diff --git a/bin/filter-errors-after-date b/bin/filter-errors-after-date new file mode 100755 index 0000000000..5a0c46408e --- /dev/null +++ b/bin/filter-errors-after-date @@ -0,0 +1,78 @@ +#!/usr/bin/env python3 + +# Used to filter errors to only show lines committed on or after a specific date +# Can be chained with filter-errors-for-user + +import sys +import re +import subprocess +from datetime import datetime + + +_blame = {} + + +def _is_after_date(file, line_no, cutoff_date): + if file not in _blame: + _blame[file] = _get_git_blame_dates_for_file(file) + line_date = _blame[file].get(line_no) + if not line_date: + return False + return line_date >= cutoff_date + + +def _get_git_blame_dates_for_file(file_name): + try: + result = subprocess.run( + ["git", "blame", "--date=short", file_name], + capture_output=True, + text=True, + check=True, + ) + + blame_map = {} + # Each line looks like: ^abc123 (Author Name 2024-01-01 1) code + blame_pattern = re.compile(r"^[^\(]+\([^\)]+(\d{4}-\d{2}-\d{2})") + + for i, line in enumerate(result.stdout.split("\n")): + if not line: + continue + match = blame_pattern.match(line) + if match: + date_str = match.group(1) + blame_map[str(i + 1)] = date_str + + return blame_map + except subprocess.CalledProcessError: + return {} + + +def main(): + if len(sys.argv) != 2: + print("Usage: filter-errors-after-date ", file=sys.stderr) + print(" Example: filter-errors-after-date 2025-10-04", file=sys.stderr) + sys.exit(1) + + cutoff_date = sys.argv[1] + + try: + datetime.strptime(cutoff_date, "%Y-%m-%d") + except ValueError: + print(f"Error: Invalid date format '{cutoff_date}'. Use YYYY-MM-DD", file=sys.stderr) + sys.exit(1) + + for line in sys.stdin.readlines(): + split = re.findall(r"^([^:]+):(\d+):(.*)", line) + if not split or len(split[0]) != 3: + continue + + file, line_no = split[0][:2] + if not file.startswith("dimos/"): + continue + + if _is_after_date(file, line_no, cutoff_date): + print(":".join(split[0])) + + +if __name__ == "__main__": + main() diff --git a/bin/filter-errors-for-user b/bin/filter-errors-for-user new file mode 100755 index 0000000000..78247a9bb2 --- /dev/null +++ b/bin/filter-errors-for-user @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 + +# Used when running `./bin/mypy-strict --for-me` + +import sys +import re +import subprocess + + +_blame = {} + + +def _is_for_user(file, line_no, user_email): + if file not in _blame: + _blame[file] = _get_git_blame_for_file(file) + return _blame[file][line_no] == user_email + + +def _get_git_blame_for_file(file_name): + try: + result = subprocess.run( + ["git", "blame", "--show-email", "-e", file_name], + capture_output=True, + text=True, + check=True, + ) + + blame_map = {} + # Each line looks like: ^abc123 ( 2024-01-01 12:00:00 +0000 1) code + blame_pattern = re.compile(r"^[^\(]+\(<([^>]+)>") + + for i, line in enumerate(result.stdout.split("\n")): + if not line: + continue + match = blame_pattern.match(line) + if match: + email = match.group(1) + blame_map[str(i + 1)] = email + + return blame_map + except subprocess.CalledProcessError: + return {} + + +def main(): + if len(sys.argv) != 2: + print("Usage: filter-errors-for-user ", file=sys.stderr) + sys.exit(1) + + user_email = sys.argv[1] + + for line in sys.stdin.readlines(): + split = re.findall(r"^([^:]+):(\d+):(.*)", line) + if not split or len(split[0]) != 3: + continue + file, line_no = split[0][:2] + if not file.startswith("dimos/"): + continue + if _is_for_user(file, line_no, user_email): + print(":".join(split[0])) + + +if __name__ == "__main__": + main() diff --git a/bin/mypy-strict b/bin/mypy-strict new file mode 100755 index 0000000000..05001bf100 --- /dev/null +++ b/bin/mypy-strict @@ -0,0 +1,98 @@ +#!/bin/bash +# +# Run mypy with strict settings on the dimos codebase. +# +# Usage: +# ./bin/mypy-strict # Run mypy and show all errors +# ./bin/mypy-strict --user me # Filter for your git user.email +# ./bin/mypy-strict --after cutoff # Filter for lines committed on or after 2025-10-08 +# ./bin/mypy-strict --after 2025-11-11 # Filter for lines committed on or after specific date +# ./bin/mypy-strict --user me --after cutoff # Chain filters +# + +set -euo pipefail + +ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" + +cd "$ROOT" + +. .venv/bin/activate + +run_mypy() { + export MYPYPATH=/opt/ros/jazzy/lib/python3.12/site-packages + + mypy_args=( + --config-file mypy_strict.ini + --show-error-codes + --hide-error-context + --no-pretty + dimos + ) + mypy "${mypy_args[@]}" +} + +main() { + local user_email="none" + local after_date="" + + # Parse arguments + while [[ $# -gt 0 ]]; do + case "$1" in + --user) + if [[ $# -lt 2 ]]; then + echo "Error: --user requires an argument" >&2 + exit 1 + fi + case "$2" in + me) + user_email="$(git config user.email || echo none)" + ;; + all) + user_email="none" + ;; + *) + user_email="$2" + ;; + esac + shift 2 + ;; + --after) + if [[ $# -lt 2 ]]; then + echo "Error: --after requires an argument" >&2 + exit 1 + fi + case "$2" in + cutoff) + after_date="2025-10-10" + ;; + start) + after_date="" + ;; + *) + after_date="$2" + ;; + esac + shift 2 + ;; + *) + echo "Error: Unknown argument '$1'" >&2 + exit 1 + ;; + esac + done + + # Build filter pipeline + local pipeline="run_mypy" + + if [[ -n "$after_date" ]]; then + pipeline="$pipeline | ./bin/filter-errors-after-date '$after_date'" + fi + + if [[ "$user_email" != "none" ]]; then + pipeline="$pipeline | ./bin/filter-errors-for-user '$user_email'" + fi + + eval "$pipeline" +} + +main "$@" diff --git a/bin/ty-check b/bin/ty-check deleted file mode 100755 index 1d819e7cbb..0000000000 --- a/bin/ty-check +++ /dev/null @@ -1,133 +0,0 @@ -#!/bin/bash - -set -euo pipefail - -files=( - dimos/agents2/conftest.py - dimos/agents2/constants.py - dimos/agents2/constants.py - dimos/agents2/skills/conftest.py - dimos/agents2/skills/google_maps_skill_container.py - dimos/agents2/skills/google_maps_skill_container.py - dimos/agents2/skills/gps_nav_skill.py - dimos/agents2/skills/gps_nav_skill.py - dimos/agents2/skills/navigation.py - dimos/agents2/skills/osm.py - dimos/agents2/skills/osm.py - dimos/agents2/skills/test_google_maps_skill_container.py - dimos/agents2/skills/test_google_maps_skill_container.py - dimos/agents2/skills/test_gps_nav_skills.py - dimos/agents2/skills/test_gps_nav_skills.py - #dimos/agents2/skills/test_navigation.py - dimos/agents2/temp/run_unitree_agents2.py - dimos/agents2/test_agent_fake.py - dimos/agents2/test_agent_fake.py - dimos/agents/memory/spatial_vector_db.py - #dimos/agents/test_agent_message_streams.py - #dimos/agents/test_agent_tools.py - #dimos/agents/test_base_agent_text.py - #dimos/agents/test_conversation_history.py - dimos/conftest.py - dimos/constants.py - dimos/constants.py - #dimos/core/stream.py - #dimos/core/test_core.py - #dimos/core/test_stream.py - #dimos/hardware/gstreamer_camera.py - #dimos/hardware/gstreamer_camera.py - #dimos/hardware/gstreamer_camera_test_script.py - #dimos/hardware/gstreamer_camera_test_script.py - #dimos/hardware/gstreamer_sender.py - #dimos/hardware/gstreamer_sender.py - #dimos/manipulation/visual_servoing/detection3d.py - #dimos/manipulation/visual_servoing/manipulation_module.py - #dimos/manipulation/visual_servoing/pbvs.py - dimos/mapping/google_maps/conftest.py - dimos/mapping/google_maps/conftest.py - dimos/mapping/google_maps/google_maps.py - dimos/mapping/google_maps/google_maps.py - dimos/mapping/google_maps/test_google_maps.py - dimos/mapping/google_maps/test_google_maps.py - dimos/mapping/google_maps/types.py - dimos/mapping/google_maps/types.py - dimos/mapping/__init__.py - dimos/mapping/__init__.py - dimos/mapping/osm/current_location_map.py - dimos/mapping/osm/current_location_map.py - dimos/mapping/osm/demo_osm.py - dimos/mapping/osm/demo_osm.py - dimos/mapping/osm/__init__.py - dimos/mapping/osm/__init__.py - dimos/mapping/osm/osm.py - dimos/mapping/osm/osm.py - dimos/mapping/osm/query.py - dimos/mapping/osm/query.py - dimos/mapping/osm/test_osm.py - dimos/mapping/osm/test_osm.py - dimos/mapping/types.py - dimos/mapping/types.py - dimos/mapping/utils/distance.py - dimos/mapping/utils/distance.py - dimos/models/qwen/video_query.py - dimos/models/vl/base.py - dimos/models/vl/base.py - dimos/models/vl/qwen.py - dimos/models/vl/qwen.py - #dimos/msgs/geometry_msgs/Vector3.py - dimos/msgs/vision_msgs/BoundingBox2DArray.py - dimos/msgs/vision_msgs/BoundingBox2DArray.py - dimos/msgs/vision_msgs/BoundingBox3DArray.py - dimos/msgs/vision_msgs/BoundingBox3DArray.py - dimos/msgs/vision_msgs/Detection2DArray.py - dimos/msgs/vision_msgs/Detection2DArray.py - dimos/msgs/vision_msgs/Detection3DArray.py - dimos/msgs/vision_msgs/Detection3DArray.py - dimos/msgs/vision_msgs/__init__.py - dimos/msgs/vision_msgs/__init__.py - #dimos/navigation/bt_navigator/goal_validator.py - #dimos/navigation/bt_navigator/navigator.py - #dimos/navigation/frontier_exploration/test_wavefront_frontier_goal_selector.py - #dimos/navigation/local_planner/local_planner.py - #dimos/navigation/local_planner/test_base_local_planner.py - dimos/perception/detection2d/test_module.py - #dimos/perception/object_tracker.py - #dimos/perception/spatial_perception.py - dimos/protocol/pubsub/lcmpubsub.py - #dimos/protocol/pubsub/test_lcmpubsub.py - #dimos/protocol/rpc/off_test_pubsubrpc.py - #dimos/protocol/rpc/spec.py - dimos/protocol/rpc/test_lcmrpc_timeout.py - #dimos/protocol/service/lcmservice.py - #dimos/protocol/skill/test_coordinator.py - dimos/protocol/skill/test_utils.py - dimos/protocol/skill/test_utils.py - dimos/protocol/skill/utils.py - dimos/protocol/skill/utils.py - dimos/robot/foxglove_bridge.py - #dimos/robot/unitree_webrtc/depth_module.py - #dimos/robot/unitree_webrtc/modular/ivan_unitree.py - dimos/robot/unitree_webrtc/mujoco_connection.py - dimos/robot/unitree_webrtc/run_agents2.py - dimos/robot/unitree_webrtc/run.py - #dimos/robot/unitree_webrtc/unitree_b1/test_connection.py - dimos/robot/utils/robot_debugger.py - dimos/robot/utils/robot_debugger.py - dimos/simulation/mujoco/depth_camera.py - dimos/simulation/mujoco/depth_camera.py - #dimos/simulation/mujoco/model.py - #dimos/simulation/mujoco/mujoco.py - #dimos/simulation/mujoco/policy.py - #dimos/simulation/mujoco/policy.py - dimos/simulation/mujoco/types.py - dimos/simulation/mujoco/types.py - #dimos/skills/navigation.py - #dimos/skills/skills.py - dimos/types/robot_location.py - dimos/utils/deprecation.py - dimos/utils/deprecation.py - dimos/utils/generic.py - dimos/utils/generic.py -) - -uvx ty check "${files[@]}" -# uvx ruff check --fix --extend-select ANN "${files[@]}" diff --git a/dimos/agents/modules/agent_pool.py b/dimos/agents/modules/agent_pool.py index 0d08bd14b7..c5b466159f 100644 --- a/dimos/agents/modules/agent_pool.py +++ b/dimos/agents/modules/agent_pool.py @@ -14,17 +14,14 @@ """Agent pool module for managing multiple agents.""" -import logging -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Union from reactivex import operators as ops -from reactivex.disposable import CompositeDisposable from reactivex.subject import Subject from dimos.core import Module, In, Out, rpc from dimos.agents.modules.base_agent import BaseAgentModule from dimos.agents.modules.unified_agent import UnifiedAgentModule -from dimos.skills.skills import SkillLibrary from dimos.utils.logging_config import setup_logger logger = setup_logger("dimos.agents.modules.agent_pool") @@ -64,7 +61,6 @@ def __init__(self, agents_config: Dict[str, Dict[str, Any]], default_agent: str self._config = agents_config self._default_agent = default_agent or next(iter(agents_config.keys())) self._agents = {} - self._disposables = CompositeDisposable() # Response routing self._response_subject = Subject() @@ -72,6 +68,7 @@ def __init__(self, agents_config: Dict[str, Dict[str, Any]], default_agent: str @rpc def start(self): """Deploy and start all agents.""" + super().start() logger.info(f"Starting agent pool with {len(self._config)} agents") # Deploy agents based on config @@ -117,11 +114,9 @@ def stop(self): except Exception as e: logger.error(f"Error stopping agent {agent_id}: {e}") - # Dispose subscriptions - self._disposables.dispose() - # Clear agents self._agents.clear() + super().stop() @rpc def add_agent(self, agent_id: str, config: Dict[str, Any]): diff --git a/dimos/agents/modules/base.py b/dimos/agents/modules/base.py index 97e9e91cb0..ef778e2da4 100644 --- a/dimos/agents/modules/base.py +++ b/dimos/agents/modules/base.py @@ -516,7 +516,7 @@ async def aquery(self, message: Union[str, AgentMessage]) -> AgentResponse: return await self._process_query_async(agent_msg) - def dispose(self): + def base_agent_dispose(self) -> None: """Dispose of all resources and close gateway.""" self.response_subject.on_completed() if self._executor: diff --git a/dimos/agents/modules/base_agent.py b/dimos/agents/modules/base_agent.py index f65c6379a9..3c83214f6c 100644 --- a/dimos/agents/modules/base_agent.py +++ b/dimos/agents/modules/base_agent.py @@ -109,6 +109,7 @@ def __init__( @rpc def start(self): """Start the agent module and connect streams.""" + super().start() logger.info(f"Starting agent module with model: {self.model}") # Primary AgentMessage input @@ -141,9 +142,10 @@ def stop(self): self._module_disposables.clear() # Dispose BaseAgent resources - self.dispose() + self.base_agent_dispose() logger.info("Agent module stopped") + super().stop() @rpc def clear_history(self): diff --git a/dimos/agents/modules/simple_vision_agent.py b/dimos/agents/modules/simple_vision_agent.py index c052a047db..9bb6fb9894 100644 --- a/dimos/agents/modules/simple_vision_agent.py +++ b/dimos/agents/modules/simple_vision_agent.py @@ -17,7 +17,6 @@ import asyncio import base64 import io -import logging import threading from typing import Optional @@ -28,8 +27,9 @@ from dimos.msgs.sensor_msgs import Image from dimos.utils.logging_config import setup_logger from dimos.agents.modules.gateway import UnifiedGatewayClient +from reactivex.disposable import Disposable -logger = setup_logger("dimos.agents.modules.simple_vision_agent") +logger = setup_logger(__file__) class SimpleVisionAgentModule(Module): @@ -74,6 +74,8 @@ def __init__( @rpc def start(self): """Initialize and start the agent.""" + super().start() + logger.info(f"Starting simple vision agent with model: {self.model}") # Initialize gateway @@ -81,20 +83,23 @@ def start(self): # Subscribe to inputs if self.query_in: - self.query_in.subscribe(self._handle_query) + unsub = self.query_in.subscribe(self._handle_query) + self._disposables.add(Disposable(unsub)) if self.image_in: - self.image_in.subscribe(self._handle_image) + unsub = self.image_in.subscribe(self._handle_image) + self._disposables.add(Disposable(unsub)) logger.info("Simple vision agent started") @rpc def stop(self): - """Stop the agent.""" logger.info("Stopping simple vision agent") if self.gateway: self.gateway.close() + super().stop() + def _handle_image(self, image: Image): """Handle incoming image.""" logger.info( diff --git a/dimos/agents2/agent.py b/dimos/agents2/agent.py index b7f4c37d81..94f418acc2 100644 --- a/dimos/agents2/agent.py +++ b/dimos/agents2/agent.py @@ -189,27 +189,20 @@ def __init__( model_provider=self.config.provider, model=self.config.model ) - def __enter__(self) -> "Agent": - self.start() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.stop() - return False - @rpc def get_agent_id(self) -> str: return self._agent_id @rpc def start(self): + super().start() self.coordinator.start() @rpc def stop(self): - self._close_module() self.coordinator.stop() self._agent_stopped = True + super().stop() def clear_history(self): self._history.clear() diff --git a/dimos/agents2/cli/human.py b/dimos/agents2/cli/human.py index 587f7aed55..5a20abb388 100644 --- a/dimos/agents2/cli/human.py +++ b/dimos/agents2/cli/human.py @@ -15,7 +15,8 @@ import queue from dimos.agents2 import Output, Reducer, Stream, skill -from dimos.core import Module, pLCMTransport +from dimos.core import Module, pLCMTransport, rpc +from reactivex.disposable import Disposable class HumanInput(Module): @@ -30,6 +31,15 @@ def human(self): transport = pLCMTransport("/human_input") msg_queue = queue.Queue() - transport.subscribe(msg_queue.put) + unsub = transport.subscribe(msg_queue.put) + self._disposables.add(Disposable(unsub)) for message in iter(msg_queue.get, None): yield message + + @rpc + def start(self) -> None: + super().start() + + @rpc + def stop(self) -> None: + super().stop() diff --git a/dimos/agents2/skills/conftest.py b/dimos/agents2/skills/conftest.py index 63f64ca5ee..7ea89e320a 100644 --- a/dimos/agents2/skills/conftest.py +++ b/dimos/agents2/skills/conftest.py @@ -66,21 +66,27 @@ def fake_gps_position_stream(): @pytest.fixture def navigation_skill_container(fake_robot, fake_video_stream): - with NavigationSkillContainer(fake_robot, fake_video_stream) as container: - yield container + container = NavigationSkillContainer(fake_robot, fake_video_stream) + container.start() + yield container + container.stop() @pytest.fixture def gps_nav_skill_container(fake_gps_robot, fake_gps_position_stream): - with GpsNavSkillContainer(fake_gps_robot, fake_gps_position_stream) as container: - yield container + container = GpsNavSkillContainer(fake_gps_robot, fake_gps_position_stream) + container.start() + yield container + container.stop() @pytest.fixture def google_maps_skill_container(fake_gps_robot, fake_gps_position_stream, mocker): - with GoogleMapsSkillContainer(fake_gps_robot, fake_gps_position_stream) as container: - container._client = mocker.MagicMock() - yield container + container = GoogleMapsSkillContainer(fake_gps_robot, fake_gps_position_stream) + container.start() + container._client = mocker.MagicMock() + yield container + container.stop() @pytest.fixture diff --git a/dimos/agents2/skills/google_maps_skill_container.py b/dimos/agents2/skills/google_maps_skill_container.py index 167782fd74..ddf64cbef0 100644 --- a/dimos/agents2/skills/google_maps_skill_container.py +++ b/dimos/agents2/skills/google_maps_skill_container.py @@ -16,6 +16,7 @@ from typing import Any, Optional, Union from reactivex import Observable +from dimos.core.resource import Resource from dimos.mapping.google_maps.google_maps import GoogleMaps from dimos.mapping.osm.current_location_map import CurrentLocationMap from dimos.mapping.types import LatLon @@ -28,7 +29,7 @@ logger = setup_logger(__file__) -class GoogleMapsSkillContainer(SkillContainer): +class GoogleMapsSkillContainer(SkillContainer, Resource): _robot: Robot _disposables: CompositeDisposable _latest_location: Optional[LatLon] @@ -45,15 +46,13 @@ def __init__(self, robot: Robot, position_stream: Observable[LatLon]): self._client = GoogleMaps() self._started = False - def __enter__(self) -> "GoogleMapsSkillContainer": + def start(self) -> None: self._started = True self._disposables.add(self._position_stream.subscribe(self._on_gps_location)) - return self - def __exit__(self, exc_type, exc_val, exc_tb): + def stop(self) -> None: self._disposables.dispose() - self.stop() - return False + super().stop() def _on_gps_location(self, location: LatLon) -> None: self._latest_location = location diff --git a/dimos/agents2/skills/gps_nav_skill.py b/dimos/agents2/skills/gps_nav_skill.py index dd29e7189d..dedda933ca 100644 --- a/dimos/agents2/skills/gps_nav_skill.py +++ b/dimos/agents2/skills/gps_nav_skill.py @@ -16,6 +16,7 @@ from typing import Optional from reactivex import Observable +from dimos.core.resource import Resource from dimos.mapping.google_maps.google_maps import GoogleMaps from dimos.mapping.osm.current_location_map import CurrentLocationMap from dimos.mapping.types import LatLon @@ -30,7 +31,7 @@ logger = setup_logger(__file__) -class GpsNavSkillContainer(SkillContainer): +class GpsNavSkillContainer(SkillContainer, Resource): _robot: Robot _disposables: CompositeDisposable _latest_location: Optional[LatLon] @@ -49,15 +50,13 @@ def __init__(self, robot: Robot, position_stream: Observable[LatLon]): self._started = False self._max_valid_distance = 50000 - def __enter__(self) -> "GpsNavSkillContainer": + def start(self) -> None: self._started = True self._disposables.add(self._position_stream.subscribe(self._on_gps_location)) - return self - def __exit__(self, exc_type, exc_val, exc_tb): + def stop(self) -> None: self._disposables.dispose() - self.stop() - return False + super().stop() def _on_gps_location(self, location: LatLon) -> None: self._latest_location = location diff --git a/dimos/agents2/skills/navigation.py b/dimos/agents2/skills/navigation.py index 7a2dd65edd..18558515e6 100644 --- a/dimos/agents2/skills/navigation.py +++ b/dimos/agents2/skills/navigation.py @@ -15,10 +15,10 @@ import time from typing import Any, Optional -import cv2 from reactivex import Observable from reactivex.disposable import CompositeDisposable, Disposable +from dimos.core.resource import Resource from dimos.models.qwen.video_query import BBox from dimos.models.vl.qwen import QwenVlModel from dimos.msgs.geometry_msgs import PoseStamped @@ -31,12 +31,11 @@ from dimos.utils.logging_config import setup_logger from dimos.utils.transform_utils import euler_to_quaternion, quaternion_to_euler from dimos.navigation.bt_navigator.navigator import NavigatorState -from reactivex.disposable import Disposable, CompositeDisposable logger = setup_logger(__file__) -class NavigationSkillContainer(SkillContainer): +class NavigationSkillContainer(SkillContainer, Resource): _robot: UnitreeRobot _disposables: CompositeDisposable _latest_image: Optional[Image] @@ -53,16 +52,14 @@ def __init__(self, robot: UnitreeRobot, video_stream: Observable[Image]): self._started = False self._vl_model = QwenVlModel() - def __enter__(self) -> "NavigationSkillContainer": + def start(self) -> None: unsub = self._video_stream.subscribe(self._on_video) self._disposables.add(Disposable(unsub) if callable(unsub) else unsub) self._started = True - return self - def __exit__(self, exc_type, exc_val, exc_tb): + def stop(self) -> None: self._disposables.dispose() - self.stop() - return False + super().stop() def _on_video(self, image: Image) -> None: self._latest_image = image diff --git a/dimos/agents2/spec.py b/dimos/agents2/spec.py index 1d1707c4e2..889092bad3 100644 --- a/dimos/agents2/spec.py +++ b/dimos/agents2/spec.py @@ -160,13 +160,11 @@ def publish(self, msg: AnyMessage): if self.transport: self.transport.publish(self.config.agent_topic, msg) - @rpc - @abstractmethod - def start(self): ... + def start(self) -> None: + super().start() - @rpc - @abstractmethod - def stop(self): ... + def stop(self) -> None: + super().stop() @rpc @abstractmethod diff --git a/dimos/agents2/temp/run_unitree_agents2.py b/dimos/agents2/temp/run_unitree_agents2.py index 8f0b9ccdea..29b9d4c978 100644 --- a/dimos/agents2/temp/run_unitree_agents2.py +++ b/dimos/agents2/temp/run_unitree_agents2.py @@ -170,7 +170,7 @@ def shutdown(self): if self.robot: try: - # WebRTC robot doesn't have a stop method + self.robot.stop() logger.info("Robot connection closed") except Exception as e: logger.error(f"Error stopping robot: {e}") diff --git a/dimos/agents2/temp/test_unitree_agent_query.py b/dimos/agents2/temp/test_unitree_agent_query.py index 81cf263739..bd2843ac19 100644 --- a/dimos/agents2/temp/test_unitree_agent_query.py +++ b/dimos/agents2/temp/test_unitree_agent_query.py @@ -75,7 +75,6 @@ async def test_async_query(): else: logger.warning("Future not completed yet") - # Clean up agent.stop() return future @@ -132,8 +131,6 @@ def run_loop(): traceback.print_exc() - # Clean up properly - # First stop the agent (this should stop its internal loop if any) agent.stop() # Then stop the manually created event loop thread if we created one diff --git a/dimos/agents2/temp/webcam_agent.py b/dimos/agents2/temp/webcam_agent.py index fed01ed96f..17a68a55ad 100644 --- a/dimos/agents2/temp/webcam_agent.py +++ b/dimos/agents2/temp/webcam_agent.py @@ -18,32 +18,23 @@ This is the migrated version using the new LangChain-based agent system. """ -import asyncio # Needed for event loop management in setup_agent -import os -import sys import time -from pathlib import Path from threading import Thread import reactivex as rx import reactivex.operators as ops -from dotenv import load_dotenv from dimos.agents2 import Agent, Output, Reducer, Stream, skill from dimos.agents2.cli.human import HumanInput from dimos.agents2.spec import Model, Provider -from dimos.core import LCMTransport, Module, pLCMTransport, start +from dimos.core import LCMTransport, Module, start, rpc from dimos.hardware.camera import zed from dimos.hardware.camera.module import CameraModule from dimos.hardware.camera.webcam import Webcam from dimos.msgs.geometry_msgs import Quaternion, Transform, Vector3 -# from dimos.hardware.webcam import ColorCameraModule, Webcam from dimos.msgs.sensor_msgs import CameraInfo, Image from dimos.protocol.skill.test_coordinator import SkillContainerTest -from dimos.robot.unitree_webrtc.unitree_go2 import UnitreeGo2 -from dimos.robot.unitree_webrtc.unitree_skill_container import UnitreeSkillContainer -from dimos.utils.logging_config import setup_logger from dimos.web.robot_web_interface import RobotWebInterface @@ -61,7 +52,10 @@ def __init__(self): self.agent_response = rx.subject.Subject() self.human_query = rx.subject.Subject() + @rpc def start(self): + super().start() + text_streams = { "agent_responses": self.agent_response, } @@ -72,15 +66,18 @@ def start(self): audio_subject=rx.subject.Subject(), ) - self.web_interface.query_stream.subscribe(self.human_query.on_next) + unsub = self.web_interface.query_stream.subscribe(self.human_query.on_next) + self._disposables.add(unsub) self.thread = Thread(target=self.web_interface.run, daemon=True) self.thread.start() + @rpc def stop(self): if self.web_interface: self.web_interface.stop() if self.thread: + # TODO, you can't just wait for a server to close, you have to signal it to end. self.thread.join(timeout=1.0) super().stop() @@ -148,6 +145,8 @@ def main(): while True: time.sleep(1) + # webcam.stop() + if __name__ == "__main__": main() diff --git a/dimos/agents2/test_mock_agent.py b/dimos/agents2/test_mock_agent.py index 3609803f11..5ade99f9ab 100644 --- a/dimos/agents2/test_mock_agent.py +++ b/dimos/agents2/test_mock_agent.py @@ -14,18 +14,16 @@ """Test agent with FakeChatModel for unit testing.""" -import os import time import pytest from dimos_lcm.sensor_msgs import CameraInfo -from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolCall +from langchain_core.messages import AIMessage, HumanMessage from dimos.agents2.agent import Agent from dimos.agents2.testing import MockModel from dimos.core import LCMTransport, start -from dimos.msgs.foxglove_msgs import ImageAnnotations -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Vector3 +from dimos.msgs.geometry_msgs import PoseStamped, Vector3 from dimos.msgs.sensor_msgs import Image from dimos.protocol.skill.test_coordinator import SkillContainerTest from dimos.robot.unitree_webrtc.modular.connection_module import ConnectionModule diff --git a/dimos/conftest.py b/dimos/conftest.py index e2a8a3ec36..7e52a6191f 100644 --- a/dimos/conftest.py +++ b/dimos/conftest.py @@ -52,7 +52,7 @@ def monitor_threads(request): if not new_leaks: return - thread_names = [t.name for f in new_leaks] + thread_names = [t.name for t in new_leaks] pytest.fail( f"Non-closed threads before or during this test. The thread names: {thread_names}. " diff --git a/dimos/core/__init__.py b/dimos/core/__init__.py index 8b00b7da8f..a6dc1aed0c 100644 --- a/dimos/core/__init__.py +++ b/dimos/core/__init__.py @@ -22,10 +22,31 @@ from dimos.protocol.rpc.spec import RPCSpec from dimos.protocol.tf import LCMTF, TF, PubSubTF, TFConfig, TFSpec -__all__ = ["TF", "LCMTF", "PubSubTF", "TFSpec", "TFConfig"] - - -def patch_actor(actor, cls): ... +__all__ = [ + "DimosCluster", + "In", + "LCMRPC", + "LCMTF", + "LCMTransport", + "Module", + "ModuleBase", + "ModuleConfig", + "Out", + "PubSubTF", + "RPCSpec", + "RemoteIn", + "RemoteOut", + "SHMTransport", + "TF", + "TFConfig", + "TFSpec", + "Transport", + "ZenohTransport", + "pLCMTransport", + "pSHMTransport", + "rpc", + "start", +] class RPCClient: @@ -185,20 +206,20 @@ def check_worker_memory(): def close_all(): import time - # Get the event loop before shutting down - loop = dask_client.loop - # Close cluster and client ActorRegistry.clear() - local_cluster.close() - dask_client.close() - # Stop the Tornado IOLoop to clean up IO loop and Profile threads - if loop and hasattr(loop, "add_callback") and hasattr(loop, "stop"): - try: - loop.add_callback(loop.stop) - except Exception: - pass + # Close client first to signal workers to shut down gracefully + try: + dask_client.close(timeout=2) + except Exception: + pass + + # Then close the cluster + try: + local_cluster.close(timeout=2) + except Exception: + pass # Shutdown the Dask offload thread pool try: diff --git a/dimos/core/dimos.py b/dimos/core/dimos.py new file mode 100644 index 0000000000..d286284fec --- /dev/null +++ b/dimos/core/dimos.py @@ -0,0 +1,56 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Type, TypeVar + +from dimos import core +from dimos.core import DimosCluster, Module +from dimos.core.resource import Resource + +T = TypeVar("T", bound="Module") + + +class Dimos(Resource): + _client: Optional[DimosCluster] = None + _n: Optional[int] = None + _memory_limit: str = "auto" + _deployed_modules: dict[Type[Module], Module] = {} + + def __init__(self, n: Optional[int] = None, memory_limit: str = "auto"): + self._n = n + self._memory_limit = memory_limit + + def start(self) -> None: + self._client = core.start(self._n, self._memory_limit) + + def stop(self) -> None: + for module in reversed(self._deployed_modules.values()): + module.stop() + + self._client.close_all() + + def deploy(self, module_class: Type[T], *args, **kwargs) -> T: + if not self._client: + raise ValueError("Not started") + + module = self._client.deploy(module_class, *args, **kwargs) + self._deployed_modules[module_class] = module + return module + + def start_all_modules(self) -> None: + for module in self._deployed_modules.values(): + module.start() + + def get_instance(self, module: Type[T]) -> T | None: + return self._deployed_modules.get(module) diff --git a/dimos/core/module.py b/dimos/core/module.py index 0385aad041..5cea554072 100644 --- a/dimos/core/module.py +++ b/dimos/core/module.py @@ -29,6 +29,7 @@ from dimos.core import colors from dimos.core.core import T, rpc +from dimos.core.resource import Resource from dimos.core.stream import In, Out, RemoteIn, RemoteOut, Transport from dimos.protocol.rpc import LCMRPC, RPCSpec from dimos.protocol.service import Configurable @@ -70,7 +71,7 @@ class ModuleConfig: tf_transport: type[TFSpec] = LCMTF -class ModuleBase(Configurable[ModuleConfig], SkillContainer): +class ModuleBase(Configurable[ModuleConfig], SkillContainer, Resource): _rpc: Optional[RPCSpec] = None _tf: Optional[TFSpec] = None _loop: Optional[asyncio.AbstractEventLoop] = None @@ -94,6 +95,15 @@ def __init__(self, *args, **kwargs): except ValueError: ... + @rpc + def start(self) -> None: + pass + + @rpc + def stop(self) -> None: + self._close_module() + super().stop() + def _close_module(self): self._close_rpc() if hasattr(self, "_loop") and self._loop_thread: diff --git a/dimos/core/resource.py b/dimos/core/resource.py new file mode 100644 index 0000000000..3d69f50bb4 --- /dev/null +++ b/dimos/core/resource.py @@ -0,0 +1,23 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod + + +class Resource(ABC): + @abstractmethod + def start(self) -> None: ... + + @abstractmethod + def stop(self) -> None: ... diff --git a/dimos/core/test_modules.py b/dimos/core/test_modules.py new file mode 100644 index 0000000000..cd2687c41f --- /dev/null +++ b/dimos/core/test_modules.py @@ -0,0 +1,267 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test that all Module subclasses implement required resource management methods.""" + +import ast +import inspect +from pathlib import Path +from typing import Dict, List, Set, Tuple + +import pytest + +from dimos.core.module import Module + + +class ModuleVisitor(ast.NodeVisitor): + """AST visitor to find classes and their base classes.""" + + def __init__(self, filepath: str): + self.filepath = filepath + self.classes: List[ + Tuple[str, List[str], Set[str]] + ] = [] # (class_name, base_classes, methods) + + def visit_ClassDef(self, node: ast.ClassDef): + """Visit a class definition.""" + # Get base class names + base_classes = [] + for base in node.bases: + if isinstance(base, ast.Name): + base_classes.append(base.id) + elif isinstance(base, ast.Attribute): + # Handle cases like dimos.core.Module + parts = [] + current = base + while isinstance(current, ast.Attribute): + parts.append(current.attr) + current = current.value + if isinstance(current, ast.Name): + parts.append(current.id) + base_classes.append(".".join(reversed(parts))) + + # Get method names defined in this class + methods = set() + for item in node.body: + if isinstance(item, ast.FunctionDef): + methods.add(item.name) + + self.classes.append((node.name, base_classes, methods)) + self.generic_visit(node) + + +def get_import_aliases(tree: ast.AST) -> Dict[str, str]: + """Extract import aliases from the AST.""" + aliases = {} + + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for alias in node.names: + key = alias.asname if alias.asname else alias.name + aliases[key] = alias.name + elif isinstance(node, ast.ImportFrom): + module = node.module or "" + for alias in node.names: + key = alias.asname if alias.asname else alias.name + full_name = f"{module}.{alias.name}" if module else alias.name + aliases[key] = full_name + + return aliases + + +def is_module_subclass( + base_classes: List[str], aliases: Dict[str, str], class_hierarchy: Dict[str, List[str]] = None +) -> bool: + """Check if any base class is or resolves to dimos.core.Module or its variants (recursively).""" + target_classes = { + "Module", + "ModuleBase", + "DaskModule", + "dimos.core.Module", + "dimos.core.ModuleBase", + "dimos.core.DaskModule", + "dimos.core.module.Module", + "dimos.core.module.ModuleBase", + "dimos.core.module.DaskModule", + } + + def check_base(base: str, visited: Set[str] = None) -> bool: + if visited is None: + visited = set() + + # Avoid infinite recursion + if base in visited: + return False + visited.add(base) + + # Check direct match + if base in target_classes: + return True + + # Check if it's an alias + if base in aliases: + resolved = aliases[base] + if resolved in target_classes: + return True + # Continue checking with resolved name + base = resolved + + # If we have a class hierarchy, recursively check parent classes + if class_hierarchy and base in class_hierarchy: + for parent_base in class_hierarchy[base]: + if check_base(parent_base, visited): + return True + + return False + + for base in base_classes: + if check_base(base): + return True + + return False + + +def scan_file( + filepath: Path, class_hierarchy: Dict[str, List[str]] = None +) -> List[Tuple[str, str, bool, bool, Set[str]]]: + """ + Scan a Python file for Module subclasses. + + Returns: + List of (class_name, filepath, has_start, has_stop, forbidden_methods) + """ + forbidden_method_names = {"acquire", "release", "open", "close", "shutdown", "clean", "cleanup"} + + try: + with open(filepath, "r", encoding="utf-8") as f: + content = f.read() + + tree = ast.parse(content, filename=str(filepath)) + aliases = get_import_aliases(tree) + + visitor = ModuleVisitor(str(filepath)) + visitor.visit(tree) + + results = [] + for class_name, base_classes, methods in visitor.classes: + if is_module_subclass(base_classes, aliases, class_hierarchy): + has_start = "start" in methods + has_stop = "stop" in methods + forbidden_found = methods & forbidden_method_names + results.append((class_name, str(filepath), has_start, has_stop, forbidden_found)) + + return results + + except (SyntaxError, UnicodeDecodeError): + # Skip files that can't be parsed + return [] + + +def build_class_hierarchy(root_path: Path) -> Dict[str, List[str]]: + """Build a complete class hierarchy by scanning all Python files.""" + hierarchy = {} + + for filepath in root_path.rglob("*.py"): + # Skip __pycache__ and other irrelevant directories + if "__pycache__" in filepath.parts or ".venv" in filepath.parts: + continue + + try: + with open(filepath, "r", encoding="utf-8") as f: + content = f.read() + + tree = ast.parse(content, filename=str(filepath)) + visitor = ModuleVisitor(str(filepath)) + visitor.visit(tree) + + for class_name, base_classes, _ in visitor.classes: + hierarchy[class_name] = base_classes + + except (SyntaxError, UnicodeDecodeError): + # Skip files that can't be parsed + continue + + return hierarchy + + +def scan_directory(root_path: Path) -> List[Tuple[str, str, bool, bool, Set[str]]]: + """Scan all Python files in the directory tree.""" + # First, build the complete class hierarchy + class_hierarchy = build_class_hierarchy(root_path) + + # Then scan for Module subclasses using the complete hierarchy + results = [] + + for filepath in root_path.rglob("*.py"): + # Skip __pycache__ and other irrelevant directories + if "__pycache__" in filepath.parts or ".venv" in filepath.parts: + continue + + file_results = scan_file(filepath, class_hierarchy) + results.extend(file_results) + + return results + + +def get_all_module_subclasses(): + """Find all Module subclasses in the dimos codebase.""" + # Get the dimos package directory + dimos_file = inspect.getfile(Module) + dimos_path = Path(dimos_file).parent.parent # Go up from dimos/core/module.py to dimos/ + + results = scan_directory(dimos_path) + + # Filter out test modules and base classes + filtered_results = [] + for class_name, filepath, has_start, has_stop, forbidden_methods in results: + # Skip base module classes themselves + if class_name in ("Module", "ModuleBase", "DaskModule"): + continue + + # Skip test-only modules (those defined in test_ files) + if "test_" in Path(filepath).name: + continue + + filtered_results.append((class_name, filepath, has_start, has_stop, forbidden_methods)) + + return filtered_results + + +@pytest.mark.parametrize( + "class_name,filepath,has_start,has_stop,forbidden_methods", + get_all_module_subclasses(), + ids=lambda val: val[0] if isinstance(val, str) else str(val), +) +def test_module_has_start_and_stop(class_name, filepath, has_start, has_stop, forbidden_methods): + """Test that Module subclasses implement start and stop methods and don't use forbidden methods.""" + # Get relative path for better error messages + try: + rel_path = Path(filepath).relative_to(Path.cwd()) + except ValueError: + rel_path = filepath + + errors = [] + + # Check for missing required methods + if not has_start: + errors.append("missing required method: start") + if not has_stop: + errors.append("missing required method: stop") + + # Check for forbidden methods + if forbidden_methods: + forbidden_list = ", ".join(sorted(forbidden_methods)) + errors.append(f"has forbidden method(s): {forbidden_list}") + + assert not errors, f"{class_name} in {rel_path} has issues:\n - " + "\n - ".join(errors) diff --git a/dimos/core/test_stream.py b/dimos/core/test_stream.py index 8a2101a9c7..59fa806716 100644 --- a/dimos/core/test_stream.py +++ b/dimos/core/test_stream.py @@ -188,6 +188,8 @@ def test_subscription(dimos, subscriber_class): # ensuring no new messages have passed through assert total_msg_n == subscriber.sub1_msgs_len() + subscriber.sub2_msgs_len() + robot.stop() + @pytest.mark.module def test_get_next(dimos): @@ -215,6 +217,7 @@ def test_get_next(dimos): assert subscriber.active_subscribers() == 0 assert next_odom != odom + robot.stop() @pytest.mark.module @@ -249,3 +252,5 @@ def test_hot_getter(dimos): assert isinstance(next_odom, Odometry) assert next_odom != odom subscriber.stop_hot_getter() + + robot.stop() diff --git a/dimos/core/testing.py b/dimos/core/testing.py index da8ff5b0c4..e17b25f41e 100644 --- a/dimos/core/testing.py +++ b/dimos/core/testing.py @@ -17,15 +17,7 @@ import pytest -from dimos.core import ( - In, - LCMTransport, - Module, - Out, - RemoteOut, - rpc, - start, -) +from dimos.core import In, Module, Out, start, rpc from dimos.msgs.geometry_msgs import Vector3 from dimos.robot.unitree_webrtc.type.lidar import LidarMessage from dimos.robot.unitree_webrtc.type.odometry import Odometry @@ -55,11 +47,22 @@ def __init__(self): self._stop_event = Event() self._thread = None + @rpc def start(self): + super().start() + self._thread = Thread(target=self.odomloop) self._thread.start() self.mov.subscribe(self.mov_callback) + @rpc + def stop(self) -> None: + self._stop_event.set() + if self._thread and self._thread.is_alive(): + self._thread.join(timeout=1.0) + + super().stop() + def odomloop(self): odomdata = SensorReplay("raw_odometry_rotate_walk", autocast=Odometry.from_msg) lidardata = SensorReplay("office_lidar", autocast=LidarMessage.from_msg) @@ -78,8 +81,3 @@ def odomloop(self): lidarmsg.pubtime = time.perf_counter() self.lidar.publish(lidarmsg) time.sleep(0.1) - - def stop(self): - self._stop_event.set() - if self._thread and self._thread.is_alive(): - self._thread.join(timeout=1.0) # Wait up to 1 second for clean shutdown diff --git a/dimos/hardware/camera/webcam.py b/dimos/hardware/camera/webcam.py index 87ba492d6e..7f9c9940a7 100644 --- a/dimos/hardware/camera/webcam.py +++ b/dimos/hardware/camera/webcam.py @@ -12,24 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -import queue import threading import time from dataclasses import dataclass, field from functools import cache -from typing import Any, Callable, Generic, Literal, Optional, Protocol, TypeVar +from typing import Literal, Optional import cv2 from dimos_lcm.sensor_msgs import CameraInfo from reactivex import create from reactivex.observable import Observable -from dimos.hardware.camera.spec import ( - CameraConfig, - CameraHardware, -) from dimos.msgs.sensor_msgs import Image from dimos.msgs.sensor_msgs.Image import ImageFormat +from dimos.hardware.camera.spec import CameraConfig, CameraHardware from dimos.utils.reactive import backpressure diff --git a/dimos/hardware/camera/zed/camera.py b/dimos/hardware/camera/zed/camera.py index 6822c19afb..e9f029c845 100644 --- a/dimos/hardware/camera/zed/camera.py +++ b/dimos/hardware/camera/zed/camera.py @@ -610,6 +610,8 @@ def start(self): logger.warning("ZED module already running") return + super().start() + try: # Initialize ZED camera self.zed_camera = ZEDCamera( @@ -671,7 +673,7 @@ def stop(self): self.zed_camera.close() self.zed_camera = None - logger.info("ZED module stopped") + super().stop() def _capture_and_publish(self): """Capture frame and publish all data.""" @@ -868,7 +870,3 @@ def get_pose(self) -> Optional[Dict[str, Any]]: if self.zed_camera and self.enable_tracking: return self.zed_camera.get_pose() return None - - def cleanup(self): - """Clean up resources on module destruction.""" - self.stop() diff --git a/dimos/hardware/fake_zed_module.py b/dimos/hardware/fake_zed_module.py index 4b7c4c2fdf..b0a246ef12 100644 --- a/dimos/hardware/fake_zed_module.py +++ b/dimos/hardware/fake_zed_module.py @@ -57,7 +57,6 @@ def __init__(self, recording_path: str, frame_id: str = "zed_camera", **kwargs): self.recording_path = recording_path self.frame_id = frame_id self._running = False - self._subscriptions = [] # Initialize TF publisher self.tf = TF() @@ -200,6 +199,8 @@ def camera_info_autocast(x): @rpc def start(self): """Start replaying recorded data.""" + super().start() + if self._running: logger.warning("FakeZEDModule already running") return @@ -211,46 +212,55 @@ def start(self): # Subscribe to all streams and publish try: # Color image stream - sub = self._get_color_stream().subscribe( + unsub = self._get_color_stream().subscribe( lambda msg: self.color_image.publish(msg) if self._running else None ) - self._subscriptions.append(sub) + self._disposables.add(unsub) logger.info("Started color image replay stream") except Exception as e: logger.warning(f"Color image stream not available: {e}") try: # Depth image stream - sub = self._get_depth_stream().subscribe( + unsub = self._get_depth_stream().subscribe( lambda msg: self.depth_image.publish(msg) if self._running else None ) - self._subscriptions.append(sub) + self._disposables.add(unsub) logger.info("Started depth image replay stream") except Exception as e: logger.warning(f"Depth image stream not available: {e}") try: # Pose stream - sub = self._get_pose_stream().subscribe( + unsub = self._get_pose_stream().subscribe( lambda msg: self._publish_pose(msg) if self._running else None ) - self._subscriptions.append(sub) + self._disposables.add(unsub) logger.info("Started pose replay stream") except Exception as e: logger.warning(f"Pose stream not available: {e}") try: # Camera info stream - sub = self._get_camera_info_stream().subscribe( + unsub = self._get_camera_info_stream().subscribe( lambda msg: self.camera_info.publish(msg) if self._running else None ) - self._subscriptions.append(sub) + self._disposables.add(unsub) logger.info("Started camera info replay stream") except Exception as e: logger.warning(f"Camera info stream not available: {e}") logger.info("FakeZEDModule replay started") + @rpc + def stop(self) -> None: + if not self._running: + return + + self._running = False + + super().stop() + def _publish_pose(self, msg): """Publish pose and TF transform.""" if msg: @@ -268,20 +278,3 @@ def _publish_pose(self, msg): ts=time.time(), ) self.tf.publish(transform) - - @rpc - def stop(self): - """Stop replaying data.""" - if not self._running: - return - - self._running = False - - # Dispose of all subscriptions - for sub in self._subscriptions: - if sub: - sub.dispose() - - self._subscriptions = [] - - logger.info("FakeZEDModule stopped") diff --git a/dimos/hardware/gstreamer_camera.py b/dimos/hardware/gstreamer_camera.py index b4d378ba29..32c2e8304b 100644 --- a/dimos/hardware/gstreamer_camera.py +++ b/dimos/hardware/gstreamer_camera.py @@ -88,24 +88,13 @@ def start(self): logger.warning("GStreamer camera module is already running") return + super().start() + self.should_reconnect = True self._connect() - def _connect(self): - if not self.should_reconnect: - return - - try: - self._create_pipeline() - self._start_pipeline() - self.running = True - logger.info(f"GStreamer TCP camera module connected to {self.host}:{self.port}") - except Exception as e: - logger.error(f"Failed to connect to {self.host}:{self.port}: {e}") - self._schedule_reconnect() - @rpc - def stop(self): + def stop(self) -> None: self.should_reconnect = False self._cleanup_reconnect_timer() @@ -124,7 +113,20 @@ def stop(self): if self.main_loop_thread and self.main_loop_thread != threading.current_thread(): self.main_loop_thread.join(timeout=2.0) - logger.info("GStreamer camera module stopped") + super().stop() + + def _connect(self) -> None: + if not self.should_reconnect: + return + + try: + self._create_pipeline() + self._start_pipeline() + self.running = True + logger.info(f"GStreamer TCP camera module connected to {self.host}:{self.port}") + except Exception as e: + logger.error(f"Failed to connect to {self.host}:{self.port}: {e}") + self._schedule_reconnect() def _cleanup_reconnect_timer(self): if self.reconnect_timer_id: @@ -305,6 +307,3 @@ def _on_new_sample(self, appsink): buffer.unmap(map_info) return Gst.FlowReturn.OK - - def __del__(self): - self.stop() diff --git a/dimos/hardware/piper_arm.py b/dimos/hardware/piper_arm.py index 774f70b1c6..71ce4bf04f 100644 --- a/dimos/hardware/piper_arm.py +++ b/dimos/hardware/piper_arm.py @@ -14,14 +14,11 @@ # dimos/hardware/piper_arm.py -from typing import ( - Optional, - Tuple, -) +from reactivex.disposable import Disposable +from typing import Tuple from piper_sdk import * # from the official Piper SDK import numpy as np import time -import subprocess import kinpy as kp import sys import termios @@ -40,7 +37,7 @@ from dimos.core import In, Module, rpc from dimos_lcm.geometry_msgs import Pose, Vector3, Twist -logger = setup_logger("dimos.hardware.piper_arm") +logger = setup_logger(__file__) class PiperArm: @@ -362,10 +359,14 @@ def __init__(self, arm, period=0.01, *args, **kwargs): self.period = period self.latest_cmd = None self.last_cmd_time = None + self._thread = None @rpc def start(self): - self.cmd_vel.subscribe(self.handle_cmd_vel) + super().start() + + unsub = self.cmd_vel.subscribe(self.handle_cmd_vel) + self._disposables.add(Disposable(unsub)) def control_loop(): while True: @@ -434,8 +435,15 @@ def control_loop(): ) time.sleep(self.period) - thread = threading.Thread(target=control_loop, daemon=True) - thread.start() + self._thread = threading.Thread(target=control_loop, daemon=True) + self._thread.start() + + @rpc + def stop(self) -> None: + if self._thread: + # TODO: trigger the thread to stop + self._thread.join(2) + super().stop() def handle_cmd_vel(self, cmd_vel: Twist): self.latest_cmd = cmd_vel @@ -456,6 +464,8 @@ def run_velocity_controller(): while True: time.sleep(1) + # velocity_controller.stop() + if __name__ == "__main__": arm = PiperArm() diff --git a/dimos/manipulation/visual_servoing/manipulation_module.py b/dimos/manipulation/visual_servoing/manipulation_module.py index a3fe0a17f9..9d2d77a0fa 100644 --- a/dimos/manipulation/visual_servoing/manipulation_module.py +++ b/dimos/manipulation/visual_servoing/manipulation_module.py @@ -26,6 +26,7 @@ import numpy as np +from reactivex.disposable import Disposable from dimos.core import Module, In, Out, rpc from dimos.msgs.sensor_msgs import Image from dimos.msgs.geometry_msgs import Vector3, Pose, Quaternion @@ -221,10 +222,15 @@ def __init__( @rpc def start(self): """Start the manipulation module.""" - # Subscribe to camera data - self.rgb_image.subscribe(self._on_rgb_image) - self.depth_image.subscribe(self._on_depth_image) - self.camera_info.subscribe(self._on_camera_info) + + unsub = self.rgb_image.subscribe(self._on_rgb_image) + self._disposables.add(Disposable(unsub)) + + unsub = self.depth_image.subscribe(self._on_depth_image) + self._disposables.add(Disposable(unsub)) + + unsub = self.camera_info.subscribe(self._on_camera_info) + self._disposables.add(Disposable(unsub)) logger.info("Manipulation module started") @@ -237,6 +243,11 @@ def stop(self): self.task_thread.join(timeout=5.0) self.reset_to_idle() + + if self.detector and hasattr(self.detector, "cleanup"): + self.detector.cleanup() + self.arm.disable() + logger.info("Manipulation module stopped") def _on_rgb_image(self, msg: Image): @@ -935,10 +946,3 @@ def get_place_target_pose(self) -> Optional[Pose]: ) return place_pose - - @rpc - def cleanup(self): - """Clean up resources on module destruction.""" - if self.detector and hasattr(self.detector, "cleanup"): - self.detector.cleanup() - self.arm.disable() diff --git a/dimos/mapping/osm/demo_osm.py b/dimos/mapping/osm/demo_osm.py index 638b48d8cd..c791008510 100644 --- a/dimos/mapping/osm/demo_osm.py +++ b/dimos/mapping/osm/demo_osm.py @@ -22,13 +22,12 @@ from dimos.agents2.cli.human import HumanInput from dimos.agents2.constants import AGENT_SYSTEM_PROMPT_PATH from dimos.agents2.skills.osm import OsmSkillContainer +from dimos.core.resource import Resource from dimos.mapping.types import LatLon from dimos.robot.robot import Robot from dimos.robot.utils.robot_debugger import RobotDebugger from dimos.utils.logging_config import setup_logger -from contextlib import ExitStack - logger = setup_logger(__file__) load_dotenv() @@ -41,29 +40,33 @@ class FakeRobot(Robot): pass -class UnitreeAgents2Runner: +class UnitreeAgents2Runner(Resource): def __init__(self): self._robot = None self._agent = None - self._exit_stack = ExitStack() + self._robot_debugger = None + self._osm_skill_container = None - def __enter__(self): + def start(self) -> None: self._robot = FakeRobot() self._agent = Agent(system_prompt=SYSTEM_PROMPT) - self._agent.register_skills( - self._exit_stack.enter_context(OsmSkillContainer(self._robot, _get_fake_location())) - ) + self._osm_skill_container = OsmSkillContainer(self._robot, _get_fake_location()) + self._osm_skill_container.__enter__() + self._agent.register_skills(self._osm_skill_container) self._agent.register_skills(HumanInput()) self._agent.run_implicit_skill("human") - self._exit_stack.enter_context(self._agent) + self._agent.start() self._agent.loop_thread() - self._exit_stack.enter_context(RobotDebugger(self._robot)) - - return self + self._robot_debugger = RobotDebugger(self._robot) + self._robot_debugger.start() - def __exit__(self, exc_type, exc_val, exc_tb): - self._exit_stack.close() - return False + def stop(self) -> None: + if self._robot_debugger: + self._robot_debugger.stop() + if self._osm_skill_container: + self._osm_skill_container.__exit__(None, None, None) + if self._agent: + self._agent.stop() def run(self): while True: @@ -74,8 +77,10 @@ def run(self): def main(): - with UnitreeAgents2Runner() as runner: - runner.run() + runner = UnitreeAgents2Runner() + runner.start() + runner.run() + runner.stop() def _get_fake_location() -> Observable[LatLon]: diff --git a/dimos/navigation/bbox_navigation.py b/dimos/navigation/bbox_navigation.py index aaafc32ac7..f498f2ec3f 100644 --- a/dimos/navigation/bbox_navigation.py +++ b/dimos/navigation/bbox_navigation.py @@ -18,6 +18,7 @@ from dimos_lcm.sensor_msgs import CameraInfo from dimos.utils.logging_config import setup_logger import logging +from reactivex.disposable import Disposable logger = setup_logger(__name__, level=logging.DEBUG) @@ -36,10 +37,17 @@ def __init__(self, goal_distance: float = 1.0): @rpc def start(self): - self.camera_info.subscribe( + unsub = self.camera_info.subscribe( lambda msg: setattr(self, "camera_intrinsics", [msg.K[0], msg.K[4], msg.K[2], msg.K[5]]) ) - self.detection2d.subscribe(self._on_detection) + self._disposables.add(Disposable(unsub)) + + unsub = self.detection2d.subscribe(self._on_detection) + self._disposables.add(Disposable(unsub)) + + @rpc + def stop(self) -> None: + super().stop() def _on_detection(self, det: Detection2DArray): if det.detections_length == 0 or not self.camera_intrinsics: diff --git a/dimos/navigation/bt_navigator/navigator.py b/dimos/navigation/bt_navigator/navigator.py index 9bde5d9ca5..33d516106f 100644 --- a/dimos/navigation/bt_navigator/navigator.py +++ b/dimos/navigation/bt_navigator/navigator.py @@ -29,6 +29,7 @@ from dimos_lcm.std_msgs import String from dimos.navigation.bt_navigator.goal_validator import find_safe_goal from dimos.navigation.bt_navigator.recovery_server import RecoveryServer +from reactivex.disposable import Disposable from dimos.protocol.tf import TF from dimos.utils.logging_config import setup_logger from dimos_lcm.std_msgs import Bool @@ -65,7 +66,7 @@ class BehaviorTreeNavigator(Module): global_costmap: In[OccupancyGrid] = None # LCM outputs - goal: Out[PoseStamped] = None + target: Out[PoseStamped] = None goal_reached: Out[Bool] = None navigation_state: Out[String] = None @@ -122,11 +123,17 @@ def __init__( @rpc def start(self): - """Start the navigator module.""" + super().start() + # Subscribe to inputs - self.odom.subscribe(self._on_odom) - self.goal_request.subscribe(self._on_goal_request) - self.global_costmap.subscribe(self._on_costmap) + unsub = self.odom.subscribe(self._on_odom) + self._disposables.add(Disposable(unsub)) + + unsub = self.goal_request.subscribe(self._on_goal_request) + self._disposables.add(Disposable(unsub)) + + unsub = self.global_costmap.subscribe(self._on_costmap) + self._disposables.add(Disposable(unsub)) # Start control thread self.stop_event.clear() @@ -135,6 +142,18 @@ def start(self): logger.info("Navigator started") + @rpc + def stop(self) -> None: + """Clean up resources including stopping the control thread.""" + + self.stop_navigation() + + self.stop_event.set() + if self.control_thread and self.control_thread.is_alive(): + self.control_thread.join(timeout=2.0) + + super().stop() + @rpc def cancel_goal(self) -> bool: """ @@ -143,22 +162,9 @@ def cancel_goal(self) -> bool: Returns: True if goal was cancelled, False if no goal was active """ - self.stop() + self.stop_navigation() return True - @rpc - def cleanup(self): - """Clean up resources including stopping the control thread.""" - # First stop navigation - self.stop() - - # Then clean up the control thread - self.stop_event.set() - if self.control_thread and self.control_thread.is_alive(): - self.control_thread.join(timeout=2.0) - - logger.info("Navigator cleanup complete") - @rpc def set_goal(self, goal: PoseStamped) -> bool: """ @@ -292,7 +298,7 @@ def _control_loop(self): frame_id=goal.frame_id, ts=goal.ts, ) - self.goal.publish(safe_goal) + self.target.publish(safe_goal) self.current_goal = safe_goal else: logger.warning("Could not find safe goal position, cancelling goal") @@ -303,7 +309,7 @@ def _control_loop(self): reached_msg = Bool() reached_msg.data = True self.goal_reached.publish(reached_msg) - self.stop() + self.stop_navigation() self._goal_reached = True logger.info("Goal reached, resetting local planner") @@ -322,7 +328,7 @@ def is_goal_reached(self) -> bool: """ return self._goal_reached - def stop(self): + def stop_navigation(self) -> None: """Stop navigation and return to IDLE state.""" with self.goal_lock: self.current_goal = None diff --git a/dimos/navigation/frontier_exploration/test_wavefront_frontier_goal_selector.py b/dimos/navigation/frontier_exploration/test_wavefront_frontier_goal_selector.py index 4daee48002..83f16051a0 100644 --- a/dimos/navigation/frontier_exploration/test_wavefront_frontier_goal_selector.py +++ b/dimos/navigation/frontier_exploration/test_wavefront_frontier_goal_selector.py @@ -37,7 +37,7 @@ def explorer(): yield explorer # Cleanup after test try: - explorer.cleanup() + explorer.stop() except: pass @@ -161,6 +161,8 @@ def test_frontier_detection_with_office_lidar(explorer, quick_costmap): else: print("No frontiers detected - map may be fully explored or parameters too restrictive") + explorer.stop() # TODO: this should be a in try-finally + def test_exploration_goal_selection(explorer): """Test the complete exploration goal selection pipeline.""" @@ -193,6 +195,8 @@ def test_exploration_goal_selection(explorer): else: print("No exploration goal selected - map may be fully explored") + explorer.stop() # TODO: this should be a in try-finally + def test_exploration_session_reset(explorer): """Test exploration session reset functionality.""" @@ -222,6 +226,7 @@ def test_exploration_session_reset(explorer): assert explorer.no_gain_counter == 0, "No-gain counter should be reset" print("Exploration session reset successfully") + explorer.stop() # TODO: this should be a in try-finally def test_frontier_ranking(explorer): @@ -267,6 +272,8 @@ def test_frontier_ranking(explorer): else: print("No frontiers found for ranking test") + explorer.stop() # TODO: this should be a in try-finally + def test_exploration_with_no_gain_detection(): """Test information gain detection and exploration termination.""" @@ -301,10 +308,8 @@ def test_exploration_with_no_gain_detection(): # Should have stopped due to no information gain assert goal is None, "Exploration should stop after no-gain threshold" assert explorer.no_gain_counter == 0, "Counter should reset after stopping" - - print("No-gain detection test passed") finally: - explorer.cleanup() + explorer.stop() @pytest.mark.vis @@ -388,10 +393,9 @@ def world_to_image_coords(world_pos: Vector3) -> tuple[int, int]: # Display the image base_image.show(title="Frontier Detection - Office Lidar") - print("Visualization displayed. Close the image window to continue.") finally: - explorer.cleanup() + explorer.stop() def test_performance_timing(): @@ -442,7 +446,7 @@ def test_performance_timing(): print(f" Goal selection: {goal_time:.4f}s") print(f" Frontiers found: {len(frontiers)}") finally: - explorer.cleanup() + explorer.stop() # Check that larger maps take more time (expected behavior) # But verify times are reasonable diff --git a/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py b/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py index 437aab3e3b..5acbf7b5bf 100644 --- a/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py +++ b/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py @@ -33,6 +33,7 @@ from dimos.utils.logging_config import setup_logger from dimos_lcm.std_msgs import Bool from dimos.utils.transform_utils import get_distance +from reactivex.disposable import Disposable logger = setup_logger("dimos.robot.unitree.frontier_exploration") @@ -90,8 +91,8 @@ class WavefrontFrontierExplorer(Module): """ # LCM inputs - costmap: In[OccupancyGrid] = None - odometry: In[PoseStamped] = None + global_costmap: In[OccupancyGrid] = None + odom: In[PoseStamped] = None goal_reached: In[Bool] = None explore_cmd: In[Bool] = None stop_explore_cmd: In[Bool] = None @@ -152,29 +153,30 @@ def __init__( @rpc def start(self): - """Start the frontier exploration module.""" - # Subscribe to inputs - self.costmap.subscribe(self._on_costmap) - self.odometry.subscribe(self._on_odometry) + super().start() + + unsub = self.global_costmap.subscribe(self._on_costmap) + self._disposables.add(Disposable(unsub)) + + unsub = self.odom.subscribe(self._on_odometry) + self._disposables.add(Disposable(unsub)) - # Subscribe to goal_reached if available if self.goal_reached.transport is not None: - self.goal_reached.subscribe(self._on_goal_reached) + unsub = self.goal_reached.subscribe(self._on_goal_reached) + self._disposables.add(Disposable(unsub)) - # Subscribe to exploration commands if self.explore_cmd.transport is not None: - self.explore_cmd.subscribe(self._on_explore_cmd) - if self.stop_explore_cmd.transport is not None: - self.stop_explore_cmd.subscribe(self._on_stop_explore_cmd) + unsub = self.explore_cmd.subscribe(self._on_explore_cmd) + self._disposables.add(Disposable(unsub)) - logger.info("WavefrontFrontierExplorer started") + if self.stop_explore_cmd.transport is not None: + unsub = self.stop_explore_cmd.subscribe(self._on_stop_explore_cmd) + self._disposables.add(Disposable(unsub)) @rpc - def cleanup(self): - """Clean up resources.""" + def stop(self) -> None: self.stop_exploration() - self._close_module() - logger.info("WavefrontFrontierExplorer cleanup complete") + super().stop() def _on_costmap(self, msg: OccupancyGrid): """Handle incoming costmap messages.""" diff --git a/dimos/navigation/global_planner/__init__.py b/dimos/navigation/global_planner/__init__.py index 4b158f73a1..0496f586b9 100644 --- a/dimos/navigation/global_planner/__init__.py +++ b/dimos/navigation/global_planner/__init__.py @@ -1,2 +1,2 @@ -from dimos.navigation.global_planner.planner import AstarPlanner, Planner +from dimos.navigation.global_planner.planner import AstarPlanner from dimos.navigation.global_planner.algo import astar diff --git a/dimos/navigation/global_planner/planner.py b/dimos/navigation/global_planner/planner.py index 984873f67a..08a00596aa 100644 --- a/dimos/navigation/global_planner/planner.py +++ b/dimos/navigation/global_planner/planner.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from abc import abstractmethod -from dataclasses import dataclass from typing import Optional from dimos.core import In, Module, Out, rpc @@ -22,8 +20,9 @@ from dimos.navigation.global_planner.algo import astar from dimos.utils.logging_config import setup_logger from dimos.utils.transform_utils import euler_to_quaternion +from reactivex.disposable import Disposable -logger = setup_logger("dimos.robot.unitree.global_planner") +logger = setup_logger(__file__) import math from dimos.msgs.geometry_msgs import Quaternion, Vector3 @@ -140,16 +139,7 @@ def resample_path(path: Path, spacing: float) -> Path: return Path(frame_id=path.frame_id, poses=resampled) -@dataclass -class Planner(Module): - target: In[PoseStamped] = None - path: Out[Path] = None - - def __init__(self): - Module.__init__(self) - - -class AstarPlanner(Planner): +class AstarPlanner(Module): # LCM inputs target: In[PoseStamped] = None global_costmap: In[OccupancyGrid] = None @@ -167,13 +157,23 @@ def __init__(self): @rpc def start(self): - # Subscribe to inputs - self.target.subscribe(self._on_target) - self.global_costmap.subscribe(self._on_costmap) - self.odom.subscribe(self._on_odom) + super().start() + + unsub = self.target.subscribe(self._on_target) + self._disposables.add(Disposable(unsub)) + + unsub = self.global_costmap.subscribe(self._on_costmap) + self._disposables.add(Disposable(unsub)) + + unsub = self.odom.subscribe(self._on_odom) + self._disposables.add(Disposable(unsub)) logger.info("A* planner started") + @rpc + def stop(self) -> None: + super().stop() + def _on_costmap(self, msg: OccupancyGrid): """Handle incoming costmap messages.""" self.latest_costmap = msg diff --git a/dimos/navigation/local_planner/local_planner.py b/dimos/navigation/local_planner/local_planner.py index ec415e6e51..ac1a6ea744 100644 --- a/dimos/navigation/local_planner/local_planner.py +++ b/dimos/navigation/local_planner/local_planner.py @@ -29,8 +29,9 @@ from dimos.msgs.nav_msgs import OccupancyGrid, Path from dimos.utils.logging_config import setup_logger from dimos.utils.transform_utils import get_distance, quaternion_to_euler, normalize_angle +from reactivex.disposable import Disposable -logger = setup_logger("dimos.robot.local_planner") +logger = setup_logger(__file__) class BaseLocalPlanner(Module): @@ -89,13 +90,21 @@ def __init__( @rpc def start(self): - """Start the local planner module.""" - # Subscribe to inputs - self.local_costmap.subscribe(self._on_costmap) - self.odom.subscribe(self._on_odom) - self.path.subscribe(self._on_path) + super().start() - logger.info("Local planner module started") + unsub = self.local_costmap.subscribe(self._on_costmap) + self._disposables.add(Disposable(unsub)) + + unsub = self.odom.subscribe(self._on_odom) + self._disposables.add(Disposable(unsub)) + + unsub = self.path.subscribe(self._on_path) + self._disposables.add(Disposable(unsub)) + + @rpc + def stop(self) -> None: + self.cancel_planning() + super().stop() def _on_costmap(self, msg: OccupancyGrid): self.latest_costmap = msg @@ -186,11 +195,11 @@ def reset(self): self.latest_path = None self.latest_odom = None self.latest_costmap = None - self.stop() + self.cancel_planning() logger.info("Local planner reset") @rpc - def stop(self): + def cancel_planning(self) -> None: """Stop the local planner and any running threads.""" if self.planning_thread and self.planning_thread.is_alive(): self.stop_planning.set() @@ -198,10 +207,3 @@ def stop(self): self.planning_thread = None stop_cmd = Twist() self.cmd_vel.publish(stop_cmd) - - logger.info("Local planner stopped") - - @rpc - def stop_planner_module(self): - self.stop() - self._close_module() diff --git a/dimos/perception/detection2d/module2D.py b/dimos/perception/detection2d/module2D.py index d11875315f..9eaa2f5aa1 100644 --- a/dimos/perception/detection2d/module2D.py +++ b/dimos/perception/detection2d/module2D.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import functools from dataclasses import dataclass from typing import Any, Callable, Optional @@ -27,11 +28,10 @@ from dimos.msgs.sensor_msgs import Image from dimos.msgs.sensor_msgs.Image import sharpness_barrier from dimos.msgs.vision_msgs import Detection2DArray +from dimos.perception.detection2d.type import ImageDetections2D from dimos.perception.detection2d.detectors import Detector, Yolo2DDetector from dimos.perception.detection2d.detectors.person.yolo import YoloPersonDetector -from dimos.perception.detection2d.type import ( - ImageDetections2D, -) +from dimos.perception.detection2d.type import ImageDetections2D from dimos.utils.decorators.decorators import simple_mcache from dimos.utils.reactive import backpressure @@ -90,13 +90,16 @@ def detection_stream_2d(self) -> Observable[ImageDetections2D]: @rpc def start(self): - self.detection_stream_2d().subscribe( + super().start() + unsub = self.detection_stream_2d().subscribe( lambda det: self.detections.publish(det.to_ros_detection2d_array()) ) + self._disposables.add(unsub) - self.detection_stream_2d().subscribe( + unsub = self.detection_stream_2d().subscribe( lambda det: self.annotations.publish(det.to_foxglove_annotations()) ) + self._disposables.add(unsub) def publish_cropped_images(detections: ImageDetections2D): for index, detection in enumerate(detections[:3]): @@ -106,4 +109,5 @@ def publish_cropped_images(detections: ImageDetections2D): self.detection_stream_2d().subscribe(publish_cropped_images) @rpc - def stop(self): ... + def stop(self) -> None: + super().stop() diff --git a/dimos/perception/detection2d/type/test_detection3dpc.py b/dimos/perception/detection2d/type/test_detection3dpc.py index a25e27d458..374385903d 100644 --- a/dimos/perception/detection2d/type/test_detection3dpc.py +++ b/dimos/perception/detection2d/type/test_detection3dpc.py @@ -58,7 +58,7 @@ def test_detection3dpc(detection3dpc): # def test_point_cloud_properties(detection3dpc): """Test point cloud data and boundaries.""" pc_points = detection3dpc.pointcloud.points() - assert len(pc_points) in [69, 70] + assert len(pc_points) in [68, 69, 70] assert detection3dpc.pointcloud.frame_id == "world", ( f"Expected frame_id 'world', got '{detection3dpc.pointcloud.frame_id}'" ) diff --git a/dimos/perception/object_tracker.py b/dimos/perception/object_tracker.py index 7fd5872314..d59165cb06 100644 --- a/dimos/perception/object_tracker.py +++ b/dimos/perception/object_tracker.py @@ -22,7 +22,8 @@ from dimos.msgs.std_msgs import Header from dimos.msgs.sensor_msgs import Image, ImageFormat from dimos.msgs.vision_msgs import Detection2DArray, Detection3DArray -from dimos.msgs.geometry_msgs import Vector3, Quaternion, Transform, Pose, PoseStamped +from reactivex.disposable import Disposable +from dimos.msgs.geometry_msgs import Vector3, Quaternion, Transform, Pose from dimos.protocol.tf import TF from dimos.utils.logging_config import setup_logger @@ -101,7 +102,6 @@ def __init__( self._latest_rgb_frame: Optional[np.ndarray] = None self._latest_depth_frame: Optional[np.ndarray] = None self._latest_camera_info: Optional[CameraInfo] = None - self._aligned_frames_subscription = None # Tracking thread control self.tracking_thread: Optional[threading.Thread] = None @@ -119,7 +119,7 @@ def __init__( @rpc def start(self): - """Start the object tracking module and subscribe to LCM streams.""" + super().start() # Subscribe to aligned rgb and depth streams def on_aligned_frames(frames_tuple): @@ -140,7 +140,8 @@ def on_aligned_frames(frames_tuple): buffer_size=2.0, # 2 second buffer match_tolerance=0.5, # 500ms tolerance ) - self._aligned_frames_subscription = aligned_frames.subscribe(on_aligned_frames) + unsub = aligned_frames.subscribe(on_aligned_frames) + self._disposables.add(unsub) # Subscribe to camera info stream separately (doesn't need alignment) def on_camera_info(camera_info_msg: CameraInfo): @@ -154,9 +155,19 @@ def on_camera_info(camera_info_msg: CameraInfo): camera_info_msg.K[5], ] - self.camera_info.subscribe(on_camera_info) + unsub = self.camera_info.subscribe(on_camera_info) + self._disposables.add(Disposable(unsub)) - logger.info("ObjectTracking module started with aligned frame subscription") + @rpc + def stop(self) -> None: + self.stop_track() + + self.stop_tracking.set() + + if self.tracking_thread and self.tracking_thread.is_alive(): + self.tracking_thread.join(timeout=2.0) + + super().stop() @rpc def track( @@ -611,18 +622,3 @@ def _get_depth_from_bbox(self, bbox: List[int], depth_frame: np.ndarray) -> Opti return depth_25th_percentile return None - - @rpc - def cleanup(self): - """Clean up resources.""" - self.stop_track() - - # Ensure thread is stopped - if self.tracking_thread and self.tracking_thread.is_alive(): - self.stop_tracking.set() - self.tracking_thread.join(timeout=2.0) - - # Unsubscribe from aligned frames - if self._aligned_frames_subscription: - self._aligned_frames_subscription.dispose() - self._aligned_frames_subscription = None diff --git a/dimos/perception/object_tracker_2d.py b/dimos/perception/object_tracker_2d.py index 481b69e1ac..84b823ce5e 100644 --- a/dimos/perception/object_tracker_2d.py +++ b/dimos/perception/object_tracker_2d.py @@ -24,6 +24,7 @@ from dimos.msgs.sensor_msgs import Image, ImageFormat from dimos.msgs.vision_msgs import Detection2DArray from dimos.utils.logging_config import setup_logger +from reactivex.disposable import Disposable # Import LCM messages from dimos_lcm.vision_msgs import ( @@ -86,7 +87,7 @@ def __init__( @rpc def start(self): - """Start the object tracking module and subscribe to video stream.""" + super().start() def on_frame(frame_msg: Image): arrival_time = time.perf_counter() @@ -94,9 +95,19 @@ def on_frame(frame_msg: Image): self._latest_rgb_frame = frame_msg.data self._frame_arrival_time = arrival_time - self.color_image.subscribe(on_frame) + unsub = self.color_image.subscribe(on_frame) + self._disposables.add(Disposable(unsub)) logger.info("ObjectTracker2D module started") + @rpc + def stop(self) -> None: + self.stop_track() + if self.tracking_thread and self.tracking_thread.is_alive(): + self.stop_tracking_event.set() + self.tracking_thread.join(timeout=2.0) + + super().stop() + @rpc def track(self, bbox: List[float]) -> Dict: """ @@ -286,11 +297,3 @@ def _draw_visualization(self, image: np.ndarray, bbox: List[int]) -> np.ndarray: cv2.rectangle(viz_image, (x1, y1), (x2, y2), (0, 255, 0), 2) cv2.putText(viz_image, "TRACKING", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2) return viz_image - - @rpc - def cleanup(self): - """Clean up resources.""" - self.stop_track() - if self.tracking_thread and self.tracking_thread.is_alive(): - self.stop_tracking_event.set() - self.tracking_thread.join(timeout=2.0) diff --git a/dimos/perception/object_tracker_3d.py b/dimos/perception/object_tracker_3d.py index a5dc96bae9..20b5705c05 100644 --- a/dimos/perception/object_tracker_3d.py +++ b/dimos/perception/object_tracker_3d.py @@ -61,7 +61,6 @@ def __init__(self, **kwargs): self.camera_intrinsics = None self._latest_depth_frame: Optional[np.ndarray] = None self._latest_camera_info: Optional[CameraInfo] = None - self._aligned_frames_subscription = None # TF publisher for tracked object self.tf = TF() @@ -71,7 +70,7 @@ def __init__(self, **kwargs): @rpc def start(self): - """Start the 3D tracking module with depth stream alignment.""" + super().start() # Subscribe to aligned RGB and depth streams def on_aligned_frames(frames_tuple): @@ -92,7 +91,8 @@ def on_aligned_frames(frames_tuple): buffer_size=2.0, # 2 second buffer match_tolerance=0.5, # 500ms tolerance ) - self._aligned_frames_subscription = aligned_frames.subscribe(on_aligned_frames) + unsub = aligned_frames.subscribe(on_aligned_frames) + self._disposables.add(unsub) # Subscribe to camera info def on_camera_info(camera_info_msg: CameraInfo): @@ -109,6 +109,10 @@ def on_camera_info(camera_info_msg: CameraInfo): logger.info("ObjectTracker3D module started with aligned frame subscription") + @rpc + def stop(self) -> None: + super().stop() + def _process_tracking(self): """Override to add 3D detection creation after 2D tracking.""" # Call parent 2D tracking @@ -298,12 +302,3 @@ def _draw_reid_overlay(self, image: np.ndarray) -> np.ndarray: cv2.putText(viz_image, text, (10, 90), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2) return viz_image - - @rpc - def cleanup(self): - """Clean up resources.""" - super().cleanup() - - if self._aligned_frames_subscription: - self._aligned_frames_subscription.dispose() - self._aligned_frames_subscription = None diff --git a/dimos/perception/person_tracker.py b/dimos/perception/person_tracker.py index fd63cc1794..d5d3e2be09 100644 --- a/dimos/perception/person_tracker.py +++ b/dimos/perception/person_tracker.py @@ -16,6 +16,7 @@ from dimos.perception.detection2d.utils import filter_detections from dimos.perception.common.ibvs import PersonDistanceEstimator from reactivex import Observable, interval +from reactivex.disposable import Disposable from reactivex import operators as ops import numpy as np import cv2 @@ -94,6 +95,8 @@ def __init__( def start(self): """Start the person tracking module and subscribe to LCM streams.""" + super().start() + # Subscribe to video stream def set_video(image_msg: Image): if hasattr(image_msg, "data"): @@ -101,13 +104,19 @@ def set_video(image_msg: Image): else: logger.warning("Received image message without data attribute") - self.video.subscribe(set_video) + unsub = self.video.subscribe(set_video) + self._disposables.add(Disposable(unsub)) # Start periodic processing - interval(self._process_interval).subscribe(lambda _: self._process_frame()) + unsub = interval(self._process_interval).subscribe(lambda _: self._process_frame()) + self._disposables.add(unsub) logger.info("PersonTracking module started and subscribed to LCM streams") + @rpc + def stop(self) -> None: + super().stop() + def _process_frame(self): """Process the latest frame if available.""" if self._latest_frame is None: @@ -250,9 +259,3 @@ def create_stream(self, video_stream: Observable) -> Observable: """ return video_stream.pipe(ops.map(self._process_tracking)) - - @rpc - def cleanup(self): - """Clean up resources.""" - # CUDA cleanup is now handled by WorkerPlugin in dimos.core - pass diff --git a/dimos/perception/spatial_perception.py b/dimos/perception/spatial_perception.py index 8205eaba0a..7d93e2e174 100644 --- a/dimos/perception/spatial_perception.py +++ b/dimos/perception/spatial_perception.py @@ -38,7 +38,7 @@ from dimos.types.vector import Vector from dimos.types.robot_location import RobotLocation -logger = setup_logger("dimos.perception.spatial_memory") +logger = setup_logger(__file__) class SpatialMemory(Module): @@ -52,7 +52,7 @@ class SpatialMemory(Module): """ # LCM inputs - video: In[Image] = None + color_image: In[Image] = None odom: In[PoseStamped] = None def __init__( @@ -176,7 +176,7 @@ def __init__( @rpc def start(self): - """Start the spatial memory module and subscribe to LCM streams.""" + super().start() # Subscribe to LCM streams def set_video(image_msg: Image): @@ -191,8 +191,9 @@ def set_video(image_msg: Image): def set_odom(odom_msg: PoseStamped): self._latest_odom = odom_msg - unsub = self.video.subscribe(set_video) + unsub = self.color_image.subscribe(set_video) self._disposables.add(Disposable(unsub)) + unsub = self.odom.subscribe(set_odom) self._disposables.add(Disposable(unsub)) @@ -200,11 +201,17 @@ def set_odom(odom_msg: PoseStamped): unsub = interval(self._process_interval).subscribe(lambda _: self._process_frame()) self._disposables.add(Disposable(unsub)) - logger.info("SpatialMemory module started and subscribed to LCM streams") - @rpc def stop(self): - self._close_module() + self.stop_continuous_processing() + + # Save data before shutdown + self.save() + + if self._visual_memory: + self._visual_memory.clear() + + super().stop() def _process_frame(self): """Process the latest frame with pose data if available.""" @@ -628,18 +635,6 @@ def get_stats(self) -> Dict[str, int]: """ return {"frame_count": self.frame_count, "stored_frame_count": self.stored_frame_count} - def cleanup(self): - """Clean up resources.""" - # Stop any ongoing processing - self.stop_continuous_processing() - - # Save data if possible - self.save() - - # Log cleanup - if self.vector_db: - logger.info(f"Cleaning up SpatialMemory, stored {self.stored_frame_count} frames") - @rpc def tag_location(self, robot_location: RobotLocation) -> bool: try: diff --git a/dimos/perception/test_spatial_memory.py b/dimos/perception/test_spatial_memory.py index c8cf8de26b..cde2b7d45c 100644 --- a/dimos/perception/test_spatial_memory.py +++ b/dimos/perception/test_spatial_memory.py @@ -55,7 +55,7 @@ def spatial_memory(self, temp_dir): ) yield memory # Clean up - memory.cleanup() + memory.stop() def test_spatial_memory_initialization(self, spatial_memory): """Test SpatialMemory initializes correctly with CLIP model.""" diff --git a/dimos/protocol/pubsub/lcmpubsub.py b/dimos/protocol/pubsub/lcmpubsub.py index dc7d067891..238c1f6545 100644 --- a/dimos/protocol/pubsub/lcmpubsub.py +++ b/dimos/protocol/pubsub/lcmpubsub.py @@ -26,7 +26,6 @@ from dimos.protocol.pubsub.spec import PickleEncoderMixin, PubSub, PubSubEncoderMixin from dimos.protocol.service.lcmservice import LCMConfig, LCMService, autoconf, check_system -from dimos.protocol.service.spec import Service from dimos.utils.deprecation import deprecated from dimos.utils.logging_config import setup_logger diff --git a/dimos/protocol/service/spec.py b/dimos/protocol/service/spec.py index c79b8d57ba..5406e2151f 100644 --- a/dimos/protocol/service/spec.py +++ b/dimos/protocol/service/spec.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from abc import ABC, abstractmethod +from abc import ABC from typing import Generic, Type, TypeVar # Generic type for service configuration @@ -27,8 +27,8 @@ def __init__(self, **kwargs) -> None: class Service(Configurable[ConfigT], ABC): - @abstractmethod - def start(self) -> None: ... + def start(self) -> None: + super().start() - @abstractmethod - def stop(self) -> None: ... + def stop(self) -> None: + super().stop() diff --git a/dimos/protocol/skill/coordinator.py b/dimos/protocol/skill/coordinator.py index cfc889fabc..23d9025a1a 100644 --- a/dimos/protocol/skill/coordinator.py +++ b/dimos/protocol/skill/coordinator.py @@ -34,8 +34,10 @@ from dimos.protocol.skill.type import MsgType, Output, Reducer, Return, SkillMsg, Stream from dimos.protocol.skill.utils import interpret_tool_call_args from dimos.utils.logging_config import setup_logger +from dimos.core.module import Module + -logger = setup_logger("dimos.protocol.skill.coordinator") +logger = setup_logger(__file__) @dataclass @@ -257,9 +259,6 @@ def __str__(self): return capture.get().strip() -from dimos.core.module import Module - - # This class is responsible for managing the lifecycle of skills, # handling skill calls, and coordinating communication between the agent and skills. # @@ -320,6 +319,7 @@ def _ensure_updates_available(self) -> asyncio.Event: @rpc def start(self) -> None: + super().start() self.skill_transport.start() self._transport_unsub_fn = self.skill_transport.subscribe(self.handle_message) @@ -337,6 +337,8 @@ def stop(self) -> None: for container in self._dynamic_containers: container.stop() + super().stop() + def len(self) -> int: return len(self._skills) diff --git a/dimos/protocol/skill/skill.py b/dimos/protocol/skill/skill.py index 9d63689527..6a7d35bcb9 100644 --- a/dimos/protocol/skill/skill.py +++ b/dimos/protocol/skill/skill.py @@ -166,21 +166,9 @@ def stop(self): self._skill_thread_pool.shutdown(wait=True) self._skill_thread_pool = None - # If this container is also a Module, close the module properly - if hasattr(self, "_close_module"): - self._close_module() - elif hasattr(self, "_close_rpc"): - self._close_rpc() - - if hasattr(self, "_loop") and hasattr(self, "_loop_thread") and self._loop_thread: - if self._loop_thread.is_alive(): - self._loop.call_soon_threadsafe(self._loop.stop) - self._loop_thread.join(timeout=2) - self._loop = None - self._loop_thread = None - - if hasattr(self, "_disposables"): - self._disposables.dispose() + # Continue the MRO chain if there's a parent stop() method + if hasattr(super(), "stop"): + super().stop() # TODO: figure out standard args/kwargs passing format, # use same interface as skill coordinator call_skill method diff --git a/dimos/protocol/skill/test_coordinator.py b/dimos/protocol/skill/test_coordinator.py index 3dad606227..65b45c50fa 100644 --- a/dimos/protocol/skill/test_coordinator.py +++ b/dimos/protocol/skill/test_coordinator.py @@ -18,7 +18,7 @@ import pytest -from dimos.core import Module +from dimos.core import Module, rpc from dimos.msgs.sensor_msgs import Image from dimos.protocol.skill.coordinator import SkillCoordinator from dimos.protocol.skill.skill import skill @@ -27,6 +27,14 @@ class SkillContainerTest(Module): + @rpc + def start(self): + super().start() + + @rpc + def stop(self): + super().stop() + @skill() def add(self, x: int, y: int) -> int: """adds x and y.""" @@ -145,5 +153,5 @@ async def test_coordinator_generator(): print("coordinator loop finished") print(skillCoordinator) - container._close_module() + container.stop() skillCoordinator.stop() diff --git a/dimos/protocol/tf/tflcmcpp.py b/dimos/protocol/tf/tflcmcpp.py index bf7f74c321..e12877bdec 100644 --- a/dimos/protocol/tf/tflcmcpp.py +++ b/dimos/protocol/tf/tflcmcpp.py @@ -15,7 +15,7 @@ from typing import Optional, Union from datetime import datetime from dimos_lcm import tf -from dimos.protocol.service.lcmservice import LCMConfig, LCMService, Service +from dimos.protocol.service.lcmservice import LCMConfig, LCMService from dimos.protocol.tf.tf import TFSpec, TFConfig from dimos.msgs.geometry_msgs import Quaternion, Transform, Vector3 diff --git a/dimos/robot/agilex/piper_arm.py b/dimos/robot/agilex/piper_arm.py index bb22db46b5..7dbb2fcbfc 100644 --- a/dimos/robot/agilex/piper_arm.py +++ b/dimos/robot/agilex/piper_arm.py @@ -150,7 +150,6 @@ def stop(self): try: if self.manipulation_interface: self.manipulation_interface.stop() - self.manipulation_interface.cleanup() if self.stereo_camera: self.stereo_camera.stop() diff --git a/dimos/robot/foxglove_bridge.py b/dimos/robot/foxglove_bridge.py index 30fa248784..18211f65c2 100644 --- a/dimos/robot/foxglove_bridge.py +++ b/dimos/robot/foxglove_bridge.py @@ -28,10 +28,11 @@ class FoxgloveBridge(Module): def __init__(self, *args, shm_channels=None, **kwargs): super().__init__(*args, **kwargs) self.shm_channels = shm_channels or [] - self.start() @rpc def start(self): + super().start() + def run_bridge(): self._loop = asyncio.new_event_loop() asyncio.set_event_loop(self._loop) @@ -55,3 +56,5 @@ def stop(self): if self._loop and self._loop.is_running(): self._loop.call_soon_threadsafe(self._loop.stop) self._thread.join(timeout=2) + + super().stop() diff --git a/dimos/robot/nav_bot.py b/dimos/robot/nav_bot.py index c8359c9a1c..e65ed8214b 100644 --- a/dimos/robot/nav_bot.py +++ b/dimos/robot/nav_bot.py @@ -23,6 +23,7 @@ from dimos import core from dimos.core import Module, In, Out, rpc +from dimos.core.resource import Resource from dimos.msgs.geometry_msgs import PoseStamped, TwistStamped, Transform, Vector3 from dimos.msgs.nav_msgs import Odometry from dimos.msgs.sensor_msgs import PointCloud2, Joy, Image @@ -39,6 +40,7 @@ from tf2_msgs.msg import TFMessage as ROSTFMessage from dimos.utils.logging_config import setup_logger from dimos.protocol.pubsub.lcmpubsub import LCM, Topic +from reactivex.disposable import Disposable logger = setup_logger("dimos.robot.unitree_webrtc.nav_bot", level=logging.INFO) @@ -65,10 +67,14 @@ def __init__(self, *args, **kwargs): @rpc def start(self): - """Start the navigation module.""" + super().start() if self.goal_reached: - self.goal_reached.subscribe(self._on_goal_reached) - logger.info("NavigationModule started") + unsub = self.goal_reached.subscribe(self._on_goal_reached) + self._disposables.add(Disposable(unsub)) + + @rpc + def stop(self) -> None: + super().stop() def _on_goal_reached(self, msg: Bool): """Handle goal reached status messages.""" @@ -137,13 +143,13 @@ def go_to(self, pose: PoseStamped, timeout: float = 60.0) -> bool: return self.goal_reach time.sleep(0.1) - self.stop() + self.stop_navigation() logger.warning(f"Navigation timed out after {timeout} seconds") return False @rpc - def stop(self) -> bool: + def stop_navigation(self) -> bool: """ Cancel current navigation by publishing to cancel_goal. @@ -180,8 +186,13 @@ def __init__(self, sensor_to_base_link_transform=None, *args, **kwargs): @rpc def start(self): - self.odom.subscribe(self._publish_odom_pose) - logger.info("TopicRemapModule started") + super().start() + unsub = self.odom.subscribe(self._publish_odom_pose) + self._disposables.add(Disposable(unsub)) + + @rpc + def stop(self) -> None: + super().stop() def _publish_odom_pose(self, msg: Odometry): pose_msg = PoseStamped( @@ -225,7 +236,7 @@ def _publish_odom_pose(self, msg: Odometry): self.tf.publish(sensor_to_base_link_tf, map_to_world_tf) -class NavBot: +class NavBot(Resource): """ NavBot class for navigation-related functionality. Manages ROS bridge and topic remapping for navigation. @@ -251,6 +262,8 @@ def __init__(self, dimos=None, sensor_to_base_link_transform=[0.0, 0.0, 0.0, 0.0 self.lcm = LCM() def start(self): + super().start() + if self.topic_remap_module: self.topic_remap_module.start() logger.info("Topic remap module started") @@ -258,6 +271,18 @@ def start(self): if self.ros_bridge: logger.info("ROS bridge started") + def stop(self) -> None: + logger.info("Shutting down navigation modules...") + + if self.ros_bridge is not None: + try: + self.ros_bridge.shutdown() + logger.info("ROS bridge shut down successfully") + except Exception as e: + logger.error(f"Error shutting down ROS bridge: {e}") + + super().stop() + def deploy_navigation_modules(self, bridge_name="nav_bot_ros_bridge"): # Deploy topic remap module logger.info("Deploying topic remap module...") @@ -396,13 +421,3 @@ def cancel_navigation(self) -> bool: self.lcm.publish(cancel_topic, cancel_msg) return True - - def shutdown(self): - logger.info("Shutting down navigation modules...") - - if self.ros_bridge is not None: - try: - self.ros_bridge.shutdown() - logger.info("ROS bridge shut down successfully") - except Exception as e: - logger.error(f"Error shutting down ROS bridge: {e}") diff --git a/dimos/robot/ros_bridge.py b/dimos/robot/ros_bridge.py index 7e845e08d0..d77d5eb1fb 100644 --- a/dimos/robot/ros_bridge.py +++ b/dimos/robot/ros_bridge.py @@ -14,8 +14,7 @@ import logging import threading -import time -from typing import Dict, Any, Type, Literal, Optional +from typing import Dict, Any, Type, Optional from enum import Enum try: @@ -32,6 +31,7 @@ QoSHistoryPolicy = None QoSDurabilityPolicy = None +from dimos.core.resource import Resource from dimos.protocol.pubsub.lcmpubsub import LCM, Topic from dimos.utils.logging_config import setup_logger @@ -45,7 +45,7 @@ class BridgeDirection(Enum): DIMOS_TO_ROS = "dimos_to_ros" -class ROSBridge: +class ROSBridge(Resource): """Unidirectional bridge between ROS and DIMOS for message passing.""" def __init__(self, node_name: str = "dimos_ros_bridge"): @@ -65,7 +65,7 @@ def __init__(self, node_name: str = "dimos_ros_bridge"): self._executor.add_node(self.node) self._spin_thread = threading.Thread(target=self._ros_spin, daemon=True) - self._spin_thread.start() + self._spin_thread.start() # TODO: don't forget to shut it down self._bridges: Dict[str, Dict[str, Any]] = {} @@ -78,6 +78,19 @@ def __init__(self, node_name: str = "dimos_ros_bridge"): logger.info(f"ROSBridge initialized with node name: {node_name}") + def start(self) -> None: + pass + + def stop(self) -> None: + """Shutdown the bridge and clean up resources.""" + self._executor.shutdown() + self.node.destroy_node() + + if rclpy.ok(): + rclpy.shutdown() + + logger.info("ROSBridge shutdown complete") + def _ros_spin(self): """Background thread for spinning ROS executor.""" try: @@ -190,13 +203,3 @@ def _dimos_to_ros(self, dimos_msg: Any, ros_publisher, _topic_name: str) -> None """ ros_msg = dimos_msg.to_ros_msg() ros_publisher.publish(ros_msg) - - def shutdown(self): - """Shutdown the bridge and clean up resources.""" - self._executor.shutdown() - self.node.destroy_node() - - if rclpy.ok(): - rclpy.shutdown() - - logger.info("ROSBridge shutdown complete") diff --git a/dimos/robot/test_ros_bridge.py b/dimos/robot/test_ros_bridge.py index c7dda5fd88..a4c0c16ed7 100644 --- a/dimos/robot/test_ros_bridge.py +++ b/dimos/robot/test_ros_bridge.py @@ -71,7 +71,7 @@ def setUp(self): def tearDown(self): """Clean up test fixtures.""" self.test_node.destroy_node() - self.bridge.shutdown() + self.bridge.stop() if rclpy.ok(): rclpy.try_shutdown() diff --git a/dimos/robot/unitree_webrtc/connection.py b/dimos/robot/unitree_webrtc/connection.py index 75d3bdd13d..f7cf683a2a 100644 --- a/dimos/robot/unitree_webrtc/connection.py +++ b/dimos/robot/unitree_webrtc/connection.py @@ -31,6 +31,7 @@ from reactivex.subject import Subject from dimos.core import In, Module, Out, rpc +from dimos.core.resource import Resource from dimos.msgs.geometry_msgs import Pose, Transform, Twist, Vector3 from dimos.msgs.sensor_msgs import Image from dimos.robot.connection_interface import ConnectionInterface @@ -70,7 +71,7 @@ def to_ndarray(self, format=None): return self.data -class UnitreeWebRTCConnection: +class UnitreeWebRTCConnection(Resource): def __init__(self, ip: str, mode: str = "ai"): self.ip = ip self.mode = mode @@ -111,6 +112,32 @@ def start_background_loop(): self.thread.start() self.connection_ready.wait() + def start(self) -> None: + pass + + def stop(self) -> None: + # Cancel timer + if self.stop_timer: + self.stop_timer.cancel() + self.stop_timer = None + + if self.task: + self.task.cancel() + + async def async_disconnect() -> None: + try: + await self.conn.disconnect() + except Exception: + pass + + if self.loop.is_running(): + asyncio.run_coroutine_threadsafe(async_disconnect(), self.loop) + + self.loop.call_soon_threadsafe(self.loop.stop) + + if self.thread.is_alive(): + self.thread.join(timeout=2.0) + def move(self, twist: Twist, duration: float = 0.0) -> bool: """Send movement command to the robot using Twist commands. diff --git a/dimos/robot/unitree_webrtc/connectionModule.py b/dimos/robot/unitree_webrtc/connectionModule.py deleted file mode 100644 index 6a6f9085cd..0000000000 --- a/dimos/robot/unitree_webrtc/connectionModule.py +++ /dev/null @@ -1,256 +0,0 @@ -#!/usr/bin/env python3 - -#!/usr/bin/env python3 - -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import functools -import logging -import math -import time -import warnings -from typing import Optional - -import reactivex as rx -from dimos_lcm.sensor_msgs import CameraInfo -from reactivex import operators as ops -from reactivex.subject import Subject - -from dimos.core import In, Module, Out, rpc -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Vector3 -from dimos.msgs.sensor_msgs.Image import Image, sharpness_window -from dimos.msgs.std_msgs import Header -from dimos.robot.unitree_webrtc.connection import UnitreeWebRTCConnection -from dimos.robot.unitree_webrtc.type.lidar import LidarMessage -from dimos.robot.unitree_webrtc.type.odometry import Odometry -from dimos.utils.data import get_data -from dimos.utils.logging_config import setup_logger -from dimos.utils.testing import TimedSensorReplay - -logger = setup_logger("dimos.robot.unitree_webrtc.unitree_go2", level=logging.INFO) - -# Suppress verbose loggers -logging.getLogger("aiortc.codecs.h264").setLevel(logging.ERROR) -logging.getLogger("lcm_foxglove_bridge").setLevel(logging.ERROR) -logging.getLogger("websockets.server").setLevel(logging.ERROR) -logging.getLogger("FoxgloveServer").setLevel(logging.ERROR) -logging.getLogger("asyncio").setLevel(logging.ERROR) -logging.getLogger("root").setLevel(logging.WARNING) - -# Suppress warnings -warnings.filterwarnings("ignore", message="coroutine.*was never awaited") -warnings.filterwarnings("ignore", message="H264Decoder.*failed to decode") - -image_resize_factor = 4 -originalwidth, originalheight = (1280, 720) - - -class FakeRTC: - """Fake WebRTC connection for testing with recorded data.""" - - def __init__(self, *args, **kwargs): - get_data("unitree_office_walk") # Preload data for testing - - def connect(self): - pass - - def standup(self): - print("standup suppressed") - - def liedown(self): - print("liedown suppressed") - - @functools.cache - def lidar_stream(self): - print("lidar stream start") - lidar_store = TimedSensorReplay("unitree_office_walk/lidar", autocast=LidarMessage.from_msg) - return lidar_store.stream() - - @functools.cache - def odom_stream(self): - print("odom stream start") - odom_store = TimedSensorReplay("unitree_office_walk/odom", autocast=Odometry.from_msg) - return odom_store.stream() - - @functools.cache - def video_stream(self): - print("video stream start") - video_store = TimedSensorReplay("unitree_office_walk/video", autocast=Image.from_numpy) - return video_store.stream() - - def move(self, vector: Vector3, duration: float = 0.0): - pass - - def publish_request(self, topic: str, data: dict): - """Fake publish request for testing.""" - return {"status": "ok", "message": "Fake publish"} - - -class ConnectionModule(Module): - """Module that handles robot sensor data and movement commands.""" - - movecmd: In[Vector3] = None - odom: Out[PoseStamped] = None - lidar: Out[LidarMessage] = None - video: Out[Image] = None - ip: str - connection_type: str = "webrtc" - camera_info: Out[CameraInfo] = None - _odom: PoseStamped = None - _lidar: LidarMessage = None - - def __init__(self, ip: str = None, connection_type: str = "webrtc", *args, **kwargs): - self.ip = ip - self.connection_type = connection_type - self.connection = None - Module.__init__(self, *args, **kwargs) - - @rpc - def start(self): - """Start the connection and subscribe to sensor streams.""" - match self.connection_type: - case "webrtc": - self.connection = UnitreeWebRTCConnection(self.ip) - case "fake": - self.connection = FakeRTC(self.ip) - case "mujoco": - from dimos.robot.unitree_webrtc.mujoco_connection import MujocoConnection - - self.connection = MujocoConnection() - self.connection.start() - case _: - raise ValueError(f"Unknown connection type: {self.connection_type}") - - def image_pub(img): - self.video.publish(img) - - # Connect sensor streams to outputs - self.connection.lidar_stream().subscribe(self.lidar.publish) - self.connection.odom_stream().subscribe(self._publish_tf) - - def attach_frame_id(image: Image) -> Image: - image.frame_id = "camera_optical" - - return image.resize( - int(originalwidth / image_resize_factor), int(originalheight / image_resize_factor) - ) - - # sharpness_window( - # 10, self.connection.video_stream().pipe(ops.map(attach_frame_id)) - # ).subscribe(image_pub) - self.connection.video_stream().pipe(ops.map(attach_frame_id)).subscribe(image_pub) - self.camera_info_stream().subscribe(self.camera_info.publish) - self.movecmd.subscribe(self.move) - - @functools.cache - def camera_info_stream(self) -> Subject[CameraInfo]: - fx, fy, cx, cy = list( - map(lambda x: x / image_resize_factor, [819.553492, 820.646595, 625.284099, 336.808987]) - ) - - # width, height = (1280, 720) - width, height = tuple( - map(lambda x: int(x / image_resize_factor), [originalwidth, originalheight]) - ) - print("WIIDHT", width, height) - # Camera matrix K (3x3) - K = [fx, 0, cx, 0, fy, cy, 0, 0, 1] - - # No distortion coefficients for now - D = [0.0, 0.0, 0.0, 0.0, 0.0] - - # Identity rotation matrix - R = [1, 0, 0, 0, 1, 0, 0, 0, 1] - - # Projection matrix P (3x4) - P = [fx, 0, cx, 0, 0, fy, cy, 0, 0, 0, 1, 0] - - base_msg = { - "D_length": len(D), - "height": height, - "width": width, - "distortion_model": "plumb_bob", - "D": D, - "K": K, - "R": R, - "P": P, - "binning_x": 0, - "binning_y": 0, - } - - return rx.interval(1).pipe( - ops.map( - lambda x: CameraInfo( - **base_msg, - header=Header("camera_optical"), - ) - ) - ) - - def _publish_tf(self, msg): - self.tf.publish(Transform.from_pose("base_link", msg)) - - camera_link = Transform( - translation=Vector3(0.3, 0.0, 0.0), - rotation=Quaternion.from_euler(Vector3([0, 0, 0])), - frame_id="base_link", - child_frame_id="camera_link", - ts=time.time(), - ) - - camera_optical = Transform( - translation=Vector3(0.0, 0.0, 0.0), - rotation=Quaternion(-0.5, 0.5, -0.5, 0.5), - frame_id="camera_link", - child_frame_id="camera_optical", - ts=camera_link.ts, - ) - - self.tf.publish(camera_link, camera_optical) - - @rpc - def get_odom(self) -> Optional[PoseStamped]: - """Get the robot's odometry. - - Returns: - The robot's odometry - """ - return self._odom - - @rpc - def move(self, vector: Vector3, duration: float = 0.0): - """Send movement command to robot.""" - self.connection.move(vector, duration) - - @rpc - def standup(self): - """Make the robot stand up.""" - return self.connection.standup() - - @rpc - def liedown(self): - """Make the robot lie down.""" - return self.connection.liedown() - - @rpc - def publish_request(self, topic: str, data: dict): - """Publish a request to the WebRTC connection. - Args: - topic: The RTC topic to publish to - data: The data dictionary to publish - Returns: - The result of the publish request - """ - return self.connection.publish_request(topic, data) diff --git a/dimos/robot/unitree_webrtc/depth_module.py b/dimos/robot/unitree_webrtc/depth_module.py index 96f6fd74f6..b5b3b12738 100644 --- a/dimos/robot/unitree_webrtc/depth_module.py +++ b/dimos/robot/unitree_webrtc/depth_module.py @@ -18,7 +18,6 @@ import threading from typing import Optional -import cv2 import numpy as np from dimos.core import Module, In, Out, rpc @@ -82,7 +81,8 @@ def __init__( @rpc def start(self): - """Start the camera module.""" + super().start() + if self._running: logger.warning("Camera module already running") return @@ -101,7 +101,6 @@ def start(self): @rpc def stop(self): - """Stop the camera module.""" if not self._running: return @@ -112,7 +111,7 @@ def stop(self): if self._processing_thread and self._processing_thread.is_alive(): self._processing_thread.join(timeout=2.0) - logger.info("Depth module stopped") + super().stop() def _on_camera_info(self, msg: CameraInfo): """Process camera info to extract intrinsics.""" @@ -233,9 +232,3 @@ def _publish_depth(self): except Exception as e: logger.error(f"Error publishing depth data: {e}", exc_info=True) - - def cleanup(self): - """Clean up resources on module destruction.""" - self.stop() - if self.metric3d: - self.metric3d.cleanup() diff --git a/dimos/robot/unitree_webrtc/g1_joystick_module.py b/dimos/robot/unitree_webrtc/g1_joystick_module.py index e98b663f53..156a0891a2 100644 --- a/dimos/robot/unitree_webrtc/g1_joystick_module.py +++ b/dimos/robot/unitree_webrtc/g1_joystick_module.py @@ -42,6 +42,8 @@ def __init__(self, *args, **kwargs): @rpc def start(self): """Initialize pygame and start control loop.""" + super().start() + try: import pygame except ImportError: @@ -58,6 +60,21 @@ def start(self): return True + @rpc + def stop(self) -> None: + super().stop() + + self.running = False + self.pygame_ready = False + + stop_twist = Twist() + stop_twist.linear = Vector3(0, 0, 0) + stop_twist.angular = Vector3(0, 0, 0) + + self._thread.join(2) + + self.twist_out.publish(stop_twist) + def _pygame_loop(self): """Main pygame event loop - ALL pygame operations happen here.""" import pygame @@ -162,18 +179,3 @@ def _update_display(self, twist): y_pos += 25 pygame.display.flip() - - @rpc - def stop(self): - """Stop the joystick module.""" - self.running = False - stop_twist = Twist() - stop_twist.linear = Vector3(0, 0, 0) - stop_twist.angular = Vector3(0, 0, 0) - self.twist_out.publish(stop_twist) - return True - - def cleanup(self): - """Clean up pygame resources.""" - self.running = False - self.pygame_ready = False diff --git a/dimos/robot/unitree_webrtc/modular/connection_module.py b/dimos/robot/unitree_webrtc/modular/connection_module.py index 6e13ed938e..30413bf182 100644 --- a/dimos/robot/unitree_webrtc/modular/connection_module.py +++ b/dimos/robot/unitree_webrtc/modular/connection_module.py @@ -30,17 +30,15 @@ from reactivex.observable import Observable from dimos.agents2 import Agent, Output, Reducer, Stream, skill -from dimos.core import DimosCluster, In, LCMTransport, Module, ModuleConfig, Out, rpc +from dimos.core import DimosCluster, LCMTransport, Module, ModuleConfig, Out, rpc, In from dimos.msgs.foxglove_msgs import ImageAnnotations from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Twist, Vector3 from dimos.msgs.sensor_msgs.Image import Image, sharpness_window from dimos.msgs.std_msgs import Header -from dimos.robot.foxglove_bridge import FoxgloveBridge from dimos.robot.unitree_webrtc.connection import UnitreeWebRTCConnection from dimos.robot.unitree_webrtc.type.lidar import LidarMessage from dimos.utils.data import get_data from dimos.utils.logging_config import setup_logger -from dimos.utils.reactive import backpressure from dimos.utils.testing import TimedSensorReplay, TimedSensorStorage logger = setup_logger("dimos.robot.unitree_webrtc.unitree_go2", level=logging.INFO) @@ -171,6 +169,9 @@ def record(self, recording_name: str): @rpc def start(self): """Start the connection and subscribe to sensor streams.""" + + super().start() + match self.connection_type: case "webrtc": self.connection = UnitreeWebRTCConnection(**self.connection_config) @@ -183,12 +184,15 @@ def start(self): self.connection.start() case _: raise ValueError(f"Unknown connection type: {self.connection_type}") - self.connection.odom_stream().subscribe( + + unsub = self.connection.odom_stream().subscribe( lambda odom: self._publish_tf(odom) and self.odom.publish(odom) ) + self._disposables.add(unsub) # Connect sensor streams to outputs - self.connection.lidar_stream().subscribe(self.lidar.publish) + unsub = self.connection.lidar_stream().subscribe(self.lidar.publish) + self._disposables.add(unsub) # self.connection.lidar_stream().subscribe(lambda lidar: print("LIDAR", lidar.ts)) # self.connection.video_stream().subscribe(lambda video: print("IMAGE", video.ts)) @@ -199,11 +203,19 @@ def resize(image: Image) -> Image: int(originalwidth / image_resize_factor), int(originalheight / image_resize_factor) ) - self.connection.video_stream().subscribe(self.video.publish) - # sharpness_window(15.0, self.connection.video_stream()).subscribe(self.video.publish) - # self.connection.video_stream().pipe(ops.map(resize)).subscribe(self.video.publish) - self.camera_info_stream().subscribe(self.camera_info.publish) - self.movecmd.subscribe(self.connection.move) + unsub = self.connection.video_stream().subscribe(self.video.publish) + self._disposables.add(unsub) + unsub = self.camera_info_stream().subscribe(self.camera_info.publish) + self._disposables.add(unsub) + unsub = self.movecmd.subscribe(self.connection.move) + self._disposables.add(unsub) + + @rpc + def stop(self) -> None: + if self.connection: + self.connection.stop() + + super().stop() @classmethod def _odom_to_tf(self, odom: PoseStamped) -> List[Transform]: diff --git a/dimos/robot/unitree_webrtc/mujoco_connection.py b/dimos/robot/unitree_webrtc/mujoco_connection.py index 885ab28a76..64bfaf2b8e 100644 --- a/dimos/robot/unitree_webrtc/mujoco_connection.py +++ b/dimos/robot/unitree_webrtc/mujoco_connection.py @@ -50,11 +50,45 @@ def __init__(self, *args, **kwargs): self._is_cleaned_up = False # Register cleanup on exit - atexit.register(self.cleanup) + atexit.register(self.stop) - def start(self): + def start(self) -> None: self.mujoco_thread.start() + def stop(self) -> None: + """Clean up all resources. Can be called multiple times safely.""" + if self._is_cleaned_up: + return + + self._is_cleaned_up = True + + # Stop all stream threads + for stop_event in self._stop_events: + stop_event.set() + + # Wait for threads to finish + for thread in self._stream_threads: + if thread.is_alive(): + thread.join(timeout=2.0) + if thread.is_alive(): + logger.warning(f"Stream thread {thread.name} did not stop gracefully") + + # Clean up the MuJoCo thread + if hasattr(self, "mujoco_thread") and self.mujoco_thread: + self.mujoco_thread.cleanup() + + # Clear references + self._stream_threads.clear() + self._stop_events.clear() + + # Clear cached methods to prevent memory leaks + if hasattr(self, "lidar_stream"): + self.lidar_stream.cache_clear() + if hasattr(self, "odom_stream"): + self.odom_stream.cache_clear() + if hasattr(self, "video_stream"): + self.video_stream.cache_clear() + def standup(self): print("standup supressed") @@ -202,49 +236,3 @@ def move(self, twist: Twist, duration: float = 0.0): def publish_request(self, topic: str, data: dict): pass - - def stop(self): - """Stop the MuJoCo connection gracefully.""" - self.cleanup() - - def cleanup(self): - """Clean up all resources. Can be called multiple times safely.""" - if self._is_cleaned_up: - return - - logger.debug("Cleaning up MuJoCo connection resources") - self._is_cleaned_up = True - - # Stop all stream threads - for stop_event in self._stop_events: - stop_event.set() - - # Wait for threads to finish - for thread in self._stream_threads: - if thread.is_alive(): - thread.join(timeout=2.0) - if thread.is_alive(): - logger.warning(f"Stream thread {thread.name} did not stop gracefully") - - # Clean up the MuJoCo thread - if hasattr(self, "mujoco_thread") and self.mujoco_thread: - self.mujoco_thread.cleanup() - - # Clear references - self._stream_threads.clear() - self._stop_events.clear() - - # Clear cached methods to prevent memory leaks - if hasattr(self, "lidar_stream"): - self.lidar_stream.cache_clear() - if hasattr(self, "odom_stream"): - self.odom_stream.cache_clear() - if hasattr(self, "video_stream"): - self.video_stream.cache_clear() - - def __del__(self): - """Destructor to ensure cleanup on object deletion.""" - try: - self.cleanup() - except Exception: - pass diff --git a/dimos/robot/unitree_webrtc/run_agents2.py b/dimos/robot/unitree_webrtc/run_agents2.py index e19b7e2692..e779c26bb6 100755 --- a/dimos/robot/unitree_webrtc/run_agents2.py +++ b/dimos/robot/unitree_webrtc/run_agents2.py @@ -21,6 +21,7 @@ from dimos.agents2 import Agent from dimos.agents2.cli.human import HumanInput from dimos.agents2.constants import AGENT_SYSTEM_PROMPT_PATH +from dimos.core.resource import Resource from dimos.robot.robot import UnitreeRobot from dimos.robot.unitree_webrtc.unitree_go2 import UnitreeGo2 from dimos.robot.unitree_webrtc.unitree_skill_container import UnitreeSkillContainer @@ -28,9 +29,7 @@ from dimos.robot.utils.robot_debugger import RobotDebugger from dimos.utils.logging_config import setup_logger -from contextlib import ExitStack - -logger = setup_logger("dimos.robot.unitree_webrtc.run_agents2") +logger = setup_logger(__file__) load_dotenv() @@ -38,24 +37,22 @@ SYSTEM_PROMPT = f.read() -class UnitreeAgents2Runner: +class UnitreeAgents2Runner(Resource): _robot: Optional[UnitreeRobot] _agent: Optional[Agent] - _exit_stack: ExitStack + _robot_debugger: Optional[RobotDebugger] + _navigation_skill: Optional[NavigationSkillContainer] def __init__(self): self._robot: UnitreeRobot = None self._agent = None - self._exit_stack = ExitStack() - - def __enter__(self): - logger.info("Initializing Unitree Go2 robot...") + self._robot_debugger = None + self._navigation_skill = None - self._robot = self._exit_stack.enter_context( - UnitreeGo2( - ip=os.getenv("ROBOT_IP"), - connection_type=os.getenv("CONNECTION_TYPE", "webrtc"), - ) + def start(self) -> None: + self._robot = UnitreeGo2( + ip=os.getenv("ROBOT_IP"), + connection_type=os.getenv("CONNECTION_TYPE", "webrtc"), ) time.sleep(3) @@ -64,26 +61,18 @@ def __enter__(self): self.setup_agent() - self._exit_stack.enter_context(RobotDebugger(self._robot)) - - logger.info("=" * 60) - logger.info("Unitree Go2 Agent Ready (agents2 framework)!") - logger.info("You can:") - logger.info(" - Type commands in the human CLI") - logger.info(" - Ask the robot to navigate to locations") - logger.info(" - Ask the robot to observe and describe surroundings") - logger.info(" - Ask the robot to follow people or explore areas") - logger.info(" - Ask the robot to perform actions (sit, stand, dance, etc.)") - logger.info(" - Ask the robot to speak text") - logger.info("=" * 60) + self._robot_debugger = RobotDebugger(self._robot) + self._robot_debugger.start() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - logger.info("Shutting down...") - self._exit_stack.close() - logger.info("Shutdown complete") - return False + def stop(self) -> None: + if self._navigation_skill: + self._navigation_skill.stop() + if self._robot_debugger: + self._robot_debugger.stop() + if self._agent: + self._agent.stop() + if self._robot: + self._robot.stop() def setup_agent(self) -> None: if not self._robot: @@ -92,15 +81,15 @@ def setup_agent(self) -> None: logger.info("Setting up agent with skills...") self._agent = Agent(system_prompt=SYSTEM_PROMPT) + self._navigation_skill = NavigationSkillContainer( + robot=self._robot, + video_stream=self._robot.connection.video, + ) + self._navigation_skill.start() skill_containers = [ UnitreeSkillContainer(robot=self._robot), - self._exit_stack.enter_context( - NavigationSkillContainer( - robot=self._robot, - video_stream=self._robot.connection.video, - ) - ), + self._navigation_skill, HumanInput(), ] @@ -110,7 +99,7 @@ def setup_agent(self) -> None: self._agent.run_implicit_skill("human") - self._exit_stack.enter_context(self._agent) + self._agent.start() # Log available skills tools = self._agent.get_tools() @@ -129,8 +118,10 @@ def run(self): def main(): - with UnitreeAgents2Runner() as runner: - runner.run() + runner = UnitreeAgents2Runner() + runner.start() + runner.run() + runner.stop() if __name__ == "__main__": diff --git a/dimos/robot/unitree_webrtc/type/map.py b/dimos/robot/unitree_webrtc/type/map.py index c0a0338de7..52e2c62260 100644 --- a/dimos/robot/unitree_webrtc/type/map.py +++ b/dimos/robot/unitree_webrtc/type/map.py @@ -13,14 +13,12 @@ # limitations under the License. import time -from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Optional import numpy as np import open3d as o3d -import reactivex.operators as ops from reactivex import interval -from reactivex.observable import Observable +from reactivex.disposable import Disposable from dimos.core import In, Module, Out, rpc from dimos.msgs.nav_msgs import OccupancyGrid @@ -54,7 +52,10 @@ def __init__( @rpc def start(self): - self.lidar.subscribe(self.add_frame) + super().start() + + unsub = self.lidar.subscribe(self.add_frame) + self._disposables.add(Disposable(unsub)) def publish(_): self.global_map.publish(self.to_lidar_message()) @@ -71,7 +72,12 @@ def publish(_): self.global_costmap.publish(occupancygrid) if self.global_publish_interval is not None: - interval(self.global_publish_interval).subscribe(publish) + unsub = interval(self.global_publish_interval).subscribe(publish) + self._disposables.add(unsub) + + @rpc + def stop(self) -> None: + super().stop() def to_PointCloud2(self) -> PointCloud2: return PointCloud2( @@ -91,6 +97,11 @@ def to_lidar_message(self) -> LidarMessage: def add_frame(self, frame: LidarMessage) -> "Map": """Voxelise *frame* and splice it into the running map.""" new_pct = frame.pointcloud.voxel_down_sample(voxel_size=self.voxel_size) + + # Skip for empty pointclouds. + if len(new_pct.points) == 0: + return self + self.pointcloud = splice_cylinder(self.pointcloud, new_pct, shrink=0.5) local_costmap = OccupancyGrid.from_pointcloud( frame, diff --git a/dimos/robot/unitree_webrtc/unitree_b1/connection.py b/dimos/robot/unitree_webrtc/unitree_b1/connection.py index 73d24bdc3c..a458858040 100644 --- a/dimos/robot/unitree_webrtc/unitree_b1/connection.py +++ b/dimos/robot/unitree_webrtc/unitree_b1/connection.py @@ -21,7 +21,6 @@ import socket import threading import time -from typing import Optional from dimos.core import In, Module, Out, rpc from dimos.msgs.geometry_msgs import PoseStamped, Twist, TwistStamped @@ -30,6 +29,7 @@ from dimos.utils.logging_config import setup_logger from .b1_command import B1Command +from reactivex.disposable import Disposable # Setup logger with DEBUG level for troubleshooting logger = setup_logger("dimos.robot.unitree_webrtc.unitree_b1.connection", level=logging.DEBUG) @@ -90,6 +90,8 @@ def __init__( def start(self): """Start the connection and subscribe to command streams.""" + super().start() + # Setup UDP socket (unless in test mode) if not self.test_mode: self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) @@ -99,11 +101,14 @@ def start(self): # Subscribe to input streams if self.cmd_vel: - self.cmd_vel.subscribe(self.handle_twist_stamped) + unsub = self.cmd_vel.subscribe(self.handle_twist_stamped) + self._disposables.add(Disposable(unsub)) if self.mode_cmd: - self.mode_cmd.subscribe(self.handle_mode) + unsub = self.mode_cmd.subscribe(self.handle_mode) + self._disposables.add(Disposable(unsub)) if self.odom_in: - self.odom_in.subscribe(self._publish_odom_pose) + unsub = self.odom_in.subscribe(self._publish_odom_pose) + self._disposables.add(Disposable(unsub)) # Start threads self.running = True @@ -117,11 +122,10 @@ def start(self): self.watchdog_thread = threading.Thread(target=self._watchdog_loop, daemon=True) self.watchdog_thread.start() - return True - @rpc def stop(self): """Stop the connection and send stop commands.""" + self.set_mode(RobotMode.IDLE) # IDLE with self.cmd_lock: self._current_cmd = B1Command(mode=RobotMode.IDLE) # Zero all velocities @@ -146,7 +150,7 @@ def stop(self): self.socket.close() self.socket = None - return True + super().stop() def handle_twist_stamped(self, twist_stamped: TwistStamped): """Handle timestamped Twist message and convert to B1Command. @@ -350,10 +354,6 @@ def move(self, twist_stamped: TwistStamped, duration: float = 0.0): self.handle_twist_stamped(twist_stamped) return True - def cleanup(self): - """Clean up resources when module is destroyed.""" - self.stop() - class MockB1ConnectionModule(B1ConnectionModule): """Test connection module that prints commands instead of sending UDP.""" @@ -389,3 +389,11 @@ def _send_loop(self): self.packet_count += 1 time.sleep(0.020) + + @rpc + def start(self) -> None: + super().start() + + @rpc + def stop(self) -> None: + super().stop() diff --git a/dimos/robot/unitree_webrtc/unitree_b1/joystick_module.py b/dimos/robot/unitree_webrtc/unitree_b1/joystick_module.py index 34fb5d79c8..9edc27f3c3 100644 --- a/dimos/robot/unitree_webrtc/unitree_b1/joystick_module.py +++ b/dimos/robot/unitree_webrtc/unitree_b1/joystick_module.py @@ -48,6 +48,9 @@ def __init__(self, *args, **kwargs): @rpc def start(self): """Initialize pygame and start control loop.""" + + super().start() + try: import pygame except ImportError: @@ -64,6 +67,27 @@ def start(self): return True + @rpc + def stop(self) -> None: + """Stop the joystick module.""" + + self.running = False + self.pygame_ready = False + + # Send stop command + stop_twist = Twist() + stop_twist_stamped = TwistStamped( + ts=time.time(), + frame_id="base_link", + linear=stop_twist.linear, + angular=stop_twist.angular, + ) + self.twist_out.publish(stop_twist_stamped) + + self._thread.join(2) + + super().stop() + def _pygame_loop(self): """Main pygame event loop - ALL pygame operations happen here.""" import pygame @@ -255,23 +279,3 @@ def _update_display(self, twist): y_pos += 25 pygame.display.flip() - - @rpc - def stop(self): - """Stop the joystick module.""" - self.running = False - # Send stop command - stop_twist = Twist() - stop_twist_stamped = TwistStamped( - ts=time.time(), - frame_id="base_link", - linear=stop_twist.linear, - angular=stop_twist.angular, - ) - self.twist_out.publish(stop_twist_stamped) - return True - - def cleanup(self): - """Clean up pygame resources.""" - self.running = False - self.pygame_ready = False diff --git a/dimos/robot/unitree_webrtc/unitree_b1/unitree_b1.py b/dimos/robot/unitree_webrtc/unitree_b1/unitree_b1.py index bef0fafbfa..78d22c37e3 100644 --- a/dimos/robot/unitree_webrtc/unitree_b1/unitree_b1.py +++ b/dimos/robot/unitree_webrtc/unitree_b1/unitree_b1.py @@ -25,7 +25,9 @@ from typing import Optional from dimos import core -from dimos.msgs.geometry_msgs import PoseStamped, Twist, TwistStamped +from dimos.core.dimos import Dimos +from dimos.core.resource import Resource +from dimos.msgs.geometry_msgs import TwistStamped, PoseStamped from dimos.msgs.nav_msgs.Odometry import Odometry from dimos.msgs.std_msgs import Int32 from dimos.msgs.tf2_msgs.TFMessage import TFMessage @@ -56,7 +58,7 @@ logger = setup_logger("dimos.robot.unitree_webrtc.unitree_b1", level=logging.INFO) -class UnitreeB1(Robot): +class UnitreeB1(Robot, Resource): """Unitree B1 quadruped robot with UDP control. Simplified architecture: @@ -97,6 +99,7 @@ def __init__( self.connection = None self.joystick = None self.ros_bridge = None + self._dimos = Dimos(n=2) os.makedirs(self.output_dir, exist_ok=True) logger.info(f"Robot outputs will be saved to: {self.output_dir}") @@ -105,13 +108,13 @@ def start(self): """Start the B1 robot - initialize DimOS, deploy modules, and start them.""" logger.info("Initializing DimOS...") - self.dimos = core.start(2) + self._dimos.start() logger.info("Deploying connection module...") if self.test_mode: - self.connection = self.dimos.deploy(MockB1ConnectionModule, self.ip, self.port) + self.connection = self._dimos.deploy(MockB1ConnectionModule, self.ip, self.port) else: - self.connection = self.dimos.deploy(B1ConnectionModule, self.ip, self.port) + self.connection = self._dimos.deploy(B1ConnectionModule, self.ip, self.port) # Configure LCM transports for connection (matching G1 pattern) self.connection.cmd_vel.transport = core.LCMTransport("/cmd_vel", TwistStamped) @@ -123,18 +126,16 @@ def start(self): if self.enable_joystick: from dimos.robot.unitree_webrtc.unitree_b1.joystick_module import JoystickModule - self.joystick = self.dimos.deploy(JoystickModule) + self.joystick = self._dimos.deploy(JoystickModule) self.joystick.twist_out.transport = core.LCMTransport("/cmd_vel", TwistStamped) self.joystick.mode_out.transport = core.LCMTransport("/b1/mode", Int32) logger.info("Joystick module deployed - pygame window will open") - self.connection.start() + self._dimos.start_all_modules() + self.connection.idle() # Start in IDLE mode for safety logger.info("B1 started in IDLE mode (safety)") - if self.joystick: - self.joystick.start() - # Deploy ROS bridge if enabled (matching G1 pattern) if self.enable_ros_bridge: self._deploy_ros_bridge() @@ -145,6 +146,11 @@ def start(self): if self.enable_ros_bridge: logger.info("ROS bridge enabled for external control") + def stop(self) -> None: + self._dimos.stop() + if self.ros_bridge: + self.ros_bridge.stop() + def _deploy_ros_bridge(self): """Deploy and configure ROS bridge (matching G1 implementation).""" self.ros_bridge = ROSBridge("b1_ros_bridge") @@ -164,6 +170,8 @@ def _deploy_ros_bridge(self): "/tf", TFMessage, ROSTFMessage, direction=BridgeDirection.ROS_TO_DIMOS ) + self.ros_bridge.start() + logger.info("ROS bridge deployed: /cmd_vel, /state_estimation, /tf (ROS → DIMOS)") # Robot control methods (standard interface) @@ -177,11 +185,6 @@ def move(self, twist_stamped: TwistStamped, duration: float = 0.0): if self.connection: self.connection.move(twist_stamped, duration) - def stop(self): - """Stop all robot movement.""" - if self.connection: - self.connection.stop() - def stand(self): """Put robot in stand mode.""" if self.connection: @@ -200,35 +203,6 @@ def idle(self): self.connection.idle() logger.info("B1 switched to IDLE mode") - def shutdown(self): - """Shutdown the robot and clean up resources.""" - logger.info("Shutting down UnitreeB1...") - - # Stop robot movement - self.stop() - - # Shutdown ROS bridge if it exists - if self.ros_bridge is not None: - try: - self.ros_bridge.shutdown() - logger.info("ROS bridge shut down successfully") - except Exception as e: - logger.error(f"Error shutting down ROS bridge: {e}") - - # Clean up connection module - if self.connection: - self.connection.cleanup() - - logger.info("UnitreeB1 shutdown complete") - - def cleanup(self): - """Clean up robot resources (calls shutdown for consistency).""" - self.shutdown() - - def __del__(self): - """Destructor to ensure cleanup.""" - self.shutdown() - def main(): """Main entry point for testing B1 robot.""" @@ -296,7 +270,7 @@ def main(): except KeyboardInterrupt: print("\nShutting down...") finally: - robot.cleanup() + robot.stop() if __name__ == "__main__": diff --git a/dimos/robot/unitree_webrtc/unitree_g1.py b/dimos/robot/unitree_webrtc/unitree_g1.py index 08a23bc2dc..b39ddd6db6 100644 --- a/dimos/robot/unitree_webrtc/unitree_g1.py +++ b/dimos/robot/unitree_webrtc/unitree_g1.py @@ -21,16 +21,6 @@ import logging import os import time -from typing import Optional - -from dimos_lcm.foxglove_msgs import SceneUpdate -from geometry_msgs.msg import PoseStamped as ROSPoseStamped -from geometry_msgs.msg import TwistStamped as ROSTwistStamped -from nav_msgs.msg import Odometry as ROSOdometry -from sensor_msgs.msg import Image as ROSImage -from sensor_msgs.msg import Joy as ROSJoy -from sensor_msgs.msg import PointCloud2 as ROSPointCloud2 -from tf2_msgs.msg import TFMessage as ROSTFMessage from dimos import core from dimos.agents2 import Agent @@ -38,6 +28,8 @@ from dimos.agents2.skills.ros_navigation import RosNavigation from dimos.agents2.spec import Model, Provider from dimos.core import In, Module, Out, rpc +from dimos.core.dimos import Dimos +from dimos.core.resource import Resource from dimos.hardware.camera import zed from dimos.hardware.camera.module import CameraModule from dimos.hardware.camera.webcam import Webcam @@ -55,7 +47,6 @@ from dimos.msgs.std_msgs.Bool import Bool from dimos.msgs.tf2_msgs.TFMessage import TFMessage from dimos.msgs.vision_msgs import Detection2DArray -from dimos.perception.detection2d import Detection3DModule from dimos.perception.detection2d.moduleDB import ObjectDBModule from dimos.perception.spatial_perception import SpatialMemory from dimos.protocol import pubsub @@ -71,6 +62,14 @@ from dimos.types.robot_capabilities import RobotCapability from dimos.utils.logging_config import setup_logger from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule +from dimos_lcm.foxglove_msgs import SceneUpdate +from geometry_msgs.msg import TwistStamped as ROSTwistStamped +from nav_msgs.msg import Odometry as ROSOdometry +from reactivex.disposable import Disposable +from sensor_msgs.msg import Joy as ROSJoy +from sensor_msgs.msg import PointCloud2 as ROSPointCloud2 +from tf2_msgs.msg import TFMessage as ROSTFMessage +from typing import Optional logger = setup_logger("dimos.robot.unitree_webrtc.unitree_g1", level=logging.INFO) @@ -101,10 +100,21 @@ def __init__(self, ip: str = None, connection_type: str = "webrtc", *args, **kwa @rpc def start(self): """Start the connection and subscribe to sensor streams.""" + + super().start() + # Use the exact same UnitreeWebRTCConnection as Go2 self.connection = UnitreeWebRTCConnection(self.ip) - self.movecmd.subscribe(self.move) - self.odom_in.subscribe(self._publish_odom_pose) + self.connection.start() + unsub = self.movecmd.subscribe(self.move) + self._disposables.add(Disposable(unsub)) + unsub = self.odom_in.subscribe(self._publish_odom_pose) + self._disposables.add(Disposable(unsub)) + + @rpc + def stop(self) -> None: + self.connection.stop() + super().stop() def _publish_odom_pose(self, msg: Odometry): self.odom_pose.publish( @@ -128,7 +138,7 @@ def publish_request(self, topic: str, data: dict): return self.connection.publish_request(topic, data) -class UnitreeG1(Robot): +class UnitreeG1(Robot, Resource): """Unitree G1 humanoid robot.""" def __init__( @@ -183,7 +193,7 @@ def __init__( self.capabilities = [RobotCapability.LOCOMOTION] # Module references - self.dimos = None + self._dimos = Dimos(n=4) self.connection = None self.websocket_vis = None self.foxglove_bridge = None @@ -242,8 +252,8 @@ def _deploy_detection(self, goto): self.detection = detection def start(self): - """Start the robot system with all modules.""" - self.dimos = core.start(8) # 2 workers for connection and visualization + self.lcm.start() + self._dimos.start() if self.enable_connection: self._deploy_connection() @@ -308,20 +318,13 @@ def start(self): logger.info(f"WebSocket visualization available at http://localhost:{self.websocket_port}") self._start_modules() - def stop(self): + def stop(self) -> None: + self._dimos.stop() self.lcm.stop() - def __enter__(self) -> "UnitreeG1": - self.start() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.stop() - return False - def _deploy_connection(self): """Deploy and configure the connection module.""" - self.connection = self.dimos.deploy(G1ConnectionModule, self.ip) + self.connection = self._dimos.deploy(G1ConnectionModule, self.ip) # Configure LCM transports self.connection.movecmd.transport = core.LCMTransport("/cmd_vel", TwistStamped) @@ -332,7 +335,7 @@ def _deploy_camera(self): """Deploy and configure a standard webcam module.""" logger.info("Deploying standard webcam module...") - self.camera = self.dimos.deploy( + self.camera = self._dimos.deploy( CameraModule, transform=Transform( translation=Vector3(0.05, 0.0, 0.0), @@ -355,7 +358,7 @@ def _deploy_camera(self): def _deploy_visualization(self): """Deploy and configure visualization modules.""" # Deploy WebSocket visualization module - self.websocket_vis = self.dimos.deploy(WebsocketVisModule, port=self.websocket_port) + self.websocket_vis = self._dimos.deploy(WebsocketVisModule, port=self.websocket_port) self.websocket_vis.movecmd_stamped.transport = core.LCMTransport("/cmd_vel", TwistStamped) # Note: robot_pose connection removed since odom was removed from G1ConnectionModule @@ -367,6 +370,7 @@ def _deploy_visualization(self): "/zed/depth_image#sensor_msgs.Image", ] ) + self.foxglove_bridge.start() def _deploy_perception(self): self.spatial_memory_module = self.dimos.deploy( @@ -387,7 +391,7 @@ def _deploy_joystick(self): from dimos.robot.unitree_webrtc.g1_joystick_module import G1JoystickModule logger.info("Deploying G1 joystick module...") - self.joystick = self.dimos.deploy(G1JoystickModule) + self.joystick = self._dimos.deploy(G1JoystickModule) self.joystick.twist_out.transport = core.LCMTransport("/cmd_vel", Twist) logger.info("Joystick module deployed - pygame window will open") @@ -436,25 +440,15 @@ def _deploy_ros_bridge(self): remap_topic="/map", ) + self.ros_bridge.start() + logger.info( "ROS bridge deployed: /cmd_vel, /state_estimation, /tf, /registered_scan (ROS → DIMOS)" ) def _start_modules(self): """Start all deployed modules.""" - if self.connection: - self.connection.start() - # self.websocket_vis.start() - self.foxglove_bridge.start() - - # if self.joystick: - # self.joystick.start() - - self.camera.start() - self.detection.start() - - if self.enable_perception: - self.spatial_memory_module.start() + self._dimos.start_all_modules() # Initialize skills after connection is established if self.skill_library is not None: @@ -475,27 +469,6 @@ def get_odom(self) -> PoseStamped: # Note: odom functionality removed from G1ConnectionModule return None - def shutdown(self): - """Shutdown the robot and clean up resources.""" - logger.info("Shutting down UnitreeG1...") - - # Shutdown ROS bridge if it exists - if self.ros_bridge is not None: - try: - self.ros_bridge.shutdown() - logger.info("ROS bridge shut down successfully") - except Exception as e: - logger.error(f"Error shutting down ROS bridge: {e}") - - # Stop other modules if needed - if self.websocket_vis: - try: - self.websocket_vis.stop() - except Exception as e: - logger.error(f"Error stopping websocket vis: {e}") - - logger.info("UnitreeG1 shutdown complete") - @property def spatial_memory(self) -> Optional[SpatialMemory]: return self.spatial_memory_module @@ -566,7 +539,7 @@ def main(): time.sleep(1) except KeyboardInterrupt: logger.info("Shutting down...") - robot.shutdown() + robot.stop() if __name__ == "__main__": diff --git a/dimos/robot/unitree_webrtc/unitree_go2.py b/dimos/robot/unitree_webrtc/unitree_go2.py index 3c05062149..5276586a43 100644 --- a/dimos/robot/unitree_webrtc/unitree_go2.py +++ b/dimos/robot/unitree_webrtc/unitree_go2.py @@ -23,10 +23,13 @@ from typing import Optional from reactivex import Observable +from reactivex.disposable import CompositeDisposable from dimos import core from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE from dimos.core import In, Module, Out, rpc +from dimos.core.dimos import Dimos +from dimos.core.resource import Resource from dimos.mapping.types import LatLon from dimos.msgs.std_msgs import Header from dimos.msgs.geometry_msgs import PoseStamped, Transform, Twist, Vector3, Quaternion @@ -67,7 +70,7 @@ from dimos.types.robot_capabilities import RobotCapability -logger = setup_logger("dimos.robot.unitree_webrtc.unitree_go2", level=logging.INFO) +logger = setup_logger(__file__, level=logging.INFO) # Suppress verbose loggers logging.getLogger("aiortc.codecs.h264").setLevel(logging.ERROR) @@ -82,13 +85,16 @@ warnings.filterwarnings("ignore", message="H264Decoder.*failed to decode") -class FakeRTC: +class FakeRTC(Resource): """Fake WebRTC connection for testing with recorded data.""" def __init__(self, *args, **kwargs): get_data("unitree_office_walk") # Preload data for testing - def connect(self): + def start(self) -> None: + pass + + def stop(self) -> None: pass def standup(self): @@ -128,11 +134,11 @@ def publish_request(self, topic: str, data: dict): class ConnectionModule(Module): """Module that handles robot sensor data, movement commands, and camera information.""" - movecmd: In[Twist] = None + cmd_vel: In[Twist] = None odom: Out[PoseStamped] = None gps_location: Out[LatLon] = None lidar: Out[LidarMessage] = None - video: Out[Image] = None + color_image: Out[Image] = None camera_info: Out[CameraInfo] = None camera_pose: Out[PoseStamped] = None ip: str @@ -180,8 +186,10 @@ def __init__( Module.__init__(self, *args, **kwargs) @rpc - def start(self): + def start(self) -> None: """Start the connection and subscribe to sensor streams.""" + super().start() + match self.connection_type: case "webrtc": self.connection = UnitreeWebRTCConnection(self.ip) @@ -191,17 +199,33 @@ def start(self): from dimos.robot.unitree_webrtc.mujoco_connection import MujocoConnection self.connection = MujocoConnection() - self.connection.start() case _: raise ValueError(f"Unknown connection type: {self.connection_type}") + self.connection.start() + # Connect sensor streams to outputs - self.connection.lidar_stream().subscribe(self.lidar.publish) - self.connection.odom_stream().subscribe(self._publish_tf) + unsub = self.connection.lidar_stream().subscribe(self.lidar.publish) + self._disposables.add(unsub) + + unsub = self.connection.odom_stream().subscribe(self._publish_tf) + self._disposables.add(unsub) + if self.connection_type == "mujoco": - self.connection.gps_stream().subscribe(self._publish_gps_location) - self.connection.video_stream().subscribe(self._on_video) - self.movecmd.subscribe(self.move) + unsub = self.connection.gps_stream().subscribe(self._publish_gps_location) + self._disposables.add(unsub) + + unsub = self.connection.video_stream().subscribe(self._on_video) + self._disposables.add(unsub) + + unsub = self.cmd_vel.subscribe(self.move) + self._disposables.add(unsub) + + @rpc + def stop(self) -> None: + if self.connection: + self.connection.stop() + super().stop() def _on_video(self, msg: Image): """Handle incoming video frames and publish synchronized camera data.""" @@ -209,10 +233,10 @@ def _on_video(self, msg: Image): if self.rectify_image: rectified_msg = rectify_image(msg, self.camera_matrix, self.dist_coeffs) self._last_image = rectified_msg - self.video.publish(rectified_msg) + self.color_image.publish(rectified_msg) else: self._last_image = msg - self.video.publish(msg) + self.color_image.publish(msg) # Publish camera info and pose synchronized with video timestamp = msg.ts if msg.ts else time.time() @@ -301,9 +325,12 @@ def publish_request(self, topic: str, data: dict): return self.connection.publish_request(topic, data) -class UnitreeGo2(UnitreeRobot): +class UnitreeGo2(UnitreeRobot, Resource): """Full Unitree Go2 robot with navigation and perception capabilities.""" + _dimos: Dimos + _disposables: CompositeDisposable = CompositeDisposable() + def __init__( self, ip: Optional[str], @@ -322,6 +349,7 @@ def __init__( connection_type: webrtc, fake, or mujoco """ super().__init__() + self._dimos = Dimos(n=8, memory_limit="8GiB") self.ip = ip self.connection_type = connection_type or "webrtc" if ip is None and self.connection_type == "webrtc": @@ -338,7 +366,6 @@ def __init__( # Set capabilities self.capabilities = [RobotCapability.LOCOMOTION, RobotCapability.VISION] - self.dimos = None self.connection = None self.mapper = None self.global_planner = None @@ -353,14 +380,6 @@ def __init__( self._setup_directories() - def __enter__(self) -> "UnitreeGo2": - self.start() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - # self.stop() - return False - def _setup_directories(self): """Setup directories for spatial memory storage.""" os.makedirs(self.output_dir, exist_ok=True) @@ -381,8 +400,8 @@ def _setup_directories(self): os.makedirs(self.db_path, exist_ok=True) def start(self): - """Start the robot system with all modules.""" - self.dimos = core.start(8, memory_limit="8GiB") + self.lcm.start() + self._dimos.start() self._deploy_connection() self._deploy_mapping() @@ -393,47 +412,35 @@ def start(self): self._deploy_camera() self._start_modules() - - self.lcm.start() - logger.info("UnitreeGo2 initialized and started") - logger.info(f"WebSocket visualization available at http://localhost:{self.websocket_port}") - - def stop(self): - # self.connection.stop() - # self.mapper.stop() - # self.global_planner.stop() - # self.local_planner.stop() - # self.navigator.stop() - # self.frontier_explorer.stop() - # self.websocket_vis.stop() - # self.foxglove_bridge.stop() - self.spatial_memory_module.stop() - # self.object_tracker.stop() - self.utilization_module.stop() - self.dimos.close_all() + + def stop(self) -> None: + if self.foxglove_bridge: + self.foxglove_bridge.stop() + self._disposables.dispose() + self._dimos.stop() self.lcm.stop() def _deploy_connection(self): """Deploy and configure the connection module.""" - self.connection = self.dimos.deploy( + self.connection = self._dimos.deploy( ConnectionModule, self.ip, connection_type=self.connection_type ) self.connection.lidar.transport = core.LCMTransport("/lidar", LidarMessage) self.connection.odom.transport = core.LCMTransport("/odom", PoseStamped) self.connection.gps_location.transport = core.pLCMTransport("/gps_location") - self.connection.video.transport = core.pSHMTransport( + self.connection.color_image.transport = core.pSHMTransport( "/go2/color_image", default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE ) - self.connection.movecmd.transport = core.LCMTransport("/cmd_vel", Twist) + self.connection.cmd_vel.transport = core.LCMTransport("/cmd_vel", Twist) self.connection.camera_info.transport = core.LCMTransport("/go2/camera_info", CameraInfo) self.connection.camera_pose.transport = core.LCMTransport("/go2/camera_pose", PoseStamped) def _deploy_mapping(self): """Deploy and configure the mapping module.""" min_height = 0.3 if self.connection_type == "mujoco" else 0.15 - self.mapper = self.dimos.deploy( + self.mapper = self._dimos.deploy( Map, voxel_size=0.5, global_publish_interval=2.5, min_height=min_height ) @@ -445,16 +452,16 @@ def _deploy_mapping(self): def _deploy_navigation(self): """Deploy and configure navigation modules.""" - self.global_planner = self.dimos.deploy(AstarPlanner) - self.local_planner = self.dimos.deploy(HolonomicLocalPlanner) - self.navigator = self.dimos.deploy( + self.global_planner = self._dimos.deploy(AstarPlanner) + self.local_planner = self._dimos.deploy(HolonomicLocalPlanner) + self.navigator = self._dimos.deploy( BehaviorTreeNavigator, reset_local_planner=self.local_planner.reset, check_goal_reached=self.local_planner.is_goal_reached, ) - self.frontier_explorer = self.dimos.deploy(WavefrontFrontierExplorer) + self.frontier_explorer = self._dimos.deploy(WavefrontFrontierExplorer) - self.navigator.goal.transport = core.LCMTransport("/navigation_goal", PoseStamped) + self.navigator.target.transport = core.LCMTransport("/navigation_goal", PoseStamped) self.navigator.goal_request.transport = core.LCMTransport("/goal_request", PoseStamped) self.navigator.goal_reached.transport = core.LCMTransport("/goal_reached", Bool) self.navigator.navigation_state.transport = core.LCMTransport("/navigation_state", String) @@ -472,7 +479,7 @@ def _deploy_navigation(self): "/stop_explore_cmd", Bool ) - self.global_planner.target.connect(self.navigator.goal) + self.global_planner.target.connect(self.navigator.target) self.global_planner.global_costmap.connect(self.mapper.global_costmap) self.global_planner.odom.connect(self.connection.odom) @@ -481,23 +488,23 @@ def _deploy_navigation(self): self.local_planner.local_costmap.connect(self.mapper.local_costmap) self.local_planner.odom.connect(self.connection.odom) - self.connection.movecmd.connect(self.local_planner.cmd_vel) + self.connection.cmd_vel.connect(self.local_planner.cmd_vel) self.navigator.odom.connect(self.connection.odom) - self.frontier_explorer.costmap.connect(self.mapper.global_costmap) - self.frontier_explorer.odometry.connect(self.connection.odom) + self.frontier_explorer.global_costmap.connect(self.mapper.global_costmap) + self.frontier_explorer.odom.connect(self.connection.odom) def _deploy_visualization(self): """Deploy and configure visualization modules.""" - self.websocket_vis = self.dimos.deploy(WebsocketVisModule, port=self.websocket_port) - self.websocket_vis.click_goal.transport = core.LCMTransport("/goal_request", PoseStamped) + self.websocket_vis = self._dimos.deploy(WebsocketVisModule, port=self.websocket_port) + self.websocket_vis.goal_request.transport = core.LCMTransport("/goal_request", PoseStamped) self.websocket_vis.gps_goal.transport = core.pLCMTransport("/gps_goal") self.websocket_vis.explore_cmd.transport = core.LCMTransport("/explore_cmd", Bool) self.websocket_vis.stop_explore_cmd.transport = core.LCMTransport("/stop_explore_cmd", Bool) - self.websocket_vis.movecmd.transport = core.LCMTransport("/cmd_vel", Twist) + self.websocket_vis.cmd_vel.transport = core.LCMTransport("/cmd_vel", Twist) - self.websocket_vis.robot_pose.connect(self.connection.odom) + self.websocket_vis.odom.connect(self.connection.odom) self.websocket_vis.gps_location.connect(self.connection.gps_location) self.websocket_vis.path.connect(self.global_planner.path) self.websocket_vis.global_costmap.connect(self.mapper.global_costmap) @@ -507,6 +514,7 @@ def _set_goal(goal: LatLon): self.set_gps_travel_goal_points([goal]) unsub = self.websocket_vis.gps_goal.transport.pure_observable().subscribe(_set_goal) + self._disposables.add(unsub) def _deploy_foxglove_bridge(self): self.foxglove_bridge = FoxgloveBridge( @@ -515,11 +523,12 @@ def _deploy_foxglove_bridge(self): "/go2/tracked_overlay#sensor_msgs.Image", ] ) + self.foxglove_bridge.start() def _deploy_perception(self): """Deploy and configure perception modules.""" # Deploy spatial memory - self.spatial_memory_module = self.dimos.deploy( + self.spatial_memory_module = self._dimos.deploy( SpatialMemory, collection_name=self.spatial_memory_collection, db_path=self.db_path, @@ -527,7 +536,7 @@ def _deploy_perception(self): output_dir=self.spatial_memory_dir, ) - self.spatial_memory_module.video.transport = core.pSHMTransport( + self.spatial_memory_module.color_image.transport = core.pSHMTransport( "/go2/color_image", default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE ) self.spatial_memory_module.odom.transport = core.LCMTransport( @@ -537,15 +546,15 @@ def _deploy_perception(self): logger.info("Spatial memory module deployed and connected") # Deploy 2D object tracker - self.object_tracker = self.dimos.deploy( + self.object_tracker = self._dimos.deploy( ObjectTracker2D, frame_id="camera_link", ) # Deploy bbox navigation module - self.bbox_navigator = self.dimos.deploy(BBoxNavigationModule, goal_distance=1.0) + self.bbox_navigator = self._dimos.deploy(BBoxNavigationModule, goal_distance=1.0) - self.utilization_module = self.dimos.deploy(UtilizationModule) + self.utilization_module = self._dimos.deploy(UtilizationModule) # Set up transports for object tracker self.object_tracker.detection2darray.transport = core.LCMTransport( @@ -564,7 +573,7 @@ def _deploy_camera(self): """Deploy and configure the camera module.""" # Connect object tracker inputs if self.object_tracker: - self.object_tracker.color_image.connect(self.connection.video) + self.object_tracker.color_image.connect(self.connection.color_image) logger.info("Object tracker connected to camera") # Connect bbox navigator inputs @@ -576,18 +585,7 @@ def _deploy_camera(self): def _start_modules(self): """Start all deployed modules in the correct order.""" - self.connection.start() - self.mapper.start() - self.global_planner.start() - self.local_planner.start() - self.navigator.start() - self.frontier_explorer.start() - # self.websocket_vis.start() - self.foxglove_bridge.start() - self.spatial_memory_module.start() - self.object_tracker.start() - self.bbox_navigator.start() - self.utilization_module.start() + self._dimos.start_all_modules() # Initialize skills after connection is established if self.skill_library is not None: @@ -680,12 +678,6 @@ def spatial_memory(self) -> Optional[SpatialMemory]: def gps_position_stream(self) -> Observable[LatLon]: return self.connection.gps_location.transport.pure_observable() - def set_gps_travel_goal_points(self, points: list[LatLon]) -> None: - logger.info(f"Travelling to: {points}") - # self.connection.... (actually set the goal) - print("websocketvis", self.websocket_vis) - self.websocket_vis.set_gps_travel_goal_points(points) - def get_odom(self) -> PoseStamped: """Get the robot's odometry. @@ -709,7 +701,7 @@ def main(): while True: time.sleep(0.1) except KeyboardInterrupt: - logger.info("Shutting down...") + pass finally: robot.stop() diff --git a/dimos/robot/unitree_webrtc/unitree_go2_nav_only.py b/dimos/robot/unitree_webrtc/unitree_go2_nav_only.py index 785f55e025..cf2136dde6 100644 --- a/dimos/robot/unitree_webrtc/unitree_go2_nav_only.py +++ b/dimos/robot/unitree_webrtc/unitree_go2_nav_only.py @@ -14,6 +14,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +# $$$$$$$$\ $$$$$$\ $$$$$$$\ $$$$$$\ +# \__$$ __|$$ __$$\ $$ __$$\ $$ __$$\ +# $$ | $$ / $$ |$$ | $$ |$$ / $$ | +# $$ | $$ | $$ |$$ | $$ |$$ | $$ | +# $$ | $$ | $$ |$$ | $$ |$$ | $$ | +# $$ | $$ | $$ |$$ | $$ |$$ | $$ | +# $$ | $$$$$$ |$$$$$$$ | $$$$$$ | +# \__| \______/ \_______/ \______/ +# DOES anyone use this? The imports are broken which tells me it's unused. import functools import logging @@ -29,11 +38,13 @@ from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Vector3 from dimos.msgs.nav_msgs import OccupancyGrid, Path from dimos.msgs.sensor_msgs import Image +from dimos.msgs.std_msgs import Header from dimos.navigation.bt_navigator.navigator import BehaviorTreeNavigator, NavigatorState from dimos.navigation.frontier_exploration import WavefrontFrontierExplorer from dimos.navigation.global_planner import AstarPlanner from dimos.navigation.local_planner.holonomic_local_planner import HolonomicLocalPlanner +from dimos.perception.common.utils import load_camera_info, load_camera_info_opencv, rectify_image from dimos.protocol import pubsub from dimos.protocol.pubsub.lcmpubsub import LCM from dimos.protocol.tf import TF @@ -162,6 +173,7 @@ def __init__( @rpc def start(self): + super().start() """Start the connection and subscribe to sensor streams.""" match self.connection_type: case "webrtc": @@ -177,10 +189,23 @@ def start(self): raise ValueError(f"Unknown connection type: {self.connection_type}") # Connect sensor streams to outputs - self.connection.lidar_stream().subscribe(self.lidar.publish) - self.connection.odom_stream().subscribe(self._publish_tf) - self.connection.video_stream().subscribe(self._on_video) - self.movecmd.subscribe(self.move) + unsub = self.connection.lidar_stream().subscribe(self.lidar.publish) + self._disposables.add(unsub) + + unsub = self.connection.odom_stream().subscribe(self._publish_tf) + self._disposables.add(unsub) + + unsub = self.connection.video_stream().subscribe(self._on_video) + self._disposables.add(unsub) + + unsub = self.movecmd.subscribe(self.move) + self._disposables.add(unsub) + + @rpc + def stop(self) -> None: + if self.connection: + self.connection.stop() + super().stop() def _on_video(self, msg: Image): """Handle incoming video frames and publish synchronized camera data.""" @@ -312,6 +337,7 @@ def __init__( self.navigator = None self.frontier_explorer = None self.websocket_vis = None + self.foxglove_bridge = None def start(self): """Start the robot system with navigation modules only.""" @@ -321,8 +347,7 @@ def start(self): self._deploy_mapping() self._deploy_navigation() - foxglove_bridge = self.dimos.deploy(FoxgloveBridge) - foxglove_bridge.start() + self.foxglove_bridge = self.dimos.deploy(FoxgloveBridge) self._start_modules() @@ -410,6 +435,7 @@ def _start_modules(self): self.local_planner.start() self.navigator.start() self.frontier_explorer.start() + self.foxglove_bridge.start() def move(self, twist: Twist, duration: float = 0.0): """Send movement command to robot.""" diff --git a/dimos/robot/unitree_webrtc/unitree_skill_container.py b/dimos/robot/unitree_webrtc/unitree_skill_container.py index 4db720be81..61df7be2d7 100644 --- a/dimos/robot/unitree_webrtc/unitree_skill_container.py +++ b/dimos/robot/unitree_webrtc/unitree_skill_container.py @@ -24,6 +24,7 @@ from typing import TYPE_CHECKING, Optional from dimos.core import Module +from dimos.core.core import rpc from dimos.msgs.geometry_msgs import Twist, Vector3 from dimos.protocol.skill.skill import skill from dimos.protocol.skill.type import Reducer, Stream @@ -52,6 +53,15 @@ def __init__(self, robot: Optional[UnitreeGo2] = None): # Dynamically generate skills from UNITREE_WEBRTC_CONTROLS self._generate_unitree_skills() + @rpc + def start(self) -> None: + super().start() + + @rpc + def stop(self) -> None: + # TODO: Do I need to clean up dynamic skills? + super().stop() + def _generate_unitree_skills(self): """Dynamically generate skills from the UNITREE_WEBRTC_CONTROLS list.""" logger.info(f"Generating {len(UNITREE_WEBRTC_CONTROLS)} dynamic Unitree skills") diff --git a/dimos/robot/utils/robot_debugger.py b/dimos/robot/utils/robot_debugger.py index 5ab33487fc..74c174f9cd 100644 --- a/dimos/robot/utils/robot_debugger.py +++ b/dimos/robot/utils/robot_debugger.py @@ -14,17 +14,18 @@ import os +from dimos.core.resource import Resource from dimos.utils.logging_config import setup_logger logger = setup_logger(__file__) -class RobotDebugger: +class RobotDebugger(Resource): def __init__(self, robot): self._robot = robot self._threaded_server = None - def __enter__(self): + def start(self) -> None: if not os.getenv("ROBOT_DEBUGGER"): return @@ -52,9 +53,7 @@ def exposed_robot(self): }, ) self._threaded_server.start() - return self - def __exit__(self, exc_type, exc_val, exc_tb): + def stop(self) -> None: if self._threaded_server: self._threaded_server.close() - return False diff --git a/dimos/utils/cli/agentspy/agentspy.py b/dimos/utils/cli/agentspy/agentspy.py index de784f4719..a3fc70f0b0 100644 --- a/dimos/utils/cli/agentspy/agentspy.py +++ b/dimos/utils/cli/agentspy/agentspy.py @@ -17,7 +17,7 @@ import time from collections import deque from dataclasses import dataclass -from typing import Any, Deque, Dict, List, Optional, Union +from typing import Any, Deque, List, Optional, Union from langchain_core.messages import ( AIMessage, @@ -25,18 +25,11 @@ SystemMessage, ToolMessage, ) -from rich.console import Console -from rich.table import Table -from rich.text import Text from textual.app import App, ComposeResult from textual.binding import Binding -from textual.containers import Container, ScrollableContainer -from textual.reactive import reactive from textual.widgets import Footer, RichLog -from dimos.protocol.pubsub import lcm from dimos.protocol.pubsub.lcmpubsub import PickleLCM -from dimos.utils.logging_config import setup_logger # Type alias for all message types we might receive AnyMessage = Union[SystemMessage, ToolMessage, AIMessage, HumanMessage] diff --git a/dimos/utils/cli/skillspy/demo_skillspy.py b/dimos/utils/cli/skillspy/demo_skillspy.py index 3ec3829794..20c5417a2e 100644 --- a/dimos/utils/cli/skillspy/demo_skillspy.py +++ b/dimos/utils/cli/skillspy/demo_skillspy.py @@ -103,6 +103,8 @@ def skill_runner(): except KeyboardInterrupt: print("\nDemo stopped.") + agent_interface.stop() + if __name__ == "__main__": run_demo_skills() diff --git a/dimos/utils/cli/skillspy/skillspy.py b/dimos/utils/cli/skillspy/skillspy.py index 8255f72587..68253aa848 100644 --- a/dimos/utils/cli/skillspy/skillspy.py +++ b/dimos/utils/cli/skillspy/skillspy.py @@ -28,7 +28,6 @@ from dimos.protocol.skill.comms import SkillMsg from dimos.protocol.skill.coordinator import SkillCoordinator, SkillState, SkillStateEnum -from dimos.protocol.skill.type import MsgType class AgentSpy: diff --git a/dimos/utils/generic.py b/dimos/utils/generic.py index 3bd84bb845..d5b9bd4364 100644 --- a/dimos/utils/generic.py +++ b/dimos/utils/generic.py @@ -14,6 +14,9 @@ import os import json +import uuid +import string +import hashlib from typing import Any, Optional @@ -48,3 +51,21 @@ def extract_json_from_llm_response(response: str) -> Any: pass return None + + +def short_id(from_string: str | None = None) -> str: + alphabet = string.digits + string.ascii_letters + base = len(alphabet) + + if from_string is None: + num = uuid.uuid4().int + else: + hash_bytes = hashlib.sha1(from_string.encode()).digest()[:16] + num = int.from_bytes(hash_bytes, "big") + + chars = [] + while num: + num, rem = divmod(num, base) + chars.append(alphabet[rem]) + + return "".join(reversed(chars))[:18] diff --git a/dimos/utils/monitoring.py b/dimos/utils/monitoring.py index 9c562dc0e7..c13c274cac 100644 --- a/dimos/utils/monitoring.py +++ b/dimos/utils/monitoring.py @@ -172,6 +172,8 @@ def __init__(self): @rpc def start(self): + super().start() + if self._utilization_thread: self._utilization_thread.start() @@ -180,7 +182,7 @@ def stop(self): if self._utilization_thread: self._utilization_thread.stop() self._utilization_thread.join(timeout=2) - self._close_module() + super().stop() def _can_use_py_spy(): diff --git a/dimos/web/websocket_vis/websocket_vis_module.py b/dimos/web/websocket_vis/websocket_vis_module.py index c31d5220a8..f8440d4b20 100644 --- a/dimos/web/websocket_vis/websocket_vis_module.py +++ b/dimos/web/websocket_vis/websocket_vis_module.py @@ -37,6 +37,7 @@ from dimos.msgs.geometry_msgs import PoseStamped, Twist, TwistStamped, Vector3 from dimos.msgs.nav_msgs import OccupancyGrid, Path from dimos.utils.logging_config import setup_logger +from reactivex.disposable import Disposable logger = setup_logger("dimos.web.websocket_vis") @@ -61,17 +62,17 @@ class WebsocketVisModule(Module): """ # LCM inputs - robot_pose: In[PoseStamped] = None + odom: In[PoseStamped] = None gps_location: In[LatLon] = None path: In[Path] = None global_costmap: In[OccupancyGrid] = None # LCM outputs - click_goal: Out[PoseStamped] = None + goal_request: Out[PoseStamped] = None gps_goal: Out[LatLon] = None explore_cmd: Out[Bool] = None stop_explore_cmd: Out[Bool] = None - movecmd: Out[Twist] = None + cmd_vel: Out[Twist] = None movecmd_stamped: Out[TwistStamped] = None def __init__(self, port: int = 7779, **kwargs): @@ -83,19 +84,20 @@ def __init__(self, port: int = 7779, **kwargs): super().__init__(**kwargs) self.port = port - self.server_thread: Optional[threading.Thread] = None + self._uvicorn_server_thread: Optional[threading.Thread] = None self.sio: Optional[socketio.AsyncServer] = None self.app = None self._broadcast_loop = None self._broadcast_thread = None + self._uvicorn_server: Optional[uvicorn.Server] = None self.vis_state = {} self.state_lock = threading.Lock() logger.info(f"WebSocket visualization module initialized on port {port}") - def _start_broadcast_loop(self): - def run_loop(): + def _start_broadcast_loop(self) -> None: + def websocket_vis_loop() -> None: self._broadcast_loop = asyncio.new_event_loop() asyncio.set_event_loop(self._broadcast_loop) try: @@ -105,37 +107,58 @@ def run_loop(): finally: self._broadcast_loop.close() - self._broadcast_thread = threading.Thread(target=run_loop, daemon=True) + self._broadcast_thread = threading.Thread(target=websocket_vis_loop, daemon=True) self._broadcast_thread.start() @rpc def start(self): + super().start() + self._create_server() + self._start_broadcast_loop() - self.server_thread = threading.Thread(target=self._run_server, daemon=True) - self.server_thread.start() + self._uvicorn_server_thread = threading.Thread(target=self._run_uvicorn_server, daemon=True) + self._uvicorn_server_thread.start() + + if self.odom.connection is not None: + unsub = self.odom.subscribe(self._on_robot_pose) + self._disposables.add(Disposable(unsub)) - # Only subscribe to connected topics - if self.robot_pose.connection is not None: - self.robot_pose.subscribe(self._on_robot_pose) if self.gps_location.connection is not None: - self.gps_location.subscribe(self._on_gps_location) + unsub = self.gps_location.subscribe(self._on_gps_location) + self._disposables.add(Disposable(unsub)) + if self.path.connection is not None: - self.path.subscribe(self._on_path) - if self.global_costmap.connection is not None: - self.global_costmap.subscribe(self._on_global_costmap) + unsub = self.path.subscribe(self._on_path) + self._disposables.add(Disposable(unsub)) - logger.info(f"WebSocket server started on http://localhost:{self.port}") + if self.global_costmap.connection is not None: + unsub = self.global_costmap.subscribe(self._on_global_costmap) + self._disposables.add(Disposable(unsub)) @rpc def stop(self): - """Stop the WebSocket server.""" + if self._uvicorn_server: + self._uvicorn_server.should_exit = True + + if self.sio and self._broadcast_loop and not self._broadcast_loop.is_closed(): + + async def _disconnect_all(): + await self.sio.disconnect() + + asyncio.run_coroutine_threadsafe(_disconnect_all(), self._broadcast_loop) + if self._broadcast_loop and not self._broadcast_loop.is_closed(): self._broadcast_loop.call_soon_threadsafe(self._broadcast_loop.stop) + if self._broadcast_thread and self._broadcast_thread.is_alive(): self._broadcast_thread.join(timeout=1.0) - logger.info("WebSocket visualization module stopped") + + if self._uvicorn_server_thread and self._uvicorn_server_thread.is_alive(): + self._uvicorn_server_thread.join(timeout=2.0) + + super().stop() @rpc def set_gps_travel_goal_points(self, points: list[LatLon]) -> None: @@ -169,7 +192,7 @@ async def click(sid, position): orientation=(0, 0, 0, 1), # Default orientation frame_id="world", ) - self.click_goal.publish(goal) + self.goal_request.publish(goal) logger.info(f"Click goal published: ({goal.position.x:.2f}, {goal.position.y:.2f})") @self.sio.event @@ -190,14 +213,14 @@ async def stop_explore(sid): @self.sio.event async def move_command(sid, data): # Publish Twist if transport is configured - if self.movecmd and self.movecmd.transport: + if self.cmd_vel and self.cmd_vel.transport: twist = Twist( linear=Vector3(data["linear"]["x"], data["linear"]["y"], data["linear"]["z"]), angular=Vector3( data["angular"]["x"], data["angular"]["y"], data["angular"]["z"] ), ) - self.movecmd.publish(twist) + self.cmd_vel.publish(twist) # Publish TwistStamped if transport is configured if self.movecmd_stamped and self.movecmd_stamped.transport: @@ -211,13 +234,15 @@ async def move_command(sid, data): ) self.movecmd_stamped.publish(twist_stamped) - def _run_server(self): - uvicorn.run( + def _run_uvicorn_server(self) -> None: + config = uvicorn.Config( self.app, host="0.0.0.0", port=self.port, log_level="error", # Reduce verbosity ) + self._uvicorn_server = uvicorn.Server(config) + self._uvicorn_server.run() def _on_robot_pose(self, msg: PoseStamped): pose_data = {"type": "vector", "c": [msg.position.x, msg.position.y, msg.position.z]} diff --git a/mypy_strict.ini b/mypy_strict.ini new file mode 100644 index 0000000000..ed49020e9b --- /dev/null +++ b/mypy_strict.ini @@ -0,0 +1,30 @@ +[mypy] +python_version = 3.10 +strict = True +exclude = ^dimos/models/Detic(/|$)|.*/test_.|.*/conftest.py* + +# Enable all optional error checks individually (redundant with strict=True, but explicit) +warn_unused_configs = True +warn_unused_ignores = True +warn_redundant_casts = True +warn_no_return = True +warn_return_any = True +warn_unreachable = True +disallow_untyped_calls = True +disallow_untyped_defs = True +disallow_incomplete_defs = True +disallow_untyped_decorators = True +disallow_any_generics = True +no_implicit_optional = True +check_untyped_defs = True +strict_optional = True +ignore_missing_imports = False +show_error_context = True +show_column_numbers = True +pretty = True +color_output = True +error_summary = True + +# Performance and caching +incremental = True +cache_dir = .mypy_cache_strict diff --git a/pyproject.toml b/pyproject.toml index 7979293411..0e4f8e0fe6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -168,7 +168,7 @@ cuda = [ dev = [ "ruff==0.11.10", - "mypy==1.15.0", + "mypy==1.18.2", "pre_commit==4.2.0", "pytest==8.3.5", "pytest-asyncio==0.26.0", diff --git a/tests/test_object_tracking_module.py b/tests/test_object_tracking_module.py index 374315c184..2fd1038c89 100755 --- a/tests/test_object_tracking_module.py +++ b/tests/test_object_tracking_module.py @@ -276,7 +276,7 @@ async def test_object_tracking_module(): cv2.destroyAllWindows() if tracker: - tracker.cleanup() + tracker.stop() if zed: zed.stop() if foxglove_bridge: diff --git a/tests/test_spatial_memory.py b/tests/test_spatial_memory.py index 4a7b72701a..16b1449509 100644 --- a/tests/test_spatial_memory.py +++ b/tests/test_spatial_memory.py @@ -198,9 +198,7 @@ def on_stored_frame(result): saved_path = spatial_memory.vector_db.visual_memory.save("visual_memory.pkl") print(f"Saved {spatial_memory.vector_db.visual_memory.count()} images to disk at {saved_path}") - # Final cleanup - print("Performing final cleanup...") - spatial_memory.cleanup() + spatial_memory.stop() print("Test completed successfully")