diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 67544f7f29..0e38877c44 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,7 +19,7 @@ repos: - --use-current-year - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.14.1 + rev: v0.14.3 hooks: - id: ruff-format stages: [pre-commit] diff --git a/bin/mypy-strict b/bin/mypy-strict index 05001bf100..660faa1a14 100755 --- a/bin/mypy-strict +++ b/bin/mypy-strict @@ -34,6 +34,7 @@ run_mypy() { main() { local user_email="none" local after_date="" + local in_this_branch="" # Parse arguments while [[ $# -gt 0 ]]; do @@ -74,6 +75,10 @@ main() { esac shift 2 ;; + --in-this-branch) + in_this_branch=true + shift 1 + ;; *) echo "Error: Unknown argument '$1'" >&2 exit 1 @@ -92,6 +97,10 @@ main() { pipeline="$pipeline | ./bin/filter-errors-for-user '$user_email'" fi + if [[ "$in_this_branch" ]]; then + pipeline="$pipeline | grep -Ff <(git diff --name-only dev..HEAD) -" + fi + eval "$pipeline" } diff --git a/dimos/agents2/agent.py b/dimos/agents2/agent.py index 04c08b0434..e58e1aa9a3 100644 --- a/dimos/agents2/agent.py +++ b/dimos/agents2/agent.py @@ -32,11 +32,11 @@ from dimos.agents2.system_prompt import get_system_prompt from dimos.core import DimosCluster, rpc from dimos.protocol.skill.coordinator import ( - SkillContainer, SkillCoordinator, SkillState, SkillStateDict, ) +from dimos.protocol.skill.skill import SkillContainer from dimos.protocol.skill.type import Output from dimos.utils.logging_config import setup_logger @@ -270,7 +270,6 @@ def _get_state() -> str: # we are getting tools from the coordinator on each turn # since this allows for skillcontainers to dynamically provide new skills tools = self.get_tools() - print("Available tools:", [tool.name for tool in tools]) self._llm = self._llm.bind_tools(tools) # publish to /agent topic for observability diff --git a/dimos/agents2/fixtures/test_pounce.json b/dimos/agents2/fixtures/test_pounce.json new file mode 100644 index 0000000000..99e84d003a --- /dev/null +++ b/dimos/agents2/fixtures/test_pounce.json @@ -0,0 +1,38 @@ +{ + "responses": [ + { + "content": "", + "tool_calls": [ + { + "name": "execute_sport_command", + "args": { + "args": [ + "FrontPounce" + ] + }, + "id": "call_Ukj6bCAnHQLj28RHRp697blZ", + "type": "tool_call" + } + ] + }, + { + "content": "", + "tool_calls": [ + { + "name": "speak", + "args": { + "args": [ + "I have successfully performed a front pounce." + ] + }, + "id": "call_FR9DtqEvJ9zSY85qVD2UFrll", + "type": "tool_call" + } + ] + }, + { + "content": "I have successfully performed a front pounce.", + "tool_calls": [] + } + ] +} diff --git a/dimos/agents2/fixtures/test_show_your_love.json b/dimos/agents2/fixtures/test_show_your_love.json new file mode 100644 index 0000000000..941906e781 --- /dev/null +++ b/dimos/agents2/fixtures/test_show_your_love.json @@ -0,0 +1,38 @@ +{ + "responses": [ + { + "content": "", + "tool_calls": [ + { + "name": "execute_sport_command", + "args": { + "args": [ + "FingerHeart" + ] + }, + "id": "call_VFp6x9F00FBmiiUiemFWewop", + "type": "tool_call" + } + ] + }, + { + "content": "", + "tool_calls": [ + { + "name": "speak", + "args": { + "args": [ + "Here's a gesture to show you some love!" + ] + }, + "id": "call_WUUmBJ95s9PtVx8YQsmlJ4EU", + "type": "tool_call" + } + ] + }, + { + "content": "Just did a finger heart gesture to show my affection!", + "tool_calls": [] + } + ] +} diff --git a/dimos/agents2/skills/conftest.py b/dimos/agents2/skills/conftest.py index a8734ca7ed..f7d1500847 100644 --- a/dimos/agents2/skills/conftest.py +++ b/dimos/agents2/skills/conftest.py @@ -15,17 +15,13 @@ from functools import partial import pytest -import reactivex as rx from reactivex.scheduler import ThreadPoolScheduler from dimos.agents2.skills.google_maps_skill_container import GoogleMapsSkillContainer from dimos.agents2.skills.gps_nav_skill import GpsNavSkillContainer from dimos.agents2.skills.navigation import NavigationSkillContainer from dimos.agents2.system_prompt import get_system_prompt -from dimos.mapping.types import LatLon -from dimos.msgs.sensor_msgs import Image -from dimos.robot.robot import GpsRobot -from dimos.utils.data import get_data +from dimos.robot.unitree_webrtc.unitree_skill_container import UnitreeSkillContainer system_prompt = get_system_prompt() @@ -45,31 +41,6 @@ def cleanup_threadpool_scheduler(monkeypatch): threadpool.scheduler = ThreadPoolScheduler(max_workers=threadpool.get_max_workers()) -# TODO: Delete -@pytest.fixture -def fake_robot(mocker): - return mocker.MagicMock() - - -# TODO: Delete -@pytest.fixture -def fake_gps_robot(mocker): - return mocker.Mock(spec=GpsRobot) - - -@pytest.fixture -def fake_video_stream(): - image_path = get_data("chair-image.png") - image = Image.from_file(str(image_path)) - return rx.of(image) - - -# TODO: Delete -@pytest.fixture -def fake_gps_position_stream(): - return rx.of(LatLon(lat=37.783, lon=-122.413)) - - @pytest.fixture def navigation_skill_container(mocker): container = NavigationSkillContainer() @@ -81,22 +52,35 @@ def navigation_skill_container(mocker): @pytest.fixture -def gps_nav_skill_container(fake_gps_robot, fake_gps_position_stream): - container = GpsNavSkillContainer(fake_gps_robot, fake_gps_position_stream) +def gps_nav_skill_container(mocker): + container = GpsNavSkillContainer() + container.gps_location.connection = mocker.MagicMock() + container.gps_goal = mocker.MagicMock() container.start() yield container container.stop() @pytest.fixture -def google_maps_skill_container(fake_gps_robot, fake_gps_position_stream, mocker): - container = GoogleMapsSkillContainer(fake_gps_robot, fake_gps_position_stream) +def google_maps_skill_container(mocker): + container = GoogleMapsSkillContainer() + container.gps_location.connection = mocker.MagicMock() container.start() container._client = mocker.MagicMock() yield container container.stop() +@pytest.fixture +def unitree_skills(mocker): + container = UnitreeSkillContainer() + container._move = mocker.Mock() + container._publish_request = mocker.Mock() + container.start() + yield container + container.stop() + + @pytest.fixture def create_navigation_agent(navigation_skill_container, create_fake_agent): return partial( @@ -122,3 +106,12 @@ def create_google_maps_agent( system_prompt=system_prompt, skill_containers=[gps_nav_skill_container, google_maps_skill_container], ) + + +@pytest.fixture +def create_unitree_skills_agent(unitree_skills, create_fake_agent): + return partial( + create_fake_agent, + system_prompt=system_prompt, + skill_containers=[unitree_skills], + ) diff --git a/dimos/agents2/skills/demo_google_maps_skill.py b/dimos/agents2/skills/demo_google_maps_skill.py new file mode 100644 index 0000000000..4bee8691a3 --- /dev/null +++ b/dimos/agents2/skills/demo_google_maps_skill.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dotenv import load_dotenv + +from dimos.agents2.agent import llm_agent +from dimos.agents2.cli.human import human_input +from dimos.agents2.skills.demo_robot import demo_robot +from dimos.agents2.skills.google_maps_skill_container import google_maps_skill +from dimos.agents2.system_prompt import get_system_prompt +from dimos.core.blueprints import autoconnect + +load_dotenv() + + +demo_google_maps_skill = autoconnect( + demo_robot(), + google_maps_skill(), + human_input(), + llm_agent(system_prompt=get_system_prompt()), +) diff --git a/dimos/agents2/skills/demo_gps_nav.py b/dimos/agents2/skills/demo_gps_nav.py new file mode 100644 index 0000000000..55ffd052ff --- /dev/null +++ b/dimos/agents2/skills/demo_gps_nav.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dotenv import load_dotenv + +from dimos.agents2.agent import llm_agent +from dimos.agents2.cli.human import human_input +from dimos.agents2.skills.demo_robot import demo_robot +from dimos.agents2.skills.gps_nav_skill import gps_nav_skill +from dimos.agents2.system_prompt import get_system_prompt +from dimos.core.blueprints import autoconnect + +load_dotenv() + + +demo_gps_nav_skill = autoconnect( + demo_robot(), + gps_nav_skill(), + human_input(), + llm_agent(system_prompt=get_system_prompt()), +) diff --git a/dimos/agents2/skills/demo_robot.py b/dimos/agents2/skills/demo_robot.py new file mode 100644 index 0000000000..74b5c47bd3 --- /dev/null +++ b/dimos/agents2/skills/demo_robot.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from reactivex import interval + +from dimos.core.module import Module +from dimos.core.stream import Out +from dimos.mapping.types import LatLon + + +class DemoRobot(Module): + gps_location: Out[LatLon] = None + + def start(self) -> None: + super().start() + self._disposables.add(interval(1.0).subscribe(lambda _: self._publish_gps_location())) + + def stop(self) -> None: + super().stop() + + def _publish_gps_location(self) -> None: + self.gps_location.publish(LatLon(lat=37.78092426217621, lon=-122.40682866540769)) + + +demo_robot = DemoRobot.blueprint + + +__all__ = ["DemoRobot", "demo_robot"] diff --git a/dimos/agents2/skills/google_maps_skill_container.py b/dimos/agents2/skills/google_maps_skill_container.py index 433914a5e3..f5c1af428e 100644 --- a/dimos/agents2/skills/google_maps_skill_container.py +++ b/dimos/agents2/skills/google_maps_skill_container.py @@ -15,43 +15,34 @@ import json from typing import Any -from reactivex import Observable -from reactivex.disposable import CompositeDisposable - -from dimos.core.resource import Resource +from dimos.core.core import rpc +from dimos.core.skill_module import SkillModule +from dimos.core.stream import In from dimos.mapping.google_maps.google_maps import GoogleMaps -from dimos.mapping.osm.current_location_map import CurrentLocationMap from dimos.mapping.types import LatLon -from dimos.protocol.skill.skill import SkillContainer, skill -from dimos.robot.robot import Robot +from dimos.protocol.skill.skill import skill from dimos.utils.logging_config import setup_logger logger = setup_logger(__file__) -class GoogleMapsSkillContainer(SkillContainer, Resource): - _robot: Robot - _disposables: CompositeDisposable - _latest_location: LatLon | None - _position_stream: Observable[LatLon] - _current_location_map: CurrentLocationMap - _started: bool +class GoogleMapsSkillContainer(SkillModule): + _latest_location: LatLon | None = None + _client: GoogleMaps + + gps_location: In[LatLon] = None - def __init__(self, robot: Robot, position_stream: Observable[LatLon]) -> None: + def __init__(self) -> None: super().__init__() - self._robot = robot - self._disposables = CompositeDisposable() - self._latest_location = None - self._position_stream = position_stream self._client = GoogleMaps() - self._started = False + @rpc def start(self) -> None: - self._started = True - self._disposables.add(self._position_stream.subscribe(self._on_gps_location)) + super().start() + self._disposables.add(self.gps_location.subscribe(self._on_gps_location)) + @rpc def stop(self) -> None: - self._disposables.dispose() super().stop() def _on_gps_location(self, location: LatLon) -> None: @@ -75,9 +66,6 @@ def where_am_i(self, context_radius: int = 200) -> str: context_radius (int): default 200, how many meters to look around """ - if not self._started: - raise ValueError(f"{self} has not been started.") - location = self._get_latest_location() result = None @@ -105,9 +93,6 @@ def get_gps_position_for_queries(self, *queries: str) -> str: queries (list[str]): The places you want to look up. """ - if not self._started: - raise ValueError(f"{self} has not been started.") - location = self._get_latest_location() results: list[dict[str, Any] | str] = [] @@ -123,3 +108,8 @@ def get_gps_position_for_queries(self, *queries: str) -> str: results.append(f"no result for {query}") return json.dumps(results) + + +google_maps_skill = GoogleMapsSkillContainer.blueprint + +__all__ = ["GoogleMapsSkillContainer", "google_maps_skill"] diff --git a/dimos/agents2/skills/gps_nav_skill.py b/dimos/agents2/skills/gps_nav_skill.py index 80e346790a..fa68b32800 100644 --- a/dimos/agents2/skills/gps_nav_skill.py +++ b/dimos/agents2/skills/gps_nav_skill.py @@ -14,48 +14,43 @@ import json -from reactivex import Observable -from reactivex.disposable import CompositeDisposable - -from dimos.core.resource import Resource -from dimos.mapping.google_maps.google_maps import GoogleMaps -from dimos.mapping.osm.current_location_map import CurrentLocationMap +from dimos.core.core import rpc +from dimos.core.rpc_client import RpcCall +from dimos.core.skill_module import SkillModule +from dimos.core.stream import In, Out from dimos.mapping.types import LatLon from dimos.mapping.utils.distance import distance_in_meters -from dimos.protocol.skill.skill import SkillContainer, skill -from dimos.robot.robot import Robot +from dimos.protocol.skill.skill import skill from dimos.utils.logging_config import setup_logger logger = setup_logger(__file__) -class GpsNavSkillContainer(SkillContainer, Resource): - _robot: Robot - _disposables: CompositeDisposable - _latest_location: LatLon | None - _position_stream: Observable[LatLon] - _current_location_map: CurrentLocationMap - _started: bool - _max_valid_distance: int +class GpsNavSkillContainer(SkillModule): + _latest_location: LatLon | None = None + _max_valid_distance: int = 50000 + _set_gps_travel_goal_points: RpcCall | None = None + + gps_location: In[LatLon] = None + gps_goal: Out[LatLon] = None - def __init__(self, robot: Robot, position_stream: Observable[LatLon]) -> None: + def __init__(self) -> None: super().__init__() - self._robot = robot - self._disposables = CompositeDisposable() - self._latest_location = None - self._position_stream = position_stream - self._client = GoogleMaps() - self._started = False - self._max_valid_distance = 50000 + @rpc def start(self) -> None: - self._started = True - self._disposables.add(self._position_stream.subscribe(self._on_gps_location)) + super().start() + self._disposables.add(self.gps_location.subscribe(self._on_gps_location)) + @rpc def stop(self) -> None: - self._disposables.dispose() super().stop() + @rpc + def set_WebsocketVisModule_set_gps_travel_goal_points(self, callable: RpcCall) -> None: + self._set_gps_travel_goal_points = callable + self._set_gps_travel_goal_points.set_rpc(self.rpc) + def _on_gps_location(self, location: LatLon) -> None: self._latest_location = location @@ -75,18 +70,24 @@ def set_gps_travel_points(self, *points: dict[str, float]) -> str: # then travel to {"lat": 37.7915, "lon": -122.4276} """ - if not self._started: - raise ValueError(f"{self} has not been started.") - new_points = [self._convert_point(x) for x in points] if not all(new_points): parsed = json.dumps([x.__dict__ if x else x for x in new_points]) return f"Not all points were valid. I parsed this: {parsed}" + for new_point in new_points: + distance = distance_in_meters(self._get_latest_location(), new_point) + if distance > self._max_valid_distance: + return f"Point {new_point} is too far ({int(distance)} meters away)." + logger.info(f"Set travel points: {new_points}") - self._robot.set_gps_travel_goal_points(new_points) + if self.gps_goal._transport is not None: + self.gps_goal.publish(new_points) + + if self._set_gps_travel_goal_points: + self._set_gps_travel_goal_points(new_points) return "I've successfully set the travel points." @@ -99,9 +100,10 @@ def _convert_point(self, point: dict[str, float]) -> LatLon | None: if lat is None or lon is None: return None - new_point = LatLon(lat=lat, lon=lon) - distance = distance_in_meters(self._get_latest_location(), new_point) - if distance > self._max_valid_distance: - return None + return LatLon(lat=lat, lon=lon) + + +gps_nav_skill = GpsNavSkillContainer.blueprint + - return new_point +__all__ = ["GpsNavSkillContainer", "gps_nav_skill"] diff --git a/dimos/agents2/skills/navigation.py b/dimos/agents2/skills/navigation.py index 9a7b91d68a..cf3411d497 100644 --- a/dimos/agents2/skills/navigation.py +++ b/dimos/agents2/skills/navigation.py @@ -16,7 +16,6 @@ from typing import Any from dimos.core.core import rpc -from dimos.core.rpc_client import RpcCall from dimos.core.skill_module import SkillModule from dimos.core.stream import In from dimos.models.qwen.video_query import BBox @@ -24,7 +23,7 @@ from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Vector3 from dimos.msgs.geometry_msgs.Vector3 import make_vector3 from dimos.msgs.sensor_msgs import Image -from dimos.navigation.bt_navigator.navigator import NavigatorState +from dimos.navigation.base import NavigationState from dimos.navigation.visual.query import get_object_bbox_from_image from dimos.protocol.skill.skill import skill from dimos.types.robot_location import RobotLocation @@ -39,19 +38,21 @@ class NavigationSkillContainer(SkillModule): _skill_started: bool = False _similarity_threshold: float = 0.23 - _tag_location: RpcCall | None = None - _query_tagged_location: RpcCall | None = None - _query_by_text: RpcCall | None = None - _set_goal: RpcCall | None = None - _get_state: RpcCall | None = None - _is_goal_reached: RpcCall | None = None - _cancel_goal: RpcCall | None = None - _track: RpcCall | None = None - _stop_track: RpcCall | None = None - _is_tracking: RpcCall | None = None - _stop_exploration: RpcCall | None = None - _explore: RpcCall | None = None - _is_exploration_active: RpcCall | None = None + rpc_calls: list[str] = [ + "SpatialMemory.tag_location", + "SpatialMemory.query_tagged_location", + "SpatialMemory.query_by_text", + "NavigationInterface.set_goal", + "NavigationInterface.get_state", + "NavigationInterface.is_goal_reached", + "NavigationInterface.cancel_goal", + "ObjectTracking.track", + "ObjectTracking.stop_track", + "ObjectTracking.is_tracking", + "WavefrontFrontierExplorer.stop_exploration", + "WavefrontFrontierExplorer.explore", + "WavefrontFrontierExplorer.is_exploration_active", + ] color_image: In[Image] = None odom: In[PoseStamped] = None @@ -77,72 +78,6 @@ def _on_color_image(self, image: Image) -> None: def _on_odom(self, odom: PoseStamped) -> None: self._latest_odom = odom - # TODO: This is quite repetitive, maybe I should automate this somehow - @rpc - def set_SpatialMemory_tag_location(self, callable: RpcCall) -> None: - self._tag_location = callable - self._tag_location.set_rpc(self.rpc) - - @rpc - def set_SpatialMemory_query_tagged_location(self, callable: RpcCall) -> None: - self._query_tagged_location = callable - self._query_tagged_location.set_rpc(self.rpc) - - @rpc - def set_SpatialMemory_query_by_text(self, callable: RpcCall) -> None: - self._query_by_text = callable - self._query_by_text.set_rpc(self.rpc) - - @rpc - def set_BehaviorTreeNavigator_set_goal(self, callable: RpcCall) -> None: - self._set_goal = callable - self._set_goal.set_rpc(self.rpc) - - @rpc - def set_BehaviorTreeNavigator_get_state(self, callable: RpcCall) -> None: - self._get_state = callable - self._get_state.set_rpc(self.rpc) - - @rpc - def set_BehaviorTreeNavigator_is_goal_reached(self, callable: RpcCall) -> None: - self._is_goal_reached = callable - self._is_goal_reached.set_rpc(self.rpc) - - @rpc - def set_BehaviorTreeNavigator_cancel_goal(self, callable: RpcCall) -> None: - self._cancel_goal = callable - self._cancel_goal.set_rpc(self.rpc) - - @rpc - def set_ObjectTracking_track(self, callable: RpcCall) -> None: - self._track = callable - self._track.set_rpc(self.rpc) - - @rpc - def set_ObjectTracking_stop_track(self, callable: RpcCall) -> None: - self._stop_track = callable - self._stop_track.set_rpc(self.rpc) - - @rpc - def set_ObjectTracking_is_tracking(self, callable: RpcCall) -> None: - self._is_tracking = callable - self._is_tracking.set_rpc(self.rpc) - - @rpc - def set_WavefrontFrontierExplorer_stop_exploration(self, callable: RpcCall) -> None: - self._stop_exploration = callable - self._stop_exploration.set_rpc(self.rpc) - - @rpc - def set_WavefrontFrontierExplorer_explore(self, callable: RpcCall) -> None: - self._explore = callable - self._explore.set_rpc(self.rpc) - - @rpc - def set_WavefrontFrontierExplorer_is_exploration_active(self, callable: RpcCall) -> None: - self._is_exploration_active = callable - self._is_exploration_active.set_rpc(self.rpc) - @skill() def tag_location(self, location_name: str) -> str: """Tag this location in the spatial memory with a name. @@ -171,20 +106,13 @@ def tag_location(self, location_name: str) -> str: rotation=(rotation.x, rotation.y, rotation.z), ) - if not self._tag_location(location): + tag_location_rpc = self.get_rpc_calls("SpatialMemory.tag_location") + if not tag_location_rpc(location): return f"Error: Failed to store '{location_name}' in the spatial memory" logger.info(f"Tagged {location}") return f"Tagged '{location_name}': ({position.x},{position.y})." - def _navigate_to_object(self, query: str) -> str | None: - position = self.detection_module.nav_vlm(query) - print("Object position from VLM:", position) - if not position: - return None - self.nav.navigate_to(position) - return f"Arrived to object matching '{query}' in view." - @skill() def navigate_with_text(self, query: str) -> str: """Navigate to a location by querying the existing semantic map using natural language. @@ -219,11 +147,13 @@ def navigate_with_text(self, query: str) -> str: return f"No tagged location called '{query}'. No object in view matching '{query}'. No matching location found in semantic map for '{query}'." def _navigate_by_tagged_location(self, query: str) -> str | None: - if not self._query_tagged_location: + try: + query_tagged_location_rpc = self.get_rpc_calls("SpatialMemory.query_tagged_location") + except Exception: logger.warning("SpatialMemory module not connected, cannot query tagged locations") return None - robot_location = self._query_tagged_location(query) + robot_location = query_tagged_location_rpc(query) if not robot_location: return None @@ -244,21 +174,27 @@ def _navigate_by_tagged_location(self, query: str) -> str | None: ) def _navigate_to(self, pose: PoseStamped) -> bool: - if not self._set_goal or not self._get_state or not self._is_goal_reached: - logger.error("BehaviorTreeNavigator module not connected properly") + try: + set_goal_rpc, get_state_rpc, is_goal_reached_rpc = self.get_rpc_calls( + "NavigationInterface.set_goal", + "NavigationInterface.get_state", + "NavigationInterface.is_goal_reached", + ) + except Exception: + logger.error("Navigation module not connected properly") return False logger.info( f"Navigating to pose: ({pose.position.x:.2f}, {pose.position.y:.2f}, {pose.position.z:.2f})" ) - self._set_goal(pose) + set_goal_rpc(pose) time.sleep(1.0) - while self._get_state() == NavigatorState.FOLLOWING_PATH: + while get_state_rpc() == NavigationState.FOLLOWING_PATH: time.sleep(0.25) time.sleep(1.0) - if not self._is_goal_reached(): + if not is_goal_reached_rpc(): logger.info("Navigation was cancelled or failed") return False else: @@ -275,18 +211,26 @@ def _navigate_to_object(self, query: str) -> str | None: if bbox is None: return None - if not self._track or not self._stop_track or not self._is_tracking: + try: + track_rpc, stop_track_rpc, is_tracking_rpc = self.get_rpc_calls( + "ObjectTracking.track", "ObjectTracking.stop_track", "ObjectTracking.is_tracking" + ) + except Exception: logger.error("ObjectTracking module not connected properly") return None - if not self._get_state or not self._is_goal_reached: - logger.error("BehaviorTreeNavigator module not connected properly") + try: + get_state_rpc, is_goal_reached_rpc = self.get_rpc_calls( + "NavigationInterface.get_state", "NavigationInterface.is_goal_reached" + ) + except Exception: + logger.error("Navigation module not connected properly") return None logger.info(f"Found {query} at {bbox}") # Start tracking - BBoxNavigationModule automatically generates goals - self._track(bbox) + track_rpc(bbox) start_time = time.time() timeout = 30.0 @@ -294,31 +238,31 @@ def _navigate_to_object(self, query: str) -> str | None: while time.time() - start_time < timeout: # Check if navigator finished - if self._get_state() == NavigatorState.IDLE and goal_set: + if get_state_rpc() == NavigationState.IDLE and goal_set: logger.info("Waiting for goal result") time.sleep(1.0) - if not self._is_goal_reached(): + if not is_goal_reached_rpc(): logger.info(f"Goal cancelled, tracking '{query}' failed") - self._stop_track() + stop_track_rpc() return None else: logger.info(f"Reached '{query}'") - self._stop_track() + stop_track_rpc() return f"Successfully arrived at '{query}'" # If goal set and tracking lost, just continue (tracker will resume or timeout) - if goal_set and not self._is_tracking(): + if goal_set and not is_tracking_rpc(): continue # BBoxNavigationModule automatically sends goals when tracker publishes # Just check if we have any detections to mark goal_set - if self._is_tracking(): + if is_tracking_rpc(): goal_set = True time.sleep(0.25) logger.warning(f"Navigation to '{query}' timed out after {timeout}s") - self._stop_track() + stop_track_rpc() return None def _get_bbox_for_current_frame(self, query: str) -> BBox | None: @@ -328,10 +272,12 @@ def _get_bbox_for_current_frame(self, query: str) -> BBox | None: return get_object_bbox_from_image(self._vl_model, self._latest_image, query) def _navigate_using_semantic_map(self, query: str) -> str: - if not self._query_by_text: + try: + query_by_text_rpc = self.get_rpc_calls("SpatialMemory.query_by_text") + except Exception: return "Error: The SpatialMemory module is not connected." - results = self._query_by_text(query) + results = query_by_text_rpc(query) if not results: return f"No matching location found in semantic map for '{query}'" @@ -368,16 +314,20 @@ def stop_movement(self) -> str: return "Stopped" def _cancel_goal_and_stop(self) -> None: - if not self._cancel_goal: - logger.warning("BehaviorTreeNavigator module not connected, cannot cancel goal") + try: + cancel_goal_rpc = self.get_rpc_calls("NavigationInterface.cancel_goal") + except Exception: + logger.warning("Navigation module not connected, cannot cancel goal") return - if not self._stop_exploration: + try: + stop_exploration_rpc = self.get_rpc_calls("WavefrontFrontierExplorer.stop_exploration") + except Exception: logger.warning("FrontierExplorer module not connected, cannot stop exploration") return - self._cancel_goal() - return self._stop_exploration() + cancel_goal_rpc() + return stop_exploration_rpc() @skill() def start_exploration(self, timeout: float = 240.0) -> str: @@ -401,18 +351,23 @@ def start_exploration(self, timeout: float = 240.0) -> str: self._cancel_goal_and_stop() def _start_exploration(self, timeout: float) -> str: - if not self._explore or not self._is_exploration_active: + try: + explore_rpc, is_exploration_active_rpc = self.get_rpc_calls( + "WavefrontFrontierExplorer.explore", + "WavefrontFrontierExplorer.is_exploration_active", + ) + except Exception: return "Error: The WavefrontFrontierExplorer module is not connected." logger.info("Starting autonomous frontier exploration") start_time = time.time() - has_started = self._explore() + has_started = explore_rpc() if not has_started: return "Error: Could not start exploration." - while time.time() - start_time < timeout and self._is_exploration_active(): + while time.time() - start_time < timeout and is_exploration_active_rpc(): time.sleep(0.5) return "Exploration completed successfuly" diff --git a/dimos/agents2/skills/osm.py b/dimos/agents2/skills/osm.py index ae721bea81..d4455f14bd 100644 --- a/dimos/agents2/skills/osm.py +++ b/dimos/agents2/skills/osm.py @@ -28,7 +28,6 @@ class OsmSkill(SkillModule): _latest_location: LatLon | None _current_location_map: CurrentLocationMap - _skill_started: bool gps_location: In[LatLon] = None @@ -36,11 +35,9 @@ def __init__(self) -> None: super().__init__() self._latest_location = None self._current_location_map = CurrentLocationMap(QwenVlModel()) - self._skill_started = False def start(self) -> None: super().start() - self._skill_started = True self._disposables.add(self.gps_location.subscribe(self._on_gps_location)) def stop(self) -> None: @@ -63,9 +60,6 @@ def street_map_query(self, query_sentence: str) -> str: query_sentence (str): The query sentence. """ - if not self._skill_started: - raise ValueError(f"{self} has not been started.") - self._current_location_map.update_position(self._latest_location) location = self._current_location_map.query_for_one_position_and_context( query_sentence, self._latest_location diff --git a/dimos/agents2/skills/ros_navigation.py b/dimos/agents2/skills/ros_navigation.py index 973cdcc10f..f6c257a941 100644 --- a/dimos/agents2/skills/ros_navigation.py +++ b/dimos/agents2/skills/ros_navigation.py @@ -15,10 +15,10 @@ import time from typing import TYPE_CHECKING, Any -from dimos.core.resource import Resource +from dimos.core.skill_module import SkillModule from dimos.msgs.geometry_msgs import PoseStamped from dimos.msgs.geometry_msgs.Vector3 import make_vector3 -from dimos.protocol.skill.skill import SkillContainer, skill +from dimos.protocol.skill.skill import skill from dimos.utils.logging_config import setup_logger from dimos.utils.transform_utils import euler_to_quaternion @@ -28,7 +28,8 @@ logger = setup_logger(__file__) -class RosNavigation(SkillContainer, Resource): +# TODO: Remove, deprecated +class RosNavigation(SkillModule): _robot: "UnitreeG1" _started: bool @@ -54,8 +55,6 @@ def navigate_with_text(self, query: str) -> str: query: Text query to search for in the semantic map """ - print("X" * 10000) - if not self._started: raise ValueError(f"{self} has not been started.") @@ -119,3 +118,8 @@ def _get_goal_pose_from_result(self, result: dict[str, Any]) -> PoseStamped | No orientation=euler_to_quaternion(make_vector3(0, 0, theta)), frame_id="map", ) + + +ros_navigation_skill = RosNavigation.blueprint + +__all__ = ["RosNavigation", "ros_navigation_skill"] diff --git a/dimos/agents2/skills/test_google_maps_skill_container.py b/dimos/agents2/skills/test_google_maps_skill_container.py index 27a9dadb8f..4f6b730b5f 100644 --- a/dimos/agents2/skills/test_google_maps_skill_container.py +++ b/dimos/agents2/skills/test_google_maps_skill_container.py @@ -15,9 +15,11 @@ import re from dimos.mapping.google_maps.types import Coordinates, LocationContext, Position +from dimos.mapping.types import LatLon def test_where_am_i(create_google_maps_agent, google_maps_skill_container) -> None: + google_maps_skill_container._latest_location = LatLon(lat=37.782654, lon=-122.413273) google_maps_skill_container._client.get_location_context.return_value = LocationContext( street="Bourbon Street", coordinates=Coordinates(lat=37.782654, lon=-122.413273) ) @@ -31,6 +33,7 @@ def test_where_am_i(create_google_maps_agent, google_maps_skill_container) -> No def test_get_gps_position_for_queries( create_google_maps_agent, google_maps_skill_container ) -> None: + google_maps_skill_container._latest_location = LatLon(lat=37.782654, lon=-122.413273) google_maps_skill_container._client.get_position.side_effect = [ Position(lat=37.782601, lon=-122.413201, description="address 1"), Position(lat=37.782602, lon=-122.413202, description="address 2"), diff --git a/dimos/agents2/skills/test_gps_nav_skills.py b/dimos/agents2/skills/test_gps_nav_skills.py index 9e8090b169..19cc8cb104 100644 --- a/dimos/agents2/skills/test_gps_nav_skills.py +++ b/dimos/agents2/skills/test_gps_nav_skills.py @@ -16,24 +16,40 @@ from dimos.mapping.types import LatLon -def test_set_gps_travel_points(fake_gps_robot, create_gps_nav_agent) -> None: +def test_set_gps_travel_points(create_gps_nav_agent, gps_nav_skill_container, mocker) -> None: + gps_nav_skill_container._latest_location = LatLon(lat=37.782654, lon=-122.413273) + gps_nav_skill_container._set_gps_travel_goal_points = mocker.Mock() agent = create_gps_nav_agent(fixture="test_set_gps_travel_points.json") agent.query("go to lat: 37.782654, lon: -122.413273") - fake_gps_robot.set_gps_travel_goal_points.assert_called_once_with( + gps_nav_skill_container._set_gps_travel_goal_points.assert_called_once_with( + [LatLon(lat=37.782654, lon=-122.413273)] + ) + gps_nav_skill_container.gps_goal.publish.assert_called_once_with( [LatLon(lat=37.782654, lon=-122.413273)] ) -def test_set_gps_travel_points_multiple(fake_gps_robot, create_gps_nav_agent) -> None: +def test_set_gps_travel_points_multiple( + create_gps_nav_agent, gps_nav_skill_container, mocker +) -> None: + gps_nav_skill_container._latest_location = LatLon(lat=37.782654, lon=-122.413273) + gps_nav_skill_container._set_gps_travel_goal_points = mocker.Mock() agent = create_gps_nav_agent(fixture="test_set_gps_travel_points_multiple.json") agent.query( "go to lat: 37.782654, lon: -122.413273, then 37.782660,-122.413260, and then 37.782670,-122.413270" ) - fake_gps_robot.set_gps_travel_goal_points.assert_called_once_with( + gps_nav_skill_container._set_gps_travel_goal_points.assert_called_once_with( + [ + LatLon(lat=37.782654, lon=-122.413273), + LatLon(lat=37.782660, lon=-122.413260), + LatLon(lat=37.782670, lon=-122.413270), + ] + ) + gps_nav_skill_container.gps_goal.publish.assert_called_once_with( [ LatLon(lat=37.782654, lon=-122.413273), LatLon(lat=37.782660, lon=-122.413260), diff --git a/dimos/agents2/skills/test_navigation.py b/dimos/agents2/skills/test_navigation.py index 9d4f3b7eff..93c0a4f5be 100644 --- a/dimos/agents2/skills/test_navigation.py +++ b/dimos/agents2/skills/test_navigation.py @@ -19,25 +19,35 @@ # @pytest.mark.skip def test_stop_movement(create_navigation_agent, navigation_skill_container, mocker) -> None: - navigation_skill_container._cancel_goal = mocker.Mock() - navigation_skill_container._stop_exploration = mocker.Mock() + cancel_goal_mock = mocker.Mock() + stop_exploration_mock = mocker.Mock() + navigation_skill_container._bound_rpc_calls["NavigationInterface.cancel_goal"] = ( + cancel_goal_mock + ) + navigation_skill_container._bound_rpc_calls["WavefrontFrontierExplorer.stop_exploration"] = ( + stop_exploration_mock + ) agent = create_navigation_agent(fixture="test_stop_movement.json") agent.query("stop") - navigation_skill_container._cancel_goal.assert_called_once_with() - navigation_skill_container._stop_exploration.assert_called_once_with() + cancel_goal_mock.assert_called_once_with() + stop_exploration_mock.assert_called_once_with() def test_take_a_look_around(create_navigation_agent, navigation_skill_container, mocker) -> None: - navigation_skill_container._explore = mocker.Mock() - navigation_skill_container._is_exploration_active = mocker.Mock() + explore_mock = mocker.Mock() + is_exploration_active_mock = mocker.Mock() + navigation_skill_container._bound_rpc_calls["WavefrontFrontierExplorer.explore"] = explore_mock + navigation_skill_container._bound_rpc_calls[ + "WavefrontFrontierExplorer.is_exploration_active" + ] = is_exploration_active_mock mocker.patch("dimos.agents2.skills.navigation.time.sleep") agent = create_navigation_agent(fixture="test_take_a_look_around.json") agent.query("take a look around for 10 seconds") - navigation_skill_container._explore.assert_called_once_with() + explore_mock.assert_called_once_with() def test_go_to_semantic_location( @@ -51,11 +61,11 @@ def test_go_to_semantic_location( "dimos.agents2.skills.navigation.NavigationSkillContainer._navigate_to_object", return_value=None, ) - mocker.patch( + navigate_to_mock = mocker.patch( "dimos.agents2.skills.navigation.NavigationSkillContainer._navigate_to", return_value=True, ) - navigation_skill_container._query_by_text = mocker.Mock( + query_by_text_mock = mocker.Mock( return_value=[ { "distance": 0.5, @@ -69,12 +79,13 @@ def test_go_to_semantic_location( } ] ) + navigation_skill_container._bound_rpc_calls["SpatialMemory.query_by_text"] = query_by_text_mock agent = create_navigation_agent(fixture="test_go_to_semantic_location.json") agent.query("go to the bookshelf") - navigation_skill_container._query_by_text.assert_called_once_with("bookshelf") - navigation_skill_container._navigate_to.assert_called_once_with( + query_by_text_mock.assert_called_once_with("bookshelf") + navigate_to_mock.assert_called_once_with( PoseStamped( position=Vector3(1, 2, 0), orientation=euler_to_quaternion(Vector3(0, 0, 3)), diff --git a/dimos/agents2/skills/test_unitree_skill_container.py b/dimos/agents2/skills/test_unitree_skill_container.py new file mode 100644 index 0000000000..d9570341d8 --- /dev/null +++ b/dimos/agents2/skills/test_unitree_skill_container.py @@ -0,0 +1,42 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def test_pounce(create_unitree_skills_agent, unitree_skills) -> None: + agent = create_unitree_skills_agent(fixture="test_pounce.json") + + response = agent.query("pounce") + + assert "front pounce" in response.lower() + unitree_skills._publish_request.assert_called_once_with( + "rt/api/sport/request", {"api_id": 1032} + ) + + +def test_show_your_love(create_unitree_skills_agent, unitree_skills) -> None: + agent = create_unitree_skills_agent(fixture="test_show_your_love.json") + + response = agent.query("show your love") + + assert "finger heart" in response.lower() + unitree_skills._publish_request.assert_called_once_with( + "rt/api/sport/request", {"api_id": 1036} + ) + + +def test_did_you_mean(unitree_skills) -> None: + assert ( + unitree_skills.execute_sport_command("Pounce") + == "There's no 'Pounce' command. Did you mean: ['FrontPounce', 'Pose']" + ) diff --git a/dimos/agents2/temp/test_unitree_agent_query.py b/dimos/agents2/temp/test_unitree_agent_query.py deleted file mode 100644 index 4990940e6c..0000000000 --- a/dimos/agents2/temp/test_unitree_agent_query.py +++ /dev/null @@ -1,229 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Test script to debug agent query issues. -Shows different ways to call the agent and handle async. -""" - -import asyncio -import os -from pathlib import Path -import sys -import time - -from dotenv import load_dotenv - -# Add parent directories to path -sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent)) - -from dimos.agents2 import Agent -from dimos.agents2.spec import Model, Provider -from dimos.robot.unitree_webrtc.unitree_skill_container import UnitreeSkillContainer -from dimos.utils.logging_config import setup_logger - -logger = setup_logger("test_agent_query") - -# Load environment variables -load_dotenv() - - -async def test_async_query(): - """Test agent query using async/await pattern.""" - print("\n=== Testing Async Query ===\n") - - # Create skill container - container = UnitreeSkillContainer(robot=None) - - # Create agent - agent = Agent( - system_prompt="You are a helpful robot assistant. List 3 skills you can do.", - model=Model.GPT_4O_MINI, - provider=Provider.OPENAI, - ) - - # Register skills and start - agent.register_skills(container) - agent.start() - - # Query asynchronously - logger.info("Sending async query...") - future = agent.query_async("Hello! What skills do you have?") - - # Wait for result - logger.info("Waiting for response...") - await asyncio.sleep(10) # Give it time to process - - # Check if future is done - if hasattr(future, "done") and future.done(): - try: - result = future.result() - logger.info(f"Got result: {result}") - except Exception as e: - logger.error(f"Future failed: {e}") - else: - logger.warning("Future not completed yet") - - agent.stop() - - return future - - -def test_sync_query_with_thread() -> None: - """Test agent query using threading for the event loop.""" - print("\n=== Testing Sync Query with Thread ===\n") - - import threading - - # Create skill container - container = UnitreeSkillContainer(robot=None) - - # Create agent - agent = Agent( - system_prompt="You are a helpful robot assistant. List 3 skills you can do.", - model=Model.GPT_4O_MINI, - provider=Provider.OPENAI, - ) - - # Register skills and start - agent.register_skills(container) - agent.start() - - # Track the thread we might create - loop_thread = None - - # The agent's event loop should be running in the Module's thread - # Let's check if it's running - if agent._loop and agent._loop.is_running(): - logger.info("Agent's event loop is running") - else: - logger.warning("Agent's event loop is NOT running - this is the problem!") - - # Try to run the loop in a thread - def run_loop() -> None: - asyncio.set_event_loop(agent._loop) - agent._loop.run_forever() - - loop_thread = threading.Thread(target=run_loop, daemon=False, name="EventLoopThread") - loop_thread.start() - time.sleep(1) # Give loop time to start - logger.info("Started event loop in thread") - - # Now try the query - try: - logger.info("Sending sync query...") - result = agent.query("Hello! What skills do you have?") - logger.info(f"Got result: {result}") - except Exception as e: - logger.error(f"Query failed: {e}") - import traceback - - traceback.print_exc() - - agent.stop() - - # Then stop the manually created event loop thread if we created one - if loop_thread and loop_thread.is_alive(): - logger.info("Stopping manually created event loop thread...") - # Stop the event loop - if agent._loop and agent._loop.is_running(): - agent._loop.call_soon_threadsafe(agent._loop.stop) - # Wait for thread to finish - loop_thread.join(timeout=5) - if loop_thread.is_alive(): - logger.warning("Thread did not stop cleanly within timeout") - - # Finally close the container - container._close_module() - - -# def test_with_real_module_system(): -# """Test using the real DimOS module system (like in test_agent.py).""" -# print("\n=== Testing with Module System ===\n") - -# from dimos.core import start - -# # Start the DimOS system -# dimos = start(2) - -# # Deploy container and agent as modules -# container = dimos.deploy(UnitreeSkillContainer, robot=None) -# agent = dimos.deploy( -# Agent, -# system_prompt="You are a helpful robot assistant. List 3 skills you can do.", -# model=Model.GPT_4O_MINI, -# provider=Provider.OPENAI, -# ) - -# # Register skills -# agent.register_skills(container) -# agent.start() - -# # Query -# try: -# logger.info("Sending query through module system...") -# future = agent.query_async("Hello! What skills do you have?") - -# # In the module system, the loop should be running -# time.sleep(5) # Wait for processing - -# if hasattr(future, "result"): -# result = future.result(timeout=10) -# logger.info(f"Got result: {result}") -# except Exception as e: -# logger.error(f"Query failed: {e}") - -# # Clean up -# agent.stop() -# dimos.stop() - - -def main() -> None: - """Run tests based on available API key.""" - - if not os.getenv("OPENAI_API_KEY"): - print("ERROR: OPENAI_API_KEY not set") - print("Please set your OpenAI API key to test the agent") - sys.exit(1) - - print("=" * 60) - print("Agent Query Testing") - print("=" * 60) - - # Test 1: Async query - try: - asyncio.run(test_async_query()) - except Exception as e: - logger.error(f"Async test failed: {e}") - - # Test 2: Sync query with threading - try: - test_sync_query_with_thread() - except Exception as e: - logger.error(f"Sync test failed: {e}") - - # Test 3: Module system (optional - more complex) - # try: - # test_with_real_module_system() - # except Exception as e: - # logger.error(f"Module test failed: {e}") - - print("\n" + "=" * 60) - print("Testing complete") - print("=" * 60) - - -if __name__ == "__main__": - main() diff --git a/dimos/agents2/temp/test_unitree_skill_container.py b/dimos/agents2/temp/test_unitree_skill_container.py deleted file mode 100644 index 16502004ff..0000000000 --- a/dimos/agents2/temp/test_unitree_skill_container.py +++ /dev/null @@ -1,126 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Test file for UnitreeSkillContainer with agents2 framework. -Tests skill registration and basic functionality. -""" - -from pathlib import Path -import sys -import time - -# Add parent directories to path -sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent)) - -from dimos.agents2 import Agent -from dimos.agents2.spec import Model, Provider -from dimos.robot.unitree_webrtc.unitree_skill_container import UnitreeSkillContainer -from dimos.utils.logging_config import setup_logger - -logger = setup_logger("test_unitree_skills") - - -def test_skill_container_creation(): - """Test that the skill container can be created and skills are registered.""" - print("\n=== Testing UnitreeSkillContainer Creation ===") - - # Create container without robot (for testing) - container = UnitreeSkillContainer(robot=None) - - try: - # Get available skills from the container - skills = container.skills() - - print(f"Number of skills registered: {len(skills)}") - print("\nAvailable skills:") - for name, skill_config in list(skills.items())[:10]: # Show first 10 - print( - f" - {name}: {skill_config.description if hasattr(skill_config, 'description') else 'No description'}" - ) - if len(skills) > 10: - print(f" ... and {len(skills) - 10} more skills") - - return container, skills - finally: - # Ensure proper cleanup - container._close_module() - # Small delay to allow threads to finish cleanup - time.sleep(0.1) - - -def test_agent_with_skills(): - """Test that an agent can be created with the skill container.""" - print("\n=== Testing Agent with Skills ===") - - # Create skill container - container = UnitreeSkillContainer(robot=None) - agent = None - - try: - # Create agent with configuration passed directly - agent = Agent( - system_prompt="You are a helpful robot assistant that can control a Unitree Go2 robot.", - model=Model.GPT_4O_MINI, - provider=Provider.OPENAI, - ) - - # Register skills - agent.register_skills(container) - - print("Agent created and skills registered successfully!") - - # Get tools to verify - tools = agent.get_tools() - print(f"Agent has access to {len(tools)} tools") - - return agent - finally: - # Ensure proper cleanup in order - if agent: - agent.stop() - container._close_module() - # Small delay to allow threads to finish cleanup - time.sleep(0.1) - - -def test_skill_schemas() -> None: - """Test that skill schemas are properly generated for LangChain.""" - print("\n=== Testing Skill Schemas ===") - - container = UnitreeSkillContainer(robot=None) - - try: - skills = container.skills() - - # Check a few key skills (using snake_case names now) - skill_names = ["move", "wait", "stand_up", "sit", "front_flip", "dance1"] - - for name in skill_names: - if name in skills: - skill_config = skills[name] - print(f"\n{name} skill:") - print(f" Config: {skill_config}") - if hasattr(skill_config, "schema"): - print( - f" Schema keys: {skill_config.schema.keys() if skill_config.schema else 'None'}" - ) - else: - print(f"\nWARNING: Skill '{name}' not found!") - finally: - # Ensure proper cleanup - container._close_module() - # Small delay to allow threads to finish cleanup - time.sleep(0.1) diff --git a/dimos/core/blueprints.py b/dimos/core/blueprints.py index fedb05769c..45aa617571 100644 --- a/dimos/core/blueprints.py +++ b/dimos/core/blueprints.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from abc import ABC from collections import defaultdict from collections.abc import Mapping from dataclasses import dataclass, field @@ -109,16 +110,9 @@ def _all_name_types(self) -> set[tuple[str, type]]: def _is_name_unique(self, name: str) -> bool: return sum(1 for n, _ in self._all_name_types if n == name) == 1 - def build(self, global_config: GlobalConfig | None = None) -> ModuleCoordinator: - if global_config is None: - global_config = GlobalConfig() - global_config = global_config.model_copy(update=self.global_config_overrides) - - module_coordinator = ModuleCoordinator(global_config=global_config) - - module_coordinator.start() - - # Deploy all modules. + def _deploy_all_modules( + self, module_coordinator: ModuleCoordinator, global_config: GlobalConfig + ) -> None: for blueprint in self.blueprints: kwargs = {**blueprint.kwargs} sig = inspect.signature(blueprint.module.__init__) @@ -126,6 +120,7 @@ def build(self, global_config: GlobalConfig | None = None) -> ModuleCoordinator: kwargs["global_config"] = global_config module_coordinator.deploy(blueprint.module, *blueprint.args, **kwargs) + def _connect_transports(self, module_coordinator: ModuleCoordinator) -> None: # Gather all the In/Out connections with remapping applied. connections = defaultdict(list) # Track original name -> remapped name for each module @@ -145,26 +140,81 @@ def build(self, global_config: GlobalConfig | None = None) -> ModuleCoordinator: transport = self._get_transport_for(remapped_name, type) for module, original_name in connections[(remapped_name, type)]: instance = module_coordinator.get_instance(module) - # Use the remote method to set transport on Dask actors instance.set_transport(original_name, transport) + def _connect_rpc_methods(self, module_coordinator: ModuleCoordinator) -> None: # Gather all RPC methods. rpc_methods = {} + rpc_methods_dot = {} + # Track interface methods to detect ambiguity + interface_methods = defaultdict(list) # interface_name.method -> [(module_class, method)] + for blueprint in self.blueprints: for method_name in blueprint.module.rpcs.keys(): method = getattr(module_coordinator.get_instance(blueprint.module), method_name) + # Register under concrete class name (backward compatibility) rpc_methods[f"{blueprint.module.__name__}_{method_name}"] = method + rpc_methods_dot[f"{blueprint.module.__name__}.{method_name}"] = method + + # Also register under any interface names + for base in blueprint.module.__bases__: + # Check if this base is an abstract interface with the method + if ( + base is not Module + and issubclass(base, ABC) + and hasattr(base, method_name) + and getattr(base, method_name, None) is not None + ): + interface_key = f"{base.__name__}.{method_name}" + interface_methods[interface_key].append((blueprint.module, method)) + + # Check for ambiguity in interface methods and add non-ambiguous ones + for interface_key, implementations in interface_methods.items(): + if len(implementations) == 1: + rpc_methods_dot[interface_key] = implementations[0][1] # Fulfil method requests (so modules can call each other). for blueprint in self.blueprints: + instance = module_coordinator.get_instance(blueprint.module) for method_name in blueprint.module.rpcs.keys(): if not method_name.startswith("set_"): continue linked_name = method_name.removeprefix("set_") if linked_name not in rpc_methods: continue - instance = module_coordinator.get_instance(blueprint.module) getattr(instance, method_name)(rpc_methods[linked_name]) + for requested_method_name in instance.get_rpc_method_names(): + # Check if this is an ambiguous interface method + if ( + requested_method_name in interface_methods + and len(interface_methods[requested_method_name]) > 1 + ): + modules_str = ", ".join( + impl[0].__name__ for impl in interface_methods[requested_method_name] + ) + raise ValueError( + f"Ambiguous RPC method '{requested_method_name}' requested by " + f"{blueprint.module.__name__}. Multiple implementations found: " + f"{modules_str}. Please use a concrete class name instead." + ) + + if requested_method_name not in rpc_methods_dot: + continue + instance.set_rpc_method( + requested_method_name, rpc_methods_dot[requested_method_name] + ) + + def build(self, global_config: GlobalConfig | None = None) -> ModuleCoordinator: + if global_config is None: + global_config = GlobalConfig() + global_config = global_config.model_copy(update=self.global_config_overrides) + + module_coordinator = ModuleCoordinator(global_config=global_config) + module_coordinator.start() + + self._deploy_all_modules(module_coordinator, global_config) + self._connect_transports(module_coordinator) + self._connect_rpc_methods(module_coordinator) module_coordinator.start_all_modules() diff --git a/dimos/core/module.py b/dimos/core/module.py index 6ce8480087..25a39b37ac 100644 --- a/dimos/core/module.py +++ b/dimos/core/module.py @@ -22,6 +22,7 @@ get_args, get_origin, get_type_hints, + overload, ) from dask.distributed import Actor, get_worker @@ -30,6 +31,7 @@ from dimos.core import colors from dimos.core.core import T, rpc from dimos.core.resource import Resource +from dimos.core.rpc_client import RpcCall from dimos.core.stream import In, Out, RemoteIn, RemoteOut, Transport from dimos.protocol.rpc import LCMRPC, RPCSpec from dimos.protocol.service import Configurable @@ -78,6 +80,9 @@ class ModuleBase(Configurable[ModuleConfig], SkillContainer, Resource): _loop: asyncio.AbstractEventLoop | None = None _loop_thread: threading.Thread | None _disposables: CompositeDisposable + _bound_rpc_calls: dict[str, RpcCall] = {} + + rpc_calls: list[str] = [] default_config = ModuleConfig @@ -245,6 +250,30 @@ def blueprint(self): return partial(create_module_blueprint, self) + @rpc + def get_rpc_method_names(self) -> list[str]: + return self.rpc_calls + + @rpc + def set_rpc_method(self, method: str, callable: RpcCall) -> None: + callable.set_rpc(self.rpc) + self._bound_rpc_calls[method] = callable + + @overload + def get_rpc_calls(self, method: str) -> RpcCall: ... + + @overload + def get_rpc_calls(self, method1: str, method2: str, *methods: str) -> tuple[RpcCall, ...]: ... + + def get_rpc_calls(self, *methods: str) -> RpcCall | tuple[RpcCall, ...]: + missing = [m for m in methods if m not in self._bound_rpc_calls] + if missing: + raise ValueError( + f"RPC methods not found. Class: {self.__class__.__name__}, RPC methods: {', '.join(missing)}" + ) + result = tuple(self._bound_rpc_calls[m] for m in methods) + return result[0] if len(result) == 1 else result + class DaskModule(ModuleBase): ref: Actor diff --git a/dimos/core/test_blueprints.py b/dimos/core/test_blueprints.py index 59f541aa58..d910e88d7d 100644 --- a/dimos/core/test_blueprints.py +++ b/dimos/core/test_blueprints.py @@ -185,7 +185,7 @@ def test_build_happy_path() -> None: coordinator.stop() -def test_remapping(): +def test_remapping() -> None: """Test that remapping connections works correctly.""" pubsub.lcm.autoconf() diff --git a/dimos/core/test_core.py b/dimos/core/test_core.py index 97f09a4182..37a179b875 100644 --- a/dimos/core/test_core.py +++ b/dimos/core/test_core.py @@ -87,7 +87,7 @@ def test_classmethods() -> None: # Check that we have the expected RPC methods assert "navigate_to" in class_rpcs, "navigate_to should be in rpcs" assert "start" in class_rpcs, "start should be in rpcs" - assert len(class_rpcs) == 6 + assert len(class_rpcs) == 8 # Check that the values are callable assert callable(class_rpcs["navigate_to"]), "navigate_to should be callable" diff --git a/dimos/core/transport.py b/dimos/core/transport.py index 32f75e6c33..5cd05460f0 100644 --- a/dimos/core/transport.py +++ b/dimos/core/transport.py @@ -15,7 +15,7 @@ from __future__ import annotations import traceback -from typing import TypeVar +from typing import Any, TypeVar import dimos.core.colors as colors @@ -38,9 +38,9 @@ class PubSubTransport(Transport[T]): - topic: any + topic: Any - def __init__(self, topic: any) -> None: + def __init__(self, topic: Any) -> None: self.topic = topic def __str__(self) -> str: @@ -101,7 +101,7 @@ def subscribe(self, callback: Callable[[T], None], selfstream: In[T] = None) -> class JpegLcmTransport(LCMTransport): - def __init__(self, topic: str, type: type, **kwargs): + def __init__(self, topic: str, type: type, **kwargs) -> None: self.lcm = JpegLCM(**kwargs) super().__init__(topic, type) @@ -160,7 +160,7 @@ def subscribe(self, callback: Callable[[T], None], selfstream: In[T] = None) -> class JpegShmTransport(PubSubTransport[T]): _started: bool = False - def __init__(self, topic: str, quality: int = 75, **kwargs): + def __init__(self, topic: str, quality: int = 75, **kwargs) -> None: super().__init__(topic) self.shm = JpegSharedMemory(quality=quality, **kwargs) self.quality = quality @@ -168,7 +168,7 @@ def __init__(self, topic: str, quality: int = 75, **kwargs): def __reduce__(self): return (JpegShmTransport, (self.topic, self.quality)) - def broadcast(self, _, msg): + def broadcast(self, _, msg) -> None: if not self._started: self.shm.start() self._started = True diff --git a/dimos/hardware/camera/module.py b/dimos/hardware/camera/module.py index a4d0a1decb..f5ef549bac 100644 --- a/dimos/hardware/camera/module.py +++ b/dimos/hardware/camera/module.py @@ -20,7 +20,6 @@ from dimos_lcm.sensor_msgs import CameraInfo import reactivex as rx from reactivex import operators as ops -from reactivex.disposable import Disposable from reactivex.observable import Observable from dimos import spec diff --git a/dimos/hardware/piper_arm.py b/dimos/hardware/piper_arm.py index d27d1df394..e6e9da693d 100644 --- a/dimos/hardware/piper_arm.py +++ b/dimos/hardware/piper_arm.py @@ -40,7 +40,7 @@ class PiperArm: def __init__(self, arm_name: str = "arm") -> None: - self.arm = C_PiperInterface_V2() + self.arm = C_PiperInterface_V2() # noqa: F405 self.arm.ConnectPort() self.resetArm() time.sleep(0.5) diff --git a/dimos/mapping/osm/demo_osm.py b/dimos/mapping/osm/demo_osm.py index cf907378f3..20d9e40e74 100644 --- a/dimos/mapping/osm/demo_osm.py +++ b/dimos/mapping/osm/demo_osm.py @@ -14,37 +14,17 @@ # limitations under the License. from dotenv import load_dotenv -from reactivex import interval from dimos.agents2.agent import llm_agent from dimos.agents2.cli.human import human_input +from dimos.agents2.skills.demo_robot import demo_robot from dimos.agents2.skills.osm import osm_skill from dimos.agents2.system_prompt import get_system_prompt from dimos.core.blueprints import autoconnect -from dimos.core.module import Module -from dimos.core.stream import Out -from dimos.mapping.types import LatLon load_dotenv() -class DemoRobot(Module): - gps_location: Out[LatLon] = None - - def start(self) -> None: - super().start() - self._disposables.add(interval(1.0).subscribe(lambda _: self._publish_gps_location())) - - def stop(self) -> None: - super().stop() - - def _publish_gps_location(self) -> None: - self.gps_location.publish(LatLon(lat=37.78092426217621, lon=-122.40682866540769)) - - -demo_robot = DemoRobot.blueprint - - demo_osm = autoconnect( demo_robot(), osm_skill(), diff --git a/dimos/models/depth/metric3d.py b/dimos/models/depth/metric3d.py index e22c546dc3..0c10f31e63 100644 --- a/dimos/models/depth/metric3d.py +++ b/dimos/models/depth/metric3d.py @@ -13,8 +13,6 @@ # limitations under the License. import cv2 -import numpy as np -from PIL import Image import torch # May need to add this back for import to work diff --git a/dimos/models/embedding/base.py b/dimos/models/embedding/base.py index 99a8d8fd15..f7c790ffbf 100644 --- a/dimos/models/embedding/base.py +++ b/dimos/models/embedding/base.py @@ -16,7 +16,7 @@ from abc import ABC, abstractmethod import time -from typing import TYPE_CHECKING, Generic, Optional, TypeVar +from typing import TYPE_CHECKING, Generic, TypeVar import numpy as np import torch diff --git a/dimos/models/manipulation/contact_graspnet_pytorch/inference.py b/dimos/models/manipulation/contact_graspnet_pytorch/inference.py index 4241392d8e..fe173dc017 100644 --- a/dimos/models/manipulation/contact_graspnet_pytorch/inference.py +++ b/dimos/models/manipulation/contact_graspnet_pytorch/inference.py @@ -6,9 +6,7 @@ from contact_graspnet_pytorch.checkpoints import CheckpointIO from contact_graspnet_pytorch.contact_grasp_estimator import GraspEstimator from contact_graspnet_pytorch.data import load_available_input_data -from contact_graspnet_pytorch.visualization_utils_o3d import show_image, visualize_grasps import numpy as np -import torch from dimos.utils.data import get_data diff --git a/dimos/models/manipulation/contact_graspnet_pytorch/test_contact_graspnet.py b/dimos/models/manipulation/contact_graspnet_pytorch/test_contact_graspnet.py index b006c98603..7964a24954 100644 --- a/dimos/models/manipulation/contact_graspnet_pytorch/test_contact_graspnet.py +++ b/dimos/models/manipulation/contact_graspnet_pytorch/test_contact_graspnet.py @@ -1,7 +1,5 @@ import glob -import importlib.util import os -import sys import numpy as np import pytest diff --git a/dimos/models/qwen/video_query.py b/dimos/models/qwen/video_query.py index 0f8a3b8f9c..80bb078bac 100644 --- a/dimos/models/qwen/video_query.py +++ b/dimos/models/qwen/video_query.py @@ -2,7 +2,6 @@ import json import os -from typing import Optional, Tuple import numpy as np from openai import OpenAI diff --git a/dimos/models/vl/moondream.py b/dimos/models/vl/moondream.py index ce63c70238..781f1adbf1 100644 --- a/dimos/models/vl/moondream.py +++ b/dimos/models/vl/moondream.py @@ -1,5 +1,4 @@ from functools import cached_property -from typing import Optional import warnings import numpy as np diff --git a/dimos/models/vl/qwen.py b/dimos/models/vl/qwen.py index c302d12c22..773fcc35ad 100644 --- a/dimos/models/vl/qwen.py +++ b/dimos/models/vl/qwen.py @@ -1,6 +1,5 @@ from functools import cached_property import os -from typing import Optional import numpy as np from openai import OpenAI diff --git a/dimos/msgs/foxglove_msgs/__init__.py b/dimos/msgs/foxglove_msgs/__init__.py index 36698f5484..945ebf94c9 100644 --- a/dimos/msgs/foxglove_msgs/__init__.py +++ b/dimos/msgs/foxglove_msgs/__init__.py @@ -1 +1,3 @@ from dimos.msgs.foxglove_msgs.ImageAnnotations import ImageAnnotations + +__all__ = ["ImageAnnotations"] diff --git a/dimos/msgs/geometry_msgs/__init__.py b/dimos/msgs/geometry_msgs/__init__.py index de46a0a079..683aa2e37c 100644 --- a/dimos/msgs/geometry_msgs/__init__.py +++ b/dimos/msgs/geometry_msgs/__init__.py @@ -9,3 +9,20 @@ from dimos.msgs.geometry_msgs.TwistWithCovariance import TwistWithCovariance from dimos.msgs.geometry_msgs.TwistWithCovarianceStamped import TwistWithCovarianceStamped from dimos.msgs.geometry_msgs.Vector3 import Vector3, VectorLike + +__all__ = [ + "Pose", + "PoseLike", + "PoseStamped", + "PoseWithCovariance", + "PoseWithCovarianceStamped", + "Quaternion", + "Transform", + "Twist", + "TwistStamped", + "TwistWithCovariance", + "TwistWithCovarianceStamped", + "Vector3", + "VectorLike", + "to_pose", +] diff --git a/dimos/msgs/geometry_msgs/test_publish.py b/dimos/msgs/geometry_msgs/test_publish.py index 50578346ae..464966d5b7 100644 --- a/dimos/msgs/geometry_msgs/test_publish.py +++ b/dimos/msgs/geometry_msgs/test_publish.py @@ -47,7 +47,7 @@ def _loop() -> None: lc.handle() # loop 10000 times for _ in range(10000000): - 3 + 3 + 3 + 3 # noqa: B018 except Exception as e: print(f"Error in LCM handling: {e}") diff --git a/dimos/msgs/sensor_msgs/Image.py b/dimos/msgs/sensor_msgs/Image.py index 051169d6a9..b9ffc6a65b 100644 --- a/dimos/msgs/sensor_msgs/Image.py +++ b/dimos/msgs/sensor_msgs/Image.py @@ -430,7 +430,7 @@ def lcm_decode(cls, data: bytes, **kwargs) -> Image: ) ) - def lcm_jpeg_encode(self, quality: int = 75, frame_id: Optional[str] = None) -> bytes: + def lcm_jpeg_encode(self, quality: int = 75, frame_id: str | None = None) -> bytes: """Convert to LCM Image message with JPEG-compressed data. Args: diff --git a/dimos/msgs/sensor_msgs/__init__.py b/dimos/msgs/sensor_msgs/__init__.py index 56574e448d..130df72964 100644 --- a/dimos/msgs/sensor_msgs/__init__.py +++ b/dimos/msgs/sensor_msgs/__init__.py @@ -2,3 +2,5 @@ from dimos.msgs.sensor_msgs.Image import Image, ImageFormat from dimos.msgs.sensor_msgs.Joy import Joy from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 + +__all__ = ["CameraInfo", "Image", "ImageFormat", "Joy", "PointCloud2"] diff --git a/dimos/navigation/base.py b/dimos/navigation/base.py new file mode 100644 index 0000000000..bc60551f54 --- /dev/null +++ b/dimos/navigation/base.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python3 +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from enum import Enum + +from dimos.msgs.geometry_msgs import PoseStamped + + +class NavigationState(Enum): + IDLE = "idle" + FOLLOWING_PATH = "following_path" + RECOVERY = "recovery" + + +class NavigationInterface(ABC): + @abstractmethod + def set_goal(self, goal: PoseStamped) -> bool: + """ + Set a new navigation goal (non-blocking). + + Args: + goal: Target pose to navigate to + + Returns: + True if goal was accepted, False otherwise + """ + pass + + @abstractmethod + def get_state(self) -> NavigationState: + """ + Get the current state of the navigator. + + Returns: + Current navigation state + """ + pass + + @abstractmethod + def is_goal_reached(self) -> bool: + """ + Check if the current goal has been reached. + + Returns: + True if goal was reached, False otherwise + """ + pass + + @abstractmethod + def cancel_goal(self) -> bool: + """ + Cancel the current navigation goal. + + Returns: + True if goal was cancelled, False if no goal was active + """ + pass + + +__all__ = ["NavigationInterface", "NavigationState"] diff --git a/dimos/navigation/bt_navigator/navigator.py b/dimos/navigation/bt_navigator/navigator.py index ec4fbe7ce9..aa5c1458a1 100644 --- a/dimos/navigation/bt_navigator/navigator.py +++ b/dimos/navigation/bt_navigator/navigator.py @@ -19,7 +19,6 @@ """ from collections.abc import Callable -from enum import Enum import threading import time @@ -30,6 +29,7 @@ from dimos.core.rpc_client import RpcCall from dimos.msgs.geometry_msgs import PoseStamped from dimos.msgs.nav_msgs import OccupancyGrid +from dimos.navigation.base import NavigationInterface, NavigationState from dimos.navigation.bt_navigator.goal_validator import find_safe_goal from dimos.navigation.bt_navigator.recovery_server import RecoveryServer from dimos.protocol.tf import TF @@ -39,15 +39,7 @@ logger = setup_logger(__file__) -class NavigatorState(Enum): - """Navigator state machine states.""" - - IDLE = "idle" - FOLLOWING_PATH = "following_path" - RECOVERY = "recovery" - - -class BehaviorTreeNavigator(Module): +class BehaviorTreeNavigator(Module, NavigationInterface): """ Navigator module for coordinating navigation tasks. @@ -91,7 +83,7 @@ def __init__( self.publishing_period = 1.0 / publishing_frequency # State machine - self.state = NavigatorState.IDLE + self.state = NavigationState.IDLE self.state_lock = threading.Lock() # Current goal @@ -200,12 +192,12 @@ def set_goal(self, goal: PoseStamped) -> bool: self._goal_reached = False with self.state_lock: - self.state = NavigatorState.FOLLOWING_PATH + self.state = NavigationState.FOLLOWING_PATH return True @rpc - def get_state(self) -> NavigatorState: + def get_state(self) -> NavigationState: """Get the current state of the navigator.""" return self.state @@ -213,7 +205,7 @@ def _on_odom(self, msg: PoseStamped) -> None: """Handle incoming odometry messages.""" self.latest_odom = msg - if self.state == NavigatorState.FOLLOWING_PATH: + if self.state == NavigationState.FOLLOWING_PATH: self.recovery_server.update_odom(msg) def _on_goal_request(self, msg: PoseStamped) -> None: @@ -281,7 +273,7 @@ def _control_loop(self) -> None: current_state = self.state self.navigation_state.publish(String(data=current_state.value)) - if current_state == NavigatorState.FOLLOWING_PATH: + if current_state == NavigationState.FOLLOWING_PATH: with self.goal_lock: goal = self.current_goal original_goal = self.original_goal @@ -328,9 +320,9 @@ def _control_loop(self) -> None: self._goal_reached = True logger.info("Goal reached, resetting local planner") - elif current_state == NavigatorState.RECOVERY: + elif current_state == NavigationState.RECOVERY: with self.state_lock: - self.state = NavigatorState.IDLE + self.state = NavigationState.IDLE time.sleep(self.publishing_period) @@ -351,7 +343,7 @@ def stop_navigation(self) -> None: self._goal_reached = False with self.state_lock: - self.state = NavigatorState.IDLE + self.state = NavigationState.IDLE self.reset_local_planner() self.recovery_server.reset() # Reset recovery server when stopping diff --git a/dimos/navigation/frontier_exploration/__init__.py b/dimos/navigation/frontier_exploration/__init__.py index 7236788842..24ce957ccf 100644 --- a/dimos/navigation/frontier_exploration/__init__.py +++ b/dimos/navigation/frontier_exploration/__init__.py @@ -1 +1,3 @@ from .wavefront_frontier_goal_selector import WavefrontFrontierExplorer, wavefront_frontier_explorer + +__all__ = ["WavefrontFrontierExplorer", "wavefront_frontier_explorer"] diff --git a/dimos/navigation/rosnav.py b/dimos/navigation/rosnav.py index 1c906be45f..bd67fbd532 100644 --- a/dimos/navigation/rosnav.py +++ b/dimos/navigation/rosnav.py @@ -53,6 +53,7 @@ from dimos.msgs.sensor_msgs import PointCloud2 from dimos.msgs.std_msgs import Bool from dimos.msgs.tf2_msgs.TFMessage import TFMessage +from dimos.navigation.base import NavigationInterface, NavigationState from dimos.utils.logging_config import setup_logger from dimos.utils.transform_utils import euler_to_quaternion @@ -68,7 +69,9 @@ class Config(ModuleConfig): ) -class ROSNav(Module, spec.Nav, spec.Global3DMap, spec.Pointcloud, spec.LocalPlanner): +class ROSNav( + Module, NavigationInterface, spec.Nav, spec.Global3DMap, spec.Pointcloud, spec.LocalPlanner +): config: Config default_config = Config @@ -89,6 +92,13 @@ class ROSNav(Module, spec.Nav, spec.Global3DMap, spec.Pointcloud, spec.LocalPlan _spin_thread: threading.Thread | None = None _goal_reach: bool | None = None + # Navigation state tracking for NavigationInterface + _navigation_state: NavigationState = NavigationState.IDLE + _state_lock: threading.Lock + _navigation_thread: threading.Thread | None = None + _current_goal: PoseStamped | None = None + _goal_reached: bool = False + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) @@ -96,6 +106,11 @@ def __init__(self, *args, **kwargs) -> None: self._local_pointcloud_subject = Subject() self._global_pointcloud_subject = Subject() + # Initialize state tracking + self._state_lock = threading.Lock() + self._navigation_state = NavigationState.IDLE + self._goal_reached = False + if not rclpy.ok(): rclpy.init() @@ -173,6 +188,10 @@ def _spin_node(self) -> None: def _on_ros_goal_reached(self, msg: ROSBool) -> None: self._goal_reach = msg.data + if msg.data: + with self._state_lock: + self._goal_reached = True + self._navigation_state = NavigationState.IDLE def _on_ros_goal_waypoint(self, msg: ROSPointStamped) -> None: dimos_pose = PoseStamped( @@ -360,8 +379,74 @@ def stop_navigation(self) -> bool: soft_stop_msg.data = 2 self.soft_stop_pub.publish(soft_stop_msg) + with self._state_lock: + self._navigation_state = NavigationState.IDLE + self._current_goal = None + self._goal_reached = False + return True + @rpc + def set_goal(self, goal: PoseStamped) -> bool: + """Set a new navigation goal (non-blocking).""" + with self._state_lock: + self._current_goal = goal + self._goal_reached = False + self._navigation_state = NavigationState.FOLLOWING_PATH + + # Start navigation in a separate thread to make it non-blocking + if self._navigation_thread and self._navigation_thread.is_alive(): + logger.warning("Previous navigation still running, cancelling") + self.stop_navigation() + self._navigation_thread.join(timeout=1.0) + + self._navigation_thread = threading.Thread( + target=self._navigate_to_goal_async, + args=(goal,), + daemon=True, + name="ROSNavNavigationThread", + ) + self._navigation_thread.start() + + return True + + def _navigate_to_goal_async(self, goal: PoseStamped) -> None: + """Internal method to handle navigation in a separate thread.""" + try: + result = self.navigate_to(goal, timeout=60.0) + with self._state_lock: + self._goal_reached = result + self._navigation_state = NavigationState.IDLE + except Exception as e: + logger.error(f"Navigation failed: {e}") + with self._state_lock: + self._goal_reached = False + self._navigation_state = NavigationState.IDLE + + @rpc + def get_state(self) -> NavigationState: + """Get the current state of the navigator.""" + with self._state_lock: + return self._navigation_state + + @rpc + def is_goal_reached(self) -> bool: + """Check if the current goal has been reached.""" + with self._state_lock: + return self._goal_reached + + @rpc + def cancel_goal(self) -> bool: + """Cancel the current navigation goal.""" + + with self._state_lock: + had_goal = self._current_goal is not None + + if had_goal: + self.stop_navigation() + + return had_goal + @rpc def stop(self) -> None: """Stop the navigation module and clean up resources.""" @@ -384,7 +469,7 @@ def stop(self) -> None: super().stop() -navigation_module = ROSNav.blueprint +ros_nav = ROSNav.blueprint def deploy(dimos: DimosCluster): @@ -401,4 +486,4 @@ def deploy(dimos: DimosCluster): return nav -__all__ = ["ROSNav", "deploy", "navigation_module"] +__all__ = ["ROSNav", "deploy", "ros_nav"] diff --git a/dimos/perception/spatial_perception.py b/dimos/perception/spatial_perception.py index 9da25f98e8..3986c0f7c7 100644 --- a/dimos/perception/spatial_perception.py +++ b/dimos/perception/spatial_perception.py @@ -24,7 +24,7 @@ import cv2 import numpy as np -from reactivex import Observable, disposable, interval, operators as ops +from reactivex import Observable, interval, operators as ops from reactivex.disposable import Disposable from dimos import spec @@ -38,7 +38,7 @@ from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: - from dimos.msgs.geometry_msgs import PoseStamped, Vector3 + from dimos.msgs.geometry_msgs import Vector3 _OUTPUT_DIR = DIMOS_PROJECT_ROOT / "assets" / "output" _MEMORY_DIR = _OUTPUT_DIR / "memory" diff --git a/dimos/protocol/pubsub/lcmpubsub.py b/dimos/protocol/pubsub/lcmpubsub.py index ef158ffb30..c348b1dda7 100644 --- a/dimos/protocol/pubsub/lcmpubsub.py +++ b/dimos/protocol/pubsub/lcmpubsub.py @@ -121,7 +121,7 @@ def decode(self, msg: bytes, topic: Topic) -> LCMMsg: class JpegSharedMemoryEncoderMixin(PubSubEncoderMixin[str, Image]): - def __init__(self, quality: int = 75, **kwargs): + def __init__(self, quality: int = 75, **kwargs) -> None: super().__init__(**kwargs) self.jpeg = TurboJPEG() self.quality = quality diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index 388b8cf1dd..f6ea5deda9 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -23,12 +23,17 @@ "unitree-go2-jpeglcm": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:standard_with_jpeglcm", "unitree-go2-agentic": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:agentic", "unitree-g1": "dimos.robot.unitree_webrtc.unitree_g1_blueprints:standard", - "unitree-g1-basic": "dimos.robot.unitree_webrtc.unitree_g1_blueprints:basic", + "unitree-g1-bt-nav": "dimos.robot.unitree_webrtc.unitree_g1_blueprints:standard_bt_nav", + "unitree-g1-basic": "dimos.robot.unitree_webrtc.unitree_g1_blueprints:basic_ros", + "unitree-g1-basic-bt-nav": "dimos.robot.unitree_webrtc.unitree_g1_blueprints:basic_bt_nav", "unitree-g1-shm": "dimos.robot.unitree_webrtc.unitree_g1_blueprints:standard_with_shm", "unitree-g1-agentic": "dimos.robot.unitree_webrtc.unitree_g1_blueprints:agentic", + "unitree-g1-agentic-bt-nav": "dimos.robot.unitree_webrtc.unitree_g1_blueprints:agentic_bt_nav", "unitree-g1-joystick": "dimos.robot.unitree_webrtc.unitree_g1_blueprints:with_joystick", "unitree-g1-full": "dimos.robot.unitree_webrtc.unitree_g1_blueprints:full_featured", "demo-osm": "dimos.mapping.osm.demo_osm:demo_osm", + "demo-gps-nav": "dimos.agents2.skills.demo_gps_nav:demo_gps_nav_skill", + "demo-google-maps-skill": "dimos.agents2.skills.demo_google_maps_skill:demo_google_maps_skill", "demo-remapping": "dimos.robot.unitree_webrtc.demo_remapping:remapping", "demo-remapping-transport": "dimos.robot.unitree_webrtc.demo_remapping:remapping_and_transport", } @@ -37,18 +42,26 @@ all_modules = { "astar_planner": "dimos.navigation.global_planner.planner", "behavior_tree_navigator": "dimos.navigation.bt_navigator.navigator", + "camera_module": "dimos.hardware.camera.module", "connection": "dimos.robot.unitree_webrtc.unitree_go2", "depth_module": "dimos.robot.unitree_webrtc.depth_module", "detection_2d": "dimos.perception.detection2d.module2D", "foxglove_bridge": "dimos.robot.foxglove_bridge", + "g1_connection": "dimos.robot.unitree_webrtc.unitree_g1", + "g1_joystick": "dimos.robot.unitree_webrtc.g1_joystick_module", + "g1_skills": "dimos.robot.unitree_webrtc.unitree_g1_skill_container", + "google_maps_skill": "dimos.agents2.skills.google_maps_skill_container", + "gps_nav_skill": "dimos.agents2.skills.gps_nav_skill", "holonomic_local_planner": "dimos.navigation.local_planner.holonomic_local_planner", "human_input": "dimos.agents2.cli.human", "llm_agent": "dimos.agents2.agent", "mapper": "dimos.robot.unitree_webrtc.type.map", "navigation_skill": "dimos.agents2.skills.navigation", "object_tracking": "dimos.perception.object_tracker", - "osm_skill": "dimos.agents2.skills.osm.py", + "osm_skill": "dimos.agents2.skills.osm", + "ros_nav": "dimos.navigation.rosnav", "spatial_memory": "dimos.perception.spatial_perception", + "unitree_skills": "dimos.robot.unitree_webrtc.unitree_skill_container", "utilization": "dimos.utils.monitoring", "wavefront_frontier_explorer": "dimos.navigation.frontier_exploration.wavefront_frontier_goal_selector", "websocket_vis": "dimos.web.websocket_vis.websocket_vis_module", diff --git a/dimos/robot/robot.py b/dimos/robot/robot.py index 002dcb4710..1b0ca7041f 100644 --- a/dimos/robot/robot.py +++ b/dimos/robot/robot.py @@ -16,11 +16,6 @@ from abc import ABC, abstractmethod -from reactivex import Observable - -from dimos.mapping.types import LatLon -from dimos.msgs.geometry_msgs import PoseStamped -from dimos.perception.spatial_perception import SpatialMemory from dimos.types.robot_capabilities import RobotCapability @@ -56,38 +51,10 @@ def get_skills(self): """ return self.skill_library + @abstractmethod def cleanup(self) -> None: """Clean up robot resources. Override this method to provide cleanup logic. """ - pass - - -# TODO: Delete -class UnitreeRobot(Robot): - @abstractmethod - def get_odom(self) -> PoseStamped: ... - - @abstractmethod - def explore(self) -> bool: ... - - @abstractmethod - def stop_exploration(self) -> bool: ... - - @abstractmethod - def is_exploration_active(self) -> bool: ... - - @property - @abstractmethod - def spatial_memory(self) -> SpatialMemory | None: ... - - -# TODO: Delete -class GpsRobot(ABC): - @property - @abstractmethod - def gps_position_stream(self) -> Observable[LatLon]: ... - - @abstractmethod - def set_gps_travel_goal_points(self, points: list[LatLon]) -> None: ... + ... diff --git a/dimos/robot/unitree_webrtc/mujoco_connection.py b/dimos/robot/unitree_webrtc/mujoco_connection.py index a17fb3c897..06c119e109 100644 --- a/dimos/robot/unitree_webrtc/mujoco_connection.py +++ b/dimos/robot/unitree_webrtc/mujoco_connection.py @@ -20,6 +20,7 @@ import logging import threading import time +from typing import Any from reactivex import Observable @@ -233,5 +234,5 @@ def move(self, twist: Twist, duration: float = 0.0) -> None: if not self._is_cleaned_up: self.mujoco_thread.move(twist, duration) - def publish_request(self, topic: str, data: dict) -> None: - pass + def publish_request(self, topic: str, data: dict[str, Any]) -> None: + print(f"publishing request, topic={topic}, data={data}") diff --git a/dimos/robot/unitree_webrtc/unitree_g1.py b/dimos/robot/unitree_webrtc/unitree_g1.py index d0e9d46acc..55d7537ef0 100644 --- a/dimos/robot/unitree_webrtc/unitree_g1.py +++ b/dimos/robot/unitree_webrtc/unitree_g1.py @@ -71,7 +71,6 @@ from dimos.skills.skills import SkillLibrary from dimos.types.robot_capabilities import RobotCapability from dimos.utils.logging_config import setup_logger -from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule logger = setup_logger("dimos.robot.unitree_webrtc.unitree_g1", level=logging.INFO) @@ -201,6 +200,7 @@ def move(self, twist: Twist, duration: float = 0.0) -> None: @rpc def publish_request(self, topic: str, data: dict): """Forward WebRTC publish requests to connection.""" + logger.info(f"Publishing request to topic: {topic} with data: {data}") return self.connection.publish_request(topic, data) diff --git a/dimos/robot/unitree_webrtc/unitree_g1_blueprints.py b/dimos/robot/unitree_webrtc/unitree_g1_blueprints.py index 6deaebb9a9..975e951e40 100644 --- a/dimos/robot/unitree_webrtc/unitree_g1_blueprints.py +++ b/dimos/robot/unitree_webrtc/unitree_g1_blueprints.py @@ -24,7 +24,7 @@ from dimos.agents2.agent import llm_agent from dimos.agents2.cli.human import human_input from dimos.agents2.skills.navigation import navigation_skill -from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE, DEFAULT_CAPACITY_DEPTH_IMAGE +from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE from dimos.core.blueprints import autoconnect from dimos.core.transport import LCMTransport, pSHMTransport from dimos.hardware.camera import zed @@ -35,7 +35,6 @@ Quaternion, Transform, Twist, - TwistStamped, Vector3, ) from dimos.msgs.nav_msgs import Odometry, Path @@ -51,11 +50,10 @@ from dimos.navigation.local_planner.holonomic_local_planner import ( holonomic_local_planner, ) -from dimos.navigation.rosnav import navigation_module +from dimos.navigation.rosnav import ros_nav from dimos.perception.object_tracker import object_tracking from dimos.perception.spatial_perception import spatial_memory from dimos.robot.foxglove_bridge import foxglove_bridge -from dimos.robot.unitree_webrtc.depth_module import depth_module from dimos.robot.unitree_webrtc.g1_joystick_module import g1_joystick from dimos.robot.unitree_webrtc.type.map import mapper from dimos.robot.unitree_webrtc.unitree_g1 import g1_connection @@ -64,7 +62,7 @@ from dimos.web.websocket_vis.websocket_vis_module import websocket_vis # Basic configuration with navigation and visualization -basic = ( +_basic_no_nav = ( autoconnect( # Core connection module for G1 g1_connection(), @@ -88,9 +86,7 @@ # Navigation stack astar_planner(), holonomic_local_planner(), - behavior_tree_navigator(), wavefront_frontier_explorer(), - navigation_module(), # G1-specific ROS navigation # Visualization websocket_vis(), foxglove_bridge(), @@ -126,12 +122,30 @@ ) ) -# Standard configuration with perception and memory -standard = autoconnect( - basic, +basic_ros = autoconnect( + _basic_no_nav, + ros_nav(), +) + +basic_bt_nav = autoconnect( + _basic_no_nav, + behavior_tree_navigator(), +) + +_perception_and_memory = autoconnect( spatial_memory(), object_tracking(frame_id="camera_link"), utilization(), +) + +standard = autoconnect( + basic_ros, + _perception_and_memory, +).global_config(n_dask_workers=8) + +standard_bt_nav = autoconnect( + basic_bt_nav, + _perception_and_memory, ).global_config(n_dask_workers=8) # Optimized configuration using shared memory for images @@ -150,27 +164,33 @@ ), ) -# Full agentic configuration with LLM and skills -agentic = autoconnect( - standard, +_agentic_skills = autoconnect( llm_agent(), human_input(), navigation_skill(), - g1_skills(), # G1-specific arm and movement mode skills + g1_skills(), +) + +# Full agentic configuration with LLM and skills +agentic = autoconnect( + standard, + _agentic_skills, +) + +agentic_bt_nav = autoconnect( + standard_bt_nav, + _agentic_skills, ) # Configuration with joystick control for teleoperation with_joystick = autoconnect( - basic, + basic_ros, g1_joystick(), # Pygame-based joystick control ) # Full featured configuration with everything full_featured = autoconnect( standard_with_shm, - llm_agent(), - human_input(), - navigation_skill(), - g1_skills(), + _agentic_skills, g1_joystick(), ) diff --git a/dimos/robot/unitree_webrtc/unitree_g1_skill_container.py b/dimos/robot/unitree_webrtc/unitree_g1_skill_container.py index afca3339f5..12635f02bc 100644 --- a/dimos/robot/unitree_webrtc/unitree_g1_skill_container.py +++ b/dimos/robot/unitree_webrtc/unitree_g1_skill_container.py @@ -17,20 +17,14 @@ Dynamically generates skills for G1 humanoid robot including arm controls and movement modes. """ -from __future__ import annotations - -from typing import TYPE_CHECKING +import difflib from dimos.core.core import rpc -from dimos.msgs.geometry_msgs import TwistStamped, Vector3 +from dimos.core.skill_module import SkillModule +from dimos.msgs.geometry_msgs import Twist, Vector3 from dimos.protocol.skill.skill import skill -from dimos.robot.unitree_webrtc.unitree_skill_container import UnitreeSkillContainer from dimos.utils.logging_config import setup_logger -if TYPE_CHECKING: - from dimos.robot.unitree_webrtc.unitree_g1 import UnitreeG1 - from dimos.robot.unitree_webrtc.unitree_go2 import UnitreeGo2 - logger = setup_logger("dimos.robot.unitree_webrtc.unitree_g1_skill_container") # G1 Arm Actions - all use api_id 7106 on topic "rt/api/arm/request" @@ -58,25 +52,20 @@ ("RunMode", 801, "Switch to running mode."), ] +_ARM_COMMANDS: dict[str, tuple[int, str]] = { + name: (id_, description) for name, id_, description in G1_ARM_CONTROLS +} -class UnitreeG1SkillContainer(UnitreeSkillContainer): - """Container for Unitree G1 humanoid robot skills. +_MODE_COMMANDS: dict[str, tuple[int, str]] = { + name: (id_, description) for name, id_, description in G1_MODE_CONTROLS +} - Inherits all Go2 skills and adds G1-specific arm controls and movement modes. - """ - def __init__(self, robot: UnitreeG1 | UnitreeGo2 | None = None) -> None: - """Initialize the skill container with robot reference. - - Args: - robot: The UnitreeG1 or UnitreeGo2 robot instance - """ - # Initialize parent class to get all base Unitree skills - super().__init__(robot) - - # Add G1-specific skills on top - self._generate_arm_skills() - self._generate_mode_skills() +class UnitreeG1SkillContainer(SkillModule): + rpc_calls: list[str] = [ + "G1ConnectionModule.move", + "G1ConnectionModule.publish_request", + ] @rpc def start(self) -> None: @@ -86,150 +75,87 @@ def start(self) -> None: def stop(self) -> None: super().stop() - def _generate_arm_skills(self) -> None: - """Dynamically generate arm control skills from G1_ARM_CONTROLS list.""" - logger.info(f"Generating {len(G1_ARM_CONTROLS)} G1 arm control skills") - - for name, data_value, description in G1_ARM_CONTROLS: - skill_name = self._convert_to_snake_case(name) - self._create_arm_skill(skill_name, data_value, description, name) - - def _generate_mode_skills(self) -> None: - """Dynamically generate movement mode skills from G1_MODE_CONTROLS list.""" - logger.info(f"Generating {len(G1_MODE_CONTROLS)} G1 movement mode skills") - - for name, data_value, description in G1_MODE_CONTROLS: - skill_name = self._convert_to_snake_case(name) - self._create_mode_skill(skill_name, data_value, description, name) + @skill() + def move(self, x: float, y: float = 0.0, yaw: float = 0.0, duration: float = 0.0) -> str: + """Move the robot using direct velocity commands. Determine duration required based on user distance instructions. - def _create_arm_skill( - self, skill_name: str, data_value: int, description: str, original_name: str - ) -> None: - """Create a dynamic arm control skill method with the @skill decorator. + Example call: + args = { "x": 0.5, "y": 0.0, "yaw": 0.0, "duration": 2.0 } + move(**args) Args: - skill_name: Snake_case name for the method - data_value: The arm action data value - description: Human-readable description - original_name: Original CamelCase name for display + x: Forward velocity (m/s) + y: Left/right velocity (m/s) + yaw: Rotational velocity (rad/s) + duration: How long to move (seconds) """ - def dynamic_skill_func(self) -> str: - """Dynamic arm skill function.""" - return self._execute_arm_command(data_value, original_name) - - # Set the function's metadata - dynamic_skill_func.__name__ = skill_name - dynamic_skill_func.__doc__ = description - - # Apply the @skill decorator - decorated_skill = skill()(dynamic_skill_func) - - # Bind the method to the instance - bound_method = decorated_skill.__get__(self, self.__class__) - - # Add it as an attribute - setattr(self, skill_name, bound_method) - - logger.debug(f"Generated arm skill: {skill_name} (data={data_value})") - - def _create_mode_skill( - self, skill_name: str, data_value: int, description: str, original_name: str - ) -> None: - """Create a dynamic movement mode skill method with the @skill decorator. + move_rpc = self.get_rpc_calls("G1ConnectionModule.move") + twist = Twist(linear=Vector3(x, y, 0), angular=Vector3(0, 0, yaw)) + move_rpc(twist, duration=duration) + return f"Started moving with velocity=({x}, {y}, {yaw}) for {duration} seconds" - Args: - skill_name: Snake_case name for the method - data_value: The mode data value - description: Human-readable description - original_name: Original CamelCase name for display - """ + @skill() + def execute_arm_command(self, command_name: str) -> str: + return self._execute_g1_command(_ARM_COMMANDS, 7106, command_name) - def dynamic_skill_func(self) -> str: - """Dynamic mode skill function.""" - return self._execute_mode_command(data_value, original_name) + @skill() + def execute_mode_command(self, command_name: str) -> str: + return self._execute_g1_command(_MODE_COMMANDS, 7101, command_name) - # Set the function's metadata - dynamic_skill_func.__name__ = skill_name - dynamic_skill_func.__doc__ = description + def _execute_g1_command( + self, command_dict: dict[str, tuple[int, str]], api_id: int, command_name: str + ) -> str: + publish_request_rpc = self.get_rpc_calls("G1ConnectionModule.publish_request") - # Apply the @skill decorator - decorated_skill = skill()(dynamic_skill_func) + if command_name not in command_dict: + suggestions = difflib.get_close_matches( + command_name, command_dict.keys(), n=3, cutoff=0.6 + ) + return f"There's no '{command_name}' command. Did you mean: {suggestions}" - # Bind the method to the instance - bound_method = decorated_skill.__get__(self, self.__class__) + id_, _ = command_dict[command_name] - # Add it as an attribute - setattr(self, skill_name, bound_method) + try: + publish_request_rpc( + "rt/api/sport/request", {"api_id": api_id, "parameter": {"data": id_}} + ) + return f"'{command_name}' command executed successfully." + except Exception as e: + logger.error(f"Failed to execute {command_name}: {e}") + return "Failed to execute the command." - logger.debug(f"Generated mode skill: {skill_name} (data={data_value})") - # ========== Override Skills for G1 ========== +_arm_commands = "\n".join( + [f'- "{name}": {description}' for name, (_, description) in _ARM_COMMANDS.items()] +) - @skill() - def move(self, x: float, y: float = 0.0, yaw: float = 0.0, duration: float = 0.0) -> str: - """Move the robot using direct velocity commands (G1 version with TwistStamped). +UnitreeG1SkillContainer.execute_arm_command.__doc__ = f"""Execute a Unitree G1 arm command. - Args: - x: Forward velocity (m/s) - y: Left/right velocity (m/s) - yaw: Rotational velocity (rad/s) - duration: How long to move (seconds) - """ - if self._robot is None: - return "Error: Robot not connected" +Example usage: - # G1 uses TwistStamped instead of Twist - twist_stamped = TwistStamped(linear=Vector3(x, y, 0), angular=Vector3(0, 0, yaw)) - self._robot.move(twist_stamped, duration=duration) - return f"Started moving with velocity=({x}, {y}, {yaw}) for {duration} seconds" + execute_arm_command("ArmHeart") - # ========== Helper Methods ========== +Here are all the command names and what they do. - def _execute_arm_command(self, data_value: int, name: str) -> str: - """Execute an arm command through WebRTC interface. +{_arm_commands} +""" - Args: - data_value: The arm action data value - name: Human-readable name of the command - """ - if self._robot is None: - return f"Error: Robot not connected (cannot execute {name})" +_mode_commands = "\n".join( + [f'- "{name}": {description}' for name, (_, description) in _MODE_COMMANDS.items()] +) - try: - self._robot.connection.publish_request( - "rt/api/arm/request", {"api_id": 7106, "parameter": {"data": data_value}} - ) - message = f"G1 arm action {name} executed successfully (data={data_value})" - logger.info(message) - return message - except Exception as e: - error_msg = f"Failed to execute G1 arm action {name}: {e}" - logger.error(error_msg) - return error_msg +UnitreeG1SkillContainer.execute_mode_command.__doc__ = f"""Execute a Unitree G1 mode command. - def _execute_mode_command(self, data_value: int, name: str) -> str: - """Execute a movement mode command through WebRTC interface. +Example usage: - Args: - data_value: The mode data value - name: Human-readable name of the command - """ - if self._robot is None: - return f"Error: Robot not connected (cannot execute {name})" + execute_mode_command("RunMode") - try: - self._robot.connection.publish_request( - "rt/api/sport/request", {"api_id": 7101, "parameter": {"data": data_value}} - ) - message = f"G1 mode {name} activated successfully (data={data_value})" - logger.info(message) - return message - except Exception as e: - error_msg = f"Failed to execute G1 mode {name}: {e}" - logger.error(error_msg) - return error_msg +Here are all the command names and what they do. +{_mode_commands} +""" -# Create blueprint function for easy instantiation g1_skills = UnitreeG1SkillContainer.blueprint + +__all__ = ["UnitreeG1SkillContainer", "g1_skills"] diff --git a/dimos/robot/unitree_webrtc/unitree_go2.py b/dimos/robot/unitree_webrtc/unitree_go2.py index 62f7b74da3..5f3be25863 100644 --- a/dimos/robot/unitree_webrtc/unitree_go2.py +++ b/dimos/robot/unitree_webrtc/unitree_go2.py @@ -38,8 +38,9 @@ from dimos.msgs.sensor_msgs import Image from dimos.msgs.std_msgs import Header from dimos.msgs.vision_msgs import Detection2DArray +from dimos.navigation.base import NavigationState from dimos.navigation.bbox_navigation import BBoxNavigationModule -from dimos.navigation.bt_navigator.navigator import BehaviorTreeNavigator, NavigatorState +from dimos.navigation.bt_navigator.navigator import BehaviorTreeNavigator from dimos.navigation.frontier_exploration import WavefrontFrontierExplorer from dimos.navigation.global_planner import AstarPlanner from dimos.navigation.local_planner.holonomic_local_planner import HolonomicLocalPlanner @@ -54,7 +55,6 @@ from dimos.protocol.pubsub.lcmpubsub import LCM from dimos.protocol.tf import TF from dimos.robot.foxglove_bridge import FoxgloveBridge -from dimos.robot.robot import UnitreeRobot from dimos.robot.unitree_webrtc.connection import UnitreeWebRTCConnection from dimos.robot.unitree_webrtc.type.lidar import LidarMessage from dimos.robot.unitree_webrtc.type.map import Map @@ -349,7 +349,7 @@ def publish_request(self, topic: str, data: dict): connection = ConnectionModule.blueprint -class UnitreeGo2(UnitreeRobot, Resource): +class UnitreeGo2(Resource): """Full Unitree Go2 robot with navigation and perception capabilities.""" _dimos: ModuleCoordinator @@ -642,7 +642,7 @@ def navigate_to(self, pose: PoseStamped, blocking: bool = True) -> bool: time.sleep(1.0) if blocking: - while self.navigator.get_state() == NavigatorState.FOLLOWING_PATH: + while self.navigator.get_state() == NavigationState.FOLLOWING_PATH: time.sleep(0.25) time.sleep(1.0) diff --git a/dimos/robot/unitree_webrtc/unitree_go2_blueprints.py b/dimos/robot/unitree_webrtc/unitree_go2_blueprints.py index fcb4f162cd..8973f6cd68 100644 --- a/dimos/robot/unitree_webrtc/unitree_go2_blueprints.py +++ b/dimos/robot/unitree_webrtc/unitree_go2_blueprints.py @@ -19,7 +19,7 @@ from dimos.agents2.agent import llm_agent from dimos.agents2.cli.human import human_input from dimos.agents2.skills.navigation import navigation_skill -from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE, DEFAULT_CAPACITY_DEPTH_IMAGE +from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE from dimos.core.blueprints import autoconnect from dimos.core.transport import JpegLcmTransport, JpegShmTransport, LCMTransport, pSHMTransport from dimos.msgs.geometry_msgs import PoseStamped @@ -37,9 +37,9 @@ from dimos.perception.object_tracker import object_tracking from dimos.perception.spatial_perception import spatial_memory from dimos.robot.foxglove_bridge import foxglove_bridge -from dimos.robot.unitree_webrtc.depth_module import depth_module from dimos.robot.unitree_webrtc.type.map import mapper from dimos.robot.unitree_webrtc.unitree_go2 import connection +from dimos.robot.unitree_webrtc.unitree_skill_container import unitree_skills from dimos.utils.monitoring import utilization from dimos.web.websocket_vis.websocket_vis_module import websocket_vis @@ -112,4 +112,5 @@ llm_agent(), human_input(), navigation_skill(), + unitree_skills(), ) diff --git a/dimos/robot/unitree_webrtc/unitree_skill_container.py b/dimos/robot/unitree_webrtc/unitree_skill_container.py index e6179adcbb..8fca216d04 100644 --- a/dimos/robot/unitree_webrtc/unitree_skill_container.py +++ b/dimos/robot/unitree_webrtc/unitree_skill_container.py @@ -12,21 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Unitree skill container for the new agents2 framework. -Dynamically generates skills from UNITREE_WEBRTC_CONTROLS list. -""" - from __future__ import annotations import datetime +import difflib import time from typing import TYPE_CHECKING from go2_webrtc_driver.constants import RTC_TOPIC -from dimos.core import Module from dimos.core.core import rpc +from dimos.core.skill_module import SkillModule from dimos.msgs.geometry_msgs import Twist, Vector3 from dimos.protocol.skill.skill import skill from dimos.protocol.skill.type import Reducer, Stream @@ -34,25 +30,23 @@ from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: - from dimos.robot.unitree_webrtc.unitree_go2 import UnitreeGo2 + from dimos.core.rpc_client import RpcCall logger = setup_logger("dimos.robot.unitree_webrtc.unitree_skill_container") -class UnitreeSkillContainer(Module): - """Container for Unitree Go2 robot skills using the new framework.""" +_UNITREE_COMMANDS = { + name: (id_, description) + for name, id_, description in UNITREE_WEBRTC_CONTROLS + if name not in ["Reverse", "Spin"] +} - def __init__(self, robot: UnitreeGo2 | None = None) -> None: - """Initialize the skill container with robot reference. - Args: - robot: The UnitreeGo2 robot instance - """ - super().__init__() - self._robot = robot +class UnitreeSkillContainer(SkillModule): + """Container for Unitree Go2 robot skills using the new framework.""" - # Dynamically generate skills from UNITREE_WEBRTC_CONTROLS - self._generate_unitree_skills() + _move: RpcCall | None = None + _publish_request: RpcCall | None = None @rpc def start(self) -> None: @@ -60,67 +54,17 @@ def start(self) -> None: @rpc def stop(self) -> None: - # TODO: Do I need to clean up dynamic skills? super().stop() - def _generate_unitree_skills(self) -> None: - """Dynamically generate skills from the UNITREE_WEBRTC_CONTROLS list.""" - logger.info(f"Generating {len(UNITREE_WEBRTC_CONTROLS)} dynamic Unitree skills") - - for name, api_id, description in UNITREE_WEBRTC_CONTROLS: - if name not in ["Reverse", "Spin"]: # Exclude reverse and spin as in original - # Convert CamelCase to snake_case for method name - skill_name = self._convert_to_snake_case(name) - self._create_dynamic_skill(skill_name, api_id, description, name) - - def _convert_to_snake_case(self, name: str) -> str: - """Convert CamelCase to snake_case. - - Examples: - StandUp -> stand_up - RecoveryStand -> recovery_stand - FrontFlip -> front_flip - """ - result = [] - for i, char in enumerate(name): - if i > 0 and char.isupper(): - result.append("_") - result.append(char.lower()) - return "".join(result) - - def _create_dynamic_skill( - self, skill_name: str, api_id: int, description: str, original_name: str - ) -> None: - """Create a dynamic skill method with the @skill decorator. - - Args: - skill_name: Snake_case name for the method - api_id: The API command ID - description: Human-readable description - original_name: Original CamelCase name for display - """ - - # Define the skill function - def dynamic_skill_func(self) -> str: - """Dynamic skill function.""" - return self._execute_sport_command(api_id, original_name) - - # Set the function's metadata - dynamic_skill_func.__name__ = skill_name - dynamic_skill_func.__doc__ = description - - # Apply the @skill decorator - decorated_skill = skill()(dynamic_skill_func) - - # Bind the method to the instance - bound_method = decorated_skill.__get__(self, self.__class__) - - # Add it as an attribute - setattr(self, skill_name, bound_method) - - logger.debug(f"Generated skill: {skill_name} (API ID: {api_id})") + @rpc + def set_ConnectionModule_move(self, callable: RpcCall) -> None: + self._move = callable + self._move.set_rpc(self.rpc) - # ========== Explicit Skills ========== + @rpc + def set_ConnectionModule_publish_request(self, callable: RpcCall) -> None: + self._publish_request = callable + self._publish_request.set_rpc(self.rpc) @skill() def move(self, x: float, y: float = 0.0, yaw: float = 0.0, duration: float = 0.0) -> str: @@ -136,11 +80,11 @@ def move(self, x: float, y: float = 0.0, yaw: float = 0.0, duration: float = 0.0 yaw: Rotational velocity (rad/s) duration: How long to move (seconds) """ - if self._robot is None: + if self._move is None: return "Error: Robot not connected" twist = Twist(linear=Vector3(x, y, 0), angular=Vector3(0, 0, yaw)) - self._robot.move(twist, duration=duration) + self._move(twist, duration=duration) return f"Started moving with velocity=({x}, {y}, {yaw}) for {duration} seconds" @skill() @@ -153,7 +97,7 @@ def wait(self, seconds: float) -> str: time.sleep(seconds) return f"Wait completed with length={seconds}s" - @skill(stream=Stream.passive, reducer=Reducer.latest) + @skill(stream=Stream.passive, reducer=Reducer.latest, hide_skill=True) def current_time(self): """Provides current time implicitly, don't call this skill directly.""" print("Starting current_time skill") @@ -166,24 +110,43 @@ def speak(self, text: str) -> str: """Speak text out loud through the robot's speakers.""" return f"This is being said aloud: {text}" - # ========== Helper Methods ========== + @skill() + def execute_sport_command(self, command_name: str) -> str: + if self._publish_request is None: + return f"Error: Robot not connected (cannot execute {command_name})" - def _execute_sport_command(self, api_id: int, name: str) -> str: - """Execute a sport command through WebRTC interface. + if command_name not in _UNITREE_COMMANDS: + suggestions = difflib.get_close_matches( + command_name, _UNITREE_COMMANDS.keys(), n=3, cutoff=0.6 + ) + return f"There's no '{command_name}' command. Did you mean: {suggestions}" - Args: - api_id: The API command ID - name: Human-readable name of the command - """ - if self._robot is None: - return f"Error: Robot not connected (cannot execute {name})" + id_, _ = _UNITREE_COMMANDS[command_name] try: - self._robot.connection.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": api_id}) - message = f"{name} command executed successfully (id={api_id})" - logger.info(message) - return message + self._publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": id_}) + return f"'{command_name}' command executed successfully." except Exception as e: - error_msg = f"Failed to execute {name}: {e}" - logger.error(error_msg) - return error_msg + logger.error(f"Failed to execute {command_name}: {e}") + return "Failed to execute the command." + + +_commands = "\n".join( + [f'- "{name}": {description}' for name, (_, description) in _UNITREE_COMMANDS.items()] +) + +UnitreeSkillContainer.execute_sport_command.__doc__ = f"""Execute a Unitree sport command. + +Example usage: + + execute_sport_command("FrontPounce") + +Here are all the command names and what they do. + +{_commands} +""" + + +unitree_skills = UnitreeSkillContainer.blueprint + +__all__ = ["UnitreeSkillContainer", "unitree_skills"] diff --git a/dimos/simulation/mujoco/mujoco.py b/dimos/simulation/mujoco/mujoco.py index dc90cce076..36cbf3d1ad 100644 --- a/dimos/simulation/mujoco/mujoco.py +++ b/dimos/simulation/mujoco/mujoco.py @@ -19,6 +19,7 @@ import logging import threading import time +from typing import Any import mujoco from mujoco import viewer @@ -42,34 +43,38 @@ class MujocoThread(threading.Thread): - def __init__(self, global_config: GlobalConfig): + def __init__(self, global_config: GlobalConfig) -> None: super().__init__(daemon=True) self.global_config = global_config self.shared_pixels = None self.pixels_lock = threading.RLock() self.shared_depth_front = None - self.shared_depth_front_pose = None + self.shared_depth_front_pose: tuple[np.ndarray[Any, Any], np.ndarray[Any, Any]] | None = ( + None + ) self.depth_lock_front = threading.RLock() self.shared_depth_left = None - self.shared_depth_left_pose = None + self.shared_depth_left_pose: tuple[np.ndarray[Any, Any], np.ndarray[Any, Any]] | None = None self.depth_left_lock = threading.RLock() self.shared_depth_right = None - self.shared_depth_right_pose = None + self.shared_depth_right_pose: tuple[np.ndarray[Any, Any], np.ndarray[Any, Any]] | None = ( + None + ) self.depth_right_lock = threading.RLock() - self.odom_data = None + self.odom_data: tuple[np.ndarray[Any, Any], np.ndarray[Any, Any]] | None = None self.odom_lock = threading.RLock() self.lidar_lock = threading.RLock() - self.model = None - self.data = None + self.model: mujoco.MjModel | None = None + self.data: mujoco.MjData | None = None self._command = np.zeros(3, dtype=np.float32) self._command_lock = threading.RLock() self._is_running = True self._stop_timer: threading.Timer | None = None self._viewer = None - self._rgb_renderer = None - self._depth_renderer = None - self._depth_left_renderer = None - self._depth_right_renderer = None + self._rgb_renderer: mujoco.Renderer | None = None + self._depth_renderer: mujoco.Renderer | None = None + self._depth_left_renderer: mujoco.Renderer | None = None + self._depth_right_renderer: mujoco.Renderer | None = None self._cleanup_registered = False # Store initial reference pose for stable point cloud generation @@ -79,7 +84,7 @@ def __init__(self, global_config: GlobalConfig): # Register cleanup on exit atexit.register(self.cleanup) - def run(self): + def run(self) -> None: try: self.run_simulation() except Exception as e: @@ -87,7 +92,7 @@ def run(self): finally: self._cleanup_resources() - def run_simulation(self): + def run_simulation(self) -> None: # Go2 isn't in the MuJoCo models yet, so use Go1 as a substitute robot_name = self.global_config.robot_model or "unitree_go1" if robot_name == "unitree_go2": @@ -97,6 +102,9 @@ def run_simulation(self): self.model, self.data = load_model(self, robot=robot_name, scene=scene_name) + if self.model is None or self.data is None: + raise ValueError("Model or data failed to load.") + # Set initial robot position match robot_name: case "unitree_go1": @@ -153,8 +161,8 @@ def run_simulation(self): scene_option = mujoco.MjvOption() # Timing control variables - last_video_time = 0 - last_lidar_time = 0 + last_video_time = 0.0 + last_lidar_time = 0.0 video_interval = 1.0 / VIDEO_FPS lidar_interval = 1.0 / LIDAR_FPS @@ -242,7 +250,12 @@ def run_simulation(self): if time_until_next_step > 0: time.sleep(time_until_next_step) - def _process_depth_camera(self, depth_data, depth_lock, pose_data) -> np.ndarray | None: + def _process_depth_camera( + self, + depth_data: np.ndarray[Any, Any] | None, + depth_lock: threading.RLock, + pose_data: tuple[np.ndarray[Any, Any], np.ndarray[Any, Any]] | None, + ) -> np.ndarray[Any, Any] | None: """Process a single depth camera and return point cloud points.""" with depth_lock: if depth_data is None or pose_data is None: @@ -331,12 +344,12 @@ def get_odom_message(self) -> Odometry | None: ) return odom_to_publish - def _stop_move(self): + def _stop_move(self) -> None: with self._command_lock: self._command = np.zeros(3, dtype=np.float32) self._stop_timer = None - def move(self, twist: Twist, duration: float = 0.0): + def move(self, twist: Twist, duration: float = 0.0) -> None: if self._stop_timer: self._stop_timer.cancel() @@ -356,7 +369,7 @@ def get_command(self) -> np.ndarray: with self._command_lock: return self._command.copy() - def stop(self): + def stop(self) -> None: """Stop the simulation thread gracefully.""" self._is_running = False @@ -371,7 +384,7 @@ def stop(self): if self.is_alive(): logger.warning("MuJoCo thread did not stop gracefully within timeout") - def cleanup(self): + def cleanup(self) -> None: """Clean up all resources. Can be called multiple times safely.""" if self._cleanup_registered: return @@ -381,7 +394,7 @@ def cleanup(self): self.stop() self._cleanup_resources() - def _cleanup_resources(self): + def _cleanup_resources(self) -> None: """Internal method to clean up MuJoCo-specific resources.""" try: # Cancel any timers @@ -454,7 +467,7 @@ def _cleanup_resources(self): except Exception as e: logger.error(f"Error during resource cleanup: {e}") - def __del__(self): + def __del__(self) -> None: """Destructor to ensure cleanup on object deletion.""" try: self.cleanup() diff --git a/dimos/simulation/mujoco/policy.py b/dimos/simulation/mujoco/policy.py index 20bfad2a2f..abe1f0f8f3 100644 --- a/dimos/simulation/mujoco/policy.py +++ b/dimos/simulation/mujoco/policy.py @@ -16,7 +16,7 @@ from abc import ABC, abstractmethod -from typing import Optional +from typing import Any import mujoco import numpy as np @@ -29,13 +29,13 @@ class OnnxController(ABC): def __init__( self, policy_path: str, - default_angles: np.ndarray, + default_angles: np.ndarray[Any, Any], n_substeps: int, action_scale: float, input_controller: InputController, ctrl_dt: float | None = None, drift_compensation: list[float] | None = None, - ): + ) -> None: self._output_names = ["continuous_actions"] self._policy = rt.InferenceSession(policy_path, providers=["CPUExecutionProvider"]) @@ -50,7 +50,7 @@ def __init__( self._drift_compensation = np.array(drift_compensation or [0.0, 0.0, 0.0], dtype=np.float32) @abstractmethod - def get_obs(self, model, data) -> np.ndarray: + def get_obs(self, model: mujoco.MjModel, data: mujoco.MjData) -> np.ndarray[Any, Any]: pass def get_control(self, model: mujoco.MjModel, data: mujoco.MjData) -> None: @@ -63,12 +63,12 @@ def get_control(self, model: mujoco.MjModel, data: mujoco.MjData) -> None: data.ctrl[:] = onnx_pred * self._action_scale + self._default_angles self._post_control_update() - def _post_control_update(self) -> None: + def _post_control_update(self) -> None: # noqa: B027 pass class Go1OnnxController(OnnxController): - def get_obs(self, model, data) -> np.ndarray: + def get_obs(self, model: mujoco.MjModel, data: mujoco.MjData) -> np.ndarray[Any, Any]: linvel = data.sensor("local_linvel").data gyro = data.sensor("gyro").data imu_xmat = data.site_xmat[model.site("imu").id].reshape(3, 3) @@ -93,13 +93,13 @@ class G1OnnxController(OnnxController): def __init__( self, policy_path: str, - default_angles: np.ndarray, + default_angles: np.ndarray[Any, Any], ctrl_dt: float, n_substeps: int, action_scale: float, input_controller: InputController, drift_compensation: list[float] | None = None, - ): + ) -> None: super().__init__( policy_path, default_angles, @@ -114,7 +114,7 @@ def __init__( self._gait_freq = 1.5 self._phase_dt = 2 * np.pi * self._gait_freq * ctrl_dt - def get_obs(self, model, data) -> np.ndarray: + def get_obs(self, model: mujoco.MjModel, data: mujoco.MjData) -> np.ndarray[Any, Any]: linvel = data.sensor("local_linvel_pelvis").data gyro = data.sensor("gyro_pelvis").data imu_xmat = data.site_xmat[model.site("imu_in_pelvis").id].reshape(3, 3) diff --git a/dimos/types/sample.py b/dimos/types/sample.py index 6d84942c55..fdb29cf174 100644 --- a/dimos/types/sample.py +++ b/dimos/types/sample.py @@ -30,6 +30,7 @@ from pydantic import BaseModel, ConfigDict, ValidationError from pydantic.fields import FieldInfo from pydantic_core import from_json +import torch Flattenable = Annotated[Literal["dict", "np", "pt", "list"], "Numpy, PyTorch, list, or dict"] @@ -165,7 +166,7 @@ def flatten( self, output_type: Flattenable = "dict", non_numerical: Literal["ignore", "forbid", "allow"] = "allow", - ) -> builtins.dict[str, Any] | np.ndarray | "torch.Tensor" | list: + ) -> builtins.dict[str, Any] | np.ndarray | torch.Tensor | list: accumulator = {} if output_type == "dict" else [] def flatten_recursive(obj, path: str = "") -> None: diff --git a/dimos/utils/demo_image_encoding.py b/dimos/utils/demo_image_encoding.py index a98924260c..00df5c1a62 100644 --- a/dimos/utils/demo_image_encoding.py +++ b/dimos/utils/demo_image_encoding.py @@ -45,19 +45,19 @@ class EmitterModule(Module): _thread: threading.Thread | None = None _stop_event: threading.Event | None = None - def start(self): + def start(self) -> None: super().start() self._stop_event = threading.Event() self._thread = threading.Thread(target=self._publish_image, daemon=True) self._thread.start() - def stop(self): + def stop(self) -> None: if self._thread: self._stop_event.set() self._thread.join(timeout=2) super().stop() - def _publish_image(self): + def _publish_image(self) -> None: open_file = open("/tmp/emitter-times", "w") while not self._stop_event.is_set(): start = time.time() @@ -74,21 +74,21 @@ class ReceiverModule(Module): _open_file = None - def start(self): + def start(self) -> None: super().start() self._disposables.add(Disposable(self.image.subscribe(self._on_image))) self._open_file = open("/tmp/receiver-times", "w") - def stop(self): + def stop(self) -> None: self._open_file.close() super().stop() - def _on_image(self, image: Image): + def _on_image(self, image: Image) -> None: self._open_file.write(str(time.time()) + "\n") print("image") -def main(): +def main() -> None: parser = argparse.ArgumentParser(description="Demo image encoding with transport options") parser.add_argument( "--use-jpeg", diff --git a/dimos/utils/fast_image_generator.py b/dimos/utils/fast_image_generator.py index f8e02cb71b..6063f1f4b9 100644 --- a/dimos/utils/fast_image_generator.py +++ b/dimos/utils/fast_image_generator.py @@ -14,7 +14,38 @@ """Fast stateful image generator with visual features for encoding tests.""" +from typing import Literal, TypedDict, Union + import numpy as np +from numpy.typing import NDArray + + +class CircleObject(TypedDict): + """Type definition for circle objects.""" + + type: Literal["circle"] + x: float + y: float + vx: float + vy: float + radius: int + color: NDArray[np.float32] + + +class RectObject(TypedDict): + """Type definition for rectangle objects.""" + + type: Literal["rect"] + x: float + y: float + vx: float + vy: float + width: int + height: int + color: NDArray[np.float32] + + +Object = Union[CircleObject, RectObject] class FastImageGenerator: @@ -31,11 +62,12 @@ class FastImageGenerator: - High contrast boundaries (tests blocking artifacts) """ - def __init__(self, width: int = 1280, height: int = 720): + def __init__(self, width: int = 1280, height: int = 720) -> None: """Initialize the generator with pre-computed elements.""" self.width = width self.height = height self.frame_count = 0 + self.objects: list[Object] = [] # Pre-allocate the main canvas self.canvas = np.zeros((height, width, 3), dtype=np.float32) @@ -57,7 +89,7 @@ def __init__(self, width: int = 1280, height: int = 720): # Pre-allocate shape masks for reuse self._init_shape_masks() - def _init_gradients(self): + def _init_gradients(self) -> None: """Pre-compute gradient patterns.""" # Diagonal gradient self.diag_gradient = (self.x_grid + self.y_grid) * 0.5 @@ -71,7 +103,7 @@ def _init_gradients(self): self.h_gradient = self.x_grid self.v_gradient = self.y_grid - def _init_moving_objects(self): + def _init_moving_objects(self) -> None: """Initialize properties of moving objects.""" self.objects = [ { @@ -104,7 +136,7 @@ def _init_moving_objects(self): }, ] - def _init_texture(self): + def _init_texture(self) -> None: """Pre-compute a texture pattern.""" # Create a simple checkerboard pattern at lower resolution checker_size = 20 @@ -118,7 +150,7 @@ def _init_texture(self): self.texture = np.repeat(np.repeat(checker, checker_size, axis=0), checker_size, axis=1) self.texture = self.texture[: self.height, : self.width].astype(np.float32) * 30 - def _init_shape_masks(self): + def _init_shape_masks(self) -> None: """Pre-allocate reusable masks for shapes.""" # Pre-allocate a mask array self.temp_mask = np.zeros((self.height, self.width), dtype=np.float32) @@ -126,7 +158,7 @@ def _init_shape_masks(self): # Pre-compute indices for the entire image self.y_indices, self.x_indices = np.indices((self.height, self.width)) - def _draw_circle_fast(self, cx: int, cy: int, radius: int, color: np.ndarray): + def _draw_circle_fast(self, cx: int, cy: int, radius: int, color: NDArray[np.float32]) -> None: """Draw a circle using vectorized operations - optimized version without anti-aliasing.""" # Compute bounding box to minimize calculations y1 = max(0, cy - radius - 1) @@ -141,7 +173,7 @@ def _draw_circle_fast(self, cx: int, cy: int, radius: int, color: np.ndarray): mask = dist_sq <= radius**2 self.canvas[y1:y2, x1:x2][mask] = color - def _draw_rect_fast(self, x: int, y: int, w: int, h: int, color: np.ndarray): + def _draw_rect_fast(self, x: int, y: int, w: int, h: int, color: NDArray[np.float32]) -> None: """Draw a rectangle using slicing.""" # Clip to canvas boundaries x1 = max(0, x) @@ -152,7 +184,7 @@ def _draw_rect_fast(self, x: int, y: int, w: int, h: int, color: np.ndarray): if x1 < x2 and y1 < y2: self.canvas[y1:y2, x1:x2] = color - def _update_objects(self): + def _update_objects(self) -> None: """Update positions of moving objects.""" for obj in self.objects: # Update position @@ -182,7 +214,7 @@ def _update_objects(self): obj["vy"] *= -1 obj["y"] = np.clip(obj["y"], 0, 1 - h) - def generate_frame(self) -> np.ndarray: + def generate_frame(self) -> NDArray[np.uint8]: """ Generate a single frame with visual features - optimized for 30+ FPS. @@ -242,17 +274,17 @@ def generate_frame(self) -> np.ndarray: # Direct conversion to uint8 (already in valid range) return self.canvas.astype(np.uint8) - def reset(self): + def reset(self) -> None: """Reset the generator to initial state.""" self.frame_count = 0 self._init_moving_objects() # Convenience function for backward compatibility -_generator = None +_generator: FastImageGenerator | None = None -def random_image(width: int, height: int) -> np.ndarray: +def random_image(width: int, height: int) -> NDArray[np.uint8]: """ Generate an image with visual features suitable for encoding tests. Maintains state for efficient stream generation. diff --git a/dimos/web/dimos_interface/api/server.py b/dimos/web/dimos_interface/api/server.py index 4f9979c085..bddb01495e 100644 --- a/dimos/web/dimos_interface/api/server.py +++ b/dimos/web/dimos_interface/api/server.py @@ -306,7 +306,7 @@ async def upload_audio(file: UploadFile = File(...)): # Push to reactive stream self.audio_subject.on_next(event) - print(f"Received audio – {event.data.shape[0] / sr:.2f} s, {sr} Hz") + print(f"Received audio - {event.data.shape[0] / sr:.2f} s, {sr} Hz") return {"success": True} except Exception as e: print(f"Failed to process uploaded audio: {e}") diff --git a/pyproject.toml b/pyproject.toml index 4a7cec2c6c..d098fad514 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -179,7 +179,7 @@ cuda = [ ] dev = [ - "ruff==0.11.10", + "ruff==0.14.3", "mypy==1.18.2", "pre_commit==4.2.0", "pytest==8.3.5", @@ -250,7 +250,10 @@ exclude = [ [tool.ruff.lint] extend-select = ["E", "W", "F", "B", "UP", "N", "I", "C90", "A", "RUF", "TCH"] # TODO: All of these should be fixed, but it's easier commit autofixes first -ignore = ["A001", "A002", "A004", "B008", "B017", "B018", "B019", "B023", "B024", "B026", "B027", "B904", "C901", "E402", "E501", "E721", "E722", "E741", "F401", "F403", "F405", "F811", "F821", "F821", "F821", "N801", "N802", "N803", "N806", "N812", "N813", "N813", "N816", "N817", "N999", "RUF001", "RUF002", "RUF003", "RUF006", "RUF009", "RUF012", "RUF034", "RUF043", "RUF059", "TC010", "UP007", "UP035"] +ignore = ["A001", "A002", "B008", "B017", "B019", "B023", "B024", "B026", "B904", "C901", "E402", "E501", "E721", "E722", "E741", "F401", "F403", "F811", "F821", "F821", "F821", "N801", "N802", "N803", "N806", "N812", "N813", "N813", "N816", "N817", "N999", "RUF002", "RUF003", "RUF006", "RUF009", "RUF012", "RUF034", "RUF043", "RUF059", "UP007"] + +[tool.ruff.lint.per-file-ignores] +"dimos/models/Detic/*" = ["ALL"] [tool.ruff.lint.isort] known-first-party = ["dimos"] diff --git a/tests/agent_manip_flow_flask_test.py b/tests/agent_manip_flow_flask_test.py index e96c6f2d20..7f7887004b 100644 --- a/tests/agent_manip_flow_flask_test.py +++ b/tests/agent_manip_flow_flask_test.py @@ -26,7 +26,7 @@ # Third-party imports from flask import Flask -from reactivex import interval, operators as ops, zip +from reactivex import interval, operators as ops, zip as rx_zip from reactivex.disposable import CompositeDisposable from reactivex.scheduler import ThreadPoolScheduler @@ -157,14 +157,14 @@ def main(): ai_2_repeat_obs = ai_2_obs.pipe(ops.repeat()) - # Combine emissions using zip - ai_1_secondly_repeating_obs = zip(secondly_emission, ai_1_repeat_obs).pipe( + # Combine emissions using rx_zip + ai_1_secondly_repeating_obs = rx_zip(secondly_emission, ai_1_repeat_obs).pipe( # ops.do_action(lambda s: print(f"AI 1 - Emission Count: {s[0]}")), ops.map(lambda r: r[1]), ) - # Combine emissions using zip - ai_2_secondly_repeating_obs = zip(secondly_emission, ai_2_repeat_obs).pipe( + # Combine emissions using rx_zip + ai_2_secondly_repeating_obs = rx_zip(secondly_emission, ai_2_repeat_obs).pipe( # ops.do_action(lambda s: print(f"AI 2 - Emission Count: {s[0]}")), ops.map(lambda r: r[1]), )