diff --git a/dimos/agents2/agent.py b/dimos/agents2/agent.py index 94f418acc2..1f03a52c5a 100644 --- a/dimos/agents2/agent.py +++ b/dimos/agents2/agent.py @@ -29,6 +29,7 @@ ) from dimos.agents2.spec import AgentSpec +from dimos.agents2.system_prompt import get_system_prompt from dimos.core import rpc from dimos.msgs.sensor_msgs import Image from dimos.protocol.skill.coordinator import SkillCoordinator, SkillState, SkillStateDict @@ -178,6 +179,8 @@ def __init__( else: self.config.system_prompt.content += SYSTEM_MSG_APPEND self.system_message = self.config.system_prompt + else: + self.system_message = SystemMessage(get_system_prompt() + SYSTEM_MSG_APPEND) self.publish(self.system_message) @@ -263,6 +266,7 @@ def _get_state() -> str: # we are getting tools from the coordinator on each turn # since this allows for skillcontainers to dynamically provide new skills tools = self.get_tools() + print("Available tools:", [tool.name for tool in tools]) self._llm = self._llm.bind_tools(tools) # publish to /agent topic for observability @@ -331,8 +335,14 @@ def query(self, query: str): async def query_async(self, query: str): return await self.agent_loop(query) - def register_skills(self, container): - return self.coordinator.register_skills(container) + @rpc + def register_skills(self, container, run_implicit_name: str | None = None): + ret = self.coordinator.register_skills(container) + + if run_implicit_name: + self.run_implicit_skill(run_implicit_name) + + return ret def get_tools(self): return self.coordinator.get_tools() @@ -346,3 +356,20 @@ def _write_debug_history_file(self): with open(file_path, "w") as f: json.dump(history, f, default=lambda x: repr(x), indent=2) + + +class LlmAgent(Agent): + @rpc + def start(self) -> None: + super().start() + self.loop_thread() + + @rpc + def stop(self) -> None: + super().stop() + + +llm_agent = LlmAgent.blueprint + + +__all__ = ["Agent", "llm_agent"] diff --git a/dimos/agents2/cli/human.py b/dimos/agents2/cli/human.py index 
5a20abb388..ca3e503bc1 100644 --- a/dimos/agents2/cli/human.py +++ b/dimos/agents2/cli/human.py @@ -15,14 +15,17 @@ import queue from dimos.agents2 import Output, Reducer, Stream, skill -from dimos.core import Module, pLCMTransport, rpc +from dimos.core import pLCMTransport, rpc from reactivex.disposable import Disposable +from dimos.core.module import Module +from dimos.core.rpc_client import RpcCall + class HumanInput(Module): running: bool = False - @skill(stream=Stream.call_agent, reducer=Reducer.string, output=Output.human) + @skill(stream=Stream.call_agent, reducer=Reducer.string, output=Output.human, hide_skill=True) def human(self): """receives human input, no need to run this, it's running implicitly""" if self.running: @@ -43,3 +46,13 @@ def start(self) -> None: @rpc def stop(self) -> None: super().stop() + + @rpc + def set_LlmAgent_register_skills(self, callable: RpcCall) -> None: + callable.set_rpc(self.rpc) + callable(self, run_implicit_name="human") + + +human_input = HumanInput.blueprint + +__all__ = ["HumanInput", "human_input"] diff --git a/dimos/agents2/skills/conftest.py b/dimos/agents2/skills/conftest.py index 7ea89e320a..5a706985f7 100644 --- a/dimos/agents2/skills/conftest.py +++ b/dimos/agents2/skills/conftest.py @@ -20,13 +20,16 @@ from dimos.agents2.skills.gps_nav_skill import GpsNavSkillContainer from dimos.agents2.skills.navigation import NavigationSkillContainer from dimos.agents2.skills.google_maps_skill_container import GoogleMapsSkillContainer +from dimos.agents2.system_prompt import get_system_prompt from dimos.mapping.types import LatLon from dimos.robot.robot import GpsRobot -from dimos.robot.unitree_webrtc.run_agents2 import SYSTEM_PROMPT from dimos.utils.data import get_data from dimos.msgs.sensor_msgs import Image +system_prompt = get_system_prompt() + + @pytest.fixture(autouse=True) def cleanup_threadpool_scheduler(monkeypatch): # TODO: get rid of this global threadpool @@ -42,11 +45,13 @@ def 
cleanup_threadpool_scheduler(monkeypatch): threadpool.scheduler = ThreadPoolScheduler(max_workers=threadpool.get_max_workers()) +# TODO: Delete @pytest.fixture def fake_robot(mocker): return mocker.MagicMock() +# TODO: Delete @pytest.fixture def fake_gps_robot(mocker): return mocker.Mock(spec=GpsRobot) @@ -59,14 +64,17 @@ def fake_video_stream(): return rx.of(image) +# TODO: Delete @pytest.fixture def fake_gps_position_stream(): return rx.of(LatLon(lat=37.783, lon=-122.413)) @pytest.fixture -def navigation_skill_container(fake_robot, fake_video_stream): - container = NavigationSkillContainer(fake_robot, fake_video_stream) +def navigation_skill_container(mocker): + container = NavigationSkillContainer() + container.color_image.connection = mocker.MagicMock() + container.odom.connection = mocker.MagicMock() container.start() yield container container.stop() @@ -93,7 +101,7 @@ def google_maps_skill_container(fake_gps_robot, fake_gps_position_stream, mocker def create_navigation_agent(navigation_skill_container, create_fake_agent): return partial( create_fake_agent, - system_prompt=SYSTEM_PROMPT, + system_prompt=system_prompt, skill_containers=[navigation_skill_container], ) @@ -101,7 +109,7 @@ def create_navigation_agent(navigation_skill_container, create_fake_agent): @pytest.fixture def create_gps_nav_agent(gps_nav_skill_container, create_fake_agent): return partial( - create_fake_agent, system_prompt=SYSTEM_PROMPT, skill_containers=[gps_nav_skill_container] + create_fake_agent, system_prompt=system_prompt, skill_containers=[gps_nav_skill_container] ) @@ -111,6 +119,6 @@ def create_google_maps_agent( ): return partial( create_fake_agent, - system_prompt=SYSTEM_PROMPT, + system_prompt=system_prompt, skill_containers=[gps_nav_skill_container, google_maps_skill_container], ) diff --git a/dimos/agents2/skills/navigation.py b/dimos/agents2/skills/navigation.py index 18558515e6..bde328faa6 100644 --- a/dimos/agents2/skills/navigation.py +++ 
b/dimos/agents2/skills/navigation.py @@ -12,21 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial import time from typing import Any, Optional -from reactivex import Observable -from reactivex.disposable import CompositeDisposable, Disposable -from dimos.core.resource import Resource +from dimos.core.core import rpc +from dimos.core.rpc_client import RpcCall +from dimos.core.skill_module import SkillModule +from dimos.core.stream import In from dimos.models.qwen.video_query import BBox from dimos.models.vl.qwen import QwenVlModel from dimos.msgs.geometry_msgs import PoseStamped from dimos.msgs.geometry_msgs.Vector3 import make_vector3 from dimos.msgs.sensor_msgs import Image from dimos.navigation.visual.query import get_object_bbox_from_image -from dimos.protocol.skill.skill import SkillContainer, skill -from dimos.robot.robot import UnitreeRobot +from dimos.protocol.skill.skill import skill from dimos.types.robot_location import RobotLocation from dimos.utils.logging_config import setup_logger from dimos.utils.transform_utils import euler_to_quaternion, quaternion_to_euler @@ -35,35 +36,116 @@ logger = setup_logger(__file__) -class NavigationSkillContainer(SkillContainer, Resource): - _robot: UnitreeRobot - _disposables: CompositeDisposable - _latest_image: Optional[Image] - _video_stream: Observable[Image] - _started: bool - - def __init__(self, robot: UnitreeRobot, video_stream: Observable[Image]): +class NavigationSkillContainer(SkillModule): + _latest_image: Optional[Image] = None + _latest_odom: Optional[PoseStamped] = None + _skill_started: bool = False + _similarity_threshold: float = 0.23 + + _tag_location: RpcCall | None = None + _query_tagged_location: RpcCall | None = None + _query_by_text: RpcCall | None = None + _set_goal: RpcCall | None = None + _get_state: RpcCall | None = None + _is_goal_reached: RpcCall | None = None + _cancel_goal: RpcCall | None = None + 
_track: RpcCall | None = None + _stop_track: RpcCall | None = None + _is_tracking: RpcCall | None = None + _stop_exploration: RpcCall | None = None + _explore: RpcCall | None = None + _is_exploration_active: RpcCall | None = None + + color_image: In[Image] = None + odom: In[PoseStamped] = None + + def __init__(self): super().__init__() - self._robot = robot - self._disposables = CompositeDisposable() - self._latest_image = None - self._video_stream = video_stream - self._similarity_threshold = 0.23 - self._started = False + self._skill_started = False self._vl_model = QwenVlModel() + @rpc def start(self) -> None: - unsub = self._video_stream.subscribe(self._on_video) - self._disposables.add(Disposable(unsub) if callable(unsub) else unsub) - self._started = True + self._disposables.add(self.color_image.subscribe(self._on_color_image)) + self._disposables.add(self.odom.subscribe(self._on_odom)) + self._skill_started = True + @rpc def stop(self) -> None: - self._disposables.dispose() super().stop() - def _on_video(self, image: Image) -> None: + def _on_color_image(self, image: Image) -> None: self._latest_image = image + def _on_odom(self, odom: PoseStamped) -> None: + self._latest_odom = odom + + # TODO: This is quite repetitive, maybe I should automate this somehow + @rpc + def set_SpatialMemory_tag_location(self, callable: RpcCall) -> None: + self._tag_location = callable + self._tag_location.set_rpc(self.rpc) + + @rpc + def set_SpatialMemory_query_tagged_location(self, callable: RpcCall) -> None: + self._query_tagged_location = callable + self._query_tagged_location.set_rpc(self.rpc) + + @rpc + def set_SpatialMemory_query_by_text(self, callable: RpcCall) -> None: + self._query_by_text = callable + self._query_by_text.set_rpc(self.rpc) + + @rpc + def set_BehaviorTreeNavigator_set_goal(self, callable: RpcCall) -> None: + self._set_goal = callable + self._set_goal.set_rpc(self.rpc) + + @rpc + def set_BehaviorTreeNavigator_get_state(self, callable: RpcCall) -> None: + 
self._get_state = callable + self._get_state.set_rpc(self.rpc) + + @rpc + def set_BehaviorTreeNavigator_is_goal_reached(self, callable: RpcCall) -> None: + self._is_goal_reached = callable + self._is_goal_reached.set_rpc(self.rpc) + + @rpc + def set_BehaviorTreeNavigator_cancel_goal(self, callable: RpcCall) -> None: + self._cancel_goal = callable + self._cancel_goal.set_rpc(self.rpc) + + @rpc + def set_ObjectTracking_track(self, callable: RpcCall) -> None: + self._track = callable + self._track.set_rpc(self.rpc) + + @rpc + def set_ObjectTracking_stop_track(self, callable: RpcCall) -> None: + self._stop_track = callable + self._stop_track.set_rpc(self.rpc) + + @rpc + def set_ObjectTracking_is_tracking(self, callable: RpcCall) -> None: + self._is_tracking = callable + self._is_tracking.set_rpc(self.rpc) + + @rpc + def set_WavefrontFrontierExplorer_stop_exploration(self, callable: RpcCall) -> None: + self._stop_exploration = callable + self._stop_exploration.set_rpc(self.rpc) + + @rpc + def set_WavefrontFrontierExplorer_explore(self, callable: RpcCall) -> None: + self._explore = callable + self._explore.set_rpc(self.rpc) + + @rpc + def set_WavefrontFrontierExplorer_is_exploration_active(self, callable: RpcCall) -> None: + self._is_exploration_active = callable + self._is_exploration_active.set_rpc(self.rpc) + @skill() def tag_location_in_spatial_memory(self, location_name: str) -> str: """Tag this location in the spatial memory with a name. @@ -77,12 +159,17 @@ def tag_location_in_spatial_memory(self, location_name: str) -> str: str: the outcome """ - if not self._started: + if not self._skill_started: raise ValueError(f"{self} has not been started.") - pose_data = self._robot.get_odom() - position = pose_data.position - rotation = quaternion_to_euler(pose_data.orientation) + if not self._latest_odom: + return "Error: No odometry data available to tag the location." + + if not self._tag_location: + return "Error: The SpatialMemory module is not connected." 
+ + position = self._latest_odom.position + rotation = quaternion_to_euler(self._latest_odom.orientation) location = RobotLocation( name=location_name, @@ -90,8 +177,8 @@ def tag_location_in_spatial_memory(self, location_name: str) -> str: rotation=(rotation.x, rotation.y, rotation.z), ) - if not self._robot.spatial_memory.tag_location(location): - return f"Failed to store '{location_name}' in the spatial memory" + if not self._tag_location(location): + return f"Error: Failed to store '{location_name}' in the spatial memory" logger.info(f"Tagged {location}") return f"The current location has been tagged as '{location_name}'." @@ -109,7 +196,7 @@ def navigate_with_text(self, query: str) -> str: query: Text query to search for in the semantic map """ - if not self._started: + if not self._skill_started: raise ValueError(f"{self} has not been started.") success_msg = self._navigate_by_tagged_location(query) @@ -131,7 +218,11 @@ def navigate_with_text(self, query: str) -> str: return f"No tagged location called '{query}'. No object in view matching '{query}'. No matching location found in semantic map for '{query}'." def _navigate_by_tagged_location(self, query: str) -> Optional[str]: - robot_location = self._robot.spatial_memory.query_tagged_location(query) + if not self._query_tagged_location: + logger.warning("SpatialMemory module not connected, cannot query tagged locations") + return None + + robot_location = self._query_tagged_location(query) if not robot_location: return None @@ -142,14 +233,36 @@ def _navigate_by_tagged_location(self, query: str) -> Optional[str]: frame_id="world", ) - result = self._robot.navigate_to(goal_pose, blocking=True) + result = self._navigate_to(goal_pose) if not result: - return None + return "Error: Faild to reach the tagged location." return ( f"Successfuly arrived at location tagged '{robot_location.name}' from query '{query}'." 
) + def _navigate_to(self, pose: PoseStamped) -> bool: + if not self._set_goal or not self._get_state or not self._is_goal_reached: + logger.error("BehaviorTreeNavigator module not connected properly") + return False + + logger.info( + f"Navigating to pose: ({pose.position.x:.2f}, {pose.position.y:.2f}, {pose.position.z:.2f})" + ) + self._set_goal(pose) + time.sleep(1.0) + + while self._get_state() == NavigatorState.FOLLOWING_PATH: + time.sleep(0.25) + + time.sleep(1.0) + if not self._is_goal_reached(): + logger.info("Navigation was cancelled or failed") + return False + else: + logger.info("Navigation goal reached") + return True + def _navigate_to_object(self, query: str) -> Optional[str]: try: bbox = self._get_bbox_for_current_frame(query) @@ -160,10 +273,18 @@ def _navigate_to_object(self, query: str) -> Optional[str]: if bbox is None: return None + if not self._track or not self._stop_track or not self._is_tracking: + logger.error("ObjectTracking module not connected properly") + return None + + if not self._get_state or not self._is_goal_reached: + logger.error("BehaviorTreeNavigator module not connected properly") + return None + logger.info(f"Found {query} at {bbox}") # Start tracking - BBoxNavigationModule automatically generates goals - self._robot.object_tracker.track(bbox) + self._track(bbox) start_time = time.time() timeout = 30.0 @@ -171,31 +292,31 @@ def _navigate_to_object(self, query: str) -> Optional[str]: while time.time() - start_time < timeout: # Check if navigator finished - if self._robot.navigator.get_state() == NavigatorState.IDLE and goal_set: + if self._get_state() == NavigatorState.IDLE and goal_set: logger.info("Waiting for goal result") time.sleep(1.0) - if not self._robot.navigator.is_goal_reached(): + if not self._is_goal_reached(): logger.info(f"Goal cancelled, tracking '{query}' failed") - self._robot.object_tracker.stop_track() + self._stop_track() return None else: logger.info(f"Reached '{query}'") - 
self._robot.object_tracker.stop_track() + self._stop_track() return f"Successfully arrived at '{query}'" # If goal set and tracking lost, just continue (tracker will resume or timeout) - if goal_set and not self._robot.object_tracker.is_tracking(): + if goal_set and not self._is_tracking(): continue # BBoxNavigationModule automatically sends goals when tracker publishes # Just check if we have any detections to mark goal_set - if self._robot.object_tracker.is_tracking(): + if self._is_tracking(): goal_set = True time.sleep(0.25) logger.warning(f"Navigation to '{query}' timed out after {timeout}s") - self._robot.object_tracker.stop_track() + self._stop_track() return None def _get_bbox_for_current_frame(self, query: str) -> Optional[BBox]: @@ -205,7 +326,10 @@ def _get_bbox_for_current_frame(self, query: str) -> Optional[BBox]: return get_object_bbox_from_image(self._vl_model, self._latest_image, query) def _navigate_using_semantic_map(self, query: str) -> str: - results = self._robot.spatial_memory.query_by_text(query) + if not self._query_by_text: + return "Error: The SpatialMemory module is not connected." + + results = self._query_by_text(query) if not results: return f"No matching location found in semantic map for '{query}'" @@ -217,7 +341,7 @@ def _navigate_using_semantic_map(self, query: str) -> str: if not goal_pose: return f"Found a result for '{query}' but it didn't have a valid position." 
- result = self._robot.navigate_to(goal_pose, blocking=True) + result = self._navigate_to(goal_pose) if not result: return f"Failed to navigate for '{query}'" @@ -233,13 +357,25 @@ def follow_human(self, person: str) -> str: def stop_movement(self) -> str: """Immediatly stop moving.""" - if not self._started: + if not self._skill_started: raise ValueError(f"{self} has not been started.") - self._robot.stop_exploration() + self._cancel_goal_and_stop() return "Stopped" + def _cancel_goal_and_stop(self) -> None: + if not self._cancel_goal: + logger.warning("BehaviorTreeNavigator module not connected, cannot cancel goal") + return + + if not self._stop_exploration: + logger.warning("FrontierExplorer module not connected, cannot stop exploration") + return + + self._cancel_goal() + return self._stop_exploration() + @skill() def start_exploration(self, timeout: float = 240.0) -> str: """A skill that performs autonomous frontier exploration. @@ -253,24 +389,27 @@ def start_exploration(self, timeout: float = 240.0) -> str: timeout (float, optional): Maximum time (in seconds) allowed for exploration """ - if not self._started: + if not self._skill_started: raise ValueError(f"{self} has not been started.") try: return self._start_exploration(timeout) finally: - self._robot.stop_exploration() + self._cancel_goal_and_stop() def _start_exploration(self, timeout: float) -> str: + if not self._explore or not self._is_exploration_active: + return "Error: The WavefrontFrontierExplorer module is not connected." + logger.info("Starting autonomous frontier exploration") start_time = time.time() - has_started = self._robot.explore() + has_started = self._explore() if not has_started: - return "Could not start exploration." + return "Error: Could not start exploration." 
- while time.time() - start_time < timeout and self._robot.is_exploration_active(): + while time.time() - start_time < timeout and self._is_exploration_active(): time.sleep(0.5) return "Exploration completed successfuly" @@ -297,3 +436,8 @@ def _get_goal_pose_from_result(self, result: dict[str, Any]) -> Optional[PoseSta orientation=euler_to_quaternion(make_vector3(0, 0, theta)), frame_id="world", ) + + +navigation_skill = NavigationSkillContainer.blueprint + +__all__ = ["NavigationSkillContainer", "navigation_skill"] diff --git a/dimos/agents2/skills/osm.py b/dimos/agents2/skills/osm.py index 6c609e87f4..eaaef41858 100644 --- a/dimos/agents2/skills/osm.py +++ b/dimos/agents2/skills/osm.py @@ -13,45 +13,42 @@ # limitations under the License. from typing import Optional -from reactivex import Observable +from dimos.core.core import rpc +from dimos.core.module import Module +from dimos.core.rpc_client import RPCClient, RpcCall +from dimos.core.skill_module import SkillModule +from dimos.core.stream import In from dimos.mapping.osm.current_location_map import CurrentLocationMap from dimos.mapping.utils.distance import distance_in_meters from dimos.mapping.types import LatLon from dimos.models.vl.qwen import QwenVlModel -from dimos.protocol.skill.skill import SkillContainer, skill -from dimos.robot.robot import Robot +from dimos.protocol.skill.skill import skill from dimos.utils.logging_config import setup_logger -from dimos.core.resource import Resource -from reactivex.disposable import CompositeDisposable logger = setup_logger(__file__) -class OsmSkillContainer(SkillContainer, Resource): - _robot: Robot - _disposables: CompositeDisposable +class OsmSkill(SkillModule): _latest_location: Optional[LatLon] - _position_stream: Observable[LatLon] _current_location_map: CurrentLocationMap - _started: bool + _skill_started: bool - def __init__(self, robot: Robot, position_stream: Observable[LatLon]): + gps_location: In[LatLon] = None + + def __init__(self): super().__init__() 
- self._robot = robot - self._disposables = CompositeDisposable() self._latest_location = None - self._position_stream = position_stream self._current_location_map = CurrentLocationMap(QwenVlModel()) - self._started = False + self._skill_started = False def start(self): - self._started = True - self._disposables.add(self._position_stream.subscribe(self._on_gps_location)) + super().start() + self._skill_started = True + self._disposables.add(self.gps_location.subscribe(self._on_gps_location)) def stop(self): - self._disposables.dispose() super().stop() def _on_gps_location(self, location: LatLon) -> None: @@ -71,7 +68,7 @@ def street_map_query(self, query_sentence: str) -> str: query_sentence (str): The query sentence. """ - if not self._started: + if not self._skill_started: raise ValueError(f"{self} has not been started.") self._current_location_map.update_position(self._latest_location) @@ -86,3 +83,8 @@ def street_map_query(self, query_sentence: str) -> str: distance = int(distance_in_meters(latlon, self._latest_location)) return f"{context}. It's at position latitude={latlon.lat}, longitude={latlon.lon}. It is {distance} meters away." 
+ + +osm_skill = OsmSkill.blueprint + +__all__ = ["OsmSkill", "osm_skill"] diff --git a/dimos/agents2/skills/test_navigation.py b/dimos/agents2/skills/test_navigation.py index f90f8a2d19..5d70fa2bc5 100644 --- a/dimos/agents2/skills/test_navigation.py +++ b/dimos/agents2/skills/test_navigation.py @@ -17,97 +17,64 @@ from dimos.utils.transform_utils import euler_to_quaternion -def test_stop_movement(fake_robot, create_navigation_agent): +def test_stop_movement(create_navigation_agent, navigation_skill_container, mocker): + navigation_skill_container._cancel_goal = mocker.Mock() + navigation_skill_container._stop_exploration = mocker.Mock() agent = create_navigation_agent(fixture="test_stop_movement.json") + agent.query("stop") - fake_robot.stop_exploration.assert_called_once_with() + navigation_skill_container._cancel_goal.assert_called_once_with() + navigation_skill_container._stop_exploration.assert_called_once_with() -def test_take_a_look_around(fake_robot, create_navigation_agent, mocker): - fake_robot.explore.return_value = True - fake_robot.is_exploration_active.side_effect = [True, False] +def test_take_a_look_around(create_navigation_agent, navigation_skill_container, mocker): + navigation_skill_container._explore = mocker.Mock() + navigation_skill_container._is_exploration_active = mocker.Mock() mocker.patch("dimos.agents2.skills.navigation.time.sleep") agent = create_navigation_agent(fixture="test_take_a_look_around.json") agent.query("take a look around for 10 seconds") - fake_robot.explore.assert_called_once_with() - - -def test_go_to_object(fake_robot, create_navigation_agent, mocker): - fake_robot.object_tracker = mocker.MagicMock() - fake_robot.object_tracker.is_tracking.side_effect = [True, True, True, True] # Tracking active - fake_robot.navigator = mocker.MagicMock() + navigation_skill_container._explore.assert_called_once_with() - # Simulate navigation states: FOLLOWING_PATH -> IDLE (goal reached) - from dimos.navigation.bt_navigator.navigator 
import NavigatorState - - fake_robot.navigator.get_state.side_effect = [ - NavigatorState.FOLLOWING_PATH, - NavigatorState.FOLLOWING_PATH, - NavigatorState.IDLE, - ] - fake_robot.navigator.is_goal_reached.return_value = True +def test_go_to_semantic_location(create_navigation_agent, navigation_skill_container, mocker): mocker.patch( "dimos.agents2.skills.navigation.NavigationSkillContainer._navigate_by_tagged_location", return_value=None, ) mocker.patch( - "dimos.agents2.skills.navigation.NavigationSkillContainer._navigate_using_semantic_map", + "dimos.agents2.skills.navigation.NavigationSkillContainer._navigate_to_object", return_value=None, ) - mocker.patch("dimos.agents2.skills.navigation.time.sleep") - - agent = create_navigation_agent(fixture="test_go_to_object.json") - - agent.query("go to the chair") - - fake_robot.object_tracker.track.assert_called_once() - actual_bbox = fake_robot.object_tracker.track.call_args[0][0] - expected_bbox = (82, 51, 163, 159) - - for actual_val, expected_val in zip(actual_bbox, expected_bbox): - assert abs(actual_val - expected_val) <= 5, ( - f"BBox {actual_bbox} not within ±5 of {expected_bbox}" - ) - - fake_robot.object_tracker.stop_track.assert_called_once() - - -def test_go_to_semantic_location(fake_robot, create_navigation_agent, mocker): mocker.patch( - "dimos.agents2.skills.navigation.NavigationSkillContainer._navigate_by_tagged_location", - return_value=None, + "dimos.agents2.skills.navigation.NavigationSkillContainer._navigate_to", + return_value=True, ) - mocker.patch( - "dimos.agents2.skills.navigation.NavigationSkillContainer._navigate_to_object", - return_value=None, + navigation_skill_container._query_by_text = mocker.Mock( + return_value=[ + { + "distance": 0.5, + "metadata": [ + { + "pos_x": 1, + "pos_y": 2, + "rot_z": 3, + } + ], + } + ] ) - fake_robot.spatial_memory = mocker.Mock() - fake_robot.spatial_memory.query_by_text.return_value = [ - { - "distance": 0.5, - "metadata": [ - { - "pos_x": 1, - "pos_y": 2, - 
"rot_z": 3, - } - ], - } - ] agent = create_navigation_agent(fixture="test_go_to_semantic_location.json") agent.query("go to the bookshelf") - fake_robot.spatial_memory.query_by_text.assert_called_once_with("bookshelf") - fake_robot.navigate_to.assert_called_once_with( + navigation_skill_container._query_by_text.assert_called_once_with("bookshelf") + navigation_skill_container._navigate_to.assert_called_once_with( PoseStamped( position=Vector3(1, 2, 0), orientation=euler_to_quaternion(Vector3(0, 0, 3)), frame_id="world", ), - blocking=True, ) diff --git a/dimos/agents2/system_prompt.py b/dimos/agents2/system_prompt.py new file mode 100644 index 0000000000..5168ed96d0 --- /dev/null +++ b/dimos/agents2/system_prompt.py @@ -0,0 +1,25 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dimos.agents2.constants import AGENT_SYSTEM_PROMPT_PATH + +_SYSTEM_PROMPT = None + + +def get_system_prompt() -> str: + global _SYSTEM_PROMPT + if _SYSTEM_PROMPT is None: + with open(AGENT_SYSTEM_PROMPT_PATH, "r") as f: + _SYSTEM_PROMPT = f.read() + return _SYSTEM_PROMPT diff --git a/dimos/conftest.py b/dimos/conftest.py index 495afa8a24..e1d0c96e42 100644 --- a/dimos/conftest.py +++ b/dimos/conftest.py @@ -70,7 +70,6 @@ def monitor_threads(request): yield - # Only check for threads created BY THIS TEST, not existing ones with _seen_threads_lock: before = _before_test_threads.get(test_name, set()) current = {t.ident for t in threading.enumerate() if t.ident is not None} @@ -86,6 +85,15 @@ def monitor_threads(request): t for t in threading.enumerate() if t.ident in new_thread_ids and t.name != "MainThread" ] + # Filter out expected persistent threads from Dask that are shared globally + # These threads are intentionally left running and cleaned up on process exit + expected_persistent_thread_prefixes = ["Dask-Offload"] + new_threads = [ + t + for t in new_threads + if not any(t.name.startswith(prefix) for prefix in expected_persistent_thread_prefixes) + ] + # Filter out threads we've already seen (from previous tests) truly_new = [t for t in new_threads if t.ident not in _seen_threads] diff --git a/dimos/constants.py b/dimos/constants.py index 86b3a39aa1..17273b6dd3 100644 --- a/dimos/constants.py +++ b/dimos/constants.py @@ -27,3 +27,6 @@ DEFAULT_CAPACITY_COLOR_IMAGE = 1920 * 1080 * 3 # Default depth image size: 1280x720 frame * 4 (float32 size) DEFAULT_CAPACITY_DEPTH_IMAGE = 1280 * 720 * 4 + +# From https://github.com/lcm-proj/lcm.git +LCM_MAX_CHANNEL_NAME_LENGTH = 63 diff --git a/dimos/core/README_BLUEPRINTS.md b/dimos/core/README_BLUEPRINTS.md new file mode 100644 index 0000000000..d54000cc6a --- /dev/null +++ b/dimos/core/README_BLUEPRINTS.md @@ -0,0 +1,219 @@ +# Blueprints + +Blueprints (`ModuleBlueprint`) are instructions for how to initialize a 
`Module`. + +You don't typically want to run a single module, so multiple blueprints are handled together in `ModuleBlueprintSet`. + +You create a `ModuleBlueprintSet` from a single module (say `ConnectionModule`) with: + +```python +blueprint = create_module_blueprint(ConnectionModule, 'arg1', 'arg2', kwarg='value') +``` + +But the same thing can be accomplished more succinctly as: + +```python +connection = ConnectionModule.blueprint +``` + +Now you can create the blueprint with: + +```python +blueprint = connection('arg1', 'arg2', kwarg='value') +``` + +## Linking blueprints + +You can link multiple blueprints together with `autoconnect`: + +```python +blueprint = autoconnect( + module1(), + module2(), + module3(), +) +``` + +`blueprint` itself is a `ModuleBlueprintSet` so you can link it with other modules: + +```python +expanded_blueprint = autoconnect( + blueprint, + module4(), + module5(), +) +``` + +Blueprints are frozen data classes, and `autoconnect()` always constructs an expanded blueprint so you never have to worry about changes in one affecting the other. + +### Duplicate module handling + +If the same module appears multiple times in `autoconnect`, the **later blueprint wins** and overrides earlier ones: + +```python +blueprint = autoconnect( + module_a(arg1=1), + module_b(), + module_a(arg1=2), # This one is used, the first is discarded +) +``` + +This is so you can "inherit" from one blueprint but override something you need to change. + +## How transports are linked + +Imagine you have this code: + +```python +class ModuleA(Module): + image: Out[Image] = None + start_explore: Out[Bool] = None + +class ModuleB(Module): + image: In[Image] = None + begin_explore: In[Bool] = None + +module_a = partial(create_module_blueprint, ModuleA) +module_b = partial(create_module_blueprint, ModuleB) + +autoconnect(module_a(), module_b()) +``` + +Connections are linked based on `(property_name, object_type)`.
In this case `('image', Image)` will be connected between the two modules, but `begin_explore` will not be linked to `start_explore`. + +## Topic names + +By default, the name of the property is used to generate the topic name. So for `image`, the topic will be `/image`. + +The property name is used only if it's unique. If two modules have the same property name with different types, then both get a random topic such as `/SGVsbG8sIFdvcmxkI`. + +If you don't like the name you can always override it like in the next section. + +## Which transport is used? + +By default `LCMTransport` is used if the object supports `lcm_encode`. If it doesn't, `pLCMTransport` is used (meaning "pickled LCM"). + +You can override transports with the `with_transports` method. It returns a new blueprint in which the override is set. + +```python +blueprint = autoconnect(...) +expanded_blueprint = autoconnect(blueprint, ...) +blueprint = blueprint.with_transports({ + ("image", Image): pSHMTransport( + "/go2/color_image", default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE + ), + ("start_explore", Bool): pLCMTransport(), +}) +``` + +Note: `expanded_blueprint` does not get the transport overrides because it's created from the initial value of `blueprint`, not the second. + +## Overriding global configuration. + +Each module can optionally take a `global_config` option in `__init__`. E.g.: + +```python +class ModuleA(Module): + + def __init__(self, global_config: GlobalConfig | None = None): + ... +``` + +The config is normally taken from .env or from environment variables. But you can specifically override the values for a specific blueprint: + +```python +blueprint = blueprint.with_global_config(n_dask_workers=8) +``` + +## Calling the methods of other modules + +Imagine you have this code: + +```python +class ModuleA(Module): + + @rpc + def get_time(self) -> str: + ... + +class ModuleB(Module): + def request_the_time(self) -> None: + ...
+``` + +And you want to call `ModuleA.get_time` in `ModuleB.request_the_time`. + +You can do so by defining a method like `set_<ModuleName>_<method_name>`. It will be called with an `RpcCall` that will call the original `ModuleA.get_time`. So you can write this: + +```python +class ModuleA(Module): + + @rpc + def get_time(self) -> str: + ... + +class ModuleB(Module): + @rpc # Note that it has to be an rpc method. + def set_ModuleA_get_time(self, rpc_call: RpcCall) -> None: + self._get_time = rpc_call + self._get_time.set_rpc(self.rpc) + + def request_the_time(self) -> None: + print(self._get_time()) +``` + +Note that `RpcCall.rpc` does not serialize, so you have to set it to the one from the module with `rpc_call.set_rpc(self.rpc)`. + +## Defining skills + +Skills have to be registered with `LlmAgent.register_skills(self)`. + +```python +class SomeSkill(Module): + + @skill + def some_skill(self) -> None: + ... + + @rpc + def set_LlmAgent_register_skills(self, register_skills: RpcCall) -> None: + register_skills.set_rpc(self.rpc) + register_skills(RPCClient(self, self.__class__)) + + # The agent is just interested in the `@skill` methods, so you'll need this if your class + # has things that cannot be pickled. + def __getstate__(self): + pass + def __setstate__(self, _state): + pass +``` + +Or, you can avoid all of this by inheriting from `SkillModule`, which does the above automatically: + +```python +class SomeSkill(SkillModule): + + @skill + def some_skill(self) -> None: + ... +``` + +## Building + +All you have to do to build a blueprint is call: + +```python +module_coordinator = blueprint.build(global_config=config) +``` + +This returns a `ModuleCoordinator` instance that manages all deployed modules. + +### Running and shutting down + +You can block the thread until it exits with: + +```python +module_coordinator.wait_until_shutdown() +``` + +This will wait for Ctrl+C and then automatically stop all modules and clean up resources.
\ No newline at end of file diff --git a/dimos/core/__init__.py b/dimos/core/__init__.py index 0bd3603126..747a25e498 100644 --- a/dimos/core/__init__.py +++ b/dimos/core/__init__.py @@ -6,9 +6,11 @@ from dask.distributed import Client, LocalCluster from rich.console import Console +import signal import dimos.core.colors as colors from dimos.core.core import rpc from dimos.core.module import Module, ModuleBase, ModuleConfig +from dimos.core.rpc_client import RPCClient from dimos.core.stream import In, Out, RemoteIn, RemoteOut, Transport from dimos.utils.actor_registry import ActorRegistry from dimos.core.transport import ( @@ -79,81 +81,6 @@ def teardown(self, worker): def patch_actor(actor, cls): ... -class RPCClient: - def __init__(self, actor_instance, actor_class): - self.rpc = LCMRPC() - self.actor_class = actor_class - self.remote_name = actor_class.__name__ - self.actor_instance = actor_instance - self.rpcs = actor_class.rpcs.keys() - self.rpc.start() - self._unsub_fns = [] - - def stop_client(self): - for unsub in self._unsub_fns: - try: - unsub() - except Exception: - pass - - self._unsub_fns = [] - - if self.rpc: - self.rpc.stop() - self.rpc = None - - def __reduce__(self): - # Return the class and the arguments needed to reconstruct the object - return ( - self.__class__, - (self.actor_instance, self.actor_class), - ) - - # passthrough - def __getattr__(self, name: str): - # Check if accessing a known safe attribute to avoid recursion - if name in { - "__class__", - "__init__", - "__dict__", - "__getattr__", - "rpcs", - "remote_name", - "remote_instance", - "actor_instance", - }: - raise AttributeError(f"{name} is not found.") - - if name in self.rpcs: - # Get the original method to preserve its docstring - original_method = getattr(self.actor_class, name, None) - - def rpc_call(*args, **kwargs): - # For stop/close/shutdown, use call_nowait to avoid deadlock - # (the remote side stops its RPC service before responding) - if name in ("stop", "close", 
"shutdown"): - if self.rpc: - self.rpc.call_nowait(f"{self.remote_name}/{name}", (args, kwargs)) - self.stop_client() - return None - - result, unsub_fn = self.rpc.call_sync(f"{self.remote_name}/{name}", (args, kwargs)) - self._unsub_fns.append(unsub_fn) - return result - - # Copy docstring and other attributes from original method - if original_method: - rpc_call.__doc__ = original_method.__doc__ - rpc_call.__name__ = original_method.__name__ - rpc_call.__qualname__ = f"{self.__class__.__name__}.{original_method.__name__}" - - return rpc_call - - # return super().__getattr__(name) - # Try to avoid recursion by directly accessing attributes that are known - return self.actor_instance.__getattr__(name) - - DimosCluster = Client @@ -173,7 +100,7 @@ def deploy( ).result() worker = actor.set_ref(actor).result() - print((f"deployed: {colors.green(actor)} @ {colors.blue('worker ' + str(worker))}")) + print((f"deployed: {colors.blue(actor)} @ {colors.orange('worker ' + str(worker))}")) # Register actor deployment in shared memory ActorRegistry.update(str(actor), str(worker)) @@ -280,14 +207,10 @@ def close_all(): except Exception: pass - # Shutdown the Dask offload thread pool - try: - from distributed.utils import _offload_executor - - if _offload_executor: - _offload_executor.shutdown(wait=False) - except Exception: - pass + # Note: We do NOT shutdown the _offload_executor here because it's a global + # module-level ThreadPoolExecutor shared across all Dask clients in the process. + # Shutting it down here would break subsequent Dask client usage (e.g., in tests). + # The executor will be cleaned up when the Python process exits. 
# Give threads time to clean up # Dask's IO loop and Profile threads are daemon threads @@ -309,8 +232,6 @@ def start(n: Optional[int] = None, memory_limit: str = "auto") -> Client: n: Number of workers (defaults to CPU count) memory_limit: Memory limit per worker (e.g., '4GB', '2GiB', or 'auto' for Dask's default) """ - import signal - import atexit console = Console() if not n: diff --git a/dimos/core/blueprints.py b/dimos/core/blueprints.py new file mode 100644 index 0000000000..53f20a0bfb --- /dev/null +++ b/dimos/core/blueprints.py @@ -0,0 +1,187 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from collections import defaultdict +from functools import cached_property +import inspect +from types import MappingProxyType +from typing import Any, Literal, Mapping, get_origin, get_args + +from dimos.core.module_coordinator import ModuleCoordinator +from dimos.core.global_config import GlobalConfig +from dimos.core.module import Module +from dimos.core.stream import In, Out +from dimos.core.transport import LCMTransport, pLCMTransport +from dimos.utils.generic import short_id + + +@dataclass(frozen=True) +class ModuleConnection: + name: str + type: type + direction: Literal["in", "out"] + + +@dataclass(frozen=True) +class ModuleBlueprint: + module: type[Module] + connections: tuple[ModuleConnection, ...] 
+ args: tuple[Any] + kwargs: dict[str, Any] + + +@dataclass(frozen=True) +class ModuleBlueprintSet: + blueprints: tuple[ModuleBlueprint, ...] + # TODO: Replace Any + transports: Mapping[tuple[str, type], Any] = field(default_factory=lambda: MappingProxyType({})) + global_config_overrides: Mapping[str, Any] = field(default_factory=lambda: MappingProxyType({})) + + def with_transports(self, transports: dict[tuple[str, type], Any]) -> "ModuleBlueprintSet": + return ModuleBlueprintSet( + blueprints=self.blueprints, + transports=MappingProxyType({**self.transports, **transports}), + global_config_overrides=self.global_config_overrides, + ) + + def with_global_config(self, **kwargs: Any) -> "ModuleBlueprintSet": + return ModuleBlueprintSet( + blueprints=self.blueprints, + transports=self.transports, + global_config_overrides=MappingProxyType({**self.global_config_overrides, **kwargs}), + ) + + def _get_transport_for(self, name: str, type: type) -> Any: + transport = self.transports.get((name, type), None) + if transport: + return transport + + use_pickled = getattr(type, "lcm_encode", None) is None + topic = f"/{name}" if self._is_name_unique(name) else f"/{short_id()}" + transport = pLCMTransport(topic) if use_pickled else LCMTransport(topic, type) + + return transport + + @cached_property + def _all_name_types(self) -> set[tuple[str, type]]: + return { + (conn.name, conn.type) + for blueprint in self.blueprints + for conn in blueprint.connections + } + + def _is_name_unique(self, name: str) -> bool: + return sum(1 for n, _ in self._all_name_types if n == name) == 1 + + def build(self, global_config: GlobalConfig) -> ModuleCoordinator: + global_config = global_config.model_copy(update=self.global_config_overrides) + + module_coordinator = ModuleCoordinator(global_config=global_config) + + module_coordinator.start() + + # Deploy all modules. 
+ for blueprint in self.blueprints: + kwargs = {**blueprint.kwargs} + sig = inspect.signature(blueprint.module.__init__) + if "global_config" in sig.parameters: + kwargs["global_config"] = global_config + module_coordinator.deploy(blueprint.module, *blueprint.args, **kwargs) + + # Gather all the In/Out connections. + connections = defaultdict(list) + for blueprint in self.blueprints: + for conn in blueprint.connections: + connections[conn.name, conn.type].append(blueprint.module) + + # Connect all In/Out connections by name and type. + for name, type in connections.keys(): + transport = self._get_transport_for(name, type) + for module in connections[(name, type)]: + instance = module_coordinator.get_instance(module) + getattr(instance, name).transport = transport + + # Gather all RPC methods. + rpc_methods = {} + for blueprint in self.blueprints: + for method_name in blueprint.module.rpcs.keys(): + method = getattr(module_coordinator.get_instance(blueprint.module), method_name) + rpc_methods[f"{blueprint.module.__name__}_{method_name}"] = method + + # Fulfil method requests (so modules can call each other). 
+ for blueprint in self.blueprints: + for method_name, method in blueprint.module.rpcs.items(): + if not method_name.startswith("set_"): + continue + linked_name = method_name.removeprefix("set_") + if linked_name not in rpc_methods: + continue + instance = module_coordinator.get_instance(blueprint.module) + getattr(instance, method_name)(rpc_methods[linked_name]) + + module_coordinator.start_all_modules() + + return module_coordinator + + +def _make_module_blueprint( + module: type[Module], args: tuple[Any], kwargs: dict[str, Any] +) -> ModuleBlueprint: + connections: list[ModuleConnection] = [] + + all_annotations = {} + for base_class in reversed(module.__mro__): + if hasattr(base_class, "__annotations__"): + all_annotations.update(base_class.__annotations__) + + for name, annotation in all_annotations.items(): + origin = get_origin(annotation) + if origin not in (In, Out): + continue + direction = "in" if origin == In else "out" + type_ = get_args(annotation)[0] + connections.append(ModuleConnection(name=name, type=type_, direction=direction)) + + return ModuleBlueprint(module=module, connections=tuple(connections), args=args, kwargs=kwargs) + + +def create_module_blueprint(module: type[Module], *args: Any, **kwargs: Any) -> ModuleBlueprintSet: + blueprint = _make_module_blueprint(module, args, kwargs) + return ModuleBlueprintSet(blueprints=(blueprint,)) + + +def autoconnect(*blueprints: ModuleBlueprintSet) -> ModuleBlueprintSet: + all_blueprints = tuple(_eliminate_duplicates([bp for bs in blueprints for bp in bs.blueprints])) + all_transports = dict(sum([list(x.transports.items()) for x in blueprints], [])) + all_config_overrides = dict( + sum([list(x.global_config_overrides.items()) for x in blueprints], []) + ) + + return ModuleBlueprintSet( + blueprints=all_blueprints, + transports=MappingProxyType(all_transports), + global_config_overrides=MappingProxyType(all_config_overrides), + ) + + +def _eliminate_duplicates(blueprints: list[ModuleBlueprint]) -> 
list[ModuleBlueprint]: + # The duplicates are eliminated in reverse so that newer blueprints override older ones. + seen = set() + unique_blueprints = [] + for bp in reversed(blueprints): + if bp.module not in seen: + seen.add(bp.module) + unique_blueprints.append(bp) + return list(reversed(unique_blueprints)) diff --git a/dimos/core/global_config.py b/dimos/core/global_config.py new file mode 100644 index 0000000000..e25184c351 --- /dev/null +++ b/dimos/core/global_config.py @@ -0,0 +1,39 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import cached_property +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class GlobalConfig(BaseSettings): + robot_ip: str | None = None + use_simulation: bool = False + use_replay: bool = False + n_dask_workers: int = 2 + + model_config = SettingsConfigDict( + env_prefix="DIMOS_", + env_file=".env", + env_file_encoding="utf-8", + extra="ignore", + frozen=True, + ) + + @cached_property + def unitree_connection_type(self) -> str: + if self.use_replay: + return "fake" + if self.use_simulation: + return "mujoco" + return "webrtc" diff --git a/dimos/core/module.py b/dimos/core/module.py index 5cea554072..aa65c1479f 100644 --- a/dimos/core/module.py +++ b/dimos/core/module.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import asyncio +from functools import partial import inspect import threading from dataclasses import dataclass @@ -29,12 +30,14 @@ from dimos.core import colors from dimos.core.core import T, rpc +from dimos.core.global_config import GlobalConfig from dimos.core.resource import Resource from dimos.core.stream import In, Out, RemoteIn, RemoteOut, Transport from dimos.protocol.rpc import LCMRPC, RPCSpec from dimos.protocol.service import Configurable from dimos.protocol.skill.skill import SkillContainer from dimos.protocol.tf import LCMTF, TFSpec +from dimos.utils.generic import classproperty def get_loop() -> tuple[asyncio.AbstractEventLoop, Optional[threading.Thread]]: @@ -124,6 +127,27 @@ def _close_rpc(self): self.rpc.stop() self.rpc = None + def __getstate__(self): + """Exclude unpicklable runtime attributes when serializing.""" + state = self.__dict__.copy() + # Remove unpicklable attributes + state.pop("_disposables", None) + state.pop("_loop", None) + state.pop("_loop_thread", None) + state.pop("_rpc", None) + state.pop("_tf", None) + return state + + def __setstate__(self, state): + """Restore object from pickled state.""" + self.__dict__.update(state) + # Reinitialize runtime attributes + self._disposables = CompositeDisposable() + self._loop = None + self._loop_thread = None + self._rpc = None + self._tf = None + @property def tf(self): if self._tf is None: @@ -216,6 +240,13 @@ def repr_rpc(fn: Callable) -> str: return "\n".join(ret) + @classproperty + def blueprint(cls): + # Here to prevent circular imports. 
+ from dimos.core.blueprints import create_module_blueprint + + return partial(create_module_blueprint, cls) + class DaskModule(ModuleBase): ref: Actor diff --git a/dimos/core/dimos.py b/dimos/core/module_coordinator.py similarity index 75% rename from dimos/core/dimos.py rename to dimos/core/module_coordinator.py index d286284fec..6eb916fda3 100644 --- a/dimos/core/dimos.py +++ b/dimos/core/module_coordinator.py @@ -12,23 +12,31 @@ # See the License for the specific language governing permissions and # limitations under the License. +import time from typing import Optional, Type, TypeVar from dimos import core from dimos.core import DimosCluster, Module +from dimos.core.global_config import GlobalConfig from dimos.core.resource import Resource T = TypeVar("T", bound="Module") -class Dimos(Resource): +class ModuleCoordinator(Resource): _client: Optional[DimosCluster] = None _n: Optional[int] = None _memory_limit: str = "auto" _deployed_modules: dict[Type[Module], Module] = {} - def __init__(self, n: Optional[int] = None, memory_limit: str = "auto"): - self._n = n + def __init__( + self, + n: Optional[int] = None, + memory_limit: str = "auto", + global_config: GlobalConfig | None = None, + ): + cfg = global_config or GlobalConfig() + self._n = n if n is not None else cfg.n_dask_workers self._memory_limit = memory_limit def start(self) -> None: @@ -54,3 +62,12 @@ def start_all_modules(self) -> None: def get_instance(self, module: Type[T]) -> T | None: return self._deployed_modules.get(module) + + def wait_until_shutdown(self) -> None: + try: + while True: + time.sleep(0.1) + except KeyboardInterrupt: + pass + finally: + self.stop() diff --git a/dimos/core/rpc_client.py b/dimos/core/rpc_client.py new file mode 100644 index 0000000000..dce1d704af --- /dev/null +++ b/dimos/core/rpc_client.py @@ -0,0 +1,142 @@ +# Copyright 2025 Dimensional Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Callable + + +from dimos.protocol.rpc.lcmrpc import LCMRPC +from dimos.utils.logging_config import setup_logger + + +logger = setup_logger(__file__) + + +class RpcCall: + _original_method: Callable[..., Any] | None + _rpc: LCMRPC | None + _name: str + _remote_name: str + _unsub_fns: list + _stop_rpc_client: Callable[[], None] | None = None + + def __init__( + self, + original_method: Callable[..., Any] | None, + rpc: LCMRPC, + name: str, + remote_name: str, + unsub_fns: list, + stop_client: Callable[[], None] | None = None, + ): + self._original_method = original_method + self._rpc = rpc + self._name = name + self._remote_name = remote_name + self._unsub_fns = unsub_fns + self._stop_rpc_client = stop_client + + if original_method: + self.__doc__ = original_method.__doc__ + self.__name__ = original_method.__name__ + self.__qualname__ = f"{self.__class__.__name__}.{original_method.__name__}" + + def set_rpc(self, rpc: LCMRPC): + self._rpc = rpc + + def __call__(self, *args, **kwargs): + if not self._rpc: + logger.warning("RPC client not initialized") + return None + + # For stop, use call_nowait to avoid deadlock + # (the remote side stops its RPC service before responding) + if self._name == "stop": + self._rpc.call_nowait(f"{self._remote_name}/{self._name}", (args, kwargs)) + if self._stop_rpc_client: + self._stop_rpc_client() + return None + + result, unsub_fn = 
self._rpc.call_sync(f"{self._remote_name}/{self._name}", (args, kwargs)) + self._unsub_fns.append(unsub_fn) + return result + + def __getstate__(self): + return (self._original_method, self._name, self._remote_name) + + def __setstate__(self, state): + self._original_method, self._name, self._remote_name = state + self._unsub_fns = [] + self._rpc = None + self._stop_rpc_client = None + + +class RPCClient: + def __init__(self, actor_instance, actor_class): + self.rpc = LCMRPC() + self.actor_class = actor_class + self.remote_name = actor_class.__name__ + self.actor_instance = actor_instance + self.rpcs = actor_class.rpcs.keys() + self.rpc.start() + self._unsub_fns = [] + + def stop_rpc_client(self): + for unsub in self._unsub_fns: + try: + unsub() + except Exception: + pass + + self._unsub_fns = [] + + if self.rpc: + self.rpc.stop() + self.rpc = None + + def __reduce__(self): + # Return the class and the arguments needed to reconstruct the object + return ( + self.__class__, + (self.actor_instance, self.actor_class), + ) + + # passthrough + def __getattr__(self, name: str): + # Check if accessing a known safe attribute to avoid recursion + if name in { + "__class__", + "__init__", + "__dict__", + "__getattr__", + "rpcs", + "remote_name", + "remote_instance", + "actor_instance", + }: + raise AttributeError(f"{name} is not found.") + + if name in self.rpcs: + original_method = getattr(self.actor_class, name, None) + return RpcCall( + original_method, + self.rpc, + name, + self.remote_name, + self._unsub_fns, + self.stop_rpc_client, + ) + + # return super().__getattr__(name) + # Try to avoid recursion by directly accessing attributes that are known + return self.actor_instance.__getattr__(name) diff --git a/dimos/core/skill_module.py b/dimos/core/skill_module.py new file mode 100644 index 0000000000..f432b48861 --- /dev/null +++ b/dimos/core/skill_module.py @@ -0,0 +1,32 @@ +# Copyright 2025 Dimensional Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos.core.module import Module +from dimos.core.rpc_client import RPCClient, RpcCall +from dimos.protocol.skill.skill import rpc + + +class SkillModule(Module): + """Use this module if you want to auto-register skills to an LlmAgent.""" + + @rpc + def set_LlmAgent_register_skills(self, callable: RpcCall) -> None: + callable.set_rpc(self.rpc) + callable(RPCClient(self, self.__class__)) + + def __getstate__(self): + pass + + def __setstate__(self, _state): + pass diff --git a/dimos/core/test_blueprints.py b/dimos/core/test_blueprints.py new file mode 100644 index 0000000000..edce54f2e1 --- /dev/null +++ b/dimos/core/test_blueprints.py @@ -0,0 +1,185 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dimos.core.blueprints import ( + ModuleBlueprint, + ModuleBlueprintSet, + ModuleConnection, + _make_module_blueprint, +) +from dimos.core.blueprints import autoconnect +from dimos.core.core import rpc +from dimos.core.global_config import GlobalConfig +from dimos.core.module import Module +from dimos.core.module_coordinator import ModuleCoordinator +from dimos.core.rpc_client import RpcCall +from dimos.core.stream import In, Out +from dimos.core.transport import LCMTransport +from dimos.protocol import pubsub + + +class Scratch: + pass + + +class Petting: + pass + + +class CatModule(Module): + pet_cat: In[Petting] + scratches: Out[Scratch] + + +class Data1: + pass + + +class Data2: + pass + + +class Data3: + pass + + +class ModuleA(Module): + data1: Out[Data1] = None + data2: Out[Data2] = None + + @rpc + def get_name(self) -> str: + return "A, Module A" + + +class ModuleB(Module): + data1: In[Data1] = None + data2: In[Data2] = None + data3: Out[Data3] = None + + _module_a_get_name: callable = None + + @rpc + def set_ModuleA_get_name(self, callable: RpcCall) -> None: + self._module_a_get_name = callable + self._module_a_get_name.set_rpc(self.rpc) + + @rpc + def what_is_as_name(self) -> str: + if self._module_a_get_name is None: + return "ModuleA.get_name not set" + return self._module_a_get_name() + + +class ModuleC(Module): + data3: In[Data3] = None + + +module_a = ModuleA.blueprint +module_b = ModuleB.blueprint +module_c = ModuleC.blueprint + + +def test_get_connection_set(): + assert _make_module_blueprint(CatModule, args=("arg1"), kwargs={"k": "v"}) == ModuleBlueprint( + module=CatModule, + connections=( + ModuleConnection(name="pet_cat", type=Petting, direction="in"), + ModuleConnection(name="scratches", type=Scratch, direction="out"), + ), + args=("arg1"), + kwargs={"k": "v"}, + ) + + +def test_autoconnect(): + blueprint_set = autoconnect(module_a(), module_b()) + + assert blueprint_set == ModuleBlueprintSet( + blueprints=( + ModuleBlueprint( + 
module=ModuleA, + connections=( + ModuleConnection(name="data1", type=Data1, direction="out"), + ModuleConnection(name="data2", type=Data2, direction="out"), + ), + args=(), + kwargs={}, + ), + ModuleBlueprint( + module=ModuleB, + connections=( + ModuleConnection(name="data1", type=Data1, direction="in"), + ModuleConnection(name="data2", type=Data2, direction="in"), + ModuleConnection(name="data3", type=Data3, direction="out"), + ), + args=(), + kwargs={}, + ), + ) + ) + + +def test_with_transports(): + custom_transport = LCMTransport("/custom_topic", Data1) + blueprint_set = autoconnect(module_a(), module_b()).with_transports( + {("data1", Data1): custom_transport} + ) + + assert ("data1", Data1) in blueprint_set.transports + assert blueprint_set.transports[("data1", Data1)] == custom_transport + + +def test_with_global_config(): + blueprint_set = autoconnect(module_a(), module_b()).with_global_config(option1=True, option2=42) + + assert "option1" in blueprint_set.global_config_overrides + assert blueprint_set.global_config_overrides["option1"] is True + assert "option2" in blueprint_set.global_config_overrides + assert blueprint_set.global_config_overrides["option2"] == 42 + + +def test_build_happy_path(): + pubsub.lcm.autoconf() + + blueprint_set = autoconnect(module_a(), module_b(), module_c()) + + coordinator = blueprint_set.build(GlobalConfig()) + + try: + assert isinstance(coordinator, ModuleCoordinator) + + module_a_instance = coordinator.get_instance(ModuleA) + module_b_instance = coordinator.get_instance(ModuleB) + module_c_instance = coordinator.get_instance(ModuleC) + + assert module_a_instance is not None + assert module_b_instance is not None + assert module_c_instance is not None + + assert module_a_instance.data1.transport is not None + assert module_a_instance.data2.transport is not None + assert module_b_instance.data1.transport is not None + assert module_b_instance.data2.transport is not None + assert module_b_instance.data3.transport is not 
None + assert module_c_instance.data3.transport is not None + + assert module_a_instance.data1.transport.topic == module_b_instance.data1.transport.topic + assert module_a_instance.data2.transport.topic == module_b_instance.data2.transport.topic + assert module_b_instance.data3.transport.topic == module_c_instance.data3.transport.topic + + assert module_b_instance.what_is_as_name() == "A, Module A" + + finally: + coordinator.stop() diff --git a/dimos/core/test_modules.py b/dimos/core/test_modules.py index 27474adc7f..42112f2415 100644 --- a/dimos/core/test_modules.py +++ b/dimos/core/test_modules.py @@ -254,10 +254,6 @@ def build_class_hierarchy(root_path: Path) -> Dict[str, List[str]]: # Skip files that can't be parsed continue - from pprint import pprint - - pprint(hierarchy) - return hierarchy @@ -292,7 +288,7 @@ def get_all_module_subclasses(): filtered_results = [] for class_name, filepath, has_start, has_stop, forbidden_methods in results: # Skip base module classes themselves - if class_name in ("Module", "ModuleBase", "DaskModule"): + if class_name in ("Module", "ModuleBase", "DaskModule", "SkillModule"): continue # Skip test-only modules (those defined in test_ files) diff --git a/dimos/mapping/osm/demo_osm.py b/dimos/mapping/osm/demo_osm.py index 7617a48b9f..46f6298591 100644 --- a/dimos/mapping/osm/demo_osm.py +++ b/dimos/mapping/osm/demo_osm.py @@ -13,79 +13,41 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import time -import reactivex as rx from dotenv import load_dotenv -from reactivex import Observable - -from dimos.agents2 import Agent -from dimos.agents2.cli.human import HumanInput -from dimos.agents2.constants import AGENT_SYSTEM_PROMPT_PATH -from dimos.agents2.skills.osm import OsmSkillContainer -from dimos.core.resource import Resource +from reactivex import interval + +from dimos.agents2.agent import llm_agent +from dimos.agents2.cli.human import human_input +from dimos.agents2.skills.osm import osm_skill +from dimos.agents2.system_prompt import get_system_prompt +from dimos.core.blueprints import autoconnect +from dimos.core.module import Module +from dimos.core.stream import Out from dimos.mapping.types import LatLon -from dimos.robot.robot import Robot -from dimos.robot.utils.robot_debugger import RobotDebugger -from dimos.utils.logging_config import setup_logger - -logger = setup_logger(__file__) load_dotenv() -with open(AGENT_SYSTEM_PROMPT_PATH, "r") as f: - SYSTEM_PROMPT = f.read() - - -class FakeRobot(Robot): - pass - - -class UnitreeAgents2Runner(Resource): - def __init__(self): - self._robot = None - self._agent = None - self._robot_debugger = None - self._osm_skill_container = None - - def start(self) -> None: - self._robot = FakeRobot() - self._agent = Agent(system_prompt=SYSTEM_PROMPT) - self._osm_skill_container = OsmSkillContainer(self._robot, _get_fake_location()) - self._osm_skill_container.start() - self._agent.register_skills(self._osm_skill_container) - self._agent.register_skills(HumanInput()) - self._agent.run_implicit_skill("human") - self._agent.start() - self._agent.loop_thread() - self._robot_debugger = RobotDebugger(self._robot) - self._robot_debugger.start() - def stop(self) -> None: - if self._robot_debugger: - self._robot_debugger.stop() - if self._osm_skill_container: - self._osm_skill_container.stop() - if self._agent: - self._agent.stop() +class DemoRobot(Module): + gps_location: Out[LatLon] = None - def run(self): - while 
True: - try: - time.sleep(1) - except KeyboardInterrupt: - return + def start(self): + super().start() + self._disposables.add(interval(1.0).subscribe(lambda _: self._publish_gps_location())) + def stop(self): + super().stop() -def main(): - runner = UnitreeAgents2Runner() - runner.start() - runner.run() - runner.stop() + def _publish_gps_location(self): + self.gps_location.publish(LatLon(lat=37.78092426217621, lon=-122.40682866540769)) -def _get_fake_location() -> Observable[LatLon]: - return rx.of(LatLon(lat=37.78092426217621, lon=-122.40682866540769)) +demo_robot = DemoRobot.blueprint -if __name__ == "__main__": - main() +demo_osm = autoconnect( + demo_robot(), + osm_skill(), + human_input(), + llm_agent(system_prompt=get_system_prompt()), +) diff --git a/dimos/navigation/bt_navigator/navigator.py b/dimos/navigation/bt_navigator/navigator.py index 33d516106f..df1d50cbf2 100644 --- a/dimos/navigation/bt_navigator/navigator.py +++ b/dimos/navigation/bt_navigator/navigator.py @@ -24,6 +24,7 @@ from typing import Callable, Optional from dimos.core import Module, In, Out, rpc +from dimos.core.rpc_client import RpcCall from dimos.msgs.geometry_msgs import PoseStamped from dimos.msgs.nav_msgs import OccupancyGrid from dimos_lcm.std_msgs import String @@ -121,6 +122,16 @@ def __init__( logger.info("Navigator initialized with stuck detection") + @rpc + def set_HolonomicLocalPlanner_reset(self, callable: RpcCall) -> None: + self.reset_local_planner = callable + self.reset_local_planner.set_rpc(self.rpc) + + @rpc + def set_HolonomicLocalPlanner_is_goal_reached(self, callable: RpcCall) -> None: + self.check_goal_reached = callable + self.check_goal_reached.set_rpc(self.rpc) + @rpc def start(self): super().start() @@ -342,3 +353,8 @@ def stop_navigation(self) -> None: self.recovery_server.reset() # Reset recovery server when stopping logger.info("Navigator stopped") + + +behavior_tree_navigator = BehaviorTreeNavigator.blueprint + +__all__ = ["BehaviorTreeNavigator", 
"behavior_tree_navigator"] diff --git a/dimos/navigation/frontier_exploration/__init__.py b/dimos/navigation/frontier_exploration/__init__.py index 388a5bfe6f..7236788842 100644 --- a/dimos/navigation/frontier_exploration/__init__.py +++ b/dimos/navigation/frontier_exploration/__init__.py @@ -1 +1 @@ -from .wavefront_frontier_goal_selector import WavefrontFrontierExplorer +from .wavefront_frontier_goal_selector import WavefrontFrontierExplorer, wavefront_frontier_explorer diff --git a/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py b/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py index 5acbf7b5bf..a1ce4e8075 100644 --- a/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py +++ b/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py @@ -731,7 +731,12 @@ def stop_exploration(self) -> bool: self.no_gain_counter = 0 # Reset counter when exploration stops self.stop_event.set() - if self.exploration_thread and self.exploration_thread.is_alive(): + # Only join if we're NOT being called from the exploration thread itself + if ( + self.exploration_thread + and self.exploration_thread.is_alive() + and threading.current_thread() != self.exploration_thread + ): self.exploration_thread.join(timeout=2.0) logger.info("Stopped autonomous frontier exploration") @@ -810,3 +815,8 @@ def _exploration_loop(self): f"No frontier found (attempt {consecutive_failures}/{max_consecutive_failures}). Retrying in 2 seconds..." 
) threading.Event().wait(2.0) + + +wavefront_frontier_explorer = WavefrontFrontierExplorer.blueprint + +__all__ = ["WavefrontFrontierExplorer", "wavefront_frontier_explorer"] diff --git a/dimos/navigation/global_planner/__init__.py b/dimos/navigation/global_planner/__init__.py index 0496f586b9..9aaf52e11e 100644 --- a/dimos/navigation/global_planner/__init__.py +++ b/dimos/navigation/global_planner/__init__.py @@ -1,2 +1,4 @@ -from dimos.navigation.global_planner.planner import AstarPlanner +from dimos.navigation.global_planner.planner import AstarPlanner, astar_planner from dimos.navigation.global_planner.algo import astar + +__all__ = ["AstarPlanner", "astar_planner", "astar"] diff --git a/dimos/navigation/global_planner/planner.py b/dimos/navigation/global_planner/planner.py index 08a00596aa..f9df988cfe 100644 --- a/dimos/navigation/global_planner/planner.py +++ b/dimos/navigation/global_planner/planner.py @@ -216,3 +216,8 @@ def plan(self, goal: Pose) -> Optional[Path]: logger.warning("No path found to the goal.") return None + + +astar_planner = AstarPlanner.blueprint + +__all__ = ["AstarPlanner", "astar_planner"] diff --git a/dimos/navigation/local_planner/holonomic_local_planner.py b/dimos/navigation/local_planner/holonomic_local_planner.py index d74e272724..94624fc65e 100644 --- a/dimos/navigation/local_planner/holonomic_local_planner.py +++ b/dimos/navigation/local_planner/holonomic_local_planner.py @@ -260,3 +260,8 @@ def _find_lookahead_point(self, path: np.ndarray, start_idx: int) -> np.ndarray: def _clip(self, v: np.ndarray) -> np.ndarray: """Instance method to clip velocity with access to v_max.""" return np.clip(v, -self.v_max, self.v_max) + + +holonomic_local_planner = HolonomicLocalPlanner.blueprint + +__all__ = ["HolonomicLocalPlanner", "holonomic_local_planner"] diff --git a/dimos/perception/detection/type/detection3d/test_pointcloud.py b/dimos/perception/detection/type/detection3d/test_pointcloud.py index 308839f8bf..edeeaacb4b 100644 --- 
a/dimos/perception/detection/type/detection3d/test_pointcloud.py +++ b/dimos/perception/detection/type/detection3d/test_pointcloud.py @@ -28,7 +28,7 @@ def test_detection3dpc(detection3dpc): assert obb.center[2] == pytest.approx(0.220184, abs=0.1) # Verify OBB extent values - assert obb.extent[0] == pytest.approx(0.531275, abs=0.1) + assert obb.extent[0] == pytest.approx(0.531275, abs=0.12) assert obb.extent[1] == pytest.approx(0.461054, abs=0.1) assert obb.extent[2] == pytest.approx(0.155, abs=0.1) diff --git a/dimos/perception/object_tracker.py b/dimos/perception/object_tracker.py index d59165cb06..497b6933b3 100644 --- a/dimos/perception/object_tracker.py +++ b/dimos/perception/object_tracker.py @@ -622,3 +622,8 @@ def _get_depth_from_bbox(self, bbox: List[int], depth_frame: np.ndarray) -> Opti return depth_25th_percentile return None + + +object_tracking = ObjectTracking.blueprint + +__all__ = ["ObjectTracking", "object_tracking"] diff --git a/dimos/perception/spatial_perception.py b/dimos/perception/spatial_perception.py index 7d93e2e174..a6f4169b3e 100644 --- a/dimos/perception/spatial_perception.py +++ b/dimos/perception/spatial_perception.py @@ -28,6 +28,7 @@ from datetime import datetime from reactivex.disposable import Disposable +from dimos.constants import DIMOS_PROJECT_ROOT from dimos.core import In, Module, rpc from dimos.msgs.sensor_msgs import Image from dimos.msgs.geometry_msgs import Vector3, Pose, PoseStamped @@ -38,6 +39,13 @@ from dimos.types.vector import Vector from dimos.types.robot_location import RobotLocation +_OUTPUT_DIR = DIMOS_PROJECT_ROOT / "assets" / "output" +_MEMORY_DIR = _OUTPUT_DIR / "memory" +_SPATIAL_MEMORY_DIR = _MEMORY_DIR / "spatial_memory" +_DB_PATH = _SPATIAL_MEMORY_DIR / "chromadb_data" +_VISUAL_MEMORY_PATH = _SPATIAL_MEMORY_DIR / "visual_memory.pkl" + + logger = setup_logger(__file__) @@ -62,10 +70,14 @@ def __init__( embedding_dimensions: int = 512, min_distance_threshold: float = 0.01, # Min distance in meters to store 
a new frame min_time_threshold: float = 1.0, # Min time in seconds to record a new frame - db_path: Optional[str] = None, # Path for ChromaDB persistence - visual_memory_path: Optional[str] = None, # Path for saving/loading visual memory + db_path: Optional[str] = str(_DB_PATH), # Path for ChromaDB persistence + visual_memory_path: Optional[str] = str( + _VISUAL_MEMORY_PATH + ), # Path for saving/loading visual memory new_memory: bool = True, # Whether to create a new memory from scratch - output_dir: Optional[str] = None, # Directory for storing visual memory data + output_dir: Optional[str] = str( + _SPATIAL_MEMORY_DIR + ), # Directory for storing visual memory data chroma_client: Any = None, # Optional ChromaDB client for persistence visual_memory: Optional[ "VisualMemory" @@ -649,3 +661,8 @@ def query_tagged_location(self, query: str) -> Optional[RobotLocation]: if semantic_distance < 0.3: return location return None + + +spatial_memory = SpatialMemory.blueprint + +__all__ = ["SpatialMemory", "spatial_memory"] diff --git a/dimos/protocol/pubsub/lcmpubsub.py b/dimos/protocol/pubsub/lcmpubsub.py index 238c1f6545..5fda6dbb83 100644 --- a/dimos/protocol/pubsub/lcmpubsub.py +++ b/dimos/protocol/pubsub/lcmpubsub.py @@ -96,49 +96,6 @@ def unsubscribe(): return unsubscribe - @deprecated("Listen for the lastest message directly") - def wait_for_message(self, topic: Topic, timeout: float = 1.0) -> Any: - """Wait for a single message on the specified topic. 
- - Args: - topic: The topic to listen on - timeout: Maximum time to wait for a message in seconds - - Returns: - The received message or None if timeout occurred - """ - - if self.l is None: - logger.error("Tried to wait for message after LCM was closed") - return None - - received_message = None - message_event = threading.Event() - - def message_handler(channel, data): - nonlocal received_message - try: - # Decode the message if type is specified - if hasattr(self, "decode") and topic.lcm_type is not None: - received_message = self.decode(data, topic) - else: - received_message = data - message_event.set() - except Exception as e: - print(f"Error decoding message: {e}") - message_event.set() - - # Subscribe to the topic - subscription = self.l.subscribe(str(topic), message_handler) - - try: - # Wait for message or timeout - message_event.wait(timeout) - return received_message - finally: - # Clean up subscription - self.l.unsubscribe(subscription) - class LCMEncoderMixin(PubSubEncoderMixin[Topic, Any]): def encode(self, msg: LCMMsg, _: Topic) -> bytes: diff --git a/dimos/protocol/pubsub/shmpubsub.py b/dimos/protocol/pubsub/shmpubsub.py index 8bcf87828c..2d643a32d8 100644 --- a/dimos/protocol/pubsub/shmpubsub.py +++ b/dimos/protocol/pubsub/shmpubsub.py @@ -206,32 +206,6 @@ def _unsub(): return _unsub - # Optional utility like in LCMPubSubBase - def wait_for_message(self, topic: str, timeout: float = 1.0) -> Any: - """Wait once; if an encoder mixin is present, returned value is decoded.""" - received: Any = None - evt = threading.Event() - - def _handler(msg: bytes, _topic: str): - nonlocal received - try: - if hasattr(self, "decode"): # provided by encoder mixin - received = self.decode(msg, topic) # type: ignore[misc] - else: - received = msg - finally: - evt.set() - - unsub = self.subscribe(topic, _handler) - try: - evt.wait(timeout) - return received - finally: - try: - unsub() - except Exception: - pass - # ----- Capacity mgmt 
---------------------------------------------------- def reconfigure(self, topic: str, *, capacity: int) -> dict: diff --git a/dimos/protocol/pubsub/test_lcmpubsub.py b/dimos/protocol/pubsub/test_lcmpubsub.py index 54a45b5cc5..d8a39248bb 100644 --- a/dimos/protocol/pubsub/test_lcmpubsub.py +++ b/dimos/protocol/pubsub/test_lcmpubsub.py @@ -193,191 +193,3 @@ def callback(msg, topic): assert received_topic == topic print(test_message, topic) - - -def test_wait_for_message_basic(lcm): - """Test basic wait_for_message functionality - message arrives before timeout.""" - topic = Topic(topic="/test_wait", lcm_type=MockLCMMessage) - test_message = MockLCMMessage("wait_test_data") - - # Publish message after a short delay in another thread - def publish_delayed(): - time.sleep(0.1) - lcm.publish(topic, test_message) - - publisher_thread = threading.Thread(target=publish_delayed) - publisher_thread.start() - - # Wait for message with 1 second timeout - start_time = time.time() - received_msg = lcm.wait_for_message(topic, timeout=1.0) - elapsed_time = time.time() - start_time - - publisher_thread.join() - - # Check that we received the message - assert received_msg is not None - assert isinstance(received_msg, MockLCMMessage) - assert received_msg.data == "wait_test_data" - - # Check that we didn't wait the full timeout - assert elapsed_time < 0.5 # Should receive message in ~0.1 seconds - - -def test_wait_for_message_timeout(lcm): - """Test wait_for_message timeout - no message published.""" - topic = Topic(topic="/test_timeout", lcm_type=MockLCMMessage) - - # Wait for message that will never come - start_time = time.time() - received_msg = lcm.wait_for_message(topic, timeout=0.5) - elapsed_time = time.time() - start_time - - # Check that we got None (timeout) - assert received_msg is None - - # Check that we waited approximately the timeout duration - assert 0.4 < elapsed_time < 0.7 # Allow some tolerance - - -def test_wait_for_message_immediate(lcm): - """Test 
wait_for_message with message published immediately after subscription.""" - topic = Topic(topic="/test_immediate", lcm_type=MockLCMMessage) - test_message = MockLCMMessage("immediate_data") - - # Start waiting in a thread - received_msg = None - - def wait_for_msg(): - nonlocal received_msg - received_msg = lcm.wait_for_message(topic, timeout=1.0) - - wait_thread = threading.Thread(target=wait_for_msg) - wait_thread.start() - - # Give a tiny bit of time for subscription to be established - time.sleep(0.01) - - # Now publish the message - start_time = time.time() - lcm.publish(topic, test_message) - - # Wait for the thread to complete - wait_thread.join() - elapsed_time = time.time() - start_time - - # Check that we received the message quickly - assert received_msg is not None - assert isinstance(received_msg, MockLCMMessage) - assert received_msg.data == "immediate_data" - assert elapsed_time < 0.2 # Should be nearly immediate - - -def test_wait_for_message_multiple_sequential(lcm): - """Test multiple sequential wait_for_message calls.""" - topic = Topic(topic="/test_sequential", lcm_type=MockLCMMessage) - - # Test multiple messages in sequence - messages = ["msg1", "msg2", "msg3"] - - for msg_data in messages: - test_message = MockLCMMessage(msg_data) - - # Publish in background - def publish_delayed(msg=test_message): - time.sleep(0.05) - lcm.publish(topic, msg) - - publisher_thread = threading.Thread(target=publish_delayed) - publisher_thread.start() - - # Wait and verify - received_msg = lcm.wait_for_message(topic, timeout=1.0) - assert received_msg is not None - assert received_msg.data == msg_data - - publisher_thread.join() - - -def test_wait_for_message_concurrent(lcm): - """Test concurrent wait_for_message calls on different topics.""" - topic1 = Topic(topic="/test_concurrent1", lcm_type=MockLCMMessage) - topic2 = Topic(topic="/test_concurrent2", lcm_type=MockLCMMessage) - - message1 = MockLCMMessage("concurrent1") - message2 = 
MockLCMMessage("concurrent2") - - received_messages = {} - - def wait_for_topic(topic_name, topic): - msg = lcm.wait_for_message(topic, timeout=2.0) - received_messages[topic_name] = msg - - # Start waiting on both topics - thread1 = threading.Thread(target=wait_for_topic, args=("topic1", topic1)) - thread2 = threading.Thread(target=wait_for_topic, args=("topic2", topic2)) - - thread1.start() - thread2.start() - - # Publish to both topics after a delay - time.sleep(0.1) - lcm.publish(topic1, message1) - lcm.publish(topic2, message2) - - # Wait for both threads to complete - thread1.join(timeout=3.0) - thread2.join(timeout=3.0) - - # Verify both messages were received - assert "topic1" in received_messages - assert "topic2" in received_messages - assert received_messages["topic1"].data == "concurrent1" - assert received_messages["topic2"].data == "concurrent2" - - -def test_wait_for_message_wrong_topic(lcm): - """Test wait_for_message doesn't receive messages from wrong topic.""" - topic_correct = Topic(topic="/test_correct", lcm_type=MockLCMMessage) - topic_wrong = Topic(topic="/test_wrong", lcm_type=MockLCMMessage) - - message = MockLCMMessage("wrong_topic_data") - - # Publish to wrong topic - lcm.publish(topic_wrong, message) - - # Wait on correct topic - received_msg = lcm.wait_for_message(topic_correct, timeout=0.3) - - # Should timeout and return None - assert received_msg is None - - -def test_wait_for_message_pickle(pickle_lcm): - """Test wait_for_message with PickleLCM.""" - lcm = pickle_lcm - topic = Topic(topic="/test_pickle") - test_obj = {"key": "value", "number": 42} - - # Publish after delay - def publish_delayed(): - time.sleep(0.1) - lcm.publish(topic, test_obj) - - publisher_thread = threading.Thread(target=publish_delayed) - publisher_thread.start() - - # Wait for message - received_msg = lcm.wait_for_message(topic, timeout=1.0) - - publisher_thread.join() - - # Verify received object - assert received_msg is not None - # PickleLCM's 
wait_for_message returns the pickled bytes, need to decode - import pickle - - decoded_msg = pickle.loads(received_msg) - assert decoded_msg == test_obj - assert decoded_msg["key"] == "value" - assert decoded_msg["number"] == 42 diff --git a/dimos/protocol/rpc/lcmrpc.py b/dimos/protocol/rpc/lcmrpc.py index 7c6ed43c59..7ff98b1338 100644 --- a/dimos/protocol/rpc/lcmrpc.py +++ b/dimos/protocol/rpc/lcmrpc.py @@ -12,10 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +from dimos.constants import LCM_MAX_CHANNEL_NAME_LENGTH from dimos.protocol.pubsub.lcmpubsub import PickleLCM, Topic from dimos.protocol.rpc.pubsubrpc import PassThroughPubSubRPC +from dimos.utils.generic import short_id class LCMRPC(PassThroughPubSubRPC, PickleLCM): def topicgen(self, name: str, req_or_res: bool) -> Topic: - return Topic(topic=f"/rpc/{name}/{'res' if req_or_res else 'req'}") + suffix = "res" if req_or_res else "req" + topic = f"/rpc/{name}/{suffix}" + if len(topic) > LCM_MAX_CHANNEL_NAME_LENGTH: + topic = f"/rpc/{short_id(name)}/{suffix}" + return Topic(topic=topic) diff --git a/dimos/protocol/rpc/pubsubrpc.py b/dimos/protocol/rpc/pubsubrpc.py index 1730b27175..ef4fb25aa4 100644 --- a/dimos/protocol/rpc/pubsubrpc.py +++ b/dimos/protocol/rpc/pubsubrpc.py @@ -14,30 +14,24 @@ from __future__ import annotations -import pickle -import subprocess -import sys -import threading import time -import traceback from abc import abstractmethod -from dataclasses import dataclass from types import FunctionType from typing import ( Any, Callable, Generic, Optional, - Protocol, - Sequence, TypedDict, TypeVar, - runtime_checkable, ) -from dimos.protocol.pubsub.spec import PickleEncoderMixin, PubSub -from dimos.protocol.rpc.spec import Args, RPCClient, RPCInspectable, RPCServer, RPCSpec -from dimos.protocol.service.spec import Service +from dimos.protocol.pubsub.spec import PubSub +from dimos.protocol.rpc.spec import Args, RPCSpec +from 
dimos.utils.logging_config import setup_logger + + +logger = setup_logger(__file__) MsgT = TypeVar("MsgT") TopicT = TypeVar("TopicT") @@ -121,11 +115,23 @@ def receive_call(msg: MsgT, _: TopicT) -> None: args = req.get("args") if args is None: return - response = f(*args[0], **args[1]) - req_id = req.get("id") - if req_id is not None: - self.publish(topic_res, self._encodeRPCRes({"id": req_id, "res": response})) + # Execute RPC handler in a separate thread to avoid deadlock when + # the handler makes nested RPC calls. + def execute_and_respond(): + try: + response = f(*args[0], **args[1]) + req_id = req.get("id") + if req_id is not None: + self.publish(topic_res, self._encodeRPCRes({"id": req_id, "res": response})) + except Exception as e: + logger.exception(f"Exception in RPC handler for {name}: {e}", exc_info=e) + + get_thread_pool = getattr(self, "_get_call_thread_pool", None) + if get_thread_pool: + get_thread_pool().submit(execute_and_respond) + else: + execute_and_respond() return self.subscribe(topic_req, receive_call) diff --git a/dimos/protocol/rpc/spec.py b/dimos/protocol/rpc/spec.py index 82115c6eec..461d60f8ae 100644 --- a/dimos/protocol/rpc/spec.py +++ b/dimos/protocol/rpc/spec.py @@ -46,7 +46,7 @@ def call( # we expect to crash if we don't get a return value after 10 seconds # but callers can override this timeout for extra long functions def call_sync( - self, name: str, arguments: Args, rpc_timeout: Optional[float] = 120.0 + self, name: str, arguments: Args, rpc_timeout: Optional[float] = 30.0 ) -> Tuple[Any, Callable[[], None]]: event = threading.Event() diff --git a/dimos/protocol/rpc/test_lcmrpc.py b/dimos/protocol/rpc/test_lcmrpc.py new file mode 100644 index 0000000000..02fe0a2d3a --- /dev/null +++ b/dimos/protocol/rpc/test_lcmrpc.py @@ -0,0 +1,43 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from collections.abc import Generator +from dimos.constants import LCM_MAX_CHANNEL_NAME_LENGTH +from dimos.protocol.rpc.lcmrpc import LCMRPC + + +@pytest.fixture +def lcmrpc() -> Generator[LCMRPC, None, None]: + ret = LCMRPC() + ret.start() + yield ret + ret.stop() + + +def test_short_name(lcmrpc) -> None: + actual = lcmrpc.topicgen("Hello/say", req_or_res=True) + assert actual.topic == "/rpc/Hello/say/res" + + +def test_long_name(lcmrpc) -> None: + long = "GreatyLongComplexExampleClassNameForTestingStuff/create" + long_topic = lcmrpc.topicgen(long, req_or_res=True).topic + assert long_topic == "/rpc/2cudPuFGMJdWxM5KZb/res" + + less_long = long[:-1] + less_long_topic = lcmrpc.topicgen(less_long, req_or_res=True).topic + assert less_long_topic == "/rpc/GreatyLongComplexExampleClassNameForTestingStuff/creat/res" + + assert len(less_long_topic) == LCM_MAX_CHANNEL_NAME_LENGTH diff --git a/dimos/protocol/service/lcmservice.py b/dimos/protocol/service/lcmservice.py index 2228a671fc..f1cabbba3d 100644 --- a/dimos/protocol/service/lcmservice.py +++ b/dimos/protocol/service/lcmservice.py @@ -14,6 +14,7 @@ from __future__ import annotations +from concurrent.futures import ThreadPoolExecutor import os import subprocess import sys @@ -225,6 +226,8 @@ class LCMService(Service[LCMConfig]): _stop_event: threading.Event _l_lock: threading.Lock _thread: Optional[threading.Thread] + _call_thread_pool: ThreadPoolExecutor | None = None + _call_thread_pool_lock: threading.RLock = threading.RLock() def __init__(self, **kwargs) -> None: 
super().__init__(**kwargs) @@ -241,7 +244,37 @@ def __init__(self, **kwargs) -> None: self._stop_event = threading.Event() self._thread = None + def __getstate__(self): + """Exclude unpicklable runtime attributes when serializing.""" + state = self.__dict__.copy() + # Remove unpicklable attributes + state.pop("l", None) + state.pop("_stop_event", None) + state.pop("_thread", None) + state.pop("_l_lock", None) + state.pop("_call_thread_pool", None) + state.pop("_call_thread_pool_lock", None) + return state + + def __setstate__(self, state): + """Restore object from pickled state.""" + self.__dict__.update(state) + # Reinitialize runtime attributes + self.l = None + self._stop_event = threading.Event() + self._thread = None + self._l_lock = threading.Lock() + self._call_thread_pool = None + self._call_thread_pool_lock = threading.RLock() + def start(self): + # Reinitialize LCM if it's None (e.g., after unpickling) + if self.l is None: + if self.config.lcm: + self.l = self.config.lcm + else: + self.l = lcm.LCM(self.config.url) if self.config.url else lcm.LCM() + if self.config.autoconf: autoconf() else: @@ -283,3 +316,28 @@ def stop(self): if self.l is not None: del self.l self.l = None + + with self._call_thread_pool_lock: + if self._call_thread_pool: + # Check if we're being called from within the thread pool + # If so, we can't wait for shutdown (would cause "cannot join current thread") + current_thread = threading.current_thread() + is_pool_thread = False + + # Check if current thread is one of the pool's threads + # ThreadPoolExecutor threads have names like "ThreadPoolExecutor-N_M" + if hasattr(self._call_thread_pool, "_threads"): + is_pool_thread = current_thread in self._call_thread_pool._threads + elif "ThreadPoolExecutor" in current_thread.name: + # Fallback: check thread name pattern + is_pool_thread = True + + # Don't wait if we're in a pool thread to avoid deadlock + self._call_thread_pool.shutdown(wait=not is_pool_thread) + self._call_thread_pool = None 
+ + def _get_call_thread_pool(self) -> ThreadPoolExecutor: + with self._call_thread_pool_lock: + if self._call_thread_pool is None: + self._call_thread_pool = ThreadPoolExecutor(max_workers=4) + return self._call_thread_pool diff --git a/dimos/protocol/skill/coordinator.py b/dimos/protocol/skill/coordinator.py index 23d9025a1a..e9c8680864 100644 --- a/dimos/protocol/skill/coordinator.py +++ b/dimos/protocol/skill/coordinator.py @@ -349,14 +349,11 @@ def __len__(self) -> int: # and langchain takes this output as well # just faster for now def get_tools(self) -> list[dict]: - # return [skill.schema for skill in self.skills().values()] - - ret = [] - for name, skill_config in self.skills().items(): - # print(f"Tool {name} config: {skill_config}, {skill_config.f}") - ret.append(langchain_tool(skill_config.f)) - - return ret + return [ + langchain_tool(skill_config.f) + for skill_config in self.skills().values() + if not skill_config.hide_skill + ] # internal skill call def call_skill( diff --git a/dimos/protocol/skill/skill.py b/dimos/protocol/skill/skill.py index 6a7d35bcb9..5008232554 100644 --- a/dimos/protocol/skill/skill.py +++ b/dimos/protocol/skill/skill.py @@ -67,6 +67,7 @@ def skill( stream: Stream = Stream.none, ret: Return = Return.call_agent, output: Output = Output.standard, + hide_skill: bool = False, ) -> Callable: def decorator(f: Callable[..., Any]) -> Any: def wrapper(self, *args, **kwargs): @@ -100,6 +101,7 @@ def wrapper(self, *args, **kwargs): ret=ret.passive if stream == Stream.passive else ret, output=output, schema=function_to_schema(f), + hide_skill=hide_skill, ) wrapper.__rpc__ = True # type: ignore[attr-defined] diff --git a/dimos/protocol/skill/type.py b/dimos/protocol/skill/type.py index 25b83661f1..7ffbe13798 100644 --- a/dimos/protocol/skill/type.py +++ b/dimos/protocol/skill/type.py @@ -61,6 +61,7 @@ class SkillConfig: schema: dict[str, Any] f: Callable | None = None autostart: bool = False + hide_skill: bool = False def bind(self, f: 
Callable) -> "SkillConfig": self.f = f diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py new file mode 100644 index 0000000000..2eef48855f --- /dev/null +++ b/dimos/robot/all_blueprints.py @@ -0,0 +1,61 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos.core.blueprints import ModuleBlueprintSet + + +# The blueprints are defined as import strings so as not to trigger unnecessary imports. +all_blueprints = { + "unitree-go2": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:standard", + "unitree-go2-basic": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:basic", + "unitree-go2-shm": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:standard_with_shm", + "unitree-go2-agentic": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:agentic", + "demo-osm": "dimos.mapping.osm.demo_osm:demo_osm", +} + + +all_modules = { + "astar_planner": "dimos.navigation.global_planner.planner", + "behavior_tree_navigator": "dimos.navigation.bt_navigator.navigator", + "connection": "dimos.robot.unitree_webrtc.unitree_go2", + "depth_module": "dimos.robot.unitree_webrtc.depth_module", + "detection_2d": "dimos.perception.detection2d.module2D", + "foxglove_bridge": "dimos.robot.foxglove_bridge", + "holonomic_local_planner": "dimos.navigation.local_planner.holonomic_local_planner", + "human_input": "dimos.agents2.cli.human", + "llm_agent": "dimos.agents2.agent", + "mapper": "dimos.robot.unitree_webrtc.type.map", + 
"navigation_skill": "dimos.agents2.skills.navigation", + "object_tracking": "dimos.perception.object_tracker", + "osm_skill": "dimos.agents2.skills.osm.py", + "spatial_memory": "dimos.perception.spatial_perception", + "utilization": "dimos.utils.monitoring", + "wavefront_frontier_explorer": "dimos.navigation.frontier_exploration.wavefront_frontier_goal_selector", + "websocket_vis": "dimos.web.websocket_vis.websocket_vis_module", +} + + +def get_blueprint_by_name(name: str) -> ModuleBlueprintSet: + if name not in all_blueprints: + raise ValueError(f"Unknown blueprint set name: {name}") + module_path, attr = all_blueprints[name].split(":") + module = __import__(module_path, fromlist=[attr]) + return getattr(module, attr) + + +def get_module_by_name(name: str) -> ModuleBlueprintSet: + if name not in all_modules: + raise ValueError(f"Unknown module name: {name}") + python_module = __import__(all_modules[name], fromlist=[name]) + return getattr(python_module, name)() diff --git a/dimos/robot/cli/README.md b/dimos/robot/cli/README.md new file mode 100644 index 0000000000..164fc8538c --- /dev/null +++ b/dimos/robot/cli/README.md @@ -0,0 +1,65 @@ +# Robot CLI + +To avoid having so many runfiles, I created a common script to run any blueprint. + +For example, to run the standard Unitree Go2 blueprint run: + +```bash +dimos-robot run unitree-go2 +``` + +For the one with agents run: + +```bash +dimos-robot run unitree-go2-agentic +``` + +You can dynamically connect additional modules. For example: + +```bash +dimos-robot run unitree-go2 --extra-module llm_agent --extra-module human_input --extra-module navigation_skill +``` + +## Definitions + +Blueprints can be defined anywhere, but they're all linked together in `dimos/robot/all_blueprints.py`. E.g.: + +```python +all_blueprints = { + "unitree-go2": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:standard", + "unitree-go2-agentic": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:agentic", + ... 
+} +``` + +(They are defined as imports to avoid triggering unrelated imports.) + +## `GlobalConfig` + +This tool also initializes the global config and passes it to the blueprint. + +`GlobalConfig` contains configuration options that are useful across many modules. For example: + +```python +class GlobalConfig(BaseSettings): + robot_ip: str | None = None + use_simulation: bool = False + use_replay: bool = False + n_dask_workers: int = 2 +``` + +Configuration values can be set from multiple places in order of precedence (later entries override earlier ones): + +- Default value defined on GlobalConfig. (`use_simulation = False`) +- Value defined in `.env` (`DIMOS_USE_SIMULATION=true`) +- Value in the environment variable (`DIMOS_USE_SIMULATION=true`) +- Value coming from the CLI (`--use-simulation` or `--no-use-simulation`) +- Value defined on the blueprint (`blueprint.with_global_config(use_simulation=True)`) + +For environment variables/`.env` values, you have to prefix the name with `DIMOS_`. + +For the command line, you call it like this: + +```bash +dimos-robot --use-simulation run unitree-go2 +``` \ No newline at end of file diff --git a/dimos/robot/cli/dimos_robot.py b/dimos/robot/cli/dimos_robot.py new file mode 100644 index 0000000000..5b589b3d69 --- /dev/null +++ b/dimos/robot/cli/dimos_robot.py @@ -0,0 +1,129 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import inspect
+import types
+from enum import Enum
+from typing import Optional, Union, get_args, get_origin
+
+import typer
+
+from dimos.core.blueprints import autoconnect
+from dimos.core.global_config import GlobalConfig
+from dimos.robot.all_blueprints import all_blueprints, get_blueprint_by_name, get_module_by_name
+from dimos.protocol import pubsub
+
+
+RobotType = Enum("RobotType", {key.replace("-", "_").upper(): key for key in all_blueprints.keys()})
+
+main = typer.Typer()
+
+
+def create_dynamic_callback():
+    """Build a typer callback that exposes every GlobalConfig field as a CLI option.
+
+    Boolean fields become --flag/--no-flag; options left unset are not applied as overrides.
+    The returned callback stores the resulting GlobalConfig on ``ctx.obj``.
+    """
+    # Build the function signature dynamically
+    params = [
+        inspect.Parameter("ctx", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=typer.Context),
+    ]
+
+    # Create parameters for each field in GlobalConfig
+    for field_name, field_info in GlobalConfig.model_fields.items():
+        field_type = field_info.annotation
+
+        # Unwrap Optional[T] / `T | None` to T. get_origin() reports typing.Union for the
+        # former and types.UnionType for the latter, so both must be accepted.
+        actual_type = field_type
+        if get_origin(field_type) in (Union, types.UnionType):
+            inner_types = get_args(field_type)
+            if len(inner_types) == 2 and type(None) in inner_types:
+                actual_type = next(t for t in inner_types if t is not type(None))
+
+        # Convert field name from snake_case to kebab-case for CLI
+        cli_option_name = field_name.replace("_", "-")
+
+        if actual_type is bool:
+            # For boolean fields, create --flag/--no-flag pattern
+            param = inspect.Parameter(
+                field_name,
+                inspect.Parameter.KEYWORD_ONLY,
+                default=typer.Option(
+                    None,  # None means use the model's default if not provided
+                    f"--{cli_option_name}/--no-{cli_option_name}",
+                    help=f"Override {field_name} in GlobalConfig",
+                ),
+                annotation=Optional[bool],
+            )
+        else:
+            # For non-boolean fields, use regular option
+            param = inspect.Parameter(
+                field_name,
+                inspect.Parameter.KEYWORD_ONLY,
+                default=typer.Option(
+                    None,  # None means use the model's default if not provided
+                    f"--{cli_option_name}",
+                    help=f"Override {field_name} in GlobalConfig",
+                ),
+                annotation=Optional[actual_type],
+            )
+        params.append(param)
+
+    def callback(**kwargs):
+        ctx = kwargs.pop("ctx")
+        overrides = {k: v for k, v in kwargs.items() if v is not None}
+        ctx.obj = GlobalConfig().model_copy(update=overrides)
+
+    callback.__signature__ = inspect.Signature(params)
+
+    return callback
+
+
+main.callback()(create_dynamic_callback())
+
+
+@main.command()
+def run(
+    ctx: typer.Context,
+    robot_type: RobotType = typer.Argument(..., help="Type of robot to run"),
+    extra_modules: list[str] = typer.Option(
+        [], "--extra-module", help="Extra modules to add to the blueprint"
+    ),
+):
+    """Run the robot with the specified configuration."""
+    config: GlobalConfig = ctx.obj
+    pubsub.lcm.autoconf()
+    blueprint = get_blueprint_by_name(robot_type.value)
+
+    if extra_modules:
+        loaded_modules = [get_module_by_name(mod_name) for mod_name in extra_modules]
+        blueprint = autoconnect(blueprint, *loaded_modules)
+
+    dimos = blueprint.build(global_config=config)
+    dimos.wait_until_shutdown()
+
+
+@main.command()
+def show_config(ctx: typer.Context):
+    """Show current configuration status."""
+    config: GlobalConfig = ctx.obj
+
+    for field_name, value in config.model_dump().items():
+        typer.echo(f"{field_name}: {value}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/dimos/robot/foxglove_bridge.py b/dimos/robot/foxglove_bridge.py index 18211f65c2..0e9b967757 100644 --- a/dimos/robot/foxglove_bridge.py +++ b/dimos/robot/foxglove_bridge.py @@ -14,6 +14,7 @@ import asyncio import threading +import logging # this is missing, I'm just trying to import lcm_foxglove_bridge.py from dimos_lcm from dimos_lcm.foxglove_bridge import FoxgloveBridge as LCMFoxgloveBridge @@ -37,6 +38,12 @@ def run_bridge(): self._loop = asyncio.new_event_loop() asyncio.set_event_loop(self._loop) try: + for logger in ["lcm_foxglove_bridge", "FoxgloveServer"]: + logger = logging.getLogger(logger) +
logger.setLevel(logging.ERROR) + for handler in logger.handlers: + handler.setLevel(logging.ERROR) + bridge = LCMFoxgloveBridge( host="0.0.0.0", port=8765, @@ -58,3 +65,9 @@ def stop(self): self._thread.join(timeout=2) super().stop() + + +foxglove_bridge = FoxgloveBridge.blueprint + + +__all__ = ["FoxgloveBridge", "foxglove_bridge"] diff --git a/dimos/robot/robot.py b/dimos/robot/robot.py index 7cdd50cf0b..7a0bd27867 100644 --- a/dimos/robot/robot.py +++ b/dimos/robot/robot.py @@ -25,6 +25,7 @@ from dimos.types.robot_capabilities import RobotCapability +# TODO: Delete class Robot(ABC): """Minimal abstract base class for all DIMOS robots. @@ -64,6 +65,7 @@ def cleanup(self): pass +# TODO: Delete class UnitreeRobot(Robot): @abstractmethod def get_odom(self) -> PoseStamped: ... @@ -82,6 +84,7 @@ def is_exploration_active(self) -> bool: ... def spatial_memory(self) -> Optional[SpatialMemory]: ... +# TODO: Delete class GpsRobot(ABC): @property @abstractmethod diff --git a/dimos/robot/unitree_webrtc/depth_module.py b/dimos/robot/unitree_webrtc/depth_module.py index b5b3b12738..2e0bd77ee2 100644 --- a/dimos/robot/unitree_webrtc/depth_module.py +++ b/dimos/robot/unitree_webrtc/depth_module.py @@ -21,6 +21,7 @@ import numpy as np from dimos.core import Module, In, Out, rpc +from dimos.core.global_config import GlobalConfig from dimos.msgs.sensor_msgs import Image, ImageFormat from dimos_lcm.sensor_msgs import CameraInfo from dimos.utils.logging_config import setup_logger @@ -49,7 +50,8 @@ class DepthModule(Module): def __init__( self, - gt_depth_scale: float = 1.0, + gt_depth_scale: float = 0.5, + global_config: GlobalConfig | None = None, **kwargs, ): """ @@ -77,7 +79,9 @@ def __init__( self._processing_thread: Optional[threading.Thread] = None self._stop_processing = threading.Event() - logger.info(f"DepthModule initialized") + if global_config: + if global_config.use_simulation: + self.gt_depth_scale = 1.0 @rpc def start(self): @@ -232,3 +236,9 @@ def 
_publish_depth(self): except Exception as e: logger.error(f"Error publishing depth data: {e}", exc_info=True) + + +depth_module = DepthModule.blueprint + + +__all__ = ["DepthModule", "depth_module"] diff --git a/dimos/robot/unitree_webrtc/run.py b/dimos/robot/unitree_webrtc/run.py deleted file mode 100644 index ee4c21b51a..0000000000 --- a/dimos/robot/unitree_webrtc/run.py +++ /dev/null @@ -1,182 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Run script for Unitree Go2 robot with Claude agent integration. -Provides navigation and interaction capabilities with natural language interface. 
-""" - -import os -import sys -import time -from dotenv import load_dotenv - -from reactivex.subject import Subject -import reactivex.operators as ops - -from dimos.robot.unitree_webrtc.unitree_go2 import UnitreeGo2 -from dimos.agents.claude_agent import ClaudeAgent -from dimos.skills.kill_skill import KillSkill -from dimos.skills.navigation import NavigateWithText, GetPose, NavigateToGoal, Explore -from dimos.web.robot_web_interface import RobotWebInterface -from dimos.stream.audio.pipelines import tts -from dimos.utils.logging_config import setup_logger - -logger = setup_logger("dimos.robot.unitree_webrtc.run") - -# Load environment variables -load_dotenv() - -# System prompt - loaded from prompt.txt -SYSTEM_PROMPT_PATH = os.path.join( - os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), - "assets/agent/prompt.txt", -) - - -def main(): - """Main entry point.""" - print("\n" + "=" * 60) - print("Unitree Go2 Robot with Claude Agent") - print("=" * 60) - print("\nThis system integrates:") - print(" - Unitree Go2 quadruped robot") - print(" - WebRTC communication interface") - print(" - Claude AI for natural language understanding") - print(" - Spatial memory and navigation") - print(" - Web interface with text and voice input") - print("\nStarting system...\n") - - # Check for API key - if not os.getenv("ANTHROPIC_API_KEY"): - print("WARNING: ANTHROPIC_API_KEY not found in environment") - print("Please set your API key in .env file or environment") - sys.exit(1) - - # Load system prompt - try: - with open(SYSTEM_PROMPT_PATH, "r") as f: - system_prompt = f.read() - except FileNotFoundError: - logger.error(f"System prompt file not found at {SYSTEM_PROMPT_PATH}") - sys.exit(1) - - logger.info("Starting Unitree Go2 Robot with Agent") - - # Create robot instance - robot = UnitreeGo2( - ip=os.getenv("ROBOT_IP"), - connection_type=os.getenv("CONNECTION_TYPE", "webrtc"), - ) - - robot.start() - time.sleep(3) - - try: - logger.info("Robot 
initialized successfully") - - # Set up skill library - skills = robot.get_skills() - skills.add(KillSkill) - skills.add(NavigateWithText) - skills.add(GetPose) - skills.add(NavigateToGoal) - skills.add(Explore) - - # Create skill instances - skills.create_instance("KillSkill", robot=robot, skill_library=skills) - skills.create_instance("NavigateWithText", robot=robot) - skills.create_instance("GetPose", robot=robot) - skills.create_instance("NavigateToGoal", robot=robot) - skills.create_instance("Explore", robot=robot) - - logger.info(f"Skills registered: {[skill.__name__ for skill in skills.get_class_skills()]}") - - # Set up streams for agent and web interface - agent_response_subject = Subject() - agent_response_stream = agent_response_subject.pipe(ops.share()) - audio_subject = Subject() - - # Set up streams for web interface - streams = {} - - text_streams = { - "agent_responses": agent_response_stream, - } - - # Create web interface first (needed for agent) - try: - web_interface = RobotWebInterface( - port=5555, text_streams=text_streams, audio_subject=audio_subject, **streams - ) - logger.info("Web interface created successfully") - except Exception as e: - logger.error(f"Failed to create web interface: {e}") - raise - - # Set up speech-to-text - # stt_node = stt() - # stt_node.consume_audio(audio_subject.pipe(ops.share())) - - # Create Claude agent - agent = ClaudeAgent( - dev_name="unitree_go2_agent", - input_query_stream=web_interface.query_stream, # Use text input from web interface - # input_query_stream=stt_node.emit_text(), # Uncomment to use voice input - skills=skills, - system_query=system_prompt, - model_name="claude-3-5-haiku-latest", - thinking_budget_tokens=0, - max_output_tokens_per_request=8192, - ) - - # Subscribe to agent responses - agent.get_response_observable().subscribe(lambda x: agent_response_subject.on_next(x)) - - # Set up text-to-speech for agent responses - tts_node = tts() - 
tts_node.consume_text(agent.get_response_observable()) - - # Create skill instances that need agent reference - - logger.info("=" * 60) - logger.info("Unitree Go2 Agent Ready!") - logger.info("Web interface available at: http://localhost:5555") - logger.info("You can:") - logger.info(" - Type commands in the web interface") - logger.info(" - Use voice commands") - logger.info(" - Ask the robot to navigate to locations") - logger.info(" - Ask the robot to observe and describe its surroundings") - logger.info(" - Ask the robot to follow people or explore areas") - logger.info("=" * 60) - - # Run web interface (this blocks) - web_interface.run() - - except KeyboardInterrupt: - logger.info("Keyboard interrupt received") - except Exception as e: - logger.error(f"Error running robot: {e}") - import traceback - - traceback.print_exc() - finally: - logger.info("Shutting down...") - # WebRTC robot doesn't have a stop method, just log shutdown - logger.info("Shutdown complete") - - -if __name__ == "__main__": - main() diff --git a/dimos/robot/unitree_webrtc/run_agents2.py b/dimos/robot/unitree_webrtc/run_agents2.py deleted file mode 100755 index e779c26bb6..0000000000 --- a/dimos/robot/unitree_webrtc/run_agents2.py +++ /dev/null @@ -1,128 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import os -import time -from typing import Optional -from dotenv import load_dotenv - -from dimos.agents2 import Agent -from dimos.agents2.cli.human import HumanInput -from dimos.agents2.constants import AGENT_SYSTEM_PROMPT_PATH -from dimos.core.resource import Resource -from dimos.robot.robot import UnitreeRobot -from dimos.robot.unitree_webrtc.unitree_go2 import UnitreeGo2 -from dimos.robot.unitree_webrtc.unitree_skill_container import UnitreeSkillContainer -from dimos.agents2.skills.navigation import NavigationSkillContainer -from dimos.robot.utils.robot_debugger import RobotDebugger -from dimos.utils.logging_config import setup_logger - -logger = setup_logger(__file__) - -load_dotenv() - -with open(AGENT_SYSTEM_PROMPT_PATH, "r") as f: - SYSTEM_PROMPT = f.read() - - -class UnitreeAgents2Runner(Resource): - _robot: Optional[UnitreeRobot] - _agent: Optional[Agent] - _robot_debugger: Optional[RobotDebugger] - _navigation_skill: Optional[NavigationSkillContainer] - - def __init__(self): - self._robot: UnitreeRobot = None - self._agent = None - self._robot_debugger = None - self._navigation_skill = None - - def start(self) -> None: - self._robot = UnitreeGo2( - ip=os.getenv("ROBOT_IP"), - connection_type=os.getenv("CONNECTION_TYPE", "webrtc"), - ) - - time.sleep(3) - - logger.info("Robot initialized successfully") - - self.setup_agent() - - self._robot_debugger = RobotDebugger(self._robot) - self._robot_debugger.start() - - def stop(self) -> None: - if self._navigation_skill: - self._navigation_skill.stop() - if self._robot_debugger: - self._robot_debugger.stop() - if self._agent: - self._agent.stop() - if self._robot: - self._robot.stop() - - def setup_agent(self) -> None: - if not self._robot: - raise ValueError("robot not set") - - logger.info("Setting up agent with skills...") - - self._agent = Agent(system_prompt=SYSTEM_PROMPT) - self._navigation_skill = NavigationSkillContainer( - robot=self._robot, - video_stream=self._robot.connection.video, - ) - 
self._navigation_skill.start() - - skill_containers = [ - UnitreeSkillContainer(robot=self._robot), - self._navigation_skill, - HumanInput(), - ] - - for container in skill_containers: - logger.info(f"Registering skills from container: {container}") - self._agent.register_skills(container) - - self._agent.run_implicit_skill("human") - - self._agent.start() - - # Log available skills - tools = self._agent.get_tools() - names = ", ".join([tool.name for tool in tools]) - logger.info(f"Agent configured with {len(tools)} skills: {names}") - - # Start the agent loop thread - self._agent.loop_thread() - - def run(self): - while True: - try: - time.sleep(1) - except KeyboardInterrupt: - return - - -def main(): - runner = UnitreeAgents2Runner() - runner.start() - runner.run() - runner.stop() - - -if __name__ == "__main__": - main() diff --git a/dimos/robot/unitree_webrtc/type/map.py b/dimos/robot/unitree_webrtc/type/map.py index 52e2c62260..068048fb8b 100644 --- a/dimos/robot/unitree_webrtc/type/map.py +++ b/dimos/robot/unitree_webrtc/type/map.py @@ -21,6 +21,7 @@ from reactivex.disposable import Disposable from dimos.core import In, Module, Out, rpc +from dimos.core.global_config import GlobalConfig from dimos.msgs.nav_msgs import OccupancyGrid from dimos.msgs.sensor_msgs import PointCloud2 from dimos.robot.unitree_webrtc.type.lidar import LidarMessage @@ -41,6 +42,7 @@ def __init__( global_publish_interval: Optional[float] = None, min_height: float = 0.15, max_height: float = 0.6, + global_config: GlobalConfig | None = None, **kwargs, ): self.voxel_size = voxel_size @@ -48,6 +50,11 @@ def __init__( self.global_publish_interval = global_publish_interval self.min_height = min_height self.max_height = max_height + + if global_config: + if global_config.use_simulation: + self.min_height = 0.3 + super().__init__(**kwargs) @rpc @@ -159,3 +166,9 @@ def splice_cylinder( survivors = map_pcd.select_by_index(victims, invert=True) return survivors + patch_pcd + + +mapper = 
Map.blueprint + + +__all__ = ["Map", "mapper"] diff --git a/dimos/robot/unitree_webrtc/unitree_b1/unitree_b1.py b/dimos/robot/unitree_webrtc/unitree_b1/unitree_b1.py index 78d22c37e3..5501557820 100644 --- a/dimos/robot/unitree_webrtc/unitree_b1/unitree_b1.py +++ b/dimos/robot/unitree_webrtc/unitree_b1/unitree_b1.py @@ -25,7 +25,7 @@ from typing import Optional from dimos import core -from dimos.core.dimos import Dimos +from dimos.core.module_coordinator import ModuleCoordinator from dimos.core.resource import Resource from dimos.msgs.geometry_msgs import TwistStamped, PoseStamped from dimos.msgs.nav_msgs.Odometry import Odometry @@ -99,7 +99,7 @@ def __init__( self.connection = None self.joystick = None self.ros_bridge = None - self._dimos = Dimos(n=2) + self._dimos = ModuleCoordinator(n=2) os.makedirs(self.output_dir, exist_ok=True) logger.info(f"Robot outputs will be saved to: {self.output_dir}") diff --git a/dimos/robot/unitree_webrtc/unitree_g1.py b/dimos/robot/unitree_webrtc/unitree_g1.py index 26b310629f..d8f6975d27 100644 --- a/dimos/robot/unitree_webrtc/unitree_g1.py +++ b/dimos/robot/unitree_webrtc/unitree_g1.py @@ -38,7 +38,7 @@ from dimos.agents2.skills.ros_navigation import RosNavigation from dimos.agents2.spec import Model, Provider from dimos.core import In, Module, Out, rpc -from dimos.core.dimos import Dimos +from dimos.core.module_coordinator import ModuleCoordinator from dimos.core.resource import Resource from dimos.hardware.camera import zed from dimos.hardware.camera.module import CameraModule @@ -195,7 +195,7 @@ def __init__( self.capabilities = [RobotCapability.LOCOMOTION] # Module references - self._dimos = Dimos(n=4) + self._dimos = ModuleCoordinator(n=4) self.connection = None self.websocket_vis = None self.foxglove_bridge = None diff --git a/dimos/robot/unitree_webrtc/unitree_go2.py b/dimos/robot/unitree_webrtc/unitree_go2.py index a3109e24f3..7bb544f52c 100644 --- a/dimos/robot/unitree_webrtc/unitree_go2.py +++ 
b/dimos/robot/unitree_webrtc/unitree_go2.py @@ -28,7 +28,8 @@ from dimos import core from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE from dimos.core import In, Module, Out, rpc -from dimos.core.dimos import Dimos +from dimos.core.module_coordinator import ModuleCoordinator +from dimos.core.global_config import GlobalConfig from dimos.core.resource import Resource from dimos.mapping.types import LatLon from dimos.msgs.std_msgs import Header @@ -150,15 +151,17 @@ class ConnectionModule(Module): def __init__( self, - ip: str = None, - connection_type: str = "webrtc", + ip: str | None = None, + connection_type: str | None = None, rectify_image: bool = True, + global_config: GlobalConfig | None = None, *args, **kwargs, ): - self.ip = ip - self.connection_type = connection_type - self.rectify_image = rectify_image + cfg = global_config or GlobalConfig() + self.ip = ip if ip is not None else cfg.robot_ip + self.connection_type = connection_type or cfg.unitree_connection_type + self.rectify_image = not cfg.use_simulation self.tf = TF() self.connection = None @@ -325,10 +328,13 @@ def publish_request(self, topic: str, data: dict): return self.connection.publish_request(topic, data) +connection = ConnectionModule.blueprint + + class UnitreeGo2(UnitreeRobot, Resource): """Full Unitree Go2 robot with navigation and perception capabilities.""" - _dimos: Dimos + _dimos: ModuleCoordinator _disposables: CompositeDisposable = CompositeDisposable() def __init__( @@ -349,7 +355,7 @@ def __init__( connection_type: webrtc, replay, or mujoco """ super().__init__() - self._dimos = Dimos(n=8, memory_limit="8GiB") + self._dimos = ModuleCoordinator(n=8, memory_limit="8GiB") self.ip = ip self.connection_type = connection_type or "webrtc" if ip is None and self.connection_type == "webrtc": @@ -590,10 +596,6 @@ def _start_modules(self): self.skill_library.init() self.skill_library.initialize_skills() - def get_single_rgb_frame(self, timeout: float = 2.0) -> Image: - topic = 
Topic("/go2/color_image", Image) - return self.lcm.wait_for_message(topic, timeout=timeout) - def move(self, twist: Twist, duration: float = 0.0): """Send movement command to robot.""" self.connection.move(twist, duration) @@ -701,3 +703,6 @@ def main(): if __name__ == "__main__": main() + + +__all__ = ["ConnectionModule", "connection", "UnitreeGo2", "ReplayRTC"]
diff --git a/dimos/robot/unitree_webrtc/unitree_go2_blueprints.py b/dimos/robot/unitree_webrtc/unitree_go2_blueprints.py
new file mode 100644
index 0000000000..af13dc20bc
--- /dev/null
+++ b/dimos/robot/unitree_webrtc/unitree_go2_blueprints.py
@@ -0,0 +1,109 @@
+#!/usr/bin/env python3
+
+# Copyright 2025 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE, DEFAULT_CAPACITY_DEPTH_IMAGE
+from dimos.core.blueprints import autoconnect
+from dimos.core.transport import LCMTransport, pSHMTransport
+from dimos.msgs.geometry_msgs import PoseStamped
+from dimos.msgs.sensor_msgs import Image
+from dimos_lcm.sensor_msgs import CameraInfo
+from dimos.perception.spatial_perception import spatial_memory
+from dimos.robot.foxglove_bridge import foxglove_bridge
+from dimos.robot.unitree_webrtc.unitree_go2 import connection
+from dimos.utils.monitoring import utilization
+from dimos.web.websocket_vis.websocket_vis_module import websocket_vis
+from dimos.navigation.global_planner import astar_planner
+from dimos.navigation.local_planner.holonomic_local_planner import (
+    holonomic_local_planner,
+)
+from dimos.navigation.bt_navigator.navigator import (
+    behavior_tree_navigator,
+)
+from dimos.navigation.frontier_exploration import (
+    wavefront_frontier_explorer,
+)
+from dimos.robot.unitree_webrtc.type.map import mapper
+from dimos.robot.unitree_webrtc.depth_module import depth_module
+from dimos.perception.object_tracker import object_tracking
+from dimos.agents2.agent import llm_agent
+from dimos.agents2.cli.human import human_input
+from dimos.agents2.skills.navigation import navigation_skill
+
+
+basic = (  # Base blueprint: connection, mapper, planners, navigator, explorer and the two visualisers.
+    autoconnect(
+        connection(),
+        mapper(voxel_size=0.5, global_publish_interval=2.5),
+        astar_planner(),
+        holonomic_local_planner(),
+        behavior_tree_navigator(),
+        wavefront_frontier_explorer(),
+        websocket_vis(),
+        foxglove_bridge(),
+    )
+    .with_global_config(n_dask_workers=4)
+    .with_transports(
+        # These are kept the same so that we don't have to change foxglove configs.
+        # Although we probably should.
+        {
+            ("color_image", Image): LCMTransport("/go2/color_image", Image),
+            ("camera_pose", PoseStamped): LCMTransport("/go2/camera_pose", PoseStamped),
+            ("camera_info", CameraInfo): LCMTransport("/go2/camera_info", CameraInfo),
+        }
+    )
+)
+
+standard = (  # `basic` plus spatial_memory, object_tracking, depth_module and utilization.
+    autoconnect(
+        basic,
+        spatial_memory(),
+        object_tracking(frame_id="camera_link"),
+        depth_module(),
+        utilization(),
+    )
+    .with_global_config(n_dask_workers=8)  # set again here with a larger value than in `basic`
+    .with_transports(
+        {
+            ("depth_image", Image): LCMTransport("/go2/depth_image", Image),
+        }
+    )
+)
+
+standard_with_shm = autoconnect(  # `standard` with both image streams switched to pSHMTransport.
+    standard.with_transports(
+        {
+            ("color_image", Image): pSHMTransport(
+                "/go2/color_image", default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE
+            ),
+            ("depth_image", Image): pSHMTransport(
+                "/go2/depth_image", default_capacity=DEFAULT_CAPACITY_DEPTH_IMAGE
+            ),
+        }
+    ),
+    foxglove_bridge(  # NOTE(review): `basic` already adds foxglove_bridge(); confirm how autoconnect treats the duplicate.
+        shm_channels=[
+            "/go2/color_image#sensor_msgs.Image",
+            "/go2/depth_image#sensor_msgs.Image",
+        ]
+    ),
+)
+
+agentic = autoconnect(  # `standard` plus llm_agent, human_input and navigation_skill.
+    standard,
+    llm_agent(),
+    human_input(),
+    navigation_skill(),
+)
diff --git a/dimos/robot/unitree_webrtc/unitree_go2_nav_only.py b/dimos/robot/unitree_webrtc/unitree_go2_nav_only.py deleted file mode 100644 index cf2136dde6..0000000000 --- a/dimos/robot/unitree_webrtc/unitree_go2_nav_only.py +++ /dev/null @@ -1,528 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License.
- -# $$$$$$$$\ $$$$$$\ $$$$$$$\ $$$$$$\ -# \__$$ __|$$ __$$\ $$ __$$\ $$ __$$\ -# $$ | $$ / $$ |$$ | $$ |$$ / $$ | -# $$ | $$ | $$ |$$ | $$ |$$ | $$ | -# $$ | $$ | $$ |$$ | $$ |$$ | $$ | -# $$ | $$ | $$ |$$ | $$ |$$ | $$ | -# $$ | $$$$$$ |$$$$$$$ | $$$$$$ | -# \__| \______/ \_______/ \______/ -# DOES anyone use this? The imports are broken which tells me it's unused. - -import functools -import logging -import os -import time -import warnings -from typing import Optional - -from dimos_lcm.std_msgs import Bool, String - -from dimos import core -from dimos.core import In, Module, Out, rpc -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Vector3 -from dimos.msgs.nav_msgs import OccupancyGrid, Path -from dimos.msgs.sensor_msgs import Image -from dimos.msgs.std_msgs import Header -from dimos.navigation.bt_navigator.navigator import BehaviorTreeNavigator, NavigatorState -from dimos.navigation.frontier_exploration import WavefrontFrontierExplorer -from dimos.navigation.global_planner import AstarPlanner -from dimos.navigation.local_planner.holonomic_local_planner import HolonomicLocalPlanner - -from dimos.perception.common.utils import load_camera_info, load_camera_info_opencv, rectify_image -from dimos.protocol import pubsub -from dimos.protocol.pubsub.lcmpubsub import LCM -from dimos.protocol.tf import TF -from dimos.robot.foxglove_bridge import FoxgloveBridge -from dimos.robot.robot import Robot -from dimos.robot.unitree_webrtc.connection import UnitreeWebRTCConnection -from dimos.robot.unitree_webrtc.type.lidar import LidarMessage -from dimos.robot.unitree_webrtc.type.map import Map -from dimos.robot.unitree_webrtc.type.odometry import Odometry -from dimos.types.robot_capabilities import RobotCapability - -from dimos.utils.data import get_data -from dimos.utils.logging_config import setup_logger -from dimos.utils.testing import TimedSensorReplay - -logger = setup_logger("dimos.robot.unitree_webrtc.unitree_go2_nav_only", level=logging.INFO) - 
-# Suppress verbose loggers -logging.getLogger("aiortc.codecs.h264").setLevel(logging.ERROR) -logging.getLogger("lcm_foxglove_bridge").setLevel(logging.ERROR) -logging.getLogger("websockets.server").setLevel(logging.ERROR) -logging.getLogger("FoxgloveServer").setLevel(logging.ERROR) -logging.getLogger("asyncio").setLevel(logging.ERROR) -logging.getLogger("root").setLevel(logging.WARNING) - -# Suppress warnings -warnings.filterwarnings("ignore", message="coroutine.*was never awaited") -warnings.filterwarnings("ignore", message="H264Decoder.*failed to decode") - - -class FakeRTC: - """Fake WebRTC connection for testing with recorded data.""" - - def __init__(self, *args, **kwargs): - get_data("unitree_office_walk") # Preload data for testing - - def connect(self): - pass - - def standup(self): - print("standup suppressed") - - def liedown(self): - print("liedown suppressed") - - @functools.cache - def lidar_stream(self): - print("lidar stream start") - lidar_store = TimedSensorReplay("unitree_office_walk/lidar", autocast=LidarMessage.from_msg) - return lidar_store.stream() - - @functools.cache - def odom_stream(self): - print("odom stream start") - odom_store = TimedSensorReplay("unitree_office_walk/odom", autocast=Odometry.from_msg) - return odom_store.stream() - - @functools.cache - def video_stream(self): - print("video stream start") - video_store = TimedSensorReplay( - "unitree_office_walk/video", autocast=lambda x: Image.from_numpy(x).to_rgb() - ) - return video_store.stream() - - def move(self, twist: Twist, duration: float = 0.0): - pass - - def publish_request(self, topic: str, data: dict): - """Fake publish request for testing.""" - return {"status": "ok", "message": "Fake publish"} - - -class ConnectionModule(Module): - """Module that handles robot sensor data, movement commands, and camera information.""" - - movecmd: In[Twist] = None - odom: Out[PoseStamped] = None - lidar: Out[LidarMessage] = None - video: Out[Image] = None - camera_info: 
Out[CameraInfo] = None - camera_pose: Out[PoseStamped] = None - ip: str - connection_type: str = "webrtc" - - _odom: PoseStamped = None - _lidar: LidarMessage = None - _last_image: Image = None - - def __init__( - self, - ip: str = None, - connection_type: str = "webrtc", - rectify_image: bool = True, - *args, - **kwargs, - ): - self.ip = ip - self.connection_type = connection_type - self.rectify_image = rectify_image - self.tf = TF() - self.connection = None - - # Load camera parameters from YAML - base_dir = os.path.dirname(os.path.abspath(__file__)) - - # Use sim camera parameters for mujoco, real camera for others - if connection_type == "mujoco": - camera_params_path = os.path.join(base_dir, "params", "sim_camera.yaml") - else: - camera_params_path = os.path.join(base_dir, "params", "front_camera_720.yaml") - - self.lcm_camera_info = load_camera_info(camera_params_path, frame_id="camera_link") - - # Load OpenCV matrices for rectification if enabled - if rectify_image: - self.camera_matrix, self.dist_coeffs = load_camera_info_opencv(camera_params_path) - self.lcm_camera_info.D = [0.0] * len( - self.lcm_camera_info.D - ) # zero out distortion coefficients for rectification - else: - self.camera_matrix = None - self.dist_coeffs = None - - Module.__init__(self, *args, **kwargs) - - @rpc - def start(self): - super().start() - """Start the connection and subscribe to sensor streams.""" - match self.connection_type: - case "webrtc": - self.connection = UnitreeWebRTCConnection(self.ip) - case "fake": - self.connection = FakeRTC(self.ip) - case "mujoco": - from dimos.robot.unitree_webrtc.mujoco_connection import MujocoConnection - - self.connection = MujocoConnection() - self.connection.start() - case _: - raise ValueError(f"Unknown connection type: {self.connection_type}") - - # Connect sensor streams to outputs - unsub = self.connection.lidar_stream().subscribe(self.lidar.publish) - self._disposables.add(unsub) - - unsub = 
self.connection.odom_stream().subscribe(self._publish_tf) - self._disposables.add(unsub) - - unsub = self.connection.video_stream().subscribe(self._on_video) - self._disposables.add(unsub) - - unsub = self.movecmd.subscribe(self.move) - self._disposables.add(unsub) - - @rpc - def stop(self) -> None: - if self.connection: - self.connection.stop() - super().stop() - - def _on_video(self, msg: Image): - """Handle incoming video frames and publish synchronized camera data.""" - # Apply rectification if enabled - if self.rectify_image: - rectified_msg = rectify_image(msg, self.camera_matrix, self.dist_coeffs) - self._last_image = rectified_msg - self.video.publish(rectified_msg) - else: - self._last_image = msg - self.video.publish(msg) - - # Publish camera info and pose synchronized with video - timestamp = msg.ts if msg.ts else time.time() - self._publish_camera_info(timestamp) - self._publish_camera_pose(timestamp) - - def _publish_tf(self, msg): - self._odom = msg - self.odom.publish(msg) - self.tf.publish(Transform.from_pose("base_link", msg)) - camera_link = Transform( - translation=Vector3(0.3, 0.0, 0.0), - rotation=Quaternion(0.0, 0.0, 0.0, 1.0), - frame_id="base_link", - child_frame_id="camera_link", - ts=time.time(), - ) - self.tf.publish(camera_link) - - def _publish_camera_info(self, timestamp: float): - header = Header(timestamp, "camera_link") - self.lcm_camera_info.header = header - self.camera_info.publish(self.lcm_camera_info) - - def _publish_camera_pose(self, timestamp: float): - """Publish camera pose from TF lookup.""" - try: - # Look up transform from world to camera_link - transform = self.tf.get( - parent_frame="world", - child_frame="camera_link", - time_point=timestamp, - time_tolerance=1.0, - ) - - if transform: - pose_msg = PoseStamped( - ts=timestamp, - frame_id="camera_link", - position=transform.translation, - orientation=transform.rotation, - ) - self.camera_pose.publish(pose_msg) - else: - logger.debug("Could not find transform from 
world to camera_link") - - except Exception as e: - logger.error(f"Error publishing camera pose: {e}") - - @rpc - def get_odom(self) -> Optional[PoseStamped]: - """Get the robot's odometry. - - Returns: - The robot's odometry - """ - return self._odom - - @rpc - def move(self, twist: Twist, duration: float = 0.0): - """Send movement command to robot.""" - self.connection.move(twist, duration) - - @rpc - def standup(self): - """Make the robot stand up.""" - return self.connection.standup() - - @rpc - def liedown(self): - """Make the robot lie down.""" - return self.connection.liedown() - - @rpc - def publish_request(self, topic: str, data: dict): - """Publish a request to the WebRTC connection. - Args: - topic: The RTC topic to publish to - data: The data dictionary to publish - Returns: - The result of the publish request - """ - return self.connection.publish_request(topic, data) - - -class UnitreeGo2NavOnly(Robot): - """Minimal Unitree Go2 robot with only navigation and visualization capabilities.""" - - def __init__( - self, - ip: str, - websocket_port: int = 7779, - connection_type: Optional[str] = "webrtc", - ): - """Initialize the navigation-only robot system. 
- - Args: - ip: Robot IP address (or None for fake connection) - websocket_port: Port for web visualization - connection_type: webrtc, fake, or mujoco - """ - super().__init__() - self.ip = ip - self.connection_type = connection_type or "webrtc" - if ip is None and self.connection_type == "webrtc": - self.connection_type = "fake" # Auto-enable playback if no IP provided - self.websocket_port = websocket_port - self.lcm = LCM() - - # Set capabilities - navigation only - self.capabilities = [RobotCapability.LOCOMOTION] - - self.dimos = None - self.connection = None - self.mapper = None - self.global_planner = None - self.local_planner = None - self.navigator = None - self.frontier_explorer = None - self.websocket_vis = None - self.foxglove_bridge = None - - def start(self): - """Start the robot system with navigation modules only.""" - self.dimos = core.start(8) - - self._deploy_connection() - self._deploy_mapping() - self._deploy_navigation() - - self.foxglove_bridge = self.dimos.deploy(FoxgloveBridge) - - self._start_modules() - - self.lcm.start() - - logger.info("UnitreeGo2NavOnly initialized and started") - logger.info(f"WebSocket visualization available at http://localhost:{self.websocket_port}") - - def _deploy_connection(self): - """Deploy and configure the connection module.""" - self.connection = self.dimos.deploy( - ConnectionModule, self.ip, connection_type=self.connection_type - ) - - self.connection.lidar.transport = core.LCMTransport("/lidar", LidarMessage) - self.connection.odom.transport = core.LCMTransport("/odom", PoseStamped) - self.connection.video.transport = core.LCMTransport("/go2/color_image", Image) - self.connection.movecmd.transport = core.LCMTransport("/cmd_vel", Twist) - self.connection.camera_info.transport = core.LCMTransport("/go2/camera_info", CameraInfo) - self.connection.camera_pose.transport = core.LCMTransport("/go2/camera_pose", PoseStamped) - - def _deploy_mapping(self): - """Deploy and configure the mapping module.""" - 
min_height = 0.3 if self.connection_type == "mujoco" else 0.15 - self.mapper = self.dimos.deploy( - Map, voxel_size=0.5, global_publish_interval=2.5, min_height=min_height - ) - - self.mapper.global_map.transport = core.LCMTransport("/global_map", LidarMessage) - self.mapper.global_costmap.transport = core.LCMTransport("/global_costmap", OccupancyGrid) - self.mapper.local_costmap.transport = core.LCMTransport("/local_costmap", OccupancyGrid) - - self.mapper.lidar.connect(self.connection.lidar) - - def _deploy_navigation(self): - """Deploy and configure navigation modules.""" - self.global_planner = self.dimos.deploy(AstarPlanner) - self.local_planner = self.dimos.deploy(HolonomicLocalPlanner) - self.navigator = self.dimos.deploy( - BehaviorTreeNavigator, - reset_local_planner=self.local_planner.reset, - check_goal_reached=self.local_planner.is_goal_reached, - ) - self.frontier_explorer = self.dimos.deploy(WavefrontFrontierExplorer) - - self.navigator.goal.transport = core.LCMTransport("/navigation_goal", PoseStamped) - self.navigator.goal_request.transport = core.LCMTransport("/goal_request", PoseStamped) - self.navigator.goal_reached.transport = core.LCMTransport("/goal_reached", Bool) - self.navigator.navigation_state.transport = core.LCMTransport("/navigation_state", String) - self.navigator.global_costmap.transport = core.LCMTransport( - "/global_costmap", OccupancyGrid - ) - self.global_planner.path.transport = core.LCMTransport("/global_path", Path) - self.local_planner.cmd_vel.transport = core.LCMTransport("/cmd_vel", Twist) - self.frontier_explorer.goal_request.transport = core.LCMTransport( - "/goal_request", PoseStamped - ) - self.frontier_explorer.goal_reached.transport = core.LCMTransport("/goal_reached", Bool) - self.frontier_explorer.explore_cmd.transport = core.LCMTransport("/explore_cmd", Bool) - self.frontier_explorer.stop_explore_cmd.transport = core.LCMTransport( - "/stop_explore_cmd", Bool - ) - - 
self.global_planner.target.connect(self.navigator.goal) - - self.global_planner.global_costmap.connect(self.mapper.global_costmap) - self.global_planner.odom.connect(self.connection.odom) - - self.local_planner.path.connect(self.global_planner.path) - self.local_planner.local_costmap.connect(self.mapper.local_costmap) - self.local_planner.odom.connect(self.connection.odom) - - self.connection.movecmd.connect(self.local_planner.cmd_vel) - - self.navigator.odom.connect(self.connection.odom) - - self.frontier_explorer.costmap.connect(self.mapper.global_costmap) - self.frontier_explorer.odometry.connect(self.connection.odom) - - def _start_modules(self): - """Start all deployed modules in the correct order.""" - self.connection.start() - self.mapper.start() - self.global_planner.start() - self.local_planner.start() - self.navigator.start() - self.frontier_explorer.start() - self.foxglove_bridge.start() - - def move(self, twist: Twist, duration: float = 0.0): - """Send movement command to robot.""" - self.connection.move(twist, duration) - - def explore(self) -> bool: - """Start autonomous frontier exploration. - - Returns: - True if exploration started successfully - """ - return self.frontier_explorer.explore() - - def navigate_to(self, pose: PoseStamped, blocking: bool = True): - """Navigate to a target pose. - - Args: - pose: Target pose to navigate to - blocking: If True, block until goal is reached. If False, return immediately. 
- - Returns: - If blocking=True: True if navigation was successful, False otherwise - If blocking=False: True if goal was accepted, False otherwise - """ - - logger.info( - f"Navigating to pose: ({pose.position.x:.2f}, {pose.position.y:.2f}, {pose.position.z:.2f})" - ) - self.navigator.set_goal(pose) - time.sleep(1.0) - - if blocking: - while self.navigator.get_state() == NavigatorState.FOLLOWING_PATH: - time.sleep(0.25) - - time.sleep(1.0) - if not self.navigator.is_goal_reached(): - logger.info("Navigation was cancelled or failed") - return False - else: - logger.info("Navigation goal reached") - return True - - return True - - def stop_exploration(self) -> bool: - """Stop autonomous exploration. - - Returns: - True if exploration was stopped - """ - self.navigator.cancel_goal() - return self.frontier_explorer.stop_exploration() - - def cancel_navigation(self) -> bool: - """Cancel the current navigation goal. - - Returns: - True if goal was cancelled - """ - return self.navigator.cancel_goal() - - def get_odom(self) -> PoseStamped: - """Get the robot's odometry. - - Returns: - The robot's odometry - """ - return self.connection.get_odom() - - -def main(): - """Main entry point.""" - ip = os.getenv("ROBOT_IP") - connection_type = os.getenv("CONNECTION_TYPE", "webrtc") - - pubsub.lcm.autoconf() - - robot = UnitreeGo2NavOnly(ip=ip, websocket_port=7779, connection_type=connection_type) - robot.start() - - try: - while True: - time.sleep(1) - except KeyboardInterrupt: - logger.info("Shutting down...") - - -if __name__ == "__main__": - main() diff --git a/dimos/skills/navigation.py b/dimos/skills/navigation.py deleted file mode 100644 index 7a6e1af4d9..0000000000 --- a/dimos/skills/navigation.py +++ /dev/null @@ -1,587 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Semantic map skills for building and navigating spatial memory maps. - -This module provides two skills: -1. BuildSemanticMap - Builds a semantic map by recording video frames at different locations -2. Navigate - Queries an existing semantic map using natural language -""" - -import os -import time -from typing import Optional, Tuple -import cv2 -from pydantic import Field - -from dimos.skills.skills import AbstractRobotSkill -from dimos.types.robot_location import RobotLocation -from dimos.utils.logging_config import setup_logger -from dimos.models.qwen.video_query import get_bbox_from_qwen_frame -from dimos.msgs.geometry_msgs import PoseStamped, Vector3 -from dimos.utils.transform_utils import euler_to_quaternion, quaternion_to_euler - -logger = setup_logger(__file__) - - -class NavigateWithText(AbstractRobotSkill): - """ - A skill that queries an existing semantic map using natural language or tries to navigate to an object in view. - - This skill first attempts to locate an object in the robot's camera view using vision. - If the object is found, it navigates to it. If not, it falls back to querying the - semantic map for a location matching the description. For example, "Find the Teddy Bear" - will first look for a Teddy Bear in view, then check the semantic map coordinates where - a Teddy Bear was previously observed. - - CALL THIS SKILL FOR ONE SUBJECT AT A TIME. For example: "Go to the person wearing a blue shirt in the living room", - you should call this skill twice, once for the person wearing a blue shirt and once for the living room. 
- - If skip_visual_search is True, this skill will skip the visual search for the object in view. - This is useful if you want to navigate to a general location such as a kitchen or office. - For example, "Go to the kitchen" will not look for a kitchen in view, but will check the semantic map coordinates where - a kitchen was previously observed. - """ - - query: str = Field("", description="Text query to search for in the semantic map") - - limit: int = Field(1, description="Maximum number of results to return") - distance: float = Field(0.3, description="Desired distance to maintain from object in meters") - skip_visual_search: bool = Field(False, description="Skip visual search for object in view") - timeout: float = Field(40.0, description="Maximum time to spend navigating in seconds") - - def __init__(self, robot=None, **data): - """ - Initialize the Navigate skill. - - Args: - robot: The robot instance - **data: Additional data for configuration - """ - super().__init__(robot=robot, **data) - self._spatial_memory = None - self._similarity_threshold = 0.23 - - def _navigate_to_object(self): - """ - Helper method that attempts to navigate to an object visible in the camera view. - - Returns: - dict: Result dictionary with success status and details - """ - logger.info( - f"Attempting to navigate to visible object: {self.query} with desired distance {self.distance}m, timeout {self.timeout} seconds..." 
- ) - - # Try to get a bounding box from Qwen - bbox = None - try: - # Get a single frame from the robot's camera - frame = self._robot.get_single_rgb_frame().data - frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) - if frame is None: - logger.error("Failed to get camera frame") - return { - "success": False, - "failure_reason": "Perception", - "error": "Could not get camera frame", - } - bbox = get_bbox_from_qwen_frame(frame, object_name=self.query) - except Exception as e: - logger.error(f"Error getting frame or bbox: {e}") - return { - "success": False, - "failure_reason": "Perception", - "error": f"Error getting frame or bbox: {e}", - } - if bbox is None: - logger.error(f"Failed to get bounding box for {self.query}") - return { - "success": False, - "failure_reason": "Perception", - "error": f"Could not find {self.query} in view", - } - - logger.info(f"Found {self.query} at {bbox}") - - # Use the robot's navigate_to_object method - success = self._robot.navigate_to_object(bbox, self.distance, self.timeout) - - if success: - logger.info(f"Successfully navigated to {self.query}") - return { - "success": True, - "failure_reason": None, - "query": self.query, - "message": f"Successfully navigated to {self.query} in view", - } - else: - logger.warning(f"Failed to reach {self.query} within timeout") - return { - "success": False, - "failure_reason": "Navigation", - "error": f"Failed to reach {self.query} within timeout", - } - - def _navigate_using_semantic_map(self): - """ - Helper method that attempts to navigate using the semantic map query. 
- - Returns: - dict: Result dictionary with success status and details - """ - logger.info(f"Querying semantic map for: '{self.query}'") - - try: - self._spatial_memory = self._robot.spatial_memory - - # Run the query - results = self._spatial_memory.query_by_text(self.query, self.limit) - - if not results: - logger.warning(f"No results found for query: '{self.query}'") - return { - "success": False, - "query": self.query, - "error": "No matching location found in semantic map", - } - - # Get the best match - best_match = results[0] - metadata = best_match.get("metadata", {}) - - if isinstance(metadata, list) and metadata: - metadata = metadata[0] - - # Extract coordinates from metadata - if ( - isinstance(metadata, dict) - and "pos_x" in metadata - and "pos_y" in metadata - and "rot_z" in metadata - ): - pos_x = metadata.get("pos_x", 0) - pos_y = metadata.get("pos_y", 0) - theta = metadata.get("rot_z", 0) - - # Calculate similarity score (distance is inverse of similarity) - similarity = 1.0 - ( - best_match.get("distance", 0) if best_match.get("distance") is not None else 0 - ) - - logger.info( - f"Found match for '{self.query}' at ({pos_x:.2f}, {pos_y:.2f}, rotation {theta:.2f}) with similarity: {similarity:.4f}" - ) - - # Check if similarity is below the threshold - if similarity < self._similarity_threshold: - logger.warning( - f"Match found but similarity score ({similarity:.4f}) is below threshold ({self._similarity_threshold})" - ) - return { - "success": False, - "query": self.query, - "position": (pos_x, pos_y), - "rotation": theta, - "similarity": similarity, - "error": f"Match found but similarity score ({similarity:.4f}) is below threshold ({self._similarity_threshold})", - } - - # Create a PoseStamped for navigation - goal_pose = PoseStamped( - position=Vector3(pos_x, pos_y, 0), - orientation=euler_to_quaternion(Vector3(0, 0, theta)), - frame_id="world", - ) - - logger.info( - f"Starting navigation to ({pos_x:.2f}, {pos_y:.2f}) with rotation 
{theta:.2f}" - ) - - # Use the robot's navigate_to method - result = self._robot.navigate_to(goal_pose, blocking=True) - - if result: - logger.info("Navigation completed successfully") - return { - "success": True, - "query": self.query, - "position": (pos_x, pos_y), - "rotation": theta, - "similarity": similarity, - "metadata": metadata, - } - else: - logger.error("Navigation did not complete successfully") - return { - "success": False, - "query": self.query, - "position": (pos_x, pos_y), - "rotation": theta, - "similarity": similarity, - "error": "Navigation did not complete successfully", - } - else: - logger.warning(f"No valid position data found for query: '{self.query}'") - return { - "success": False, - "query": self.query, - "error": "No valid position data found in semantic map", - } - - except Exception as e: - logger.error(f"Error in semantic map navigation: {e}") - return {"success": False, "error": f"Semantic map error: {e}"} - - def __call__(self): - """ - First attempts to navigate to an object in view, then falls back to querying the semantic map. 
- - Returns: - A dictionary with the result of the navigation attempt - """ - super().__call__() - - if not self.query: - error_msg = "No query provided to Navigate skill" - logger.error(error_msg) - return {"success": False, "error": error_msg} - - # First, try to find and navigate to the object in camera view - logger.info(f"First attempting to find and navigate to visible object: '{self.query}'") - - if not self.skip_visual_search: - object_result = self._navigate_to_object() - - if object_result and object_result["success"]: - logger.info(f"Successfully navigated to {self.query} in view") - return object_result - - elif object_result and object_result["failure_reason"] == "Navigation": - logger.info( - f"Failed to navigate to {self.query} in view: {object_result.get('error', 'Unknown error')}" - ) - return object_result - - # If object navigation failed, fall back to semantic map - logger.info( - f"Object not found in view. Falling back to semantic map query for: '{self.query}'" - ) - - return self._navigate_using_semantic_map() - - def stop(self): - """ - Stop the navigation skill and clean up resources. - - Returns: - A message indicating whether the navigation was stopped successfully - """ - logger.info("Stopping Navigate skill") - - # Cancel navigation - self._robot.cancel_navigation() - - skill_library = self._robot.get_skills() - self.unregister_as_running("Navigate", skill_library) - - return "Navigate skill stopped successfully." - - -class GetPose(AbstractRobotSkill): - """ - A skill that returns the current position and orientation of the robot. - - This skill is useful for getting the current pose of the robot in the map frame. You call this skill - if you want to remember a location, for example, "remember this is where my favorite chair is" and then - call this skill to get the position and rotation of approximately where the chair is. You can then use - the position to navigate to the chair. 
- - When location_name is provided, this skill will also remember the current location with that name, - allowing you to navigate back to it later using the Navigate skill. - """ - - location_name: str = Field( - "", description="Optional name to assign to this location (e.g., 'kitchen', 'office')" - ) - - def __init__(self, robot=None, **data): - """ - Initialize the GetPose skill. - - Args: - robot: The robot instance - **data: Additional data for configuration - """ - super().__init__(robot=robot, **data) - - def __call__(self): - """ - Get the current pose of the robot. - - Returns: - A dictionary containing the position and rotation of the robot - """ - super().__call__() - - if self._robot is None: - error_msg = "No robot instance provided to GetPose skill" - logger.error(error_msg) - return {"success": False, "error": error_msg} - - try: - # Get the current pose using the robot's get_pose method - pose_data = self._robot.get_odom() - - # Extract position and rotation from the new dictionary format - position = pose_data.position - rotation = quaternion_to_euler(pose_data.orientation) - - # Format the response - result = { - "success": True, - "position": { - "x": position.x, - "y": position.y, - "z": position.z, - }, - "rotation": {"roll": rotation.x, "pitch": rotation.y, "yaw": rotation.z}, - } - - # If location_name is provided, remember this location - if self.location_name: - # Get the spatial memory instance - spatial_memory = self._robot.spatial_memory - - # Create a RobotLocation object - location = RobotLocation( - name=self.location_name, - position=(position.x, position.y, position.z), - rotation=(rotation.x, rotation.y, rotation.z), - ) - - # Add to spatial memory - if spatial_memory.add_robot_location(location): - result["location_saved"] = True - result["location_name"] = self.location_name - logger.info(f"Location '{self.location_name}' saved at {position}") - else: - result["location_saved"] = False - logger.error(f"Failed to save location 
'{self.location_name}'") - - return result - except Exception as e: - error_msg = f"Error getting robot pose: {e}" - logger.error(error_msg) - return {"success": False, "error": error_msg} - - -class NavigateToGoal(AbstractRobotSkill): - """ - A skill that navigates the robot to a specified position and orientation. - - This skill uses the global planner to generate a path to the target position - and then uses navigate_path_local to follow that path, achieving the desired - orientation at the goal position. - """ - - position: Tuple[float, float] = Field( - (0.0, 0.0), description="Target position (x, y) in map frame" - ) - rotation: Optional[float] = Field(None, description="Target orientation (yaw) in radians") - frame: str = Field("map", description="Reference frame for the position and rotation") - timeout: float = Field(120.0, description="Maximum time (in seconds) allowed for navigation") - - def __init__(self, robot=None, **data): - """ - Initialize the NavigateToGoal skill. - - Args: - robot: The robot instance - **data: Additional data for configuration - """ - super().__init__(robot=robot, **data) - - def __call__(self): - """ - Navigate to the specified goal position and orientation. 
- - Returns: - A dictionary containing the result of the navigation attempt - """ - super().__call__() - - if self._robot is None: - error_msg = "No robot instance provided to NavigateToGoal skill" - logger.error(error_msg) - return {"success": False, "error": error_msg} - - skill_library = self._robot.get_skills() - self.register_as_running("NavigateToGoal", skill_library) - - logger.info( - f"Starting navigation to position=({self.position[0]:.2f}, {self.position[1]:.2f}) " - f"with rotation={self.rotation if self.rotation is not None else 'None'} " - f"in frame={self.frame}" - ) - - try: - # Create a PoseStamped for navigation - goal_pose = PoseStamped( - position=Vector3(self.position[0], self.position[1], 0), - orientation=euler_to_quaternion(Vector3(0, 0, self.rotation or 0)), - ) - - # Use the robot's navigate_to method - result = self._robot.navigate_to(goal_pose, blocking=True) - - if result: - logger.info("Navigation completed successfully") - return { - "success": True, - "position": self.position, - "rotation": self.rotation, - "message": "Goal reached successfully", - } - else: - logger.warning("Navigation did not complete successfully") - return { - "success": False, - "position": self.position, - "rotation": self.rotation, - "message": "Goal could not be reached", - } - - except Exception as e: - error_msg = f"Error during navigation: {e}" - logger.error(error_msg) - return { - "success": False, - "position": self.position, - "rotation": self.rotation, - "error": error_msg, - } - finally: - self.stop() - - def stop(self): - """ - Stop the navigation. - - Returns: - A message indicating that the navigation was stopped - """ - logger.info("Stopping NavigateToGoal") - skill_library = self._robot.get_skills() - self.unregister_as_running("NavigateToGoal", skill_library) - self._robot.cancel_navigation() - return "Navigation stopped" - - -class Explore(AbstractRobotSkill): - """ - A skill that performs autonomous frontier exploration. 
- - This skill continuously finds and navigates to unknown frontiers in the environment - until no more frontiers are found or the exploration is stopped. - - Don't save GetPose locations when frontier exploring. Don't call any other skills except stop skill when needed. - """ - - timeout: float = Field(240.0, description="Maximum time (in seconds) allowed for exploration") - - def __init__(self, robot=None, **data): - """ - Initialize the Explore skill. - - Args: - robot: The robot instance - **data: Additional data for configuration - """ - super().__init__(robot=robot, **data) - - def __call__(self): - """ - Start autonomous frontier exploration. - - Returns: - A dictionary containing the result of the exploration - """ - super().__call__() - - if self._robot is None: - error_msg = "No robot instance provided to Explore skill" - logger.error(error_msg) - return {"success": False, "error": error_msg} - - skill_library = self._robot.get_skills() - self.register_as_running("Explore", skill_library) - - logger.info("Starting autonomous frontier exploration") - - try: - # Start exploration using the robot's explore method - result = self._robot.explore() - - if result: - logger.info("Exploration started successfully") - - # Wait for exploration to complete or timeout - start_time = time.time() - while time.time() - start_time < self.timeout: - time.sleep(0.5) - - # Timeout reached, stop exploration - logger.info(f"Exploration timeout reached after {self.timeout} seconds") - self._robot.stop_exploration() - return { - "success": True, - "message": f"Exploration ran for {self.timeout} seconds", - } - else: - logger.warning("Failed to start exploration") - return { - "success": False, - "message": "Failed to start exploration", - } - - except Exception as e: - error_msg = f"Error during exploration: {e}" - logger.error(error_msg) - return { - "success": False, - "error": error_msg, - } - finally: - self.stop() - - def stop(self): - """ - Stop the exploration. 
- - Returns: - A message indicating that the exploration was stopped - """ - logger.info("Stopping Explore") - skill_library = self._robot.get_skills() - self.unregister_as_running("Explore", skill_library) - - # Stop the robot's exploration if it's running - try: - self._robot.stop_exploration() - except Exception as e: - logger.error(f"Error stopping exploration: {e}") - - return "Exploration stopped" diff --git a/dimos/utils/generic.py b/dimos/utils/generic.py index d5b9bd4364..7c776e984e 100644 --- a/dimos/utils/generic.py +++ b/dimos/utils/generic.py @@ -63,9 +63,16 @@ def short_id(from_string: str | None = None) -> str: hash_bytes = hashlib.sha1(from_string.encode()).digest()[:16] num = int.from_bytes(hash_bytes, "big") - chars = [] - while num: + min_chars = 18 + + chars: list[str] = [] + while num > 0 or len(chars) < min_chars: num, rem = divmod(num, base) chars.append(alphabet[rem]) - return "".join(reversed(chars))[:18] + return "".join(reversed(chars))[:min_chars] + + +class classproperty(property): + def __get__(self, obj, cls): + return self.fget(cls) diff --git a/dimos/utils/monitoring.py b/dimos/utils/monitoring.py index c13c274cac..abadbe591c 100644 --- a/dimos/utils/monitoring.py +++ b/dimos/utils/monitoring.py @@ -23,7 +23,7 @@ import re import os import shutil -from functools import lru_cache +from functools import lru_cache, partial from typing import Optional from distributed.client import Client @@ -185,6 +185,12 @@ def stop(self): super().stop() +utilization = UtilizationModule.blueprint + + +__all__ = ["UtilizationModule", "utilization"] + + def _can_use_py_spy(): try: with open("/proc/sys/kernel/yama/ptrace_scope") as f: diff --git a/dimos/utils/test_generic.py b/dimos/utils/test_generic.py new file mode 100644 index 0000000000..f85201d9bf --- /dev/null +++ b/dimos/utils/test_generic.py @@ -0,0 +1,30 @@ +# Copyright 2025 Dimensional Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from uuid import UUID +from dimos.utils.generic import short_id + + +def test_short_id_hello_world() -> None: + assert short_id("HelloWorld") == "6GgJmzi1KYf4iaHVxk" + + +def test_short_id_uuid_one(mocker) -> None: + mocker.patch("uuid.uuid4", return_value=UUID("11111111-1111-1111-1111-111111111111")) + assert short_id() == "wcFtOGNXQnQFZ8QRh1" + + +def test_short_id_uuid_zero(mocker) -> None: + mocker.patch("uuid.uuid4", return_value=UUID("00000000-0000-0000-0000-000000000000")) + assert short_id() == "000000000000000000" diff --git a/dimos/web/websocket_vis/websocket_vis_module.py b/dimos/web/websocket_vis/websocket_vis_module.py index 004853a2d6..b33b874ecc 100644 --- a/dimos/web/websocket_vis/websocket_vis_module.py +++ b/dimos/web/websocket_vis/websocket_vis_module.py @@ -124,21 +124,17 @@ def start(self): self._uvicorn_server_thread = threading.Thread(target=self._run_uvicorn_server, daemon=True) self._uvicorn_server_thread.start() - if self.odom.connection is not None: - unsub = self.odom.subscribe(self._on_robot_pose) - self._disposables.add(Disposable(unsub)) + unsub = self.odom.subscribe(self._on_robot_pose) + self._disposables.add(Disposable(unsub)) - if self.gps_location.connection is not None: - unsub = self.gps_location.subscribe(self._on_gps_location) - self._disposables.add(Disposable(unsub)) + unsub = self.gps_location.subscribe(self._on_gps_location) + self._disposables.add(Disposable(unsub)) - if 
self.path.connection is not None: - unsub = self.path.subscribe(self._on_path) - self._disposables.add(Disposable(unsub)) + unsub = self.path.subscribe(self._on_path) + self._disposables.add(Disposable(unsub)) - if self.global_costmap.connection is not None: - unsub = self.global_costmap.subscribe(self._on_global_costmap) - self._disposables.add(Disposable(unsub)) + unsub = self.global_costmap.subscribe(self._on_global_costmap) + self._disposables.add(Disposable(unsub)) @rpc def stop(self): @@ -291,3 +287,8 @@ def _process_costmap(self, costmap: OccupancyGrid) -> Dict[str, Any]: def _emit(self, event: str, data: Any): if self._broadcast_loop and not self._broadcast_loop.is_closed(): asyncio.run_coroutine_threadsafe(self.sio.emit(event, data), self._broadcast_loop) + + +websocket_vis = WebsocketVisModule.blueprint + +__all__ = ["WebsocketVisModule", "websocket_vis"] diff --git a/pyproject.toml b/pyproject.toml index 7e978fe907..2d3804c1fc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -104,7 +104,11 @@ dependencies = [ "dask[complete]==2025.5.1", # LCM / DimOS utilities - "dimos-lcm @ git+https://github.com/dimensionalOS/dimos-lcm.git@03e320b325edf3ead9b74746baea318d431030bc" + "dimos-lcm @ git+https://github.com/dimensionalOS/dimos-lcm.git@03e320b325edf3ead9b74746baea318d431030bc", + + # CLI + "pydantic-settings>=2.11.0,<3", + "typer>=0.19.2,<1", ] [project.scripts] @@ -113,6 +117,7 @@ foxglove-bridge = "dimos.utils.cli.foxglove_bridge.run_foxglove_bridge:main" skillspy = "dimos.utils.cli.skillspy.skillspy:main" agentspy = "dimos.utils.cli.agentspy.agentspy:main" human-cli = "dimos.utils.cli.human.humancli:main" +dimos-robot = "dimos.robot.cli.dimos_robot:main" [project.optional-dependencies] manipulation = [ diff --git a/tests/test_object_tracking_module.py b/tests/test_object_tracking_module.py index 2fd1038c89..0b4b1f1364 100755 --- a/tests/test_object_tracking_module.py +++ b/tests/test_object_tracking_module.py @@ -223,7 +223,7 @@ async def 
test_object_tracking_module(): # Start Foxglove bridge for visualization foxglove_bridge = FoxgloveBridge() - foxglove_bridge.start() + foxglove_bridge.acquire() # Give modules time to initialize await asyncio.sleep(1) @@ -280,7 +280,7 @@ async def test_object_tracking_module(): if zed: zed.stop() if foxglove_bridge: - foxglove_bridge.stop() + foxglove_bridge.release() dimos.close() logger.info("Test completed")