From 9e17432c706623fb8f9bcd050f7febc5e9575e60 Mon Sep 17 00:00:00 2001 From: Paul Nechifor Date: Wed, 4 Mar 2026 02:37:26 +0200 Subject: [PATCH] Revert "Config adjustments (#1369)" This reverts commit 790397cebbacdf0bb1d8a4e850069598b04dc754. --- dimos/agents/agent.py | 7 +- dimos/agents/agent_test_runner.py | 20 ++---- dimos/agents/mcp/mcp_client.py | 7 +- dimos/agents/mcp/mcp_server.py | 20 ++++-- dimos/agents/mcp/test_mcp_client.py | 12 ++-- .../skills/google_maps_skill_container.py | 5 +- dimos/agents/skills/gps_nav_skill.py | 3 + dimos/agents/skills/navigation.py | 5 +- dimos/agents/skills/person_follow.py | 41 +++++------ .../test_google_maps_skill_container.py | 10 ++- dimos/agents/skills/test_gps_nav_skills.py | 8 ++- dimos/agents/skills/test_navigation.py | 15 +++- .../skills/test_unitree_skill_container.py | 6 +- dimos/agents/test_agent.py | 12 ++-- dimos/control/coordinator.py | 1 + dimos/core/blueprints.py | 53 +++++++------- dimos/core/docker_runner.py | 3 +- dimos/core/introspection/blueprint/dot.py | 10 +-- dimos/core/module.py | 54 ++++++-------- dimos/core/module_coordinator.py | 23 +++--- dimos/core/native_module.py | 42 +++++------ dimos/core/test_blueprints.py | 9 ++- dimos/core/test_core.py | 3 + dimos/core/test_native_module.py | 2 + dimos/core/test_stream.py | 10 ++- dimos/core/test_worker.py | 15 ++-- dimos/core/testing.py | 6 +- dimos/core/worker.py | 27 ++++--- dimos/core/worker_manager.py | 25 ++++--- dimos/hardware/sensors/camera/module.py | 12 +++- dimos/hardware/sensors/camera/spec.py | 8 +-- dimos/hardware/sensors/camera/zed/__init__.py | 8 +-- dimos/hardware/sensors/camera/zed/test_zed.py | 7 +- dimos/hardware/sensors/fake_zed_module.py | 11 +-- .../hardware/sensors/lidar/fastlio2/module.py | 6 +- dimos/hardware/sensors/lidar/livox/module.py | 10 +-- dimos/manipulation/manipulation_module.py | 15 ++-- dimos/manipulation/pick_and_place_module.py | 15 ++-- dimos/manipulation/planning/spec/config.py | 25 +++---- dimos/mapping/costmapper.py | 8 ++- dimos/mapping/osm/current_location_map.py | 6 +- dimos/mapping/osm/query.py | 7 +- dimos/mapping/voxels.py | 8 ++- dimos/models/base.py | 7 +- dimos/models/embedding/base.py | 2 + dimos/models/embedding/clip.py | 2 + dimos/models/embedding/mobileclip.py | 2 + dimos/models/embedding/treid.py | 2 + dimos/models/vl/base.py | 23 +++--- dimos/models/vl/moondream.py | 9 ++- dimos/models/vl/moondream_hosted.py | 13 ++-- dimos/models/vl/openai.py | 5 +- dimos/models/vl/qwen.py | 5 +- .../test_wavefront_frontier_goal_selector.py | 2 +- .../wavefront_frontier_goal_selector.py | 71 ++++++++++--------- dimos/navigation/visual/query.py | 3 +- dimos/perception/detection/conftest.py | 7 +- dimos/perception/detection/module2D.py | 25 ++++--- .../temporal_memory/entity_graph_db.py | 2 +- .../temporal_memory/temporal_memory.py | 17 +++-- .../temporal_memory/temporal_memory_deploy.py | 4 +- .../temporal_utils/graph_utils.py | 2 +- dimos/perception/object_tracker.py | 23 +++--- dimos/perception/object_tracker_2d.py | 9 +-- dimos/protocol/pubsub/bridge.py | 6 +- dimos/protocol/pubsub/impl/lcmpubsub.py | 14 ++-- dimos/protocol/pubsub/impl/redispubsub.py | 9 +-- dimos/protocol/service/__init__.py | 7 +- dimos/protocol/service/ddsservice.py | 11 ++- dimos/protocol/service/lcmservice.py | 37 +++++----- dimos/protocol/service/spec.py | 15 ++-- dimos/protocol/service/test_lcmservice.py | 58 +++++++-------- dimos/protocol/tf/tf.py | 28 ++++---- dimos/protocol/tf/tflcmcpp.py | 9 +-- dimos/robot/drone/connection_module.py | 40 ++++++----- dimos/robot/foxglove_bridge.py | 32 +++++---- dimos/robot/unitree/b1/connection.py | 31 +++----- dimos/robot/unitree/b1/unitree_b1.py | 4 +- dimos/robot/unitree/g1/connection.py | 2 +- dimos/robot/unitree/go2/connection.py | 2 +- dimos/simulation/manipulators/sim_module.py | 8 ++- .../manipulators/test_sim_module.py | 3 +- dimos/utils/cli/lcmspy/lcmspy.py | 14 ++-- dimos/visualization/rerun/bridge.py | 11 +-- pyproject.toml | 8 +-- 85 files changed, 609 insertions(+), 575 deletions(-) diff --git a/dimos/agents/agent.py b/dimos/agents/agent.py index 76e3f0c30f..37e1a4757c 100644 --- a/dimos/agents/agent.py +++ b/dimos/agents/agent.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from dataclasses import dataclass import json from queue import Empty, Queue from threading import Event, RLock, Thread @@ -27,7 +28,6 @@ from dimos.agents.system_prompt import SYSTEM_PROMPT from dimos.agents.utils import pretty_print_langchain_message from dimos.core.core import rpc -from dimos.core.global_config import GlobalConfig, global_config from dimos.core.module import Module, ModuleConfig, SkillInfo from dimos.core.rpc_client import RpcCall, RPCClient from dimos.core.stream import In, Out @@ -38,6 +38,7 @@ from langchain_core.language_models import BaseChatModel +@dataclass class AgentConfig(ModuleConfig): system_prompt: str | None = SYSTEM_PROMPT model: str = "gpt-4o" @@ -57,8 +58,8 @@ class Agent(Module[AgentConfig]): _thread: Thread _stop_event: Event - def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any) -> None: - super().__init__(global_config, **kwargs) + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) self._lock = RLock() self._state_graph = None self._message_queue = Queue() diff --git a/dimos/agents/agent_test_runner.py b/dimos/agents/agent_test_runner.py index 2562cc7688..7d7fbab03d 100644 --- a/dimos/agents/agent_test_runner.py +++ b/dimos/agents/agent_test_runner.py @@ -12,9 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Iterable from threading import Event, Thread -from typing import Any from langchain_core.messages import AIMessage from langchain_core.messages.base import BaseMessage @@ -22,27 +20,21 @@ from dimos.agents.agent import AgentSpec from dimos.core.core import rpc -from dimos.core.global_config import GlobalConfig, global_config -from dimos.core.module import Module, ModuleConfig +from dimos.core.module import Module from dimos.core.rpc_client import RPCClient from dimos.core.stream import In, Out -class Config(ModuleConfig): - messages: Iterable[BaseMessage] - - -class AgentTestRunner(Module[Config]): - default_config = Config - +class AgentTestRunner(Module): agent_spec: AgentSpec agent: In[BaseMessage] agent_idle: In[bool] finished: Out[bool] added: Out[bool] - def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any) -> None: - super().__init__(global_config, **kwargs) + def __init__(self, messages: list[BaseMessage]) -> None: + super().__init__() + self._messages = messages self._idle_event = Event() self._subscription_ready = Event() self._thread = Thread(target=self._thread_loop, daemon=True) @@ -79,7 +71,7 @@ def _thread_loop(self) -> None: if not self._subscription_ready.wait(5): raise TimeoutError("Timed out waiting for subscription to be ready.") - for message in self.config.messages: + for message in self._messages: self._idle_event.clear() self.agent_spec.add_message(message) if not self._idle_event.wait(60): diff --git a/dimos/agents/mcp/mcp_client.py b/dimos/agents/mcp/mcp_client.py index 02c4672b47..7c5eda5302 100644 --- a/dimos/agents/mcp/mcp_client.py +++ b/dimos/agents/mcp/mcp_client.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from dataclasses import dataclass from queue import Empty, Queue from threading import Event, RLock, Thread import time @@ -29,7 +30,6 @@ from dimos.agents.system_prompt import SYSTEM_PROMPT from dimos.agents.utils import pretty_print_langchain_message from dimos.core.core import rpc -from dimos.core.global_config import GlobalConfig, global_config from dimos.core.module import Module, ModuleConfig from dimos.core.rpc_client import RPCClient from dimos.core.stream import In, Out @@ -39,6 +39,7 @@ logger = setup_logger() +@dataclass class McpClientConfig(ModuleConfig): system_prompt: str | None = SYSTEM_PROMPT model: str = "gpt-4o" @@ -61,8 +62,8 @@ class McpClient(Module[McpClientConfig]): _http_client: httpx.Client _seq_ids: SequentialIds - def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any) -> None: - super().__init__(global_config, **kwargs) + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) self._lock = RLock() self._state_graph = None self._message_queue = Queue() diff --git a/dimos/agents/mcp/mcp_server.py b/dimos/agents/mcp/mcp_server.py index 89a7843b15..87c27302db 100644 --- a/dimos/agents/mcp/mcp_server.py +++ b/dimos/agents/mcp/mcp_server.py @@ -14,26 +14,30 @@ from __future__ import annotations import asyncio -import concurrent.futures import json from typing import TYPE_CHECKING, Any from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse -from starlette.requests import Request from starlette.responses import Response import uvicorn +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +from starlette.requests import Request # noqa: TC002 + from dimos.core.core import rpc from dimos.core.module import Module from dimos.core.rpc_client import RpcCall, RPCClient -from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: - from dimos.core.module import SkillInfo + import concurrent.futures -logger = setup_logger() + from dimos.core.module import SkillInfo app = FastAPI() @@ -155,8 +159,10 @@ async def mcp_endpoint(request: Request) -> Response: class McpServer(Module): - _uvicorn_server: uvicorn.Server | None = None - _serve_future: concurrent.futures.Future[None] | None = None + def __init__(self) -> None: + super().__init__() + self._uvicorn_server: uvicorn.Server | None = None + self._serve_future: concurrent.futures.Future[None] | None = None @rpc def start(self) -> None: diff --git a/dimos/agents/mcp/test_mcp_client.py b/dimos/agents/mcp/test_mcp_client.py index 94209f5bf5..16427103e4 100644 --- a/dimos/agents/mcp/test_mcp_client.py +++ b/dimos/agents/mcp/test_mcp_client.py @@ -12,13 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any from langchain_core.messages import HumanMessage import pytest from dimos.agents.annotation import skill -from dimos.core.global_config import GlobalConfig, global_config from dimos.core.module import Module from dimos.msgs.sensor_msgs import Image from dimos.utils.data import get_data @@ -42,8 +40,10 @@ def test_can_call_tool(agent_setup): class UserRegistration(Module): - _first_call = True - _use_upper = False + def __init__(self): + super().__init__() + self._first_call = True + self._use_upper = False @skill def register_user(self, name: str) -> str: @@ -79,8 +79,8 @@ def test_can_call_again_on_error(agent_setup): class MultipleTools(Module): - def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any): - super().__init__(global_config, **kwargs) + def __init__(self): + super().__init__() self._people = {"Ben": "office", "Bob": "garage"} @skill diff --git a/dimos/agents/skills/google_maps_skill_container.py b/dimos/agents/skills/google_maps_skill_container.py index 039b126500..33b2ee9f10 100644 --- a/dimos/agents/skills/google_maps_skill_container.py +++ b/dimos/agents/skills/google_maps_skill_container.py @@ -17,7 +17,6 @@ from dimos.agents.annotation import skill from dimos.core.core import rpc -from dimos.core.global_config import GlobalConfig, global_config from dimos.core.module import Module from dimos.core.stream import In from dimos.mapping.google_maps.google_maps import GoogleMaps @@ -33,8 +32,8 @@ class GoogleMapsSkillContainer(Module): gps_location: In[LatLon] - def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any) -> None: - super().__init__(global_config, **kwargs) + def __init__(self) -> None: + super().__init__() self._client = GoogleMaps() self._started = True self._max_valid_distance = 20000 # meters diff --git a/dimos/agents/skills/gps_nav_skill.py b/dimos/agents/skills/gps_nav_skill.py index 63cf4a3dd3..721119f6e6 100644 --- a/dimos/agents/skills/gps_nav_skill.py +++ b/dimos/agents/skills/gps_nav_skill.py @@ -34,6 +34,9 @@ class GpsNavSkillContainer(Module): gps_location: In[LatLon] gps_goal: Out[LatLon] + def __init__(self) -> None: + super().__init__() + @rpc def start(self) -> None: super().start() diff --git a/dimos/agents/skills/navigation.py b/dimos/agents/skills/navigation.py index 838bbd6d92..b02ff3a446 100644 --- a/dimos/agents/skills/navigation.py +++ b/dimos/agents/skills/navigation.py @@ -19,7 +19,6 @@ from dimos.agents.annotation import skill from dimos.core.core import rpc -from dimos.core.global_config import GlobalConfig, global_config from dimos.core.module import Module from dimos.core.stream import In from dimos.models.qwen.bbox import BBox @@ -56,8 +55,8 @@ class NavigationSkillContainer(Module): color_image: In[Image] odom: In[PoseStamped] - def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any) -> None: - super().__init__(global_config, **kwargs) + def __init__(self) -> None: + super().__init__() self._skill_started = False # Here to prevent unwanted imports in the file. diff --git a/dimos/agents/skills/person_follow.py b/dimos/agents/skills/person_follow.py index c491ee9137..4bb42b2970 100644 --- a/dimos/agents/skills/person_follow.py +++ b/dimos/agents/skills/person_follow.py @@ -14,7 +14,7 @@ from threading import Event, RLock, Thread import time -from typing import Any +from typing import TYPE_CHECKING from langchain_core.messages import HumanMessage import numpy as np @@ -23,12 +23,10 @@ from dimos.agents.agent import AgentSpec from dimos.agents.annotation import skill from dimos.core.core import rpc -from dimos.core.global_config import GlobalConfig, global_config -from dimos.core.module import Module, ModuleConfig +from dimos.core.global_config import GlobalConfig +from dimos.core.module import Module from dimos.core.stream import In, Out from dimos.models.qwen.bbox import BBox -from dimos.models.segmentation.edge_tam import EdgeTAMProcessor -from dimos.models.vl.base import VlModel from dimos.models.vl.qwen import QwenVlModel from dimos.msgs.geometry_msgs import Twist from dimos.msgs.sensor_msgs import CameraInfo, Image, PointCloud2 @@ -37,15 +35,14 @@ from dimos.navigation.visual_servoing.visual_servoing_2d import VisualServoing2D from dimos.utils.logging_config import setup_logger -logger = setup_logger() - +if TYPE_CHECKING: + from dimos.models.segmentation.edge_tam import EdgeTAMProcessor + from dimos.models.vl.base import VlModel -class Config(ModuleConfig): - camera_info: CameraInfo - use_3d_navigation: bool = False +logger = setup_logger() -class PersonFollowSkillContainer(Module[Config]): +class PersonFollowSkillContainer(Module): """Skill container for following a person. This skill uses: @@ -55,8 +52,6 @@ class PersonFollowSkillContainer(Module[Config]): - Does not do obstacle avoidance; assumes a clear path. """ - default_config = Config - color_image: In[Image] global_map: In[PointCloud2] cmd_vel: Out[Twist] @@ -65,24 +60,30 @@ class PersonFollowSkillContainer(Module[Config]): _frequency: float = 20.0 # Hz - control loop frequency _max_lost_frames: int = 15 # number of frames to wait before declaring person lost - def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any) -> None: - super().__init__(global_config, **kwargs) + def __init__( + self, + camera_info: CameraInfo, + cfg: GlobalConfig, + use_3d_navigation: bool = False, + ) -> None: + super().__init__() + self._global_config: GlobalConfig = cfg + self._use_3d_navigation: bool = use_3d_navigation self._latest_image: Image | None = None self._latest_pointcloud: PointCloud2 | None = None - # Use VlModel to keep usage in this class generic - self._vl_model: VlModel[Any] = QwenVlModel() + self._vl_model: VlModel = QwenVlModel() self._tracker: EdgeTAMProcessor | None = None self._thread: Thread | None = None self._should_stop: Event = Event() self._lock = RLock() # Use MuJoCo camera intrinsics in simulation mode - camera_info = self.config.camera_info if self._global_config.simulation: from dimos.robot.unitree.mujoco_connection import MujocoConnection camera_info = MujocoConnection.camera_info_static + self._camera_info = camera_info self._visual_servo = VisualServoing2D(camera_info, self._global_config.simulation) self._detection_navigation = DetectionNavigation(self.tf, camera_info) @@ -90,7 +91,7 @@ def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any) - def start(self) -> None: super().start() self._disposables.add(Disposable(self.color_image.subscribe(self._on_color_image))) - if self.config.use_3d_navigation: + if self._use_3d_navigation: self._disposables.add(Disposable(self.global_map.subscribe(self._on_pointcloud))) @rpc @@ -229,7 +230,7 @@ def _follow_loop(self, tracker: "EdgeTAMProcessor", query: str) -> None: lost_count = 0 best_detection = max(detections.detections, key=lambda d: d.bbox_2d_volume()) - if self.config.use_3d_navigation: + if self._use_3d_navigation: with self._lock: pointcloud = self._latest_pointcloud if pointcloud is None: diff --git a/dimos/agents/skills/test_google_maps_skill_container.py b/dimos/agents/skills/test_google_maps_skill_container.py index cf2dd1c62e..1d8e4549b0 100644 --- a/dimos/agents/skills/test_google_maps_skill_container.py +++ b/dimos/agents/skills/test_google_maps_skill_container.py @@ -13,13 +13,11 @@ # limitations under the License. import re -from typing import Any from langchain_core.messages import HumanMessage import pytest from dimos.agents.skills.google_maps_skill_container import GoogleMapsSkillContainer -from dimos.core.global_config import GlobalConfig, global_config from dimos.core.module import Module from dimos.core.stream import Out from dimos.mapping.google_maps.types import Coordinates, LocationContext, Position @@ -41,8 +39,8 @@ def get_location_context(self, location, radius=200): class MockedWhereAmISkill(GoogleMapsSkillContainer): - def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any): - Module.__init__(self, global_config, **kwargs) # Skip GoogleMapsSkillContainer's __init__. + def __init__(self): + Module.__init__(self) # Skip GoogleMapsSkillContainer's __init__. self._client = FakeLocationClient() self._latest_location = LatLon(lat=37.782654, lon=-122.413273) self._started = True @@ -64,8 +62,8 @@ def get_position(self, query, location): class MockedPositionSkill(GoogleMapsSkillContainer): - def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any): - Module.__init__(self, global_config, **kwargs) + def __init__(self): + Module.__init__(self) self._client = FakePositionClient() self._latest_location = LatLon(lat=37.782654, lon=-122.413273) self._started = True diff --git a/dimos/agents/skills/test_gps_nav_skills.py b/dimos/agents/skills/test_gps_nav_skills.py index 4060b1814e..d701d469ca 100644 --- a/dimos/agents/skills/test_gps_nav_skills.py +++ b/dimos/agents/skills/test_gps_nav_skills.py @@ -28,9 +28,11 @@ class FakeGPS(Module): class MockedGpsNavSkill(GpsNavSkillContainer): - _latest_location = LatLon(lat=37.782654, lon=-122.413273) - _started = True - _max_valid_distance = 50000 + def __init__(self): + Module.__init__(self) + self._latest_location = LatLon(lat=37.782654, lon=-122.413273) + self._started = True + self._max_valid_distance = 50000 @pytest.mark.slow diff --git a/dimos/agents/skills/test_navigation.py b/dimos/agents/skills/test_navigation.py index e31fae93b5..a7505b23c7 100644 --- a/dimos/agents/skills/test_navigation.py +++ b/dimos/agents/skills/test_navigation.py @@ -31,17 +31,23 @@ class FakeOdom(Module): class MockedStopNavSkill(NavigationSkillContainer): - _skill_started = True rpc_calls: list[str] = [] + def __init__(self): + Module.__init__(self) + self._skill_started = True + def _cancel_goal_and_stop(self): pass class MockedExploreNavSkill(NavigationSkillContainer): - _skill_started = True rpc_calls: list[str] = [] + def __init__(self): + Module.__init__(self) + self._skill_started = True + def _start_exploration(self, timeout): return "Exploration completed successfuly" @@ -50,9 +56,12 @@ def _cancel_goal_and_stop(self): class MockedSemanticNavSkill(NavigationSkillContainer): - _skill_started = True rpc_calls: list[str] = [] + def __init__(self): + Module.__init__(self) + self._skill_started = True + def _navigate_by_tagged_location(self, query): return None diff --git a/dimos/agents/skills/test_unitree_skill_container.py b/dimos/agents/skills/test_unitree_skill_container.py index 8b95711e35..dde7239bbd 100644 --- a/dimos/agents/skills/test_unitree_skill_container.py +++ b/dimos/agents/skills/test_unitree_skill_container.py @@ -13,20 +13,18 @@ # limitations under the License. import difflib -from typing import Any from langchain_core.messages import HumanMessage import pytest -from dimos.core.global_config import GlobalConfig, global_config from dimos.robot.unitree.unitree_skill_container import _UNITREE_COMMANDS, UnitreeSkillContainer class MockedUnitreeSkill(UnitreeSkillContainer): rpc_calls: list[str] = [] - def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any): - super().__init__(global_config, **kwargs) + def __init__(self): + super().__init__() # Provide a fake RPC so the real execute_sport_command runs end-to-end. self._bound_rpc_calls["GO2Connection.publish_request"] = lambda *args, **kwargs: None diff --git a/dimos/agents/test_agent.py b/dimos/agents/test_agent.py index b5cb743ef5..2464e622ca 100644 --- a/dimos/agents/test_agent.py +++ b/dimos/agents/test_agent.py @@ -12,13 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any from langchain_core.messages import HumanMessage import pytest from dimos.agents.annotation import skill -from dimos.core.global_config import GlobalConfig, global_config from dimos.core.module import Module from dimos.msgs.sensor_msgs import Image from dimos.utils.data import get_data @@ -42,8 +40,10 @@ def test_can_call_tool(agent_setup): class UserRegistration(Module): - _first_call = True - _use_upper = False + def __init__(self): + super().__init__() + self._first_call = True + self._use_upper = False @skill def register_user(self, name: str) -> str: @@ -81,8 +81,8 @@ def test_can_call_again_on_error(agent_setup): class MultipleTools(Module): - def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any): - super().__init__(global_config, **kwargs) + def __init__(self): + super().__init__() self._people = {"Ben": "office", "Bob": "garage"} @skill diff --git a/dimos/control/coordinator.py b/dimos/control/coordinator.py index 73e036e873..21d4c9d06c 100644 --- a/dimos/control/coordinator.py +++ b/dimos/control/coordinator.py @@ -104,6 +104,7 @@ class TaskConfig: gripper_closed_pos: float = 0.0 +@dataclass class ControlCoordinatorConfig(ModuleConfig): """Configuration for the ControlCoordinator. diff --git a/dimos/core/blueprints.py b/dimos/core/blueprints.py index ae668c814c..41b554a5b4 100644 --- a/dimos/core/blueprints.py +++ b/dimos/core/blueprints.py @@ -17,6 +17,7 @@ from collections.abc import Callable, Mapping from dataclasses import dataclass, field from functools import cached_property, reduce +import inspect import operator import sys from types import MappingProxyType @@ -26,7 +27,7 @@ from dimos.protocol.service.system_configurator.base import SystemConfigurator from dimos.core.global_config import GlobalConfig, global_config -from dimos.core.module import Module, ModuleBase, ModuleSpec, is_module_type +from dimos.core.module import Module, is_module_type from dimos.core.module_coordinator import ModuleCoordinator from dimos.core.stream import In, Out from dimos.core.transport import LCMTransport, PubSubTransport, pLCMTransport @@ -34,11 +35,6 @@ from dimos.utils.generic import short_id from dimos.utils.logging_config import setup_logger -if sys.version_info >= (3, 11): - from typing import Self -else: - from typing import Any as Self - logger = setup_logger() @@ -52,18 +48,21 @@ class StreamRef: @dataclass(frozen=True) class ModuleRef: name: str - spec: type[Spec] | type[ModuleBase] + spec: type[Spec] | type[Module] @dataclass(frozen=True) class _BlueprintAtom: - kwargs: dict[str, Any] - module: type[ModuleBase[Any]] + module: type[Module] streams: tuple[StreamRef, ...] module_refs: tuple[ModuleRef, ...] + args: tuple[Any, ...] + kwargs: dict[str, Any] @classmethod - def create(cls, module: type[ModuleBase[Any]], kwargs: dict[str, Any]) -> Self: + def create( + cls, module: type[Module], args: tuple[Any, ...], kwargs: dict[str, Any] + ) -> "_BlueprintAtom": streams: list[StreamRef] = [] module_refs: list[ModuleRef] = [] @@ -104,6 +103,7 @@ def create(cls, module: type[ModuleBase[Any]], kwargs: dict[str, Any]) -> Self: module=module, streams=tuple(streams), module_refs=tuple(module_refs), + args=args, kwargs=kwargs, ) @@ -115,15 +115,15 @@ class Blueprint: default_factory=lambda: MappingProxyType({}) ) global_config_overrides: Mapping[str, Any] = field(default_factory=lambda: MappingProxyType({})) - remapping_map: Mapping[tuple[type[ModuleBase], str], str | type[ModuleBase] | type[Spec]] = ( - field(default_factory=lambda: MappingProxyType({})) + remapping_map: Mapping[tuple[type[Module], str], str | type[Module] | type[Spec]] = field( + default_factory=lambda: MappingProxyType({}) ) requirement_checks: tuple[Callable[[], str | None], ...] = field(default_factory=tuple) configurator_checks: "tuple[SystemConfigurator, ...]" = field(default_factory=tuple) @classmethod - def create(cls, module: type[ModuleBase], **kwargs: Any) -> "Blueprint": - blueprint = _BlueprintAtom.create(module, kwargs) + def create(cls, module: type[Module], *args: Any, **kwargs: Any) -> "Blueprint": + blueprint = _BlueprintAtom.create(module, args, kwargs) return cls(blueprints=(blueprint,)) def transports(self, transports: dict[tuple[str, type], Any]) -> "Blueprint": @@ -147,10 +147,7 @@ def global_config(self, **kwargs: Any) -> "Blueprint": ) def remappings( - self, - remappings: list[ - tuple[type[ModuleBase[Any]], str, str | type[ModuleBase[Any]] | type[Spec]] - ], + self, remappings: list[tuple[type[Module], str, str | type[Module] | type[Spec]]] ) -> "Blueprint": remappings_dict = dict(self.remapping_map) for module, old, new in remappings: @@ -188,8 +185,8 @@ def configurators(self, *checks: "SystemConfigurator") -> "Blueprint": def _check_ambiguity( self, requested_method_name: str, - interface_methods: Mapping[str, list[tuple[type[ModuleBase], Callable[..., Any]]]], - requesting_module: type[ModuleBase], + interface_methods: Mapping[str, list[tuple[type[Module], Callable[..., Any]]]], + requesting_module: type[Module], ) -> None: if ( requested_method_name in interface_methods @@ -298,9 +295,13 @@ def _verify_no_name_conflicts(self) -> None: def _deploy_all_modules( self, module_coordinator: ModuleCoordinator, global_config: GlobalConfig ) -> None: - module_specs: list[ModuleSpec] = [] + module_specs: list[tuple[type[Module], tuple[Any, ...], dict[str, Any]]] = [] for blueprint in self.blueprints: - module_specs.append((blueprint.module, global_config, blueprint.kwargs)) + kwargs = {**blueprint.kwargs} + sig = inspect.signature(blueprint.module.__init__) + if "cfg" in sig.parameters: + kwargs["cfg"] = global_config + module_specs.append((blueprint.module, blueprint.args, kwargs)) module_coordinator.deploy_parallel(module_specs) @@ -420,12 +421,12 @@ def _connect_rpc_methods(self, module_coordinator: ModuleCoordinator) -> None: rpc_methods_dot = {} # Track interface methods to detect ambiguity. - interface_methods: defaultdict[str, list[tuple[type[ModuleBase], Callable[..., Any]]]] = ( + interface_methods: defaultdict[str, list[tuple[type[Module], Callable[..., Any]]]] = ( defaultdict(list) ) # interface_name_method -> [(module_class, method)] - interface_methods_dot: defaultdict[ - str, list[tuple[type[ModuleBase], Callable[..., Any]]] - ] = defaultdict(list) # interface_name.method -> [(module_class, method)] + interface_methods_dot: defaultdict[str, list[tuple[type[Module], Callable[..., Any]]]] = ( + defaultdict(list) + ) # interface_name.method -> [(module_class, method)] for blueprint in self.blueprints: for method_name in blueprint.module.rpcs.keys(): # type: ignore[attr-defined] diff --git a/dimos/core/docker_runner.py b/dimos/core/docker_runner.py index 99833a9b97..ee56163ca6 100644 --- a/dimos/core/docker_runner.py +++ b/dimos/core/docker_runner.py @@ -15,7 +15,7 @@ import argparse from contextlib import suppress -from dataclasses import field +from dataclasses import dataclass, field import importlib import json import os @@ -46,6 +46,7 @@ LOG_TAIL_LINES = 200 # Number of log lines to include in error messages +@dataclass(kw_only=True) class DockerModuleConfig(ModuleConfig): """ Configuration for running a DimOS module inside Docker. diff --git a/dimos/core/introspection/blueprint/dot.py b/dimos/core/introspection/blueprint/dot.py index 74ee9406a9..ea66401033 100644 --- a/dimos/core/introspection/blueprint/dot.py +++ b/dimos/core/introspection/blueprint/dot.py @@ -31,7 +31,7 @@ color_for_string, sanitize_id, ) -from dimos.core.module import ModuleBase +from dimos.core.module import Module from dimos.utils.cli import theme @@ -82,11 +82,11 @@ def render( ignored_modules = DEFAULT_IGNORED_MODULES # Collect all outputs: (name, type) -> list of producer modules - producers: dict[tuple[str, type], list[type[ModuleBase]]] = defaultdict(list) + producers: dict[tuple[str, type], list[type[Module]]] = defaultdict(list) # Collect all inputs: (name, type) -> list of consumer modules - consumers: dict[tuple[str, type], list[type[ModuleBase]]] = defaultdict(list) + consumers: dict[tuple[str, type], list[type[Module]]] = defaultdict(list) # Module name -> module class (for getting package info) - module_classes: dict[str, type[ModuleBase]] = {} + module_classes: dict[str, type[Module]] = {} for bp in blueprint_set.blueprints: module_classes[bp.module.__name__] = bp.module @@ -117,7 +117,7 @@ def render( active_channels[key] = color_for_string(TYPE_COLORS, label) # Group modules by package - def get_group(mod_class: type[ModuleBase]) -> str: + def get_group(mod_class: type[Module]) -> str: module_path = mod_class.__module__ parts = module_path.split(".") if len(parts) >= 2 and parts[0] == "dimos": diff --git a/dimos/core/module.py b/dimos/core/module.py index daf0f0d7fd..48a99a79a3 100644 --- a/dimos/core/module.py +++ b/dimos/core/module.py @@ -17,43 +17,38 @@ from functools import partial import inspect import json -import sys import threading from typing import ( TYPE_CHECKING, Any, - Protocol, get_args, get_origin, get_type_hints, overload, ) +from typing_extensions import TypeVar as TypeVarExtension + +if TYPE_CHECKING: + from dimos.core.introspection.module import ModuleInfo + from dimos.core.rpc_client import RPCClient + +from typing import TypeVar + from langchain_core.tools import tool from reactivex.disposable import CompositeDisposable from dimos.core.core import T, rpc -from dimos.core.global_config import GlobalConfig, global_config from dimos.core.introspection.module import extract_module_info, render_module_io from dimos.core.resource import Resource from dimos.core.rpc_client import RpcCall from dimos.core.stream import In, Out, RemoteOut, Transport from dimos.protocol.rpc import LCMRPC, RPCSpec -from dimos.protocol.service import BaseConfig, Configurable +from dimos.protocol.service import Configurable # type: ignore[attr-defined] from dimos.protocol.tf import LCMTF, TFSpec from dimos.utils import colors from dimos.utils.generic import classproperty -if TYPE_CHECKING: - from dimos.core.blueprints import Blueprint - from dimos.core.introspection.module import ModuleInfo - from dimos.core.rpc_client import RPCClient - -if sys.version_info >= (3, 13): - from typing import TypeVar -else: - from typing_extensions import TypeVar - @dataclass(frozen=True) class SkillInfo: @@ -75,26 +70,20 @@ def get_loop() -> tuple[asyncio.AbstractEventLoop, threading.Thread | None]: return loop, thr -class ModuleConfig(BaseConfig): +@dataclass +class ModuleConfig: rpc_transport: type[RPCSpec] = LCMRPC - tf_transport: type[TFSpec] = LCMTF # type: ignore[type-arg] + tf_transport: type[TFSpec] = LCMTF frame_id_prefix: str | None = None frame_id: str | None = None -ModuleConfigT = TypeVar("ModuleConfigT", bound=ModuleConfig, default=ModuleConfig) - - -class _BlueprintPartial(Protocol): - def __call__(self, **kwargs: Any) -> "Blueprint": ... +ModuleConfigT = TypeVarExtension("ModuleConfigT", bound=ModuleConfig, default=ModuleConfig) class ModuleBase(Configurable[ModuleConfigT], Resource): - # This won't type check against the TypeVar, but we need it as the default. - default_config: type[ModuleConfigT] = ModuleConfig # type: ignore[assignment] - _rpc: RPCSpec | None = None - _tf: TFSpec[Any] | None = None + _tf: TFSpec | None = None _loop: asyncio.AbstractEventLoop | None = None _loop_thread: threading.Thread | None _disposables: CompositeDisposable @@ -104,9 +93,10 @@ class ModuleBase(Configurable[ModuleConfigT], Resource): rpc_calls: list[str] = [] - def __init__(self, config_args: dict[str, Any], global_config: GlobalConfig): - super().__init__(**config_args) - self._global_config = global_config + default_config: type[ModuleConfigT] = ModuleConfig # type: ignore[assignment] + + def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] + super().__init__(*args, **kwargs) self._module_closed_lock = threading.Lock() self._loop, self._loop_thread = get_loop() self._disposables = CompositeDisposable() @@ -348,7 +338,7 @@ def __get__( module_info = _module_info_descriptor() @classproperty - def blueprint(self) -> _BlueprintPartial: + def blueprint(self): # type: ignore[no-untyped-def] # Here to prevent circular imports. from dimos.core.blueprints import Blueprint @@ -419,7 +409,7 @@ def __init_subclass__(cls, **kwargs: Any) -> None: if not hasattr(cls, name) or getattr(cls, name) is None: setattr(cls, name, None) - def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any): + def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] self.ref = None # type: ignore[assignment] try: @@ -437,7 +427,7 @@ def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any): inner, *_ = get_args(ann) or (Any,) stream = In(inner, name, self) # type: ignore[assignment] setattr(self, name, stream) - super().__init__(config_args=kwargs, global_config=global_config) + super().__init__(*args, **kwargs) def __str__(self) -> str: return f"{self.__class__.__name__}" @@ -475,7 +465,7 @@ def connect_stream(self, input_name: str, remote_stream: RemoteOut[T]): # type: input_stream.connection = remote_stream -ModuleSpec = tuple[type[ModuleBase], GlobalConfig, dict[str, Any]] +ModuleT = TypeVar("ModuleT", bound="Module[Any]") def is_module_type(value: Any) -> bool: diff --git a/dimos/core/module_coordinator.py b/dimos/core/module_coordinator.py index 4ec6bc5725..86afb9ebc4 100644 --- a/dimos/core/module_coordinator.py +++ b/dimos/core/module_coordinator.py @@ -19,12 +19,12 @@ from typing import TYPE_CHECKING, Any from dimos.core.global_config import GlobalConfig, global_config -from dimos.core.module import ModuleBase, ModuleSpec from dimos.core.resource import Resource from dimos.core.worker_manager import WorkerManager from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: + from dimos.core.module import Module, ModuleT from dimos.core.resource_monitor.monitor import StatsMonitor from dimos.core.rpc_client import ModuleProxy @@ -36,7 +36,7 @@ class ModuleCoordinator(Resource): # type: ignore[misc] _global_config: GlobalConfig _n: int | None = None _memory_limit: str = "auto" - _deployed_modules: dict[type[ModuleBase], ModuleProxy] + _deployed_modules: dict[type[Module], ModuleProxy] _stats_monitor: StatsMonitor | None = None def __init__( @@ -75,20 +75,17 @@ def stop(self) -> None: self._client.close_all() # type: ignore[union-attr] - def deploy( - self, - module_class: type[ModuleBase[Any]], - global_config: GlobalConfig = global_config, - **kwargs: Any, - ) -> ModuleProxy: + def deploy(self, module_class: type[ModuleT], *args, **kwargs) -> ModuleProxy: # type: ignore[no-untyped-def] if not self._client: raise ValueError("Trying to dimos.deploy before the client has started") - module = self._client.deploy(module_class, global_config, kwargs) - self._deployed_modules[module_class] = module # type: ignore[assignment] - return module # type: ignore[return-value] + module: ModuleProxy = self._client.deploy(module_class, *args, **kwargs) # type: ignore[union-attr, attr-defined, assignment] + self._deployed_modules[module_class] = module + return module - def deploy_parallel(self, module_specs: list[ModuleSpec]) -> list[ModuleProxy]: + def deploy_parallel( + self, module_specs: list[tuple[type[ModuleT], tuple[Any, ...], dict[str, Any]]] + ) -> list[ModuleProxy]: if not self._client: raise ValueError("Not started") @@ -111,7 +108,7 @@ def start_all_modules(self) -> None: if hasattr(module, "on_system_modules"): module.on_system_modules(module_list) - def get_instance(self, module: type[ModuleBase]) -> ModuleProxy: + def get_instance(self, module: type[ModuleT]) -> ModuleProxy: return self._deployed_modules.get(module) # type: ignore[return-value, no-any-return] def loop(self) -> None: diff --git a/dimos/core/native_module.py b/dimos/core/native_module.py index bec23e42e1..6a93e6453a 100644 --- a/dimos/core/native_module.py +++ b/dimos/core/native_module.py @@ -40,6 +40,7 @@ class MyCppModule(NativeModule): from __future__ import annotations +from dataclasses import dataclass, field, fields import enum import inspect import json @@ -47,22 +48,13 @@ class MyCppModule(NativeModule): from pathlib import Path import signal import subprocess -import sys import threading from typing import IO, Any -from pydantic import Field - from dimos.core.core import rpc -from dimos.core.global_config import GlobalConfig, global_config from dimos.core.module import Module, ModuleConfig from dimos.utils.logging_config import setup_logger -if sys.version_info < (3, 13): - from typing_extensions import TypeVar -else: - from typing import TypeVar - logger = setup_logger() @@ -71,14 +63,15 @@ class LogFormat(enum.Enum): JSON = "json" +@dataclass(kw_only=True) class NativeModuleConfig(ModuleConfig): """Configuration for a native (C/C++) subprocess module.""" executable: str build_command: str | None = None cwd: str | None = None - extra_args: list[str] = Field(default_factory=list) - extra_env: dict[str, str] = Field(default_factory=dict) + extra_args: list[str] = field(default_factory=list) + extra_env: dict[str, str] = field(default_factory=dict) shutdown_timeout: float = 10.0 log_format: LogFormat = LogFormat.TEXT @@ -92,29 +85,26 @@ def to_cli_args(self) -> list[str]: or its parents) and converts them to ``["--name", str(value)]`` pairs. Skips fields whose values are ``None`` and fields in ``cli_exclude``. """ - ignore_fields = {f for f in NativeModuleConfig.model_fields} + ignore_fields = {f.name for f in fields(NativeModuleConfig)} args: list[str] = [] - for f in self.__class__.model_fields: - if f in ignore_fields: + for f in fields(self): + if f.name in ignore_fields: continue - if f in self.cli_exclude: + if f.name in self.cli_exclude: continue - val = getattr(self, f) + val = getattr(self, f.name) if val is None: continue if isinstance(val, bool): - args.extend([f"--{f}", str(val).lower()]) + args.extend([f"--{f.name}", str(val).lower()]) elif isinstance(val, list): - args.extend([f"--{f}", ",".join(str(v) for v in val)]) + args.extend([f"--{f.name}", ",".join(str(v) for v in val)]) else: - args.extend([f"--{f}", str(val)]) + args.extend([f"--{f.name}", str(val)]) return args -_NativeConfig = TypeVar("_NativeConfig", bound=NativeModuleConfig, default=NativeModuleConfig) - - -class NativeModule(Module[_NativeConfig]): +class NativeModule(Module[NativeModuleConfig]): """Module that wraps a native executable as a managed subprocess. Subclass this, declare In/Out ports, and set ``default_config`` to a @@ -128,13 +118,13 @@ class NativeModule(Module[_NativeConfig]): LCM topics directly. On ``stop()``, the process receives SIGTERM. """ - default_config: type[_NativeConfig] = NativeModuleConfig # type: ignore[assignment] + default_config: type[NativeModuleConfig] = NativeModuleConfig _process: subprocess.Popen[bytes] | None = None _watchdog: threading.Thread | None = None _stopping: bool = False - def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any) -> None: - super().__init__(global_config, **kwargs) + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) self._resolve_paths() @rpc diff --git a/dimos/core/test_blueprints.py b/dimos/core/test_blueprints.py index f91591d919..fd18fe72d8 100644 --- a/dimos/core/test_blueprints.py +++ b/dimos/core/test_blueprints.py @@ -114,13 +114,14 @@ class ModuleC(Module): def test_get_connection_set() -> None: - assert _BlueprintAtom.create(CatModule, kwargs={"k": "v"}) == _BlueprintAtom( + assert _BlueprintAtom.create(CatModule, args=("arg1",), kwargs={"k": "v"}) == _BlueprintAtom( module=CatModule, streams=( StreamRef(name="pet_cat", type=Petting, direction="in"), StreamRef(name="scratches", type=Scratch, direction="out"), ), module_refs=(), + args=("arg1",), kwargs={"k": "v"}, ) @@ -137,6 +138,7 @@ def test_autoconnect() -> None: StreamRef(name="data2", type=Data2, direction="out"), ), module_refs=(), + args=(), kwargs={}, ), _BlueprintAtom( @@ -147,6 +149,7 @@ def test_autoconnect() -> None: StreamRef(name="data3", type=Data3, direction="out"), ), module_refs=(), + args=(), kwargs={}, ), ) @@ -343,11 +346,11 @@ def test_future_annotations_support() -> None: """ # Test that streams are properly extracted from modules with future annotations - out_blueprint = _BlueprintAtom.create(FutureModuleOut, kwargs={}) + out_blueprint = _BlueprintAtom.create(FutureModuleOut, args=(), kwargs={}) assert len(out_blueprint.streams) == 1 assert out_blueprint.streams[0] == StreamRef(name="data", type=FutureData, direction="out") - in_blueprint = _BlueprintAtom.create(FutureModuleIn, kwargs={}) + in_blueprint = _BlueprintAtom.create(FutureModuleIn, args=(), kwargs={}) assert len(in_blueprint.streams) == 1 assert in_blueprint.streams[0] == StreamRef(name="data", type=FutureData, direction="in") diff --git a/dimos/core/test_core.py b/dimos/core/test_core.py index 3bd1383761..197539ef67 100644 --- a/dimos/core/test_core.py +++ b/dimos/core/test_core.py @@ -39,6 +39,9 @@ class Navigation(Module): @rpc def navigate_to(self, target: Vector3) -> bool: ... + def __init__(self) -> None: + super().__init__() + @rpc def start(self) -> None: def _odom(msg) -> None: diff --git a/dimos/core/test_native_module.py b/dimos/core/test_native_module.py index 4e4ef1d8f7..0df78ac23f 100644 --- a/dimos/core/test_native_module.py +++ b/dimos/core/test_native_module.py @@ -18,6 +18,7 @@ The echo script writes received CLI args to a temp file for assertions. """ +from dataclasses import dataclass import json from pathlib import Path import time @@ -58,6 +59,7 @@ def read_json_file(path: str) -> dict[str, str]: return result +@dataclass(kw_only=True) class StubNativeConfig(NativeModuleConfig): executable: str = _ECHO log_format: LogFormat = LogFormat.TEXT diff --git a/dimos/core/test_stream.py b/dimos/core/test_stream.py index 005d5f3bff..a7c949b33a 100644 --- a/dimos/core/test_stream.py +++ b/dimos/core/test_stream.py @@ -15,12 +15,10 @@ from collections.abc import Callable import threading import time -from typing import Any import pytest from dimos.core.core import rpc -from dimos.core.global_config import GlobalConfig, global_config from dimos.core.module import Module from dimos.core.stream import In from dimos.core.testing import MockRobotClient @@ -30,15 +28,15 @@ class SubscriberBase(Module): - sub1_msgs: list[Odometry] - sub2_msgs: list[Odometry] + sub1_msgs: list[Odometry] = None + sub2_msgs: list[Odometry] = None - def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any) -> None: + def __init__(self) -> None: self.sub1_msgs = [] self.sub2_msgs = [] self._sub1_received = threading.Event() self._sub2_received = threading.Event() - super().__init__(global_config, **kwargs) + super().__init__() def _sub1_callback(self, msg) -> None: self.sub1_msgs.append(msg) diff --git a/dimos/core/test_worker.py b/dimos/core/test_worker.py index 9a5ac83009..a5217f2dd6 100644 --- a/dimos/core/test_worker.py +++ b/dimos/core/test_worker.py @@ -17,7 +17,6 @@ import pytest from dimos.core.core import rpc -from dimos.core.global_config import global_config from dimos.core.module import Module from dimos.core.stream import In, Out from dimos.core.worker_manager import WorkerManager @@ -100,7 +99,7 @@ def _create(n_workers): @pytest.mark.slow def test_worker_manager_basic(create_worker_manager): worker_manager = create_worker_manager(n_workers=2) - module = worker_manager.deploy(SimpleModule, global_config, {}) + module = worker_manager.deploy(SimpleModule) module.start() result = module.increment() @@ -118,8 +117,8 @@ def test_worker_manager_basic(create_worker_manager): @pytest.mark.slow def test_worker_manager_multiple_different_modules(create_worker_manager): worker_manager = create_worker_manager(n_workers=2) - module1 = worker_manager.deploy(SimpleModule, global_config, {}) - module2 = worker_manager.deploy(AnotherModule, global_config, {}) + module1 = worker_manager.deploy(SimpleModule) + module2 = worker_manager.deploy(AnotherModule) module1.start() module2.start() @@ -176,8 +175,8 @@ def test_collect_stats(create_worker_manager): from dimos.core.resource_monitor.monitor import StatsMonitor manager = create_worker_manager(n_workers=2) - module1 = manager.deploy(SimpleModule, global_config, {}) - module2 = manager.deploy(AnotherModule, global_config, {}) + module1 = manager.deploy(SimpleModule) + module2 = manager.deploy(AnotherModule) module1.start() module2.start() @@ -220,8 +219,8 @@ def log_stats(self, coordinator, workers): @pytest.mark.slow def test_worker_pool_modules_share_workers(create_worker_manager): manager = create_worker_manager(n_workers=1) - module1 = manager.deploy(SimpleModule, global_config, {}) - module2 = manager.deploy(AnotherModule, global_config, {}) + module1 = manager.deploy(SimpleModule) + module2 = manager.deploy(AnotherModule) module1.start() module2.start() diff --git a/dimos/core/testing.py b/dimos/core/testing.py index 4884272fc6..6431c09dbd 100644 --- a/dimos/core/testing.py +++ b/dimos/core/testing.py @@ -14,10 +14,8 @@ from threading import Event, Thread import time -from typing import Any from dimos.core.core import rpc -from dimos.core.global_config import GlobalConfig, global_config from dimos.core.module import Module from dimos.core.stream import In, Out from dimos.msgs.geometry_msgs import Vector3 @@ -34,8 +32,8 @@ class MockRobotClient(Module): mov_msg_count = 0 - def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any) -> None: - super().__init__(global_config, **kwargs) + def __init__(self) -> None: + super().__init__() self._stop_event = Event() self._thread = None diff --git a/dimos/core/worker.py b/dimos/core/worker.py index e32c70b9ab..b0dd802841 100644 --- a/dimos/core/worker.py +++ b/dimos/core/worker.py @@ -14,17 +14,17 @@ from __future__ import annotations import multiprocessing as mp -from multiprocessing.connection import Connection import threading import traceback from typing import TYPE_CHECKING, Any -from dimos.core.global_config import GlobalConfig, global_config from dimos.utils.logging_config import setup_logger from dimos.utils.sequential_ids import SequentialIds if TYPE_CHECKING: - from dimos.core.module import ModuleBase + from multiprocessing.connection import Connection + + from dimos.core.module import ModuleT logger = setup_logger() @@ -72,7 +72,7 @@ class Actor: def __init__( self, conn: Connection | None, - module_class: type[ModuleBase], + module_class: type[ModuleT], worker_id: int, module_id: int = 0, lock: threading.Lock | None = None, @@ -140,6 +140,8 @@ def reset_forkserver_context() -> None: class Worker: + """Generic worker process that can host multiple modules.""" + def __init__(self) -> None: self._lock = threading.Lock() self._modules: dict[int, Actor] = {} @@ -186,9 +188,9 @@ def start_process(self) -> None: def deploy_module( self, - module_class: type[ModuleBase], - global_config: GlobalConfig = global_config, - kwargs: dict[str, Any] | None = None, + module_class: type[ModuleT], + args: tuple[Any, ...] = (), + kwargs: dict[Any, Any] | None = None, ) -> Actor: if self._conn is None: raise RuntimeError("Worker process not started") @@ -201,7 +203,7 @@ def deploy_module( "type": "deploy_module", "module_id": module_id, "module_class": module_class, - "global_config": global_config, + "args": args, "kwargs": kwargs, } with self._lock: @@ -254,7 +256,10 @@ def shutdown(self) -> None: self._process = None -def _worker_entrypoint(conn: Connection, worker_id: int) -> None: +def _worker_entrypoint( + conn: Connection, + worker_id: int, +) -> None: instances: dict[int, Any] = {} try: @@ -304,10 +309,10 @@ def _worker_loop(conn: Connection, instances: dict[int, Any], worker_id: int) -> if req_type == "deploy_module": module_class = request["module_class"] - request["global_config"] + args = request.get("args", ()) kwargs = request.get("kwargs", {}) module_id = request["module_id"] - instance = module_class(global_config, **kwargs) + instance = module_class(*args, **kwargs) instances[module_id] = instance response["result"] = module_id diff --git a/dimos/core/worker_manager.py b/dimos/core/worker_manager.py index cbd55c9b47..4dbb51eb54 100644 --- a/dimos/core/worker_manager.py +++ b/dimos/core/worker_manager.py @@ -15,14 +15,15 @@ from __future__ import annotations from concurrent.futures import ThreadPoolExecutor -from typing import Any +from typing import TYPE_CHECKING, Any -from dimos.core.global_config import GlobalConfig -from dimos.core.module import ModuleBase from dimos.core.rpc_client import RPCClient from dimos.core.worker import Worker from dimos.utils.logging_config import setup_logger +if TYPE_CHECKING: + from dimos.core.module import ModuleT + logger = setup_logger() @@ -46,9 +47,7 @@ def start(self) -> None: def _select_worker(self) -> Worker: return min(self._workers, key=lambda w: w.module_count) - def deploy( - self, module_class: type[ModuleBase], global_config: GlobalConfig, kwargs: dict[str, Any] - ) -> RPCClient: + def deploy(self, module_class: type[ModuleT], *args: Any, **kwargs: Any) -> RPCClient: if self._closed: raise RuntimeError("WorkerManager is closed") @@ -57,11 +56,11 @@ def deploy( self.start() worker = self._select_worker() - actor = worker.deploy_module(module_class, global_config, kwargs=kwargs) + actor = worker.deploy_module(module_class, args=args, kwargs=kwargs) return RPCClient(actor, module_class) def deploy_parallel( - self, module_specs: list[tuple[type[ModuleBase], GlobalConfig, dict[str, Any]]] + self, module_specs: list[tuple[type[ModuleT], tuple[Any, ...], dict[Any, Any]]] ) -> list[RPCClient]: if self._closed: raise RuntimeError("WorkerManager is closed") @@ -73,17 +72,17 @@ def deploy_parallel( # Pre-assign workers sequentially (so least-loaded accounting is # correct), then deploy concurrently via threads. The per-worker lock # serializes deploys that land on the same worker process. - assignments: list[tuple[Worker, type[ModuleBase], GlobalConfig, dict[str, Any]]] = [] - for module_class, global_config, kwargs in module_specs: + assignments: list[tuple[Worker, type[ModuleT], tuple[Any, ...], dict[Any, Any]]] = [] + for module_class, args, kwargs in module_specs: worker = self._select_worker() worker.reserve_slot() - assignments.append((worker, module_class, global_config, kwargs)) + assignments.append((worker, module_class, args, kwargs)) def _deploy( - item: tuple[Worker, type[ModuleBase], GlobalConfig, dict[str, Any]], + item: tuple[Worker, type[ModuleT], tuple[Any, ...], dict[Any, Any]], ) -> RPCClient: worker, module_class, args, kwargs = item - actor = worker.deploy_module(module_class, global_config=global_config, kwargs=kwargs) + actor = worker.deploy_module(module_class, args=args, kwargs=kwargs) return RPCClient(actor, module_class) with ThreadPoolExecutor(max_workers=len(assignments)) as pool: diff --git a/dimos/hardware/sensors/camera/module.py b/dimos/hardware/sensors/camera/module.py index 7c7edaaf5d..11821d4724 100644 --- a/dimos/hardware/sensors/camera/module.py +++ b/dimos/hardware/sensors/camera/module.py @@ -22,6 +22,7 @@ from dimos.agents.annotation import skill from dimos.core.blueprints import autoconnect from dimos.core.core import rpc +from dimos.core.global_config import GlobalConfig, global_config from dimos.core.module import Module, ModuleConfig from dimos.core.stream import Out from dimos.hardware.sensors.camera.spec import CameraHardware @@ -54,9 +55,16 @@ class CameraModule(Module[CameraModuleConfig], perception.Camera): color_image: Out[Image] camera_info: Out[CameraInfo] - default_config = CameraModuleConfig hardware: CameraHardware[Any] - _latest_image: Image | None = None + + config: CameraModuleConfig + default_config = CameraModuleConfig + _global_config: GlobalConfig + + def __init__(self, *args: Any, cfg: GlobalConfig = global_config, **kwargs: Any) -> None: + self._global_config = cfg + self._latest_image: Image | None = None + super().__init__(*args, **kwargs) @rpc def start(self) -> None: diff --git a/dimos/hardware/sensors/camera/spec.py b/dimos/hardware/sensors/camera/spec.py index c913e4bfea..23fd1a076e 100644 --- a/dimos/hardware/sensors/camera/spec.py +++ b/dimos/hardware/sensors/camera/spec.py @@ -13,19 +13,19 @@ # limitations under the License. from abc import ABC, abstractmethod -from typing import TypeVar +from typing import Generic, Protocol, TypeVar from reactivex.observable import Observable from dimos.msgs.geometry_msgs import Quaternion, Transform from dimos.msgs.sensor_msgs import CameraInfo from dimos.msgs.sensor_msgs.Image import Image -from dimos.protocol.service.spec import BaseConfig, Configurable +from dimos.protocol.service import Configurable # type: ignore[attr-defined] OPTICAL_ROTATION = Quaternion(-0.5, 0.5, -0.5, 0.5) -class CameraConfig(BaseConfig): +class CameraConfig(Protocol): frame_id_prefix: str | None width: int height: int @@ -35,7 +35,7 @@ class CameraConfig(BaseConfig): CameraConfigT = TypeVar("CameraConfigT", bound=CameraConfig) -class CameraHardware(ABC, Configurable[CameraConfigT]): +class CameraHardware(ABC, Configurable[CameraConfigT], Generic[CameraConfigT]): @abstractmethod def image_stream(self) -> Observable[Image]: pass diff --git a/dimos/hardware/sensors/camera/zed/__init__.py b/dimos/hardware/sensors/camera/zed/__init__.py index 6e3b905e90..f8e73273bf 100644 --- a/dimos/hardware/sensors/camera/zed/__init__.py +++ b/dimos/hardware/sensors/camera/zed/__init__.py @@ -18,15 +18,15 @@ from dimos.msgs.sensor_msgs.CameraInfo import CalibrationProvider +# Check if ZED SDK is available try: - import pyzed.sl # noqa: F401 + import pyzed.sl as sl # noqa: F401 - # This awkwardness is needed as pytest implicitly imports this to collect - # the test in this directory. HAS_ZED_SDK = True except ImportError: HAS_ZED_SDK = False +# Only import ZED classes if SDK is available if HAS_ZED_SDK: from dimos.hardware.sensors.camera.zed.camera import ZEDCamera, ZEDModule, zed_camera else: @@ -43,7 +43,7 @@ def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] "ZED SDK not installed. Please install pyzed package to use ZED camera functionality." ) - def zed_camera(*args: object, **kwargs: object) -> None: # type: ignore[misc,no-redef] + def zed_camera(*args: object, **kwargs: object) -> None: # type: ignore[no-redef] raise ModuleNotFoundError( "ZED SDK not installed. Please install pyzed package to use ZED camera functionality.", name="pyzed", diff --git a/dimos/hardware/sensors/camera/zed/test_zed.py b/dimos/hardware/sensors/camera/zed/test_zed.py index 2716e809a5..2d912553c6 100644 --- a/dimos/hardware/sensors/camera/zed/test_zed.py +++ b/dimos/hardware/sensors/camera/zed/test_zed.py @@ -13,15 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest - -from dimos.hardware.sensors.camera import zed from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo -@pytest.mark.skipif(not zed.HAS_ZED_SDK, reason="ZED SDK not installed") def test_zed_import_and_calibration_access() -> None: """Test that zed module can be imported and calibrations accessed.""" + # Import zed module from camera + from dimos.hardware.sensors.camera import zed + # Test that CameraInfo is accessible assert hasattr(zed, "CameraInfo") diff --git a/dimos/hardware/sensors/fake_zed_module.py b/dimos/hardware/sensors/fake_zed_module.py index f119179705..ec5613077d 100644 --- a/dimos/hardware/sensors/fake_zed_module.py +++ b/dimos/hardware/sensors/fake_zed_module.py @@ -17,6 +17,7 @@ FakeZEDModule - Replays recorded ZED data for testing without hardware. """ +from dataclasses import dataclass import functools import logging @@ -24,7 +25,6 @@ import numpy as np from dimos.core.core import rpc -from dimos.core.global_config import GlobalConfig, global_config from dimos.core.module import Module, ModuleConfig from dimos.core.stream import Out from dimos.msgs.geometry_msgs import PoseStamped @@ -37,8 +37,8 @@ logger = setup_logger(level=logging.INFO) +@dataclass class FakeZEDModuleConfig(ModuleConfig): - recording_path: str frame_id: str = "zed_camera" @@ -54,17 +54,18 @@ class FakeZEDModule(Module[FakeZEDModuleConfig]): pose: Out[PoseStamped] default_config = FakeZEDModuleConfig + config: FakeZEDModuleConfig - def __init__(self, global_config: GlobalConfig = global_config, **kwargs: object) -> None: + def __init__(self, recording_path: str, **kwargs: object) -> None: """ Initialize FakeZEDModule with recording path. Args: recording_path: Path to recorded data directory """ - super().__init__(global_config, **kwargs) + super().__init__(**kwargs) - self.recording_path = self.config.recording_path + self.recording_path = recording_path self._running = False # Initialize TF publisher diff --git a/dimos/hardware/sensors/lidar/fastlio2/module.py b/dimos/hardware/sensors/lidar/fastlio2/module.py index 6ccb7aa764..fb894ddce5 100644 --- a/dimos/hardware/sensors/lidar/fastlio2/module.py +++ b/dimos/hardware/sensors/lidar/fastlio2/module.py @@ -35,7 +35,7 @@ from typing import TYPE_CHECKING from dimos.core.native_module import NativeModule, NativeModuleConfig -from dimos.core.stream import Out +from dimos.core.stream import Out # noqa: TC001 from dimos.hardware.sensors.lidar.livox.ports import ( SDK_CMD_DATA_PORT, SDK_HOST_CMD_DATA_PORT, @@ -48,8 +48,8 @@ SDK_POINT_DATA_PORT, SDK_PUSH_MSG_PORT, ) -from dimos.msgs.nav_msgs.Odometry import Odometry -from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.msgs.nav_msgs.Odometry import Odometry # noqa: TC001 +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 # noqa: TC001 from dimos.spec import mapping, perception _CONFIG_DIR = Path(__file__).parent / "config" diff --git a/dimos/hardware/sensors/lidar/livox/module.py b/dimos/hardware/sensors/lidar/livox/module.py index 999cdd9aa1..2e470b21ef 100644 --- a/dimos/hardware/sensors/lidar/livox/module.py +++ b/dimos/hardware/sensors/lidar/livox/module.py @@ -26,10 +26,11 @@ from __future__ import annotations +from dataclasses import dataclass from typing import TYPE_CHECKING from dimos.core.native_module import NativeModule, NativeModuleConfig -from dimos.core.stream import Out +from dimos.core.stream import Out # noqa: TC001 from dimos.hardware.sensors.lidar.livox.ports import ( SDK_CMD_DATA_PORT, SDK_HOST_CMD_DATA_PORT, @@ -42,11 +43,12 @@ SDK_POINT_DATA_PORT, SDK_PUSH_MSG_PORT, ) -from dimos.msgs.sensor_msgs.Imu import Imu -from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.msgs.sensor_msgs.Imu import Imu # noqa: TC001 +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 # noqa: TC001 from dimos.spec import perception +@dataclass(kw_only=True) class Mid360Config(NativeModuleConfig): """Config for the C++ Mid-360 native module.""" @@ -74,7 +76,7 @@ class Mid360Config(NativeModuleConfig): host_log_data_port: int = SDK_HOST_LOG_DATA_PORT -class Mid360(NativeModule[Mid360Config], perception.Lidar, perception.IMU): +class Mid360(NativeModule, perception.Lidar, perception.IMU): """Livox Mid-360 LiDAR module backed by a native C++ binary. Ports: diff --git a/dimos/manipulation/manipulation_module.py b/dimos/manipulation/manipulation_module.py index 6c6320b301..40dd6734c5 100644 --- a/dimos/manipulation/manipulation_module.py +++ b/dimos/manipulation/manipulation_module.py @@ -24,7 +24,7 @@ from __future__ import annotations -from collections.abc import Iterable +from dataclasses import dataclass, field from enum import Enum import threading import time @@ -32,7 +32,6 @@ from dimos.agents.annotation import skill from dimos.core.core import rpc -from dimos.core.global_config import GlobalConfig, global_config from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In from dimos.manipulation.planning import ( @@ -83,17 +82,18 @@ class ManipulationState(Enum): FAULT = 4 +@dataclass class ManipulationModuleConfig(ModuleConfig): """Configuration for ManipulationModule.""" - robots: Iterable[RobotModelConfig] = () + robots: list[RobotModelConfig] = field(default_factory=list) planning_timeout: float = 10.0 enable_viz: bool = False planner_name: str = "rrt_connect" # "rrt_connect" kinematics_name: str = "jacobian" # "jacobian" or "drake_optimization" -class ManipulationModule(Module[ManipulationModuleConfig]): +class ManipulationModule(Module): """Base motion planning module with ControlCoordinator execution. - @rpc: Low-level building blocks (plan, execute, gripper) @@ -104,11 +104,14 @@ class ManipulationModule(Module[ManipulationModuleConfig]): default_config = ManipulationModuleConfig + # Type annotation for the config attribute (mypy uses this) + config: ManipulationModuleConfig + # Input: Joint state from coordinator (for world sync) joint_state: In[JointState] - def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any) -> None: - super().__init__(global_config, **kwargs) + def __init__(self, *args: object, **kwargs: object) -> None: + super().__init__(*args, **kwargs) # State machine self._state = ManipulationState.IDLE diff --git a/dimos/manipulation/pick_and_place_module.py b/dimos/manipulation/pick_and_place_module.py index 251e42e1c0..84ede61793 100644 --- a/dimos/manipulation/pick_and_place_module.py +++ b/dimos/manipulation/pick_and_place_module.py @@ -22,6 +22,7 @@ from __future__ import annotations +from dataclasses import dataclass, field import math from pathlib import Path import time @@ -31,8 +32,7 @@ from dimos.constants import DIMOS_PROJECT_ROOT from dimos.core.core import rpc from dimos.core.docker_runner import DockerModule as DockerRunner -from dimos.core.global_config import GlobalConfig, global_config -from dimos.core.stream import In +from dimos.core.stream import In # noqa: TC001 from dimos.manipulation.grasping.graspgen_module import GraspGenModule from dimos.manipulation.manipulation_module import ( ManipulationModule, @@ -40,7 +40,7 @@ ) from dimos.msgs.geometry_msgs import Pose, Quaternion, Vector3 from dimos.perception.detection.type.detection3d.object import ( - Object as DetObject, + Object as DetObject, # noqa: TC001 ) from dimos.utils.data import get_data from dimos.utils.logging_config import setup_logger @@ -56,6 +56,7 @@ _GRASPGEN_VIZ_CONTAINER_PATH = f"{_GRASPGEN_VIZ_CONTAINER_DIR}/visualization.json" +@dataclass class PickAndPlaceModuleConfig(ManipulationModuleConfig): """Configuration for PickAndPlaceModule (adds GraspGen settings).""" @@ -67,8 +68,8 @@ class PickAndPlaceModuleConfig(ManipulationModuleConfig): graspgen_grasp_threshold: float = -1.0 graspgen_filter_collisions: bool = False graspgen_save_visualization_data: bool = False - graspgen_visualization_output_path: Path = ( - Path.home() / ".dimos" / "graspgen" / "visualization.json" + graspgen_visualization_output_path: Path = field( + default_factory=lambda: Path.home() / ".dimos" / "graspgen" / "visualization.json" ) @@ -89,8 +90,8 @@ class PickAndPlaceModule(ManipulationModule): # Input: Objects from perception (for obstacle integration) objects: In[list[DetObject]] - def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any) -> None: - super().__init__(global_config, **kwargs) + def __init__(self, *args: object, **kwargs: object) -> None: + super().__init__(*args, **kwargs) # GraspGen Docker runner (lazy initialized on first generate_grasps call) self._graspgen: DockerRunner | None = None diff --git a/dimos/manipulation/planning/spec/config.py b/dimos/manipulation/planning/spec/config.py index e379fc1eb5..dc302689ea 100644 --- a/dimos/manipulation/planning/spec/config.py +++ b/dimos/manipulation/planning/spec/config.py @@ -16,16 +16,17 @@ from __future__ import annotations -from collections.abc import Iterable, Sequence -from pathlib import Path +from dataclasses import dataclass, field +from typing import TYPE_CHECKING -from pydantic import Field +if TYPE_CHECKING: + from pathlib import Path -from dimos.core.module import ModuleConfig -from dimos.msgs.geometry_msgs import PoseStamped + from dimos.msgs.geometry_msgs import PoseStamped -class RobotModelConfig(ModuleConfig): +@dataclass +class RobotModelConfig: """Configuration for adding a robot to the world. Attributes: @@ -59,24 +60,24 @@ class RobotModelConfig(ModuleConfig): joint_names: list[str] end_effector_link: str base_link: str = "base_link" - package_paths: dict[str, Path] = Field(default_factory=dict) + package_paths: dict[str, Path] = field(default_factory=dict) joint_limits_lower: list[float] | None = None joint_limits_upper: list[float] | None = None velocity_limits: list[float] | None = None auto_convert_meshes: bool = False - xacro_args: dict[str, str] = Field(default_factory=dict) - collision_exclusion_pairs: Iterable[tuple[str, str]] = () + xacro_args: dict[str, str] = field(default_factory=dict) + collision_exclusion_pairs: list[tuple[str, str]] = field(default_factory=list) # Motion constraints for trajectory generation max_velocity: float = 1.0 max_acceleration: float = 2.0 # Coordinator integration - joint_name_mapping: dict[str, str] = Field(default_factory=dict) + joint_name_mapping: dict[str, str] = field(default_factory=dict) coordinator_task_name: str | None = None gripper_hardware_id: str | None = None # TF publishing for extra links (e.g., camera mount) - tf_extra_links: Sequence[str] = () + tf_extra_links: list[str] = field(default_factory=list) # Home/observe joint configuration for go_home skill - home_joints: Iterable[float] | None = None + home_joints: list[float] | None = None # Pre-grasp offset distance in meters (along approach direction) pre_grasp_offset: float = 0.10 diff --git a/dimos/mapping/costmapper.py b/dimos/mapping/costmapper.py index 40c4fc742f..fa0ce826f2 100644 --- a/dimos/mapping/costmapper.py +++ b/dimos/mapping/costmapper.py @@ -39,14 +39,16 @@ class Config(ModuleConfig): config: OccupancyConfig = field(default_factory=HeightCostConfig) -class CostMapper(Module[Config]): +class CostMapper(Module): default_config = Config + config: Config global_map: In[PointCloud2] global_costmap: Out[OccupancyGrid] - def __init__(self, global_config: GlobalConfig = global_config, **kwargs: object) -> None: - super().__init__(global_config, **kwargs) + def __init__(self, cfg: GlobalConfig = global_config, **kwargs: object) -> None: + super().__init__(**kwargs) + self._global_config = cfg @rpc def start(self) -> None: diff --git a/dimos/mapping/osm/current_location_map.py b/dimos/mapping/osm/current_location_map.py index 832116e25c..ef0a832cd6 100644 --- a/dimos/mapping/osm/current_location_map.py +++ b/dimos/mapping/osm/current_location_map.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any - from PIL import Image as PILImage, ImageDraw from dimos.mapping.osm.osm import MapImage, get_osm_map @@ -26,11 +24,11 @@ class CurrentLocationMap: - _vl_model: VlModel[Any] + _vl_model: VlModel _position: LatLon | None _map_image: MapImage | None - def __init__(self, vl_model: VlModel[Any]) -> None: + def __init__(self, vl_model: VlModel) -> None: self._vl_model = vl_model self._position = None self._map_image = None diff --git a/dimos/mapping/osm/query.py b/dimos/mapping/osm/query.py index 17fbfe3d4b..410f879c20 100644 --- a/dimos/mapping/osm/query.py +++ b/dimos/mapping/osm/query.py @@ -13,7 +13,6 @@ # limitations under the License. import re -from typing import Any from dimos.mapping.osm.osm import MapImage from dimos.mapping.types import LatLon @@ -26,9 +25,7 @@ logger = setup_logger() -def query_for_one_position( - vl_model: VlModel[Any], map_image: MapImage, query: str -) -> LatLon | None: +def query_for_one_position(vl_model: VlModel, map_image: MapImage, query: str) -> LatLon | None: full_query = f"{_PROLOGUE} {query} {_JSON} If there's a match return the x, y coordinates from the image. Example: `[123, 321]`. If there's no match return `null`." response = vl_model.query(map_image.image, full_query) coords = tuple(map(int, re.findall(r"\d+", response))) @@ -38,7 +35,7 @@ def query_for_one_position( def query_for_one_position_and_context( - vl_model: VlModel[Any], map_image: MapImage, query: str, robot_position: LatLon + vl_model: VlModel, map_image: MapImage, query: str, robot_position: LatLon ) -> tuple[LatLon, str] | None: example = '{"coordinates": [123, 321], "description": "A Starbucks on 27th Street"}' x, y = map_image.latlon_to_pixel(robot_position) diff --git a/dimos/mapping/voxels.py b/dimos/mapping/voxels.py index 89b4e68be1..124073cf49 100644 --- a/dimos/mapping/voxels.py +++ b/dimos/mapping/voxels.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from dataclasses import dataclass import time -from typing import Any import numpy as np import open3d as o3d # type: ignore[import-untyped] @@ -34,6 +34,7 @@ logger = setup_logger() +@dataclass class Config(ModuleConfig): frame_id: str = "world" # -1 never publishes, 0 publishes on every frame, >0 publishes at interval in seconds @@ -51,8 +52,9 @@ class VoxelGridMapper(Module): lidar: In[PointCloud2] global_map: Out[PointCloud2] - def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any) -> None: - super().__init__(global_config, **kwargs) + def __init__(self, cfg: GlobalConfig = global_config, **kwargs: object) -> None: + super().__init__(**kwargs) + self._global_config = cfg dev = ( o3c.Device(self.config.device) diff --git a/dimos/models/base.py b/dimos/models/base.py index d03ce5c539..2269a6d0b8 100644 --- a/dimos/models/base.py +++ b/dimos/models/base.py @@ -16,19 +16,21 @@ from __future__ import annotations +from dataclasses import dataclass from functools import cached_property from typing import Annotated, Any import torch from dimos.core.resource import Resource -from dimos.protocol.service import BaseConfig, Configurable +from dimos.protocol.service import Configurable # type: ignore[attr-defined] # Device string type - 'cuda', 'cpu', 'cuda:0', 'cuda:1', etc. DeviceType = Annotated[str, "Device identifier (e.g., 'cuda', 'cpu', 'cuda:0')"] -class LocalModelConfig(BaseConfig): +@dataclass +class LocalModelConfig: device: DeviceType = "cuda" if torch.cuda.is_available() else "cpu" dtype: torch.dtype = torch.float32 warmup: bool = False @@ -125,6 +127,7 @@ def _ensure_cuda_initialized(self) -> None: pass +@dataclass class HuggingFaceModelConfig(LocalModelConfig): model_name: str = "" trust_remote_code: bool = True diff --git a/dimos/models/embedding/base.py b/dimos/models/embedding/base.py index da1b1601ba..c6b78fcf2c 100644 --- a/dimos/models/embedding/base.py +++ b/dimos/models/embedding/base.py @@ -29,12 +29,14 @@ from dimos.msgs.sensor_msgs import Image +@dataclass class EmbeddingModelConfig(LocalModelConfig): """Base config for embedding models.""" normalize: bool = True +@dataclass class HuggingFaceEmbeddingModelConfig(HuggingFaceModelConfig): """Base config for HuggingFace-based embedding models.""" diff --git a/dimos/models/embedding/clip.py b/dimos/models/embedding/clip.py index e3a61e9570..1b8d3e68bb 100644 --- a/dimos/models/embedding/clip.py +++ b/dimos/models/embedding/clip.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from dataclasses import dataclass from functools import cached_property from PIL import Image as PILImage @@ -24,6 +25,7 @@ from dimos.msgs.sensor_msgs import Image +@dataclass class CLIPModelConfig(HuggingFaceEmbeddingModelConfig): model_name: str = "openai/clip-vit-base-patch32" dtype: torch.dtype = torch.float32 diff --git a/dimos/models/embedding/mobileclip.py b/dimos/models/embedding/mobileclip.py index 8ad37936be..c02361b367 100644 --- a/dimos/models/embedding/mobileclip.py +++ b/dimos/models/embedding/mobileclip.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from dataclasses import dataclass from functools import cached_property from typing import Any @@ -26,6 +27,7 @@ from dimos.utils.data import get_data +@dataclass class MobileCLIPModelConfig(EmbeddingModelConfig): model_name: str = "MobileCLIP2-S4" diff --git a/dimos/models/embedding/treid.py b/dimos/models/embedding/treid.py index 69cc1aae13..85e32cd39b 100644 --- a/dimos/models/embedding/treid.py +++ b/dimos/models/embedding/treid.py @@ -16,6 +16,7 @@ warnings.filterwarnings("ignore", message="Cython evaluation.*unavailable", category=UserWarning) +from dataclasses import dataclass from functools import cached_property import torch @@ -31,6 +32,7 @@ # osnet models downloaded from https://kaiyangzhou.github.io/deep-person-reid/MODEL_ZOO.html # into dimos/data/models_torchreid/ # feel free to add more +@dataclass class TorchReIDModelConfig(EmbeddingModelConfig): model_name: str = "osnet_x1_0" diff --git a/dimos/models/vl/base.py b/dimos/models/vl/base.py index 84079a0835..237feb1d1b 100644 --- a/dimos/models/vl/base.py +++ b/dimos/models/vl/base.py @@ -1,23 +1,21 @@ from __future__ import annotations from abc import ABC, abstractmethod +from dataclasses import dataclass import json import logging -import sys +from typing import TYPE_CHECKING import warnings from dimos.core.resource import Resource from dimos.msgs.sensor_msgs import Image -from dimos.perception.detection.type import Detection2DBBox, Detection2DPoint, ImageDetections2D -from dimos.protocol.service.spec import BaseConfig, Configurable +from dimos.protocol.service import Configurable # type: ignore[attr-defined] from dimos.utils.data import get_data from dimos.utils.decorators import retry from dimos.utils.llm_utils import extract_json -if sys.version_info < (3, 13): - from typing_extensions import TypeVar -else: - from typing import TypeVar +if TYPE_CHECKING: + from dimos.perception.detection.type import Detection2DBBox, Detection2DPoint, ImageDetections2D logger = logging.getLogger(__name__) @@ -161,17 +159,15 @@ def vlm_point_to_detection2d_point( ) -class VlModelConfig(BaseConfig): +@dataclass +class VlModelConfig: """Configuration for VlModel.""" auto_resize: tuple[int, int] | None = None """Optional (width, height) tuple. If set, images are resized to fit.""" -_VlConfig = TypeVar("_VlConfig", bound=VlModelConfig) - - -class VlModel(Captioner, Resource, Configurable[_VlConfig]): +class VlModel(Captioner, Resource, Configurable[VlModelConfig]): """Vision-language model that can answer questions about images. Inherits from Captioner, providing a default caption() implementation @@ -180,7 +176,8 @@ class VlModel(Captioner, Resource, Configurable[_VlConfig]): Implements Resource interface for lifecycle management. """ - default_config: type[_VlConfig] = VlModelConfig # type: ignore[assignment] + default_config = VlModelConfig + config: VlModelConfig def _prepare_image(self, image: Image) -> tuple[Image, float]: """Prepare image for inference, applying any configured transformations. diff --git a/dimos/models/vl/moondream.py b/dimos/models/vl/moondream.py index c444d8b9ed..f31611e867 100644 --- a/dimos/models/vl/moondream.py +++ b/dimos/models/vl/moondream.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass from functools import cached_property from typing import Any import warnings @@ -8,7 +9,7 @@ from transformers import AutoModelForCausalLM # type: ignore[import-untyped] from dimos.models.base import HuggingFaceModel, HuggingFaceModelConfig -from dimos.models.vl.base import VlModel, VlModelConfig +from dimos.models.vl.base import VlModel from dimos.msgs.sensor_msgs import Image from dimos.perception.detection.type import Detection2DBBox, Detection2DPoint, ImageDetections2D @@ -16,7 +17,8 @@ MOONDREAM_DEFAULT_AUTO_RESIZE = (512, 512) -class MoondreamConfig(HuggingFaceModelConfig, VlModelConfig): +@dataclass +class MoondreamConfig(HuggingFaceModelConfig): """Configuration for MoondreamVlModel.""" model_name: str = "vikhyatk/moondream2" @@ -24,9 +26,10 @@ class MoondreamConfig(HuggingFaceModelConfig, VlModelConfig): auto_resize: tuple[int, int] | None = MOONDREAM_DEFAULT_AUTO_RESIZE -class MoondreamVlModel(HuggingFaceModel, VlModel[MoondreamConfig]): +class MoondreamVlModel(HuggingFaceModel, VlModel): _model_class = AutoModelForCausalLM default_config = MoondreamConfig # type: ignore[assignment] + config: MoondreamConfig # type: ignore[assignment] @cached_property def _model(self) -> AutoModelForCausalLM: diff --git a/dimos/models/vl/moondream_hosted.py b/dimos/models/vl/moondream_hosted.py index 57df91b47e..fc1f8b7a17 100644 --- a/dimos/models/vl/moondream_hosted.py +++ b/dimos/models/vl/moondream_hosted.py @@ -6,21 +6,20 @@ import numpy as np from PIL import Image as PILImage -from dimos.models.vl.base import VlModel, VlModelConfig +from dimos.models.vl.base import VlModel from dimos.msgs.sensor_msgs import Image from dimos.perception.detection.type import Detection2DBBox, Detection2DPoint, ImageDetections2D -class Config(VlModelConfig): - api_key: str | None = None +class MoondreamHostedVlModel(VlModel): + _api_key: str | None - -class MoondreamHostedVlModel(VlModel[Config]): - default_config = Config + def __init__(self, api_key: str | None = None) -> None: + self._api_key = api_key @cached_property def _client(self) -> md.vl: - api_key = self.config.api_key or os.getenv("MOONDREAM_API_KEY") + api_key = self._api_key or os.getenv("MOONDREAM_API_KEY") if not api_key: raise ValueError( "Moondream API key must be provided or set in MOONDREAM_API_KEY environment variable" diff --git a/dimos/models/vl/openai.py b/dimos/models/vl/openai.py index ec774189e4..f596f1ee1e 100644 --- a/dimos/models/vl/openai.py +++ b/dimos/models/vl/openai.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass from functools import cached_property import os from typing import Any @@ -12,13 +13,15 @@ logger = setup_logger() +@dataclass class OpenAIVlModelConfig(VlModelConfig): model_name: str = "gpt-4o-mini" api_key: str | None = None -class OpenAIVlModel(VlModel[OpenAIVlModelConfig]): +class OpenAIVlModel(VlModel): default_config = OpenAIVlModelConfig + config: OpenAIVlModelConfig @cached_property def _client(self) -> OpenAI: diff --git a/dimos/models/vl/qwen.py b/dimos/models/vl/qwen.py index 014c6f73a5..93b31bf74c 100644 --- a/dimos/models/vl/qwen.py +++ b/dimos/models/vl/qwen.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass from functools import cached_property import os from typing import Any @@ -9,6 +10,7 @@ from dimos.msgs.sensor_msgs import Image +@dataclass class QwenVlModelConfig(VlModelConfig): """Configuration for Qwen VL model.""" @@ -16,8 +18,9 @@ class QwenVlModelConfig(VlModelConfig): api_key: str | None = None -class QwenVlModel(VlModel[QwenVlModelConfig]): +class QwenVlModel(VlModel): default_config = QwenVlModelConfig + config: QwenVlModelConfig @cached_property def _client(self) -> OpenAI: diff --git a/dimos/navigation/frontier_exploration/test_wavefront_frontier_goal_selector.py b/dimos/navigation/frontier_exploration/test_wavefront_frontier_goal_selector.py index 419986780a..1c8082b414 100644 --- a/dimos/navigation/frontier_exploration/test_wavefront_frontier_goal_selector.py +++ b/dimos/navigation/frontier_exploration/test_wavefront_frontier_goal_selector.py @@ -262,7 +262,7 @@ def test_frontier_ranking(explorer) -> None: # Note: Goals might be closer than safe_distance if that's the best available frontier # The safe_distance is used for scoring, not as a hard constraint print( - f"Distance to obstacles: {obstacle_dist:.2f}m (safe distance: {explorer.config.safe_distance}m)" + f"Distance to obstacles: {obstacle_dist:.2f}m (safe distance: {explorer.safe_distance}m)" ) print(f"Frontier ranking test passed - selected goal at ({goal1.x:.2f}, {goal1.y:.2f})") diff --git a/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py b/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py index b79315808c..6e598e8316 100644 --- a/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py +++ b/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py @@ -23,7 +23,6 @@ from dataclasses import dataclass from enum import IntFlag import threading -from typing import Any from dimos_lcm.std_msgs import Bool import numpy as np @@ -31,8 +30,7 @@ from dimos.agents.annotation import skill from dimos.core.core import rpc -from dimos.core.global_config import GlobalConfig, global_config -from dimos.core.module import Module, ModuleConfig +from dimos.core.module import Module from dimos.core.stream import In, Out from dimos.mapping.occupancy.inflation import simple_inflate from dimos.msgs.geometry_msgs import PoseStamped, Vector3 @@ -80,18 +78,7 @@ def clear(self) -> None: self.points.clear() -class WavefrontConfig(ModuleConfig): - min_frontier_perimeter: float = 0.5 - occupancy_threshold: int = 99 - safe_distance: float = 3.0 - lookahead_distance: float = 5.0 - max_explored_distance: float = 10.0 - info_gain_threshold: float = 0.03 - num_no_gain_attempts: int = 2 - goal_timeout: float = 15.0 - - -class WavefrontFrontierExplorer(Module[WavefrontConfig]): +class WavefrontFrontierExplorer(Module): """ Wavefront frontier exploration algorithm implementation. @@ -106,8 +93,6 @@ class WavefrontFrontierExplorer(Module[WavefrontConfig]): - goal_request: Exploration goals sent to the navigator """ - default_config = WavefrontConfig - # LCM inputs global_costmap: In[OccupancyGrid] odom: In[PoseStamped] @@ -118,7 +103,17 @@ class WavefrontFrontierExplorer(Module[WavefrontConfig]): # LCM outputs goal_request: Out[PoseStamped] - def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any) -> None: + def __init__( + self, + min_frontier_perimeter: float = 0.5, + occupancy_threshold: int = 99, + safe_distance: float = 3.0, + lookahead_distance: float = 5.0, + max_explored_distance: float = 10.0, + info_gain_threshold: float = 0.03, + num_no_gain_attempts: int = 2, + goal_timeout: float = 15.0, + ) -> None: """ Initialize the frontier explorer. @@ -129,12 +124,20 @@ def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any) - info_gain_threshold: Minimum percentage increase in costmap information required to continue exploration (0.05 = 5%) num_no_gain_attempts: Maximum number of consecutive attempts with no information gain """ - super().__init__(global_config, **kwargs) + super().__init__() + self.min_frontier_perimeter = min_frontier_perimeter + self.occupancy_threshold = occupancy_threshold + self.safe_distance = safe_distance + self.max_explored_distance = max_explored_distance + self.lookahead_distance = lookahead_distance + self.info_gain_threshold = info_gain_threshold + self.num_no_gain_attempts = num_no_gain_attempts self._cache = FrontierCache() self.explored_goals = [] # type: ignore[var-annotated] # list of explored goals self.exploration_direction = Vector3(0.0, 0.0, 0.0) # current exploration direction self.last_costmap = None # store last costmap for information comparison self.no_gain_counter = 0 # track consecutive no-gain attempts + self.goal_timeout = goal_timeout # Latest data self.latest_costmap: OccupancyGrid | None = None @@ -211,7 +214,7 @@ def _count_costmap_information(self, costmap: OccupancyGrid) -> int: Number of cells that are free space or obstacles (not unknown) """ free_count = np.sum(costmap.grid == CostValues.FREE) - obstacle_count = np.sum(costmap.grid >= self.config.occupancy_threshold) + obstacle_count = np.sum(costmap.grid >= self.occupancy_threshold) return int(free_count + obstacle_count) def _get_neighbors(self, point: GridPoint, costmap: OccupancyGrid) -> list[GridPoint]: @@ -249,7 +252,7 @@ def _is_frontier_point(self, point: GridPoint, costmap: OccupancyGrid) -> bool: neighbor_cost = costmap.grid[neighbor.y, neighbor.x] # If adjacent to occupied space, not a frontier - if neighbor_cost > self.config.occupancy_threshold: + if neighbor_cost > self.occupancy_threshold: return False # Check if adjacent to free space @@ -373,7 +376,7 @@ def detect_frontiers(self, robot_pose: Vector3, costmap: OccupancyGrid) -> list[ # Check if we found a large enough frontier # Convert minimum perimeter to minimum number of cells based on resolution - min_cells = int(self.config.min_frontier_perimeter / costmap.resolution) + min_cells = int(self.min_frontier_perimeter / costmap.resolution) if len(new_frontier) >= min_cells: world_points = [] for point in new_frontier: @@ -486,7 +489,7 @@ def _compute_distance_to_obstacles(self, frontier: Vector3, costmap: OccupancyGr min_distance = float("inf") search_radius = ( - int(self.config.safe_distance / costmap.resolution) + 5 + int(self.safe_distance / costmap.resolution) + 5 ) # Search a bit beyond minimum # Search in a square around the frontier point @@ -505,14 +508,14 @@ def _compute_distance_to_obstacles(self, frontier: Vector3, costmap: OccupancyGr continue # Check if this cell is an obstacle - if costmap.grid[check_y, check_x] >= self.config.occupancy_threshold: + if costmap.grid[check_y, check_x] >= self.occupancy_threshold: # Calculate distance in meters distance = np.sqrt(dx**2 + dy**2) * costmap.resolution min_distance = min(min_distance, distance) # If no obstacles found within search radius, return the safe distance # This indicates the frontier is safely away from obstacles - return min_distance if min_distance != float("inf") else self.config.safe_distance + return min_distance if min_distance != float("inf") else self.safe_distance def _compute_comprehensive_frontier_score( self, frontier: Vector3, frontier_size: int, robot_pose: Vector3, costmap: OccupancyGrid @@ -524,25 +527,25 @@ def _compute_comprehensive_frontier_score( # Distance score: prefer moderate distances (not too close, not too far) # Normalized to 0-1 range - distance_score = 1.0 / (1.0 + abs(robot_distance - self.config.lookahead_distance)) + distance_score = 1.0 / (1.0 + abs(robot_distance - self.lookahead_distance)) # 2. Information gain (frontier size) # Normalize by a reasonable max frontier size - max_expected_frontier_size = self.config.min_frontier_perimeter / costmap.resolution * 10 + max_expected_frontier_size = self.min_frontier_perimeter / costmap.resolution * 10 info_gain_score = min(frontier_size / max_expected_frontier_size, 1.0) # 3. Distance to explored goals (bonus for being far from explored areas) # Normalize by a reasonable max distance (e.g., 10 meters) explored_goals_distance = self._compute_distance_to_explored_goals(frontier) - explored_goals_score = min(explored_goals_distance / self.config.max_explored_distance, 1.0) + explored_goals_score = min(explored_goals_distance / self.max_explored_distance, 1.0) # 4. Distance to obstacles (score based on safety) # 0 = too close to obstacles, 1 = at or beyond safe distance obstacles_distance = self._compute_distance_to_obstacles(frontier, costmap) - if obstacles_distance >= self.config.safe_distance: + if obstacles_distance >= self.safe_distance: obstacles_score = 1.0 # Fully safe else: - obstacles_score = obstacles_distance / self.config.safe_distance # Linear penalty + obstacles_score = obstacles_distance / self.safe_distance # Linear penalty # 5. Direction momentum (already in 0-1 range from dot product) momentum_score = self._compute_direction_momentum_score(frontier, robot_pose) @@ -625,15 +628,15 @@ def get_exploration_goal(self, robot_pose: Vector3, costmap: OccupancyGrid) -> V # Check if information increase meets minimum percentage threshold if last_info > 0: # Avoid division by zero info_increase_percent = (current_info - last_info) / last_info - if info_increase_percent < self.config.info_gain_threshold: + if info_increase_percent < self.info_gain_threshold: logger.info( - f"Information increase ({info_increase_percent:.2f}) below threshold ({self.config.info_gain_threshold:.2f})" + f"Information increase ({info_increase_percent:.2f}) below threshold ({self.info_gain_threshold:.2f})" ) logger.info( f"Current information: {current_info}, Last information: {last_info}" ) self.no_gain_counter += 1 - if self.no_gain_counter >= self.config.num_no_gain_attempts: + if self.no_gain_counter >= self.num_no_gain_attempts: logger.info( f"No information gain for {self.no_gain_counter} consecutive attempts" ) @@ -794,7 +797,7 @@ def _exploration_loop(self) -> None: # Wait for goal to be reached or timeout logger.info("Waiting for goal to be reached...") - goal_reached = self.goal_reached_event.wait(timeout=self.config.goal_timeout) + goal_reached = self.goal_reached_event.wait(timeout=self.goal_timeout) if goal_reached: logger.info("Goal reached, finding next frontier") diff --git a/dimos/navigation/visual/query.py b/dimos/navigation/visual/query.py index 0c84e8ac34..37b743506a 100644 --- a/dimos/navigation/visual/query.py +++ b/dimos/navigation/visual/query.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any from dimos.models.qwen.bbox import BBox from dimos.models.vl.base import VlModel @@ -21,7 +20,7 @@ def get_object_bbox_from_image( - vl_model: VlModel[Any], image: Image, object_description: str + vl_model: VlModel, image: Image, object_description: str ) -> BBox | None: prompt = ( f"Look at this image and find the '{object_description}'. " diff --git a/dimos/perception/detection/conftest.py b/dimos/perception/detection/conftest.py index 8c1a65eb8b..e81ab2ab4a 100644 --- a/dimos/perception/detection/conftest.py +++ b/dimos/perception/detection/conftest.py @@ -15,7 +15,6 @@ from collections.abc import Callable, Generator import functools from typing import TypedDict -from unittest import mock from dimos_lcm.foxglove_msgs.ImageAnnotations import ImageAnnotations from dimos_lcm.foxglove_msgs.SceneUpdate import SceneUpdate @@ -205,8 +204,7 @@ def detection3dpc(detections3dpc) -> Detection3DPC: def get_moment_2d(get_moment) -> Generator[Callable[[], Moment2D], None, None]: from dimos.perception.detection.detectors import Yolo2DDetector - c = mock.create_autospec(CameraInfo, spec_set=True, instance=True) - module = Detection2DModule(detector=lambda: Yolo2DDetector(device="cpu"), camera_info=c) + module = Detection2DModule(detector=lambda: Yolo2DDetector(device="cpu")) @functools.lru_cache(maxsize=1) def moment_provider(**kwargs) -> Moment2D: @@ -264,8 +262,7 @@ def object_db_module(get_moment): """Create and populate an ObjectDBModule with detections from multiple frames.""" from dimos.perception.detection.detectors import Yolo2DDetector - c = mock.create_autospec(CameraInfo, spec_set=True, instance=True) - module2d = Detection2DModule(detector=lambda: Yolo2DDetector(device="cpu"), camera_info=c) + module2d = Detection2DModule(detector=lambda: Yolo2DDetector(device="cpu")) module3d = Detection3DModule(camera_info=connection._camera_info_static()) moduleDB = ObjectDBModule(camera_info=connection._camera_info_static()) diff --git a/dimos/perception/detection/module2D.py b/dimos/perception/detection/module2D.py index 0a07b1238d..f86794a1f7 100644 --- a/dimos/perception/detection/module2D.py +++ b/dimos/perception/detection/module2D.py @@ -11,13 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Callable, Sequence -from typing import Annotated, Any +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any from dimos_lcm.foxglove_msgs.ImageAnnotations import ( ImageAnnotations, ) -from pydantic.experimental.pipeline import validate_as from reactivex import operators as ops from reactivex.observable import Observable from reactivex.subject import Subject @@ -38,21 +38,24 @@ from dimos.utils.reactive import backpressure +@dataclass class Config(ModuleConfig): max_freq: float = 10 detector: Callable[[Any], Detector] | None = Yolo2DDetector publish_detection_images: bool = True - camera_info: CameraInfo - filter: Annotated[ - Sequence[Filter2D], - validate_as(Sequence[Filter2D] | Filter2D).transform( - lambda f: f if isinstance(f, Sequence) else (f,) - ), - ] = () + camera_info: CameraInfo = None # type: ignore[assignment] + filter: list[Filter2D] | Filter2D | None = None + def __post_init__(self) -> None: + if self.filter is None: + self.filter = [] + elif not isinstance(self.filter, list): + self.filter = [self.filter] -class Detection2DModule(Module[Config]): + +class Detection2DModule(Module): default_config = Config + config: Config detector: Detector color_image: In[Image] diff --git a/dimos/perception/experimental/temporal_memory/entity_graph_db.py b/dimos/perception/experimental/temporal_memory/entity_graph_db.py index 953fd00dac..7109459f40 100644 --- a/dimos/perception/experimental/temporal_memory/entity_graph_db.py +++ b/dimos/perception/experimental/temporal_memory/entity_graph_db.py @@ -931,7 +931,7 @@ def estimate_and_save_distances( self, parsed: dict[str, Any], frame_image: "Image", - vlm: "VlModel[Any]", + vlm: "VlModel", timestamp_s: float, max_distance_pairs: int = 5, ) -> None: diff --git a/dimos/perception/experimental/temporal_memory/temporal_memory.py b/dimos/perception/experimental/temporal_memory/temporal_memory.py index 6d66955a61..66b6fce911 100644 --- a/dimos/perception/experimental/temporal_memory/temporal_memory.py +++ b/dimos/perception/experimental/temporal_memory/temporal_memory.py @@ -34,7 +34,6 @@ from dimos.agents.annotation import skill from dimos.core.core import rpc -from dimos.core.global_config import GlobalConfig, global_config from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In from dimos.models.vl.base import VlModel @@ -70,8 +69,6 @@ class Frame: @dataclass class TemporalMemoryConfig(ModuleConfig): - vlm: VlModel[Any] | None = None - # Frame processing fps: float = 1.0 window_s: float = 2.0 @@ -103,7 +100,7 @@ class TemporalMemoryConfig(ModuleConfig): nearby_distance_meters: float = 5.0 # "Nearby" threshold -class TemporalMemory(Module[TemporalMemoryConfig]): +class TemporalMemory(Module): """ builds temporal understanding of video streams using vlms. @@ -113,12 +110,14 @@ class TemporalMemory(Module[TemporalMemoryConfig]): """ color_image: In[Image] - default_config = TemporalMemoryConfig - def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any) -> None: - super().__init__(global_config, **kwargs) + def __init__( + self, vlm: VlModel | None = None, config: TemporalMemoryConfig | None = None + ) -> None: + super().__init__() - self._vlm = self.config.vlm # Can be None for blueprint usage + self._vlm = vlm # Can be None for blueprint usage + self.config: TemporalMemoryConfig = config or TemporalMemoryConfig() # single lock protects all state self._state_lock = threading.Lock() @@ -184,7 +183,7 @@ def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any) - ) @property - def vlm(self) -> VlModel[Any]: + def vlm(self) -> VlModel: """Get or create VLM instance lazily.""" if self._vlm is None: from dimos.models.vl.openai import OpenAIVlModel diff --git a/dimos/perception/experimental/temporal_memory/temporal_memory_deploy.py b/dimos/perception/experimental/temporal_memory/temporal_memory_deploy.py index 517726fea4..b4aaa6ac94 100644 --- a/dimos/perception/experimental/temporal_memory/temporal_memory_deploy.py +++ b/dimos/perception/experimental/temporal_memory/temporal_memory_deploy.py @@ -17,7 +17,7 @@ """ import os -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING from dimos.core.module_coordinator import ModuleCoordinator from dimos.models.vl.base import VlModel @@ -32,7 +32,7 @@ def deploy( dimos: ModuleCoordinator, camera: CameraSpec, - vlm: VlModel[Any] | None = None, + vlm: VlModel | None = None, config: TemporalMemoryConfig | None = None, ) -> TemporalMemory: """Deploy TemporalMemory with a camera. diff --git a/dimos/perception/experimental/temporal_memory/temporal_utils/graph_utils.py b/dimos/perception/experimental/temporal_memory/temporal_utils/graph_utils.py index 9d30cd3338..8d05f8c1e1 100644 --- a/dimos/perception/experimental/temporal_memory/temporal_utils/graph_utils.py +++ b/dimos/perception/experimental/temporal_memory/temporal_utils/graph_utils.py @@ -30,7 +30,7 @@ def extract_time_window( question: str, - vlm: "VlModel[Any]", + vlm: "VlModel", latest_frame: "Image | None" = None, ) -> float | None: """Extract time window from question using VLM with example-based learning. diff --git a/dimos/perception/object_tracker.py b/dimos/perception/object_tracker.py index 8b0ce52d06..da415ac32a 100644 --- a/dimos/perception/object_tracker.py +++ b/dimos/perception/object_tracker.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from dataclasses import dataclass import threading import time -from typing import Any import cv2 @@ -29,7 +29,6 @@ from reactivex.disposable import Disposable from dimos.core.core import rpc -from dimos.core.global_config import GlobalConfig, global_config from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out from dimos.msgs.geometry_msgs import Pose, Quaternion, Transform, Vector3 @@ -52,10 +51,9 @@ logger = setup_logger() +@dataclass class ObjectTrackingConfig(ModuleConfig): frame_id: str = "camera_link" - reid_threshold: int = 10 - reid_fail_tolerance: int = 5 class ObjectTracking(Module[ObjectTrackingConfig]): @@ -72,8 +70,11 @@ class ObjectTracking(Module[ObjectTrackingConfig]): tracked_overlay: Out[Image] # Visualization output default_config = ObjectTrackingConfig + config: ObjectTrackingConfig - def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any) -> None: + def __init__( + self, reid_threshold: int = 10, reid_fail_tolerance: int = 5, **kwargs: object + ) -> None: """ Initialize an object tracking module using OpenCV's CSRT tracker with ORB re-ID. @@ -85,9 +86,11 @@ def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any) - tracking is stopped. """ # Call parent Module init - super().__init__(global_config, **kwargs) + super().__init__(**kwargs) self.camera_intrinsics = None + self.reid_threshold = reid_threshold + self.reid_fail_tolerance = reid_fail_tolerance self.tracker = None self.tracking_bbox = None # Stores (x, y, w, h) for tracker initialization @@ -273,7 +276,7 @@ def reid(self, frame, current_bbox) -> bool: # type: ignore[no-untyped-def] good_matches += 1 self.last_good_matches = good_matches_list # Store good matches for visualization - return good_matches >= self.config.reid_threshold + return good_matches >= self.reid_threshold def _start_tracking_thread(self) -> None: """Start the tracking thread.""" @@ -386,7 +389,7 @@ def _process_tracking(self) -> None: # Determine final success if tracker_succeeded: - if self.reid_fail_count >= self.config.reid_fail_tolerance: + if self.reid_fail_count >= self.reid_fail_tolerance: logger.warning( f"Re-ID failed consecutively {self.reid_fail_count} times. Target lost." ) @@ -586,11 +589,11 @@ def _draw_reid_matches(self, image: NDArray[np.uint8]) -> NDArray[np.uint8]: # f"REID: WARMING UP ({self.tracking_frame_count}/{self.reid_warmup_frames})" ) status_color = (255, 255, 0) # Yellow - elif len(self.last_good_matches) >= self.config.reid_threshold: + elif len(self.last_good_matches) >= self.reid_threshold: status_text = "REID: CONFIRMED" status_color = (0, 255, 0) # Green else: - status_text = f"REID: WEAK ({self.reid_fail_count}/{self.config.reid_fail_tolerance})" + status_text = f"REID: WEAK ({self.reid_fail_count}/{self.reid_fail_tolerance})" status_color = (0, 165, 255) # Orange cv2.putText( diff --git a/dimos/perception/object_tracker_2d.py b/dimos/perception/object_tracker_2d.py index 27b7c0e93c..1264b0e92b 100644 --- a/dimos/perception/object_tracker_2d.py +++ b/dimos/perception/object_tracker_2d.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from dataclasses import dataclass import logging import threading import time -from typing import Any import cv2 @@ -33,7 +33,6 @@ from reactivex.disposable import Disposable from dimos.core.core import rpc -from dimos.core.global_config import GlobalConfig, global_config from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out from dimos.msgs.sensor_msgs import Image, ImageFormat @@ -44,6 +43,7 @@ logger = setup_logger(level=logging.INFO) +@dataclass class ObjectTracker2DConfig(ModuleConfig): frame_id: str = "camera_link" @@ -57,10 +57,11 @@ class ObjectTracker2D(Module[ObjectTracker2DConfig]): tracked_overlay: Out[Image] # Visualization output default_config = ObjectTracker2DConfig + config: ObjectTracker2DConfig - def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any) -> None: + def __init__(self, **kwargs: object) -> None: """Initialize 2D object tracking module using OpenCV's CSRT tracker.""" - super().__init__(global_config, **kwargs) + super().__init__(**kwargs) # Tracker state self.tracker = None diff --git a/dimos/protocol/pubsub/bridge.py b/dimos/protocol/pubsub/bridge.py index 72cbe155d9..f312caed7b 100644 --- a/dimos/protocol/pubsub/bridge.py +++ b/dimos/protocol/pubsub/bridge.py @@ -16,9 +16,10 @@ from __future__ import annotations +from dataclasses import dataclass from typing import TYPE_CHECKING, Generic, Protocol, TypeVar -from dimos.protocol.service.spec import BaseConfig, Service +from dimos.protocol.service.spec import Service if TYPE_CHECKING: from collections.abc import Callable @@ -65,7 +66,8 @@ def pass_msg(msg: MsgFrom, topic: TopicFrom) -> None: return pubsub1.subscribe_all(pass_msg) -class BridgeConfig(BaseConfig, Generic[TopicFrom, TopicTo, MsgFrom, MsgTo]): +@dataclass +class BridgeConfig(Generic[TopicFrom, TopicTo, MsgFrom, MsgTo]): """Configuration for a one-way bridge.""" source: AllPubSub[TopicFrom, MsgFrom] diff --git a/dimos/protocol/pubsub/impl/lcmpubsub.py b/dimos/protocol/pubsub/impl/lcmpubsub.py index 4e792f5965..bf6bbd0dec 100644 --- a/dimos/protocol/pubsub/impl/lcmpubsub.py +++ b/dimos/protocol/pubsub/impl/lcmpubsub.py @@ -14,13 +14,10 @@ from __future__ import annotations -from collections.abc import Callable from dataclasses import dataclass import re -import threading -from typing import Any +from typing import TYPE_CHECKING, Any -from dimos.msgs import DimosMsg from dimos.protocol.pubsub.encoders import ( JpegEncoderMixin, LCMEncoderMixin, @@ -28,9 +25,15 @@ ) from dimos.protocol.pubsub.patterns import Glob from dimos.protocol.pubsub.spec import AllPubSub -from dimos.protocol.service.lcmservice import LCMService, autoconf +from dimos.protocol.service.lcmservice import LCMConfig, LCMService, autoconf from dimos.utils.logging_config import setup_logger +if TYPE_CHECKING: + from collections.abc import Callable + import threading + + from dimos.msgs import DimosMsg + logger = setup_logger() @@ -80,6 +83,7 @@ class LCMPubSubBase(LCMService, AllPubSub[Topic, Any]): RegexSubscribable directly without needing discovery-based fallback. """ + default_config = LCMConfig _stop_event: threading.Event _thread: threading.Thread | None diff --git a/dimos/protocol/pubsub/impl/redispubsub.py b/dimos/protocol/pubsub/impl/redispubsub.py index b299d6b883..6cc089e953 100644 --- a/dimos/protocol/pubsub/impl/redispubsub.py +++ b/dimos/protocol/pubsub/impl/redispubsub.py @@ -14,24 +14,25 @@ from collections import defaultdict from collections.abc import Callable +from dataclasses import dataclass, field import json import threading import time from types import TracebackType from typing import Any -from pydantic import Field import redis # type: ignore[import-not-found] from dimos.protocol.pubsub.spec import PubSub -from dimos.protocol.service.spec import BaseConfig, Service +from dimos.protocol.service.spec import Service -class RedisConfig(BaseConfig): +@dataclass +class RedisConfig: host: str = "localhost" port: int = 6379 db: int = 0 - kwargs: dict[str, Any] = Field(default_factory=dict) + kwargs: dict[str, Any] = field(default_factory=dict) class Redis(PubSub[str, Any], Service[RedisConfig]): diff --git a/dimos/protocol/service/__init__.py b/dimos/protocol/service/__init__.py index ed6caf93c2..fb9df08ca9 100644 --- a/dimos/protocol/service/__init__.py +++ b/dimos/protocol/service/__init__.py @@ -1,9 +1,8 @@ from dimos.protocol.service.lcmservice import LCMService -from dimos.protocol.service.spec import BaseConfig, Configurable, Service +from dimos.protocol.service.spec import Configurable as Configurable, Service as Service -__all__ = ( - "BaseConfig", +__all__ = [ "Configurable", "LCMService", "Service", -) +] diff --git a/dimos/protocol/service/ddsservice.py b/dimos/protocol/service/ddsservice.py index b5562defff..6ed04c07ad 100644 --- a/dimos/protocol/service/ddsservice.py +++ b/dimos/protocol/service/ddsservice.py @@ -14,8 +14,9 @@ from __future__ import annotations +from dataclasses import dataclass import threading -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any try: from cyclonedds.domain import DomainParticipant @@ -25,7 +26,7 @@ DDS_AVAILABLE = False DomainParticipant = None # type: ignore[assignment, misc] -from dimos.protocol.service.spec import BaseConfig, Service +from dimos.protocol.service.spec import Service from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: @@ -37,7 +38,8 @@ _participants_lock = threading.Lock() -class DDSConfig(BaseConfig): +@dataclass +class DDSConfig: """Configuration for DDS service.""" domain_id: int = 0 @@ -47,6 +49,9 @@ class DDSConfig(BaseConfig): class DDSService(Service[DDSConfig]): default_config = DDSConfig + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + def start(self) -> None: """Start the DDS service.""" domain_id = self.config.domain_id diff --git a/dimos/protocol/service/lcmservice.py b/dimos/protocol/service/lcmservice.py index 8fc982bef2..f414ce9e23 100644 --- a/dimos/protocol/service/lcmservice.py +++ b/dimos/protocol/service/lcmservice.py @@ -15,24 +15,18 @@ from __future__ import annotations from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass import os import platform -import sys import threading import traceback -from typing import Any -import lcm as lcm_mod +import lcm -from dimos.protocol.service.spec import BaseConfig, Service +from dimos.protocol.service.spec import Service from dimos.protocol.service.system_configurator import configure_system, lcm_configurators from dimos.utils.logging_config import setup_logger -if sys.version_info < (3, 13): - from typing_extensions import TypeVar -else: - from typing import TypeVar - logger = setup_logger() _DEFAULT_LCM_HOST = "239.255.76.67" @@ -51,38 +45,41 @@ def autoconf(check_only: bool = False) -> None: configure_system(checks, check_only=check_only) -class LCMConfig(BaseConfig): +@dataclass +class LCMConfig: ttl: int = 0 - url: str = _DEFAULT_LCM_URL + url: str | None = None autoconf: bool = True - lcm: lcm_mod.LCM | None = None + lcm: lcm.LCM | None = None + + def __post_init__(self) -> None: + if self.url is None: + self.url = _DEFAULT_LCM_URL -_Config = TypeVar("_Config", bound=LCMConfig, default=LCMConfig) _LCM_LOOP_TIMEOUT = 50 # this class just sets up cpp LCM instance # and runs its handle loop in a thread # higher order stuff is done by pubsub/impl/lcmpubsub.py -class LCMService(Service[_Config]): - default_config = LCMConfig # type: ignore[assignment] - - l: lcm_mod.LCM | None +class LCMService(Service[LCMConfig]): + default_config = LCMConfig + l: lcm.LCM | None _stop_event: threading.Event _l_lock: threading.Lock _thread: threading.Thread | None _call_thread_pool: ThreadPoolExecutor | None = None _call_thread_pool_lock: threading.RLock = threading.RLock() - def __init__(self, **kwargs: Any) -> None: # type: ignore[no-untyped-def] + def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] super().__init__(**kwargs) # we support passing an existing LCM instance if self.config.lcm: self.l = self.config.lcm else: - self.l = lcm_mod.LCM(self.config.url) if self.config.url else lcm_mod.LCM() + self.l = lcm.LCM(self.config.url) if self.config.url else lcm.LCM() self._l_lock = threading.Lock() self._stop_event = threading.Event() @@ -117,7 +114,7 @@ def start(self) -> None: if self.config.lcm: self.l = self.config.lcm else: - self.l = lcm_mod.LCM(self.config.url) if self.config.url else lcm_mod.LCM() + self.l = lcm.LCM(self.config.url) if self.config.url else lcm.LCM() try: autoconf(check_only=not self.config.autoconf) diff --git a/dimos/protocol/service/spec.py b/dimos/protocol/service/spec.py index 4dcb9398b6..c4e6758614 100644 --- a/dimos/protocol/service/spec.py +++ b/dimos/protocol/service/spec.py @@ -13,24 +13,17 @@ # limitations under the License. from abc import ABC -from typing import Any, Generic, TypeVar - -from pydantic import BaseModel - - -class BaseConfig(BaseModel): - model_config = {"arbitrary_types_allowed": True} - +from typing import Generic, TypeVar # Generic type for service configuration -ConfigT = TypeVar("ConfigT", bound=BaseConfig) +ConfigT = TypeVar("ConfigT") class Configurable(Generic[ConfigT]): default_config: type[ConfigT] - def __init__(self, **kwargs: Any) -> None: - self.config = self.default_config(**kwargs) + def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] + self.config: ConfigT = self.default_config(**kwargs) class Service(Configurable[ConfigT], ABC): diff --git a/dimos/protocol/service/test_lcmservice.py b/dimos/protocol/service/test_lcmservice.py index f89d13ea82..fdd2340e54 100644 --- a/dimos/protocol/service/test_lcmservice.py +++ b/dimos/protocol/service/test_lcmservice.py @@ -14,9 +14,7 @@ import threading import time -from unittest.mock import MagicMock, create_autospec, patch - -from lcm import LCM +from unittest.mock import MagicMock, patch from dimos.protocol.pubsub.impl.lcmpubsub import Topic from dimos.protocol.service.lcmservice import ( @@ -101,6 +99,10 @@ def test_custom_url(self) -> None: config = LCMConfig(url=custom_url) assert config.url == custom_url + def test_post_init_sets_default_url_when_none(self) -> None: + config = LCMConfig(url=None) + assert config.url == _DEFAULT_LCM_URL + def test_autoconf_can_be_disabled(self) -> None: config = LCMConfig(autoconf=False) assert config.autoconf is False @@ -126,8 +128,8 @@ def test_str_with_lcm_type(self) -> None: class TestLCMService: def test_init_with_default_config(self) -> None: - with patch("dimos.protocol.service.lcmservice.lcm_mod.LCM") as mock_lcm_class: - mock_lcm_instance = create_autospec(LCM, spec_set=True, instance=True) + with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class: + mock_lcm_instance = MagicMock() mock_lcm_class.return_value = mock_lcm_instance service = LCMService() @@ -137,8 +139,8 @@ def test_init_with_default_config(self) -> None: def test_init_with_custom_url(self) -> None: custom_url = "udpm://192.168.1.1:7777?ttl=1" - with patch("dimos.protocol.service.lcmservice.lcm_mod.LCM") as mock_lcm_class: - mock_lcm_instance = create_autospec(LCM, spec_set=True, instance=True) + with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class: + mock_lcm_instance = MagicMock() mock_lcm_class.return_value = mock_lcm_instance # Pass url as kwarg, not config= @@ -146,17 +148,17 @@ def test_init_with_custom_url(self) -> None: mock_lcm_class.assert_called_once_with(custom_url) def test_init_with_existing_lcm_instance(self) -> None: - mock_lcm_instance = create_autospec(LCM, spec_set=True, instance=True) + mock_lcm_instance = MagicMock() - with patch("dimos.protocol.service.lcmservice.lcm_mod.LCM") as mock_lcm_class: + with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class: # Pass lcm as kwarg service = LCMService(lcm=mock_lcm_instance) mock_lcm_class.assert_not_called() assert service.l == mock_lcm_instance def test_start_and_stop(self) -> None: - with patch("dimos.protocol.service.lcmservice.lcm_mod.LCM") as mock_lcm_class: - mock_lcm_instance = create_autospec(LCM, spec_set=True, instance=True) + with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class: + mock_lcm_instance = MagicMock() mock_lcm_class.return_value = mock_lcm_instance with patch("dimos.protocol.service.lcmservice.autoconf"): @@ -174,8 +176,8 @@ def test_start_and_stop(self) -> None: assert not service._thread.is_alive() def test_start_calls_configure_system(self) -> None: - with patch("dimos.protocol.service.lcmservice.lcm_mod.LCM") as mock_lcm_class: - mock_lcm_instance = create_autospec(LCM, spec_set=True, instance=True) + with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class: + mock_lcm_instance = MagicMock() mock_lcm_class.return_value = mock_lcm_instance with patch("dimos.protocol.service.lcmservice.autoconf") as mock_configure: @@ -188,8 +190,8 @@ def test_start_calls_configure_system(self) -> None: service.stop() def test_start_with_autoconf_disabled(self) -> None: - with patch("dimos.protocol.service.lcmservice.lcm_mod.LCM") as mock_lcm_class: - mock_lcm_instance = create_autospec(LCM, spec_set=True, instance=True) + with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class: + mock_lcm_instance = MagicMock() mock_lcm_class.return_value = mock_lcm_instance with patch("dimos.protocol.service.lcmservice.autoconf") as mock_configure: @@ -202,8 +204,8 @@ def test_start_with_autoconf_disabled(self) -> None: service.stop() def test_getstate_excludes_unpicklable_attrs(self) -> None: - with patch("dimos.protocol.service.lcmservice.lcm_mod.LCM") as mock_lcm_class: - mock_lcm_instance = create_autospec(LCM, spec_set=True, instance=True) + with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class: + mock_lcm_instance = MagicMock() mock_lcm_class.return_value = mock_lcm_instance service = LCMService() @@ -217,8 +219,8 @@ def test_getstate_excludes_unpicklable_attrs(self) -> None: assert "_call_thread_pool_lock" not in state def test_setstate_reinitializes_runtime_attrs(self) -> None: - with patch("dimos.protocol.service.lcmservice.lcm_mod.LCM") as mock_lcm_class: - mock_lcm_instance = create_autospec(LCM, spec_set=True, instance=True) + with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class: + mock_lcm_instance = MagicMock() mock_lcm_class.return_value = mock_lcm_instance service = LCMService() @@ -237,8 +239,8 @@ def test_setstate_reinitializes_runtime_attrs(self) -> None: assert hasattr(new_service._l_lock, "release") def test_start_reinitializes_lcm_after_unpickling(self) -> None: - with patch("dimos.protocol.service.lcmservice.lcm_mod.LCM") as mock_lcm_class: - mock_lcm_instance = create_autospec(LCM, spec_set=True, instance=True) + with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class: + mock_lcm_instance = MagicMock() mock_lcm_class.return_value = mock_lcm_instance with patch("dimos.protocol.service.lcmservice.autoconf"): @@ -258,8 +260,8 @@ def test_start_reinitializes_lcm_after_unpickling(self) -> None: new_service.stop() def test_stop_cleans_up_lcm_instance(self) -> None: - with patch("dimos.protocol.service.lcmservice.lcm_mod.LCM") as mock_lcm_class: - mock_lcm_instance = create_autospec(LCM, spec_set=True, instance=True) + with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class: + mock_lcm_instance = MagicMock() mock_lcm_class.return_value = mock_lcm_instance with patch("dimos.protocol.service.lcmservice.autoconf"): @@ -271,7 +273,7 @@ def test_stop_cleans_up_lcm_instance(self) -> None: assert service.l is None def test_stop_preserves_external_lcm_instance(self) -> None: - mock_lcm_instance = create_autospec(LCM, spec_set=True, instance=True) + mock_lcm_instance = MagicMock() with patch("dimos.protocol.service.lcmservice.autoconf"): # Pass lcm as kwarg @@ -283,8 +285,8 @@ def test_stop_preserves_external_lcm_instance(self) -> None: assert service.l == mock_lcm_instance def test_get_call_thread_pool_creates_pool(self) -> None: - with patch("dimos.protocol.service.lcmservice.lcm_mod.LCM") as mock_lcm_class: - mock_lcm_instance = create_autospec(LCM, spec_set=True, instance=True) + with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class: + mock_lcm_instance = MagicMock() mock_lcm_class.return_value = mock_lcm_instance service = LCMService() @@ -302,8 +304,8 @@ def test_get_call_thread_pool_creates_pool(self) -> None: pool.shutdown(wait=False) def test_stop_shuts_down_thread_pool(self) -> None: - with patch("dimos.protocol.service.lcmservice.lcm_mod.LCM") as mock_lcm_class: - mock_lcm_instance = create_autospec(LCM, spec_set=True, instance=True) + with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class: + mock_lcm_instance = MagicMock() mock_lcm_class.return_value = mock_lcm_instance with patch("dimos.protocol.service.lcmservice.autoconf"): diff --git a/dimos/protocol/tf/tf.py b/dimos/protocol/tf/tf.py index 1b5ccadf3c..825e89fc8c 100644 --- a/dimos/protocol/tf/tf.py +++ b/dimos/protocol/tf/tf.py @@ -16,7 +16,7 @@ from abc import abstractmethod from collections import deque -from dataclasses import field +from dataclasses import dataclass, field from functools import reduce from typing import TypeVar @@ -25,22 +25,23 @@ from dimos.msgs.tf2_msgs import TFMessage from dimos.protocol.pubsub.impl.lcmpubsub import LCM, Topic from dimos.protocol.pubsub.spec import PubSub -from dimos.protocol.service.spec import BaseConfig, Service +from dimos.protocol.service.lcmservice import Service # type: ignore[attr-defined] CONFIG = TypeVar("CONFIG") # generic configuration for transform service -class TFConfig(BaseConfig): +@dataclass +class TFConfig: buffer_size: float = 10.0 # seconds rate_limit: float = 10.0 # Hz -_TFConfig = TypeVar("_TFConfig", bound=TFConfig) - - # generic specification for transform service -class TFSpec(Service[_TFConfig]): +class TFSpec(Service[TFConfig]): + def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] + super().__init__(**kwargs) + @abstractmethod def publish(self, *args: Transform) -> None: ... @@ -243,17 +244,15 @@ def __str__(self) -> str: return "\n".join(lines) +@dataclass class PubSubTFConfig(TFConfig): topic: Topic | None = None # Required field but needs default for dataclass inheritance pubsub: type[PubSub] | PubSub | None = None # type: ignore[type-arg] autostart: bool = True -_PubSubConfig = TypeVar("_PubSubConfig", bound=PubSubTFConfig) - - -class PubSubTF(MultiTBuffer, TFSpec[_PubSubConfig]): - default_config: type[_PubSubConfig] = PubSubTFConfig # type: ignore[assignment] +class PubSubTF(MultiTBuffer, TFSpec): + default_config: type[PubSubTFConfig] = PubSubTFConfig def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] TFSpec.__init__(self, **kwargs) @@ -331,14 +330,15 @@ def receive_msg(self, msg: TFMessage, topic: Topic) -> None: self.receive_tfmessage(msg) +@dataclass class LCMPubsubConfig(PubSubTFConfig): topic: Topic = field(default_factory=lambda: Topic("/tf", TFMessage)) pubsub: type[PubSub] | PubSub | None = LCM # type: ignore[type-arg] autostart: bool = True -class LCMTF(PubSubTF[LCMPubsubConfig]): - default_config = LCMPubsubConfig +class LCMTF(PubSubTF): + default_config: type[LCMPubsubConfig] = LCMPubsubConfig TF = LCMTF diff --git a/dimos/protocol/tf/tflcmcpp.py b/dimos/protocol/tf/tflcmcpp.py index bf2885958d..158a68d3d8 100644 --- a/dimos/protocol/tf/tflcmcpp.py +++ b/dimos/protocol/tf/tflcmcpp.py @@ -13,18 +13,15 @@ # limitations under the License. from datetime import datetime +from typing import Union from dimos.msgs.geometry_msgs import Transform from dimos.protocol.service.lcmservice import LCMConfig, LCMService from dimos.protocol.tf.tf import TFConfig, TFSpec -class Config(TFConfig, LCMConfig): - """Combined config""" - - # this doesn't work due to tf_lcm_py package -class TFLCM(TFSpec[Config], LCMService[Config]): +class TFLCM(TFSpec, LCMService): """A service for managing and broadcasting transforms using LCM. This is not a separete module, You can include this in your module if you need to access transforms. @@ -37,7 +34,7 @@ class TFLCM(TFSpec[Config], LCMService[Config]): for each module. """ - default_config = Config + default_config = Union[TFConfig, LCMConfig] def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] super().__init__(**kwargs) diff --git a/dimos/robot/drone/connection_module.py b/dimos/robot/drone/connection_module.py index 6f6fc4b261..4e5559f220 100644 --- a/dimos/robot/drone/connection_module.py +++ b/dimos/robot/drone/connection_module.py @@ -26,8 +26,7 @@ from dimos.agents.annotation import skill from dimos.core.core import rpc -from dimos.core.global_config import GlobalConfig, global_config -from dimos.core.module import Module, ModuleConfig +from dimos.core.module import Module from dimos.core.stream import In, Out from dimos.mapping.types import LatLon from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Twist, Vector3 @@ -46,17 +45,9 @@ def _add_disposable(composite: CompositeDisposable, item: Disposable | Any) -> N composite.add(Disposable(item)) -class Config(ModuleConfig): - connection_string: str = "udp:0.0.0.0:14550" - video_port: int = 5600 - outdoor: bool = False - - -class DroneConnectionModule(Module[Config]): +class DroneConnectionModule(Module): """Module that handles drone sensor data and movement commands.""" - default_config = Config - # Inputs movecmd: In[Vector3] movecmd_twist: In[Twist] # Twist commands from tracking/navigation @@ -71,6 +62,9 @@ class DroneConnectionModule(Module[Config]): video: Out[Image] follow_object_cmd: Out[Any] + # Parameters + connection_string: str + # Internal state _odom: PoseStamped | None = None _status: dict[str, Any] = {} @@ -79,7 +73,14 @@ class DroneConnectionModule(Module[Config]): _latest_status: dict[str, Any] | None = None _latest_status_lock: threading.RLock - def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any) -> None: + def __init__( + self, + connection_string: str = "udp:0.0.0.0:14550", + video_port: int = 5600, + outdoor: bool = False, + *args: Any, + **kwargs: Any, + ) -> None: """Initialize drone connection module. Args: @@ -87,6 +88,9 @@ def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any) - video_port: UDP port for video stream outdoor: Use GPS only mode (no velocity integration) """ + self.connection_string = connection_string + self.video_port = video_port + self.outdoor = outdoor self.connection: MavlinkConnection | None = None self.video_stream: DJIDroneVideoStream | None = None self._latest_video_frame = None @@ -95,25 +99,23 @@ def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any) - self._latest_status_lock = threading.RLock() self._running = False self._telemetry_thread: threading.Thread | None = None - super().__init__(global_config, **kwargs) + Module.__init__(self, *args, **kwargs) @rpc def start(self) -> None: """Start the connection and subscribe to sensor streams.""" # Check for replay mode - if self.config.connection_string == "replay": + if self.connection_string == "replay": from dimos.robot.drone.dji_video_stream import FakeDJIVideoStream from dimos.robot.drone.mavlink_connection import FakeMavlinkConnection self.connection = FakeMavlinkConnection("replay") - self.video_stream = FakeDJIVideoStream(port=self.config.video_port) + self.video_stream = FakeDJIVideoStream(port=self.video_port) else: - self.connection = MavlinkConnection( - self.config.connection_string, outdoor=self.config.outdoor - ) + self.connection = MavlinkConnection(self.connection_string, outdoor=self.outdoor) self.connection.connect() - self.video_stream = DJIDroneVideoStream(port=self.config.video_port) + self.video_stream = DJIDroneVideoStream(port=self.video_port) if not self.connection.connected: logger.error("Failed to connect to drone") diff --git a/dimos/robot/foxglove_bridge.py b/dimos/robot/foxglove_bridge.py index 734ad85e95..6700c38901 100644 --- a/dimos/robot/foxglove_bridge.py +++ b/dimos/robot/foxglove_bridge.py @@ -13,21 +13,21 @@ # limitations under the License. import asyncio -from collections.abc import Sequence import logging import threading -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from dimos_lcm.foxglove_bridge import ( FoxgloveBridge as LCMFoxgloveBridge, ) from dimos.core.core import rpc -from dimos.core.module import Module, ModuleConfig +from dimos.core.module import Module from dimos.core.module_coordinator import ModuleCoordinator from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: + from dimos.core.global_config import GlobalConfig from dimos.core.rpc_client import ModuleProxy logging.getLogger("lcm_foxglove_bridge").setLevel(logging.ERROR) @@ -36,15 +36,23 @@ logger = setup_logger() -class FoxgloveConfig(ModuleConfig): - shm_channels: Sequence[str] = () - jpeg_shm_channels: Sequence[str] = () - - -class FoxgloveBridge(Module[FoxgloveConfig]): +class FoxgloveBridge(Module): _thread: threading.Thread _loop: asyncio.AbstractEventLoop - default_config = FoxgloveConfig + _global_config: "GlobalConfig | None" = None + + def __init__( + self, + *args: Any, + shm_channels: list[str] | None = None, + jpeg_shm_channels: list[str] | None = None, + global_config: "GlobalConfig | None" = None, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + self.shm_channels = shm_channels or [] + self.jpeg_shm_channels = jpeg_shm_channels or [] + self._global_config = global_config @rpc def start(self) -> None: @@ -72,8 +80,8 @@ def run_bridge() -> None: port=8765, debug=False, num_threads=4, - shm_channels=self.config.shm_channels, - jpeg_shm_channels=self.config.jpeg_shm_channels, + shm_channels=self.shm_channels, + jpeg_shm_channels=self.jpeg_shm_channels, ) self._loop.run_until_complete(bridge.run()) except Exception as e: diff --git a/dimos/robot/unitree/b1/connection.py b/dimos/robot/unitree/b1/connection.py index 3c783857be..4279f78399 100644 --- a/dimos/robot/unitree/b1/connection.py +++ b/dimos/robot/unitree/b1/connection.py @@ -21,13 +21,11 @@ import socket import threading import time -from typing import Any from reactivex.disposable import Disposable from dimos.core.core import rpc -from dimos.core.global_config import GlobalConfig, global_config -from dimos.core.module import Module, ModuleConfig +from dimos.core.module import Module from dimos.core.stream import In, Out from dimos.msgs.geometry_msgs import PoseStamped, Twist, TwistStamped from dimos.msgs.nav_msgs.Odometry import Odometry @@ -50,21 +48,13 @@ class RobotMode: RECOVERY = 6 -class B1ConnectionConfig(ModuleConfig): - ip: str = "192.168.12.1" - port: int = 9090 - test_mode: bool = False - - -class B1ConnectionModule(Module[B1ConnectionConfig]): +class B1ConnectionModule(Module): """UDP connection module for B1 robot with standard Twist interface. Accepts standard ROS Twist messages on /cmd_vel and mode changes on /b1/mode, internally converts to B1Command format, and sends UDP packets at 50Hz. """ - default_config = B1ConnectionConfig - # LCM ports (inter-module communication) cmd_vel: In[TwistStamped] mode_cmd: In[Int32] @@ -77,7 +67,9 @@ class B1ConnectionModule(Module[B1ConnectionConfig]): ros_odom_in: In[Odometry] ros_tf: In[TFMessage] - def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any) -> None: + def __init__( # type: ignore[no-untyped-def] + self, ip: str = "192.168.12.1", port: int = 9090, test_mode: bool = False, *args, **kwargs + ) -> None: """Initialize B1 connection module. Args: @@ -85,11 +77,11 @@ def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any) - port: UDP port for joystick server test_mode: If True, print commands instead of sending UDP """ - super().__init__(global_config, **kwargs) + Module.__init__(self, *args, **kwargs) - self.ip = self.config.ip - self.port = self.config.port - self.test_mode = self.config.test_mode + self.ip = ip + self.port = port + self.test_mode = test_mode self.current_mode = RobotMode.IDLE # Start in IDLE mode self._current_cmd = B1Command(mode=RobotMode.IDLE) self.cmd_lock = threading.Lock() # Thread lock for _current_cmd access @@ -391,10 +383,9 @@ def move(self, twist_stamped: TwistStamped, duration: float = 0.0) -> bool: class MockB1ConnectionModule(B1ConnectionModule): """Test connection module that prints commands instead of sending UDP.""" - def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any) -> None: # type: ignore[no-untyped-def] + def __init__(self, ip: str = "127.0.0.1", port: int = 9090, *args, **kwargs) -> None: # type: ignore[no-untyped-def] """Initialize test connection without creating socket.""" - kwargs["test_mode"] = True - super().__init__(global_config, **kwargs) + super().__init__(ip, port, test_mode=True, *args, **kwargs) # type: ignore[misc] def _send_loop(self) -> None: """Override to provide better test output with timeout detection.""" diff --git a/dimos/robot/unitree/b1/unitree_b1.py b/dimos/robot/unitree/b1/unitree_b1.py index 6b374d1d5b..2c0c918942 100644 --- a/dimos/robot/unitree/b1/unitree_b1.py +++ b/dimos/robot/unitree/b1/unitree_b1.py @@ -92,9 +92,9 @@ def start(self) -> None: logger.info("Deploying connection module...") if self.test_mode: - self.connection = self._dimos.deploy(MockB1ConnectionModule, ip=self.ip, port=self.port) # type: ignore[assignment] + self.connection = self._dimos.deploy(MockB1ConnectionModule, self.ip, self.port) # type: ignore[assignment] else: - self.connection = self._dimos.deploy(B1ConnectionModule, ip=self.ip, port=self.port) # type: ignore[assignment] + self.connection = self._dimos.deploy(B1ConnectionModule, self.ip, self.port) # type: ignore[assignment] # Configure LCM transports for connection (matching G1 pattern) self.connection.cmd_vel.transport = LCMTransport("/cmd_vel", TwistStamped) # type: ignore[attr-defined] diff --git a/dimos/robot/unitree/g1/connection.py b/dimos/robot/unitree/g1/connection.py index 4b87cc2baa..17a66945f9 100644 --- a/dimos/robot/unitree/g1/connection.py +++ b/dimos/robot/unitree/g1/connection.py @@ -99,7 +99,7 @@ def publish_request(self, topic: str, data: dict[str, Any]) -> dict[Any, Any]: def deploy(dimos: ModuleCoordinator, ip: str, local_planner: spec.LocalPlanner) -> "ModuleProxy": - connection = dimos.deploy(G1Connection, ip=ip) + connection = dimos.deploy(G1Connection, ip) # type: ignore[attr-defined] connection.cmd_vel.connect(local_planner.cmd_vel) connection.start() return connection diff --git a/dimos/robot/unitree/go2/connection.py b/dimos/robot/unitree/go2/connection.py index 3529a6e833..4c7d4755fb 100644 --- a/dimos/robot/unitree/go2/connection.py +++ b/dimos/robot/unitree/go2/connection.py @@ -320,7 +320,7 @@ def observe(self) -> Image | None: def deploy(dimos: ModuleCoordinator, ip: str, prefix: str = "") -> "ModuleProxy": from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE - connection = dimos.deploy(GO2Connection, ip=ip) + connection = dimos.deploy(GO2Connection, ip) # type: ignore[attr-defined] connection.pointcloud.transport = pSHMTransport( f"{prefix}/lidar", default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE diff --git a/dimos/simulation/manipulators/sim_module.py b/dimos/simulation/manipulators/sim_module.py index 127748efd1..831ea6ee34 100644 --- a/dimos/simulation/manipulators/sim_module.py +++ b/dimos/simulation/manipulators/sim_module.py @@ -15,6 +15,7 @@ """Simulator-agnostic manipulator simulation module.""" from collections.abc import Callable +from dataclasses import dataclass from pathlib import Path import threading import time @@ -23,7 +24,6 @@ from reactivex.disposable import Disposable from dimos.core.core import rpc -from dimos.core.global_config import GlobalConfig, global_config from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out from dimos.msgs.sensor_msgs import JointCommand, JointState, RobotState @@ -31,6 +31,7 @@ from dimos.simulation.manipulators.sim_manip_interface import SimManipInterface +@dataclass(kw_only=True) class SimulationModuleConfig(ModuleConfig): engine: EngineType config_path: Path | Callable[[], Path] @@ -41,6 +42,7 @@ class SimulationModule(Module[SimulationModuleConfig]): """Module wrapper for manipulator simulation across engines.""" default_config = SimulationModuleConfig + config: SimulationModuleConfig joint_state: Out[JointState] robot_state: Out[RobotState] @@ -49,8 +51,8 @@ class SimulationModule(Module[SimulationModuleConfig]): MIN_CONTROL_RATE = 1.0 - def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any) -> None: - super().__init__(global_config, **kwargs) + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) self._backend: SimManipInterface | None = None self._control_rate = 100.0 self._monitor_rate = 100.0 diff --git a/dimos/simulation/manipulators/test_sim_module.py b/dimos/simulation/manipulators/test_sim_module.py index 72408fefed..334e2ce85f 100644 --- a/dimos/simulation/manipulators/test_sim_module.py +++ b/dimos/simulation/manipulators/test_sim_module.py @@ -17,11 +17,10 @@ import pytest -from dimos.protocol.rpc import RPCSpec from dimos.simulation.manipulators.sim_module import SimulationModule -class _DummyRPC(RPCSpec): +class _DummyRPC: def serve_module_rpc(self, _module) -> None: # type: ignore[no-untyped-def] return None diff --git a/dimos/utils/cli/lcmspy/lcmspy.py b/dimos/utils/cli/lcmspy/lcmspy.py index a0f491e81b..2df800591f 100755 --- a/dimos/utils/cli/lcmspy/lcmspy.py +++ b/dimos/utils/cli/lcmspy/lcmspy.py @@ -13,9 +13,9 @@ # limitations under the License. from collections import deque +from dataclasses import dataclass import threading import time -from typing import Any from dimos.protocol.service.lcmservice import LCMConfig, LCMService from dimos.utils.human import human_bytes @@ -94,19 +94,20 @@ def __str__(self) -> str: return f"topic({self.name})" +@dataclass class LCMSpyConfig(LCMConfig): topic_history_window: float = 60.0 -class LCMSpy(LCMService[LCMSpyConfig], Topic): +class LCMSpy(LCMService, Topic): default_config = LCMSpyConfig topic = dict[str, Topic] graph_log_window: float = 1.0 topic_class: type[Topic] = Topic - def __init__(self, **kwargs: Any) -> None: + def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] super().__init__(**kwargs) - Topic.__init__(self, name="total", history_window=self.config.topic_history_window) + Topic.__init__(self, name="total", history_window=self.config.topic_history_window) # type: ignore[attr-defined] self.topic = {} # type: ignore[assignment] def start(self) -> None: @@ -143,6 +144,7 @@ def update_graphs(self, step_window: float = 1.0) -> None: self.bandwidth_history.append(kbps) +@dataclass class GraphLCMSpyConfig(LCMSpyConfig): graph_log_window: float = 1.0 @@ -154,9 +156,9 @@ class GraphLCMSpy(LCMSpy, GraphTopic): graph_log_stop_event: threading.Event = threading.Event() topic_class: type[Topic] = GraphTopic - def __init__(self, **kwargs: Any) -> None: + def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] super().__init__(**kwargs) - GraphTopic.__init__(self, name="total", history_window=self.config.topic_history_window) + GraphTopic.__init__(self, name="total", history_window=self.config.topic_history_window) # type: ignore[attr-defined] def start(self) -> None: super().start() diff --git a/dimos/visualization/rerun/bridge.py b/dimos/visualization/rerun/bridge.py index 23b9cd5e3f..af91f1b8b8 100644 --- a/dimos/visualization/rerun/bridge.py +++ b/dimos/visualization/rerun/bridge.py @@ -16,8 +16,7 @@ from __future__ import annotations -from collections.abc import Callable -from dataclasses import field +from dataclasses import dataclass, field from functools import lru_cache from typing import ( TYPE_CHECKING, @@ -89,12 +88,14 @@ logger = setup_logger() if TYPE_CHECKING: + from collections.abc import Callable + from rerun._baseclasses import Archetype from rerun.blueprint import Blueprint from dimos.protocol.pubsub.spec import SubscribeAllCapable -BlueprintFactory: TypeAlias = Callable[[], "Blueprint"] +BlueprintFactory: TypeAlias = "Callable[[], Blueprint]" # to_rerun() can return a single archetype or a list of (entity_path, archetype) tuples RerunMulti: TypeAlias = "list[tuple[str, Archetype]]" @@ -141,6 +142,7 @@ def _default_blueprint() -> Blueprint: ) +@dataclass class Config(ModuleConfig): """Configuration for RerunBridgeModule.""" @@ -163,7 +165,7 @@ class Config(ModuleConfig): blueprint: BlueprintFactory | None = _default_blueprint -class RerunBridgeModule(Module[Config]): +class RerunBridgeModule(Module): """Bridge that logs messages from pubsubs to Rerun. Spawns its own Rerun viewer and subscribes to all topics on each provided @@ -180,6 +182,7 @@ class RerunBridgeModule(Module[Config]): """ default_config = Config + config: Config @lru_cache(maxsize=256) def _visual_override_for_entity_path( diff --git a/pyproject.toml b/pyproject.toml index a616526166..cb4607ced5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -324,12 +324,8 @@ exclude = [ [tool.ruff.lint] extend-select = ["E", "W", "F", "B", "UP", "N", "I", "C90", "A", "RUF", "TCH"] -ignore = [ - # TODO: All of these should be fixed, but it's easier commit autofixes first - "A001", "A002", "B008", "B017", "B019", "B024", "B026", "B904", "C901", "E402", "E501", "E721", "E722", "E741", "F811", "F821", "F821", "F821", "N801", "N802", "N803", "N806", "N817", "N999", "RUF003", "RUF009", "RUF012", "RUF034", "RUF043", "RUF059", "UP007", - # This breaks runtime type checking (both for us, and users introspecting our APIs) - "TC001", "TC002", "TC003" -] +# TODO: All of these should be fixed, but it's easier commit autofixes first +ignore = ["A001", "A002", "B008", "B017", "B019", "B024", "B026", "B904", "C901", "E402", "E501", "E721", "E722", "E741", "F811", "F821", "F821", "F821", "N801", "N802", "N803", "N806", "N817", "N999", "RUF003", "RUF009", "RUF012", "RUF034", "RUF043", "RUF059", "UP007"] [tool.ruff.lint.per-file-ignores] "dimos/models/Detic/*" = ["ALL"]