diff --git a/.gitignore b/.gitignore index 48717f7e6a..26a19fc463 100644 --- a/.gitignore +++ b/.gitignore @@ -17,8 +17,10 @@ __pycache__ # Ignore default runtime output folder /assets/output/ +/assets/rgbd_data/ +/assets/saved_maps/ /assets/model-cache/ -assets/agent/memory.txt +/assets/agent/memory.txt .bash_history @@ -33,3 +35,8 @@ package-lock.json # Ignore build artifacts dist/ + +# Ignore data and modelfiles +data/ +FastSAM-x.pt +yolo11n.pt diff --git a/dimos/perception/detection2d/utils.py b/dimos/perception/detection2d/utils.py index ebcfa75ec5..dbe19baf30 100644 --- a/dimos/perception/detection2d/utils.py +++ b/dimos/perception/detection2d/utils.py @@ -14,8 +14,8 @@ import numpy as np import cv2 - -from dimos.utils.ros_utils import distance_angle_to_goal_xy +from dimos.types.vector import Vector +from dimos.utils.transform_utils import distance_angle_to_goal_xy def filter_detections( @@ -206,24 +206,19 @@ def plot_results(image, bboxes, track_ids, class_ids, confidences, names, alpha= return vis_img -def calculate_depth_from_bbox(depth_model, frame, bbox): +def calculate_depth_from_bbox(depth_map, bbox): """ Calculate the average depth of an object within a bounding box. Uses the 25th to 75th percentile range to filter outliers. 
Args: - depth_model: Depth model - frame: The image frame + depth_map: The depth map bbox: Bounding box in format [x1, y1, x2, y2] Returns: float: Average depth in meters, or None if depth estimation fails """ try: - # Get depth map for the entire frame - depth_map = depth_model.infer_depth(frame) - depth_map = np.array(depth_map) - # Extract region of interest from the depth map x1, y1, x2, y2 = map(int, bbox) roi_depth = depth_map[y1:y2, x1:x2] @@ -323,7 +318,8 @@ def calculate_position_rotation_from_bbox(bbox, depth, camera_intrinsics): camera_intrinsics: List [fx, fy, cx, cy] with camera parameters Returns: - Tuple of (position_dict, rotation_dict) + Vector: position + Vector: rotation """ # Calculate distance and angle to object distance, angle = calculate_distance_angle_from_bbox(bbox, depth, camera_intrinsics) @@ -336,11 +332,7 @@ def calculate_position_rotation_from_bbox(bbox, depth, camera_intrinsics): # For now, rotation is only in yaw (around z-axis) # We can use the negative of the angle as an estimate of the object's yaw # assuming objects tend to face the camera - position = {"x": x, "y": y, "z": 0.0} # z=0 assuming objects are on the ground - rotation = { - "roll": 0.0, - "pitch": 0.0, - "yaw": -angle, - } # Only yaw is meaningful with monocular camera + position = Vector([x, y, 0.0]) + rotation = Vector([0.0, 0.0, -angle]) return position, rotation diff --git a/dimos/perception/object_detection_stream.py b/dimos/perception/object_detection_stream.py index 8376944c87..0c6514ea36 100644 --- a/dimos/perception/object_detection_stream.py +++ b/dimos/perception/object_detection_stream.py @@ -13,6 +13,7 @@ # limitations under the License. 
import cv2 +import time import numpy as np from reactivex import Observable from reactivex import operators as ops @@ -26,8 +27,9 @@ calculate_position_rotation_from_bbox, ) from dimos.types.vector import Vector -from typing import Optional, Union +from typing import Optional, Union, Callable from dimos.types.manipulation import ObjectData +from dimos.utils.transform_utils import transform_robot_to_map from dimos.utils.logging_config import setup_logger @@ -54,7 +56,7 @@ def __init__( gt_depth_scale=1000.0, min_confidence=0.7, class_filter=None, # Optional list of class names to filter (e.g., ["person", "car"]) - transform_to_map=None, # Optional function to transform coordinates to map frame + get_pose: Callable = None, # Optional function to transform coordinates to map frame detector: Optional[Union[Detic2DDetector, Yolo2DDetector]] = None, video_stream: Observable = None, disable_depth: bool = False, # Flag to disable monocular Metric3D depth estimation @@ -69,7 +71,7 @@ def __init__( gt_depth_scale: Ground truth depth scale for Metric3D min_confidence: Minimum confidence for detections class_filter: Optional list of class names to filter - transform_to_map: Optional function to transform pose to map coordinates + get_pose: Optional function to transform pose to map coordinates detector: Optional detector instance (Detic or Yolo) video_stream: Observable of video frames to process (if provided, returns a stream immediately) disable_depth: Flag to disable monocular Metric3D depth estimation @@ -77,7 +79,7 @@ def __init__( """ self.min_confidence = min_confidence self.class_filter = class_filter - self.transform_to_map = transform_to_map + self.get_pose = get_pose self.disable_depth = disable_depth self.draw_masks = draw_masks # Initialize object detector @@ -131,6 +133,11 @@ def process_frame(frame): # Process detections objects = [] + if not self.disable_depth: + depth_map = self.depth_model.infer_depth(frame) + depth_map = np.array(depth_map) + else: + 
depth_map = None for i, bbox in enumerate(bboxes): # Skip if confidence is too low @@ -142,9 +149,9 @@ def process_frame(frame): if self.class_filter and class_name not in self.class_filter: continue - if not self.disable_depth: + if not self.disable_depth and depth_map is not None: # Get depth for this object - depth = calculate_depth_from_bbox(self.depth_model, frame, bbox) + depth = calculate_depth_from_bbox(depth_map, bbox) if depth is None: # Skip objects with invalid depth continue @@ -159,13 +166,11 @@ def process_frame(frame): # Transform to map frame if a transform function is provided try: - if self.transform_to_map: - position = Vector([position["x"], position["y"], position["z"]]) - rotation = Vector( - [rotation["roll"], rotation["pitch"], rotation["yaw"]] - ) - position, rotation = self.transform_to_map( - position, rotation, source_frame="base_link" + if self.get_pose: + # position and rotation are already Vector objects, no need to convert + robot_pose = self.get_pose() + position, rotation = transform_robot_to_map( + robot_pose, position, rotation ) except Exception as e: logger.error(f"Error transforming to map frame: {e}") diff --git a/dimos/perception/object_tracker.py b/dimos/perception/object_tracker.py index 8bdd7c05b7..010dbb9f3e 100644 --- a/dimos/perception/object_tracker.py +++ b/dimos/perception/object_tracker.py @@ -102,7 +102,9 @@ def track(self, bbox, frame=None, distance=None, size=None): # Calculate depth only if distance and size not provided if frame is not None and distance is None and size is None: - depth_estimate = calculate_depth_from_bbox(self.depth_model, frame, bbox) + depth_map = self.depth_model.infer_depth(frame) + depth_map = np.array(depth_map) + depth_estimate = calculate_depth_from_bbox(depth_map, bbox) if depth_estimate is not None: print(f"Estimated depth for object: {depth_estimate:.2f}m") diff --git a/dimos/perception/spatial_perception.py b/dimos/perception/spatial_perception.py index 6a9ee553b3..b994b52bc4 
100644 --- a/dimos/perception/spatial_perception.py +++ b/dimos/perception/spatial_perception.py @@ -61,9 +61,7 @@ def __init__( "VisualMemory" ] = None, # Optional VisualMemory instance for storing images video_stream: Optional[Observable] = None, # Video stream to process - transform_provider: Optional[ - callable - ] = None, # Function that returns position and rotation + get_pose: Optional[callable] = None, # Function that returns position and rotation ): """ Initialize the spatial perception system. @@ -162,8 +160,8 @@ def __init__( logger.info(f"SpatialMemory initialized with model {embedding_model}") # Start processing video stream if provided - if video_stream is not None and transform_provider is not None: - self.start_continuous_processing(video_stream, transform_provider) + if video_stream is not None and get_pose is not None: + self.start_continuous_processing(video_stream, get_pose) def query_by_location( self, x: float, y: float, radius: float = 2.0, limit: int = 5 @@ -183,14 +181,14 @@ def query_by_location( return self.vector_db.query_by_location(x, y, radius, limit) def start_continuous_processing( - self, video_stream: Observable, transform_provider: callable + self, video_stream: Observable, get_pose: callable ) -> disposable.Disposable: """ Start continuous processing of video frames from an Observable stream. 
Args: video_stream: Observable of video frames - transform_provider: Callable that returns position and rotation for each frame + get_pose: Callable that returns position and rotation for each frame Returns: Disposable subscription that can be used to stop processing @@ -200,7 +198,7 @@ def start_continuous_processing( # Map each video frame to include transform data combined_stream = video_stream.pipe( - ops.map(lambda video_frame: {"frame": video_frame, **transform_provider()}), + ops.map(lambda video_frame: {"frame": video_frame, **get_pose()}), # Filter out bad transforms ops.filter( lambda data: data.get("position") is not None and data.get("rotation") is not None diff --git a/dimos/robot/abstract_robot.py b/dimos/robot/abstract_robot.py deleted file mode 100644 index 50502ab988..0000000000 --- a/dimos/robot/abstract_robot.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Abstract base class for all DIMOS robot implementations. - -This module defines the AbstractRobot class which serves as the foundation for -all robot implementations in DIMOS, establishing a common interface regardless -of the underlying hardware or communication protocol (ROS, WebRTC, etc). -""" - -from abc import ABC, abstractmethod -from reactivex.observable import Observable - - -class AbstractRobot(ABC): - """Abstract base class for all robot implementations. 
- - This class defines the minimal interface that all robot implementations - must provide, regardless of whether they use ROS, WebRTC, or other - communication protocols. - """ - - @abstractmethod - def connect(self) -> bool: - """Establish a connection to the robot. - - This method should handle all necessary setup to establish - communication with the robot hardware. - - Returns: - bool: True if connection was successful, False otherwise. - """ - pass - - @abstractmethod - def move(self, *args, **kwargs) -> bool: - """Move the robot. - - This is a generic movement interface that should be implemented - by all robot classes. The exact parameters will depend on the - specific robot implementation. - - Returns: - bool: True if movement command was successfully sent. - """ - pass - - @abstractmethod - def get_video_stream(self, fps: int = 30) -> Observable: - """Get a video stream from the robot's camera. - - Args: - fps: Frames per second for the video stream. Defaults to 30. - - Returns: - Observable: An observable stream of video frames. - """ - pass - - @abstractmethod - def stop(self) -> None: - """Clean up resources and stop the robot. - - This method should handle all necessary cleanup when shutting down - the robot connection, including stopping any ongoing movements. - """ - pass diff --git a/dimos/robot/connection_interface.py b/dimos/robot/connection_interface.py new file mode 100644 index 0000000000..1f327a7939 --- /dev/null +++ b/dimos/robot/connection_interface.py @@ -0,0 +1,70 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from typing import Optional +from reactivex.observable import Observable +from dimos.types.vector import Vector + +__all__ = ["ConnectionInterface"] + + +class ConnectionInterface(ABC): + """Abstract base class for robot connection interfaces. + + This class defines the minimal interface that all connection types (ROS, WebRTC, etc.) + must implement to provide robot control and data streaming capabilities. + """ + + @abstractmethod + def move(self, velocity: Vector, duration: float = 0.0) -> bool: + """Send movement command to the robot using velocity commands. + + Args: + velocity: Velocity vector [x, y, yaw] where: + x: Forward/backward velocity (m/s) + y: Left/right velocity (m/s) + yaw: Rotational velocity (rad/s) + duration: How long to move (seconds). If 0, command is continuous + + Returns: + bool: True if command was sent successfully + """ + pass + + @abstractmethod + def get_video_stream(self, fps: int = 30) -> Optional[Observable]: + """Get the video stream from the robot's camera. + + Args: + fps: Frames per second for the video stream + + Returns: + Observable: An observable stream of video frames or None if not available + """ + pass + + @abstractmethod + def stop(self) -> bool: + """Stop the robot's movement. 
+ + Returns: + bool: True if stop command was sent successfully + """ + pass + + @abstractmethod + def disconnect(self) -> None: + """Disconnect from the robot and clean up resources.""" + pass diff --git a/dimos/robot/frontier_exploration/__init__.py b/dimos/robot/frontier_exploration/__init__.py new file mode 100644 index 0000000000..2b69011a9f --- /dev/null +++ b/dimos/robot/frontier_exploration/__init__.py @@ -0,0 +1 @@ +from .utils import * diff --git a/dimos/robot/frontier_exploration/qwen_frontier_predictor.py b/dimos/robot/frontier_exploration/qwen_frontier_predictor.py new file mode 100644 index 0000000000..10a1d8a265 --- /dev/null +++ b/dimos/robot/frontier_exploration/qwen_frontier_predictor.py @@ -0,0 +1,368 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Qwen-based frontier exploration goal predictor using vision language model. + +This module provides a frontier goal detector that uses Qwen's vision capabilities +to analyze costmap images and predict optimal exploration goals.
+""" + +import os +import glob +import json +import re +from typing import Optional, List, Tuple + +import numpy as np +from PIL import Image, ImageDraw + +from dimos.types.costmap import Costmap +from dimos.types.vector import Vector +from dimos.models.qwen.video_query import query_single_frame +from dimos.robot.frontier_exploration.utils import ( + costmap_to_pil_image, + smooth_costmap_for_frontiers, +) + + +class QwenFrontierPredictor: + """ + Qwen-based frontier exploration goal predictor. + + Uses Qwen's vision capabilities to analyze costmap images and predict + optimal exploration goals based on visual understanding of the map structure. + """ + + def __init__( + self, + api_key: Optional[str] = None, + model_name: str = "qwen2.5-vl-72b-instruct", + use_smoothed_costmap: bool = True, + image_scale_factor: int = 4, + ): + """ + Initialize the Qwen frontier predictor. + + Args: + api_key: Alibaba API key for Qwen access + model_name: Qwen model to use for predictions + image_scale_factor: Scale factor for image processing + """ + self.api_key = api_key or os.getenv("ALIBABA_API_KEY") + if not self.api_key: + raise ValueError( + "Alibaba API key must be provided or set in ALIBABA_API_KEY environment variable" + ) + + self.model_name = model_name + self.image_scale_factor = image_scale_factor + self.use_smoothed_costmap = use_smoothed_costmap + + # Storage for previously explored goals + self.explored_goals: List[Vector] = [] + + def _world_to_image_coords(self, world_pos: Vector, costmap: Costmap) -> Tuple[int, int]: + """Convert world coordinates to image pixel coordinates.""" + grid_pos = costmap.world_to_grid(world_pos) + img_x = int(grid_pos.x * self.image_scale_factor) + img_y = int((costmap.height - grid_pos.y) * self.image_scale_factor) # Flip Y + return img_x, img_y + + def _image_to_world_coords(self, img_x: int, img_y: int, costmap: Costmap) -> Vector: + """Convert image pixel coordinates to world coordinates.""" + # Unscale and flip Y coordinate + 
grid_x = img_x / self.image_scale_factor + grid_y = costmap.height - (img_y / self.image_scale_factor) + + # Convert grid to world coordinates + world_pos = costmap.grid_to_world(Vector([grid_x, grid_y])) + return world_pos + + def _draw_goals_on_image( + self, + image: Image.Image, + robot_pose: Vector, + costmap: Costmap, + latest_goal: Optional[Vector] = None, + ) -> Image.Image: + """ + Draw explored goals and robot position on the costmap image. + + Args: + image: PIL Image to draw on + robot_pose: Current robot position + costmap: Costmap for coordinate conversion + latest_goal: Latest predicted goal to highlight in red + + Returns: + PIL Image with goals drawn + """ + img_copy = image.copy() + draw = ImageDraw.Draw(img_copy) + + # Draw previously explored goals as green dots + for explored_goal in self.explored_goals: + x, y = self._world_to_image_coords(explored_goal, costmap) + radius = 8 + draw.ellipse( + [x - radius, y - radius, x + radius, y + radius], + fill=(0, 255, 0), + outline=(0, 128, 0), + width=2, + ) + + # Draw robot position as blue dot + robot_x, robot_y = self._world_to_image_coords(robot_pose, costmap) + robot_radius = 10 + draw.ellipse( + [ + robot_x - robot_radius, + robot_y - robot_radius, + robot_x + robot_radius, + robot_y + robot_radius, + ], + fill=(0, 0, 255), + outline=(0, 0, 128), + width=3, + ) + + # Draw latest predicted goal as red dot + if latest_goal: + goal_x, goal_y = self._world_to_image_coords(latest_goal, costmap) + goal_radius = 12 + draw.ellipse( + [ + goal_x - goal_radius, + goal_y - goal_radius, + goal_x + goal_radius, + goal_y + goal_radius, + ], + fill=(255, 0, 0), + outline=(128, 0, 0), + width=3, + ) + + return img_copy + + def _create_vision_prompt(self) -> str: + """Create the vision prompt for Qwen model.""" + prompt = """You are an expert robot navigation system analyzing a costmap for frontier exploration. 
+ +COSTMAP LEGEND: +- Light gray pixels (205,205,205): FREE SPACE - areas the robot can navigate +- Dark gray pixels (128,128,128): UNKNOWN SPACE - unexplored areas that need exploration +- Black pixels (0,0,0): OBSTACLES - walls, furniture, blocked areas +- Blue dot: CURRENT ROBOT POSITION +- Green dots: PREVIOUSLY EXPLORED GOALS - avoid these areas + +TASK: Find the best frontier exploration goal by identifying the optimal point where: +1. It's at the boundary between FREE SPACE (light gray) and UNKNOWN SPACE (dark gray) (HIGHEST Priority) +2. It's reasonably far from the robot position (blue dot) (MEDIUM Priority) +3. It's reasonably far from previously explored goals (green dots) (MEDIUM Priority) +4. It leads to a large area of unknown space to explore (HIGH Priority) +5. It's accessible from the robot's current position through free space (MEDIUM Priority) +6. It's not near or on obstacles (HIGHEST Priority) + +RESPONSE FORMAT: Return ONLY the pixel coordinates as a JSON object: +{"x": pixel_x_coordinate, "y": pixel_y_coordinate, "reasoning": "brief explanation"} + +Example: {"x": 245, "y": 187, "reasoning": "Large unknown area to the north, good distance from robot and previous goals"} + +Analyze the image and identify the single best frontier exploration goal.""" + + return prompt + + def _parse_prediction_response(self, response: str) -> Optional[Tuple[int, int, str]]: + """ + Parse the model's response to extract coordinates and reasoning. 
+ + Args: + response: Raw response from Qwen model + + Returns: + Tuple of (x, y, reasoning) or None if parsing failed + """ + try: + # Try to find JSON object in response + json_match = re.search(r"\{[^}]*\}", response) + if json_match: + json_str = json_match.group() + data = json.loads(json_str) + + if "x" in data and "y" in data: + x = int(data["x"]) + y = int(data["y"]) + reasoning = data.get("reasoning", "No reasoning provided") + return (x, y, reasoning) + + # Fallback: try to extract coordinates with regex + coord_match = re.search(r"[^\d]*(\d+)[^\d]+(\d+)", response) + if coord_match: + x = int(coord_match.group(1)) + y = int(coord_match.group(2)) + return (x, y, "Coordinates extracted from response") + + except (json.JSONDecodeError, ValueError, KeyError) as e: + print(f"DEBUG: Failed to parse prediction response: {e}") + + return None + + def get_exploration_goal(self, robot_pose: Vector, costmap: Costmap) -> Optional[Vector]: + """ + Get the best exploration goal using Qwen vision analysis. 
+ + Args: + robot_pose: Current robot position in world coordinates + costmap: Current costmap for analysis + + Returns: + Single best frontier goal in world coordinates, or None if no suitable goal found + """ + print( + f"DEBUG: Qwen frontier prediction starting with {len(self.explored_goals)} explored goals" + ) + + # Create costmap image + if self.use_smoothed_costmap: + costmap = smooth_costmap_for_frontiers(costmap, alpha=4.0) + + base_image = costmap_to_pil_image(costmap, self.image_scale_factor) + + # Draw goals on image (without latest goal initially) + annotated_image = self._draw_goals_on_image(base_image, robot_pose, costmap) + + # Query Qwen model for frontier prediction + try: + prompt = self._create_vision_prompt() + response = query_single_frame( + annotated_image, prompt, api_key=self.api_key, model_name=self.model_name + ) + + print(f"DEBUG: Qwen response: {response}") + + # Parse response to get coordinates + parsed_result = self._parse_prediction_response(response) + if not parsed_result: + print("DEBUG: Failed to parse Qwen response") + return None + + img_x, img_y, reasoning = parsed_result + print(f"DEBUG: Parsed coordinates: ({img_x}, {img_y}), Reasoning: {reasoning}") + + # Convert image coordinates to world coordinates + predicted_goal = self._image_to_world_coords(img_x, img_y, costmap) + print( + f"DEBUG: Predicted goal in world coordinates: ({predicted_goal.x:.2f}, {predicted_goal.y:.2f})" + ) + + # Store the goal in explored_goals for future reference + self.explored_goals.append(predicted_goal) + + print(f"DEBUG: Successfully predicted frontier goal: {predicted_goal}") + return predicted_goal + + except Exception as e: + print(f"DEBUG: Error during Qwen prediction: {e}") + return None + + +def test_qwen_frontier_detection(): + """ + Visual test for Qwen frontier detection using saved costmaps. + Shows frontier detection results with Qwen predictions. 
+ """ + + # Path to saved costmaps + saved_maps_dir = os.path.join(os.getcwd(), "assets", "saved_maps") + + if not os.path.exists(saved_maps_dir): + print(f"Error: Saved maps directory not found: {saved_maps_dir}") + return + + # Get all pickle files + pickle_files = sorted(glob.glob(os.path.join(saved_maps_dir, "*.pickle"))) + + if not pickle_files: + print(f"No pickle files found in {saved_maps_dir}") + return + + print(f"Found {len(pickle_files)} costmap files for Qwen testing") + + # Initialize Qwen frontier predictor + predictor = QwenFrontierPredictor(image_scale_factor=4, use_smoothed_costmap=False) + + # Track the robot pose across iterations + robot_pose = None + + # Process each costmap + for i, pickle_file in enumerate(pickle_files): + print( + f"\n--- Processing costmap {i + 1}/{len(pickle_files)}: {os.path.basename(pickle_file)} ---" + ) + + try: + # Load the costmap + costmap = Costmap.from_pickle(pickle_file) + print( + f"Loaded costmap: {costmap.width}x{costmap.height}, resolution: {costmap.resolution}" + ) + + # Set robot pose: first iteration uses center, subsequent use last predicted goal + if robot_pose is None: + # First iteration: use center of costmap as robot position + center_world = costmap.grid_to_world( + Vector([costmap.width / 2, costmap.height / 2]) + ) + robot_pose = Vector([center_world.x, center_world.y]) + # else: robot_pose remains the last predicted goal + + print(f"Using robot position: {robot_pose}") + + # Get frontier prediction from Qwen + print("Getting Qwen frontier prediction...") + predicted_goal = predictor.get_exploration_goal(robot_pose, costmap) + + if predicted_goal: + distance = np.sqrt( + (predicted_goal.x - robot_pose.x) ** 2 + (predicted_goal.y - robot_pose.y) ** 2 + ) + print(f"Predicted goal: {predicted_goal}, Distance: {distance:.2f}m") + + # Show the final visualization + base_image = costmap_to_pil_image(costmap, predictor.image_scale_factor) + final_image = predictor._draw_goals_on_image( + base_image, 
robot_pose, costmap, predicted_goal + ) + + # Display image + title = f"Qwen Frontier Prediction {i + 1:04d}" + final_image.show(title=title) + + # Update robot pose for next iteration + robot_pose = predicted_goal + + else: + print("No suitable frontier goal predicted by Qwen") + + except Exception as e: + print(f"Error processing {pickle_file}: {e}") + continue + + print(f"\n=== Qwen Frontier Detection Test Complete ===") + print(f"Final explored goals count: {len(predictor.explored_goals)}") + + +if __name__ == "__main__": + test_qwen_frontier_detection() diff --git a/dimos/robot/frontier_exploration/test_wavefront_frontier_goal_selector.py b/dimos/robot/frontier_exploration/test_wavefront_frontier_goal_selector.py new file mode 100644 index 0000000000..8273b21a52 --- /dev/null +++ b/dimos/robot/frontier_exploration/test_wavefront_frontier_goal_selector.py @@ -0,0 +1,296 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest +import numpy as np +from typing import List, Optional +from PIL import Image, ImageDraw + +from dimos.utils.testing import SensorReplay +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.map import Map +from dimos.types.vector import Vector +from dimos.robot.frontier_exploration.wavefront_frontier_goal_selector import ( + WavefrontFrontierExplorer, +) +from dimos.robot.frontier_exploration.utils import costmap_to_pil_image +from reactivex import operators as ops + + +def get_office_lidar_costmap(take_frames: int = 1, voxel_size: float = 0.5) -> tuple: + """ + Get a costmap from office_lidar data using SensorReplay. + + Args: + take_frames: Number of lidar frames to take (default 1) + voxel_size: Voxel size for map construction + + Returns: + Tuple of (costmap, first_lidar_message) for testing + """ + # Load office lidar data using SensorReplay as documented + lidar_stream = SensorReplay("office_lidar", autocast=LidarMessage.from_msg) + + # Create map with specified voxel size + map_obj = Map(voxel_size=voxel_size) + + # Take only the specified number of frames and build map + limited_stream = lidar_stream.stream().pipe(ops.take(take_frames)) + + # Store the first lidar message for reference + first_lidar = None + + def capture_first_and_add(lidar_msg): + nonlocal first_lidar + if first_lidar is None: + first_lidar = lidar_msg + return map_obj.add_frame(lidar_msg) + + # Process the stream + limited_stream.pipe(ops.map(capture_first_and_add)).run() + + # Get the resulting costmap + costmap = map_obj.costmap + + return costmap, first_lidar + + +def test_frontier_detection_with_office_lidar(): + """Test frontier detection using a single frame from office_lidar data.""" + # Get costmap from office lidar data + costmap, first_lidar = get_office_lidar_costmap(take_frames=1, voxel_size=0.3) + + # Verify we have a valid costmap + assert costmap is not None, "Costmap should not be None" + assert 
costmap.width > 0 and costmap.height > 0, "Costmap should have valid dimensions" + + print(f"Costmap dimensions: {costmap.width}x{costmap.height}") + print(f"Costmap resolution: {costmap.resolution}") + print(f"Unknown percent: {costmap.unknown_percent:.1f}%") + print(f"Free percent: {costmap.free_percent:.1f}%") + print(f"Occupied percent: {costmap.occupied_percent:.1f}%") + + # Initialize frontier explorer with default parameters + explorer = WavefrontFrontierExplorer() + + # Set robot pose near the center of free space in the costmap + # We'll use the lidar origin as a reasonable robot position + robot_pose = first_lidar.origin + print(f"Robot pose: {robot_pose}") + + # Detect frontiers + frontiers = explorer.detect_frontiers(robot_pose, costmap) + + # Verify frontier detection results + assert isinstance(frontiers, list), "Frontiers should be returned as a list" + print(f"Detected {len(frontiers)} frontiers") + + # Test that we get some frontiers (office environment should have unexplored areas) + if len(frontiers) > 0: + print("Frontier detection successful - found unexplored areas") + + # Verify frontiers are Vector objects with valid coordinates + for i, frontier in enumerate(frontiers[:5]): # Check first 5 + assert isinstance(frontier, Vector), f"Frontier {i} should be a Vector" + assert hasattr(frontier, "x") and hasattr(frontier, "y"), ( + f"Frontier {i} should have x,y coordinates" + ) + print(f" Frontier {i}: ({frontier.x:.2f}, {frontier.y:.2f})") + else: + print("No frontiers detected - map may be fully explored or parameters too restrictive") + + +def test_exploration_goal_selection(): + """Test the complete exploration goal selection pipeline.""" + # Get costmap from office lidar data + costmap, first_lidar = get_office_lidar_costmap(take_frames=1, voxel_size=0.3) + + # Initialize frontier explorer with default parameters + explorer = WavefrontFrontierExplorer() + + # Use lidar origin as robot position + robot_pose = first_lidar.origin + + # Get 
exploration goal + goal = explorer.get_exploration_goal(robot_pose, costmap) + + if goal is not None: + assert isinstance(goal, Vector), "Goal should be a Vector" + print(f"Selected exploration goal: ({goal.x:.2f}, {goal.y:.2f})") + + # Verify goal is at reasonable distance from robot + distance = np.sqrt((goal.x - robot_pose.x) ** 2 + (goal.y - robot_pose.y) ** 2) + print(f"Goal distance from robot: {distance:.2f}m") + assert distance >= explorer.min_distance_from_robot, ( + "Goal should respect minimum distance from robot" + ) + + # Test that goal gets marked as explored + assert len(explorer.explored_goals) == 1, "Goal should be marked as explored" + assert explorer.explored_goals[0] == goal, "Explored goal should match selected goal" + + else: + print("No exploration goal selected - map may be fully explored") + + +def test_exploration_session_reset(): + """Test exploration session reset functionality.""" + # Get costmap + costmap, first_lidar = get_office_lidar_costmap(take_frames=1, voxel_size=0.3) + + # Initialize explorer and select a goal + explorer = WavefrontFrontierExplorer() + robot_pose = first_lidar.origin + + # Select a goal to populate exploration state + goal = explorer.get_exploration_goal(robot_pose, costmap) + + # Verify state is populated + initial_explored_count = len(explorer.explored_goals) + initial_direction = explorer.exploration_direction + + # Reset exploration session + explorer.reset_exploration_session() + + # Verify state is cleared + assert len(explorer.explored_goals) == 0, "Explored goals should be cleared after reset" + assert explorer.exploration_direction.x == 0.0 and explorer.exploration_direction.y == 0.0, ( + "Exploration direction should be reset" + ) + assert explorer.last_costmap is None, "Last costmap should be cleared" + assert explorer.num_no_gain_attempts == 0, "No-gain attempts should be reset" + + print("Exploration session reset successfully") + + +@pytest.mark.vis +def test_frontier_detection_visualization(): + 
"""Test frontier detection with visualization (marked with @pytest.mark.vis).""" + # Get costmap from office lidar data + costmap, first_lidar = get_office_lidar_costmap(take_frames=1, voxel_size=0.2) + + # Initialize frontier explorer with default parameters + explorer = WavefrontFrontierExplorer() + + # Use lidar origin as robot position + robot_pose = first_lidar.origin + + # Detect all frontiers for visualization + all_frontiers = explorer.detect_frontiers(robot_pose, costmap) + + # Get selected goal + selected_goal = explorer.get_exploration_goal(robot_pose, costmap) + + print(f"Visualizing {len(all_frontiers)} frontier candidates") + if selected_goal: + print(f"Selected goal: ({selected_goal.x:.2f}, {selected_goal.y:.2f})") + + # Create visualization + image_scale_factor = 4 + base_image = costmap_to_pil_image(costmap, image_scale_factor) + + # Helper function to convert world coordinates to image coordinates + def world_to_image_coords(world_pos: Vector) -> tuple[int, int]: + grid_pos = costmap.world_to_grid(world_pos) + img_x = int(grid_pos.x * image_scale_factor) + img_y = int((costmap.height - grid_pos.y) * image_scale_factor) # Flip Y + return img_x, img_y + + # Draw visualization + draw = ImageDraw.Draw(base_image) + + # Draw frontier candidates as gray dots + for frontier in all_frontiers[:20]: # Limit to top 20 + x, y = world_to_image_coords(frontier) + radius = 6 + draw.ellipse( + [x - radius, y - radius, x + radius, y + radius], + fill=(128, 128, 128), # Gray + outline=(64, 64, 64), + width=1, + ) + + # Draw robot position as blue dot + robot_x, robot_y = world_to_image_coords(robot_pose) + robot_radius = 10 + draw.ellipse( + [ + robot_x - robot_radius, + robot_y - robot_radius, + robot_x + robot_radius, + robot_y + robot_radius, + ], + fill=(0, 0, 255), # Blue + outline=(0, 0, 128), + width=3, + ) + + # Draw selected goal as red dot + if selected_goal: + goal_x, goal_y = world_to_image_coords(selected_goal) + goal_radius = 12 + draw.ellipse( + [ + 
goal_x - goal_radius, + goal_y - goal_radius, + goal_x + goal_radius, + goal_y + goal_radius, + ], + fill=(255, 0, 0), # Red + outline=(128, 0, 0), + width=3, + ) + + # Display the image + base_image.show(title="Frontier Detection - Office Lidar") + + print("Visualization displayed. Close the image window to continue.") + + +def test_multi_frame_exploration(): + """Tool test for multi-frame exploration analysis.""" + print("=== Multi-Frame Exploration Analysis ===") + + # Test with different numbers of frames + frame_counts = [1, 3, 5] + + for frame_count in frame_counts: + print(f"\n--- Testing with {frame_count} lidar frame(s) ---") + + # Get costmap with multiple frames + costmap, first_lidar = get_office_lidar_costmap(take_frames=frame_count, voxel_size=0.3) + + print( + f"Costmap: {costmap.width}x{costmap.height}, " + f"unknown: {costmap.unknown_percent:.1f}%, " + f"free: {costmap.free_percent:.1f}%, " + f"occupied: {costmap.occupied_percent:.1f}%" + ) + + # Initialize explorer with default parameters + explorer = WavefrontFrontierExplorer() + + # Detect frontiers + robot_pose = first_lidar.origin + frontiers = explorer.detect_frontiers(robot_pose, costmap) + + print(f"Detected {len(frontiers)} frontiers") + + # Get exploration goal + goal = explorer.get_exploration_goal(robot_pose, costmap) + if goal: + distance = np.sqrt((goal.x - robot_pose.x) ** 2 + (goal.y - robot_pose.y) ** 2) + print(f"Selected goal at distance {distance:.2f}m") + else: + print("No exploration goal selected") diff --git a/dimos/robot/frontier_exploration/utils.py b/dimos/robot/frontier_exploration/utils.py new file mode 100644 index 0000000000..746f72e2f5 --- /dev/null +++ b/dimos/robot/frontier_exploration/utils.py @@ -0,0 +1,188 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Utility functions for frontier exploration visualization and testing.
"""

import numpy as np
from PIL import Image, ImageDraw
from typing import List, Optional, Tuple
from dimos.types.costmap import Costmap, CostValues
from dimos.types.vector import Vector
import os
import pickle
import cv2


def costmap_to_pil_image(costmap: Costmap, scale_factor: int = 2) -> Image.Image:
    """
    Convert costmap to PIL Image with ROS-style coloring and optional scaling.

    Args:
        costmap: Costmap to convert
        scale_factor: Factor to scale up the image for better visibility

    Returns:
        PIL Image with ROS-style colors
    """
    grid = costmap.grid

    # Vectorized ROS-style coloring (replaces a per-pixel Python loop that was
    # O(width*height) with interpreter overhead).  Masks are painted from the
    # lowest to the highest precedence so the final color of every cell matches
    # the original if/elif chain: FREE first, then UNKNOWN, then >= OCCUPIED,
    # and everything else falls through to light grey.
    img_array = np.full((costmap.height, costmap.width, 3), 205, dtype=np.uint8)  # low cost / else
    img_array[grid >= CostValues.OCCUPIED] = (0, 0, 0)  # occupied / obstacles = black
    img_array[grid == CostValues.UNKNOWN] = (128, 128, 128)  # unknown = dark gray
    img_array[grid == CostValues.FREE] = (205, 205, 205)  # free space = light grey

    # Flip vertically to match ROS convention (origin at bottom-left).
    # ascontiguousarray makes the flipped (negative-stride) view safe for PIL.
    img_array = np.ascontiguousarray(np.flipud(img_array))

    img = Image.fromarray(img_array, "RGB")

    # Scale up if requested; NEAREST keeps the grid-cell edges sharp.
    if scale_factor > 1:
        new_size = (img.width * scale_factor, img.height * scale_factor)
        img = img.resize(new_size, Image.NEAREST)

    return img


def draw_frontiers_on_image(
    image: Image.Image,
    costmap: Costmap,
    frontiers: List[Vector],
    scale_factor: int = 2,
    unfiltered_frontiers: Optional[List[Vector]] = None,
) -> Image.Image:
    """
    Draw frontier points on the costmap image.

    Args:
        image: PIL Image to draw on
        costmap: Original costmap for coordinate conversion
        frontiers: List of ranked frontier centroids (best first)
        scale_factor: Scaling factor used for the image
        unfiltered_frontiers: All unfiltered frontier results (light green)

    Returns:
        PIL Image with frontiers drawn
    """
    img_copy = image.copy()
    draw = ImageDraw.Draw(img_copy)

    def world_to_image_coords(world_pos: Vector) -> Tuple[int, int]:
        """Convert world coordinates to image pixel coordinates."""
        grid_pos = costmap.world_to_grid(world_pos)
        # Flip Y coordinate and apply scaling
        img_x = int(grid_pos.x * scale_factor)
        img_y = int((costmap.height - grid_pos.y) * scale_factor)  # Flip Y
        return img_x, img_y

    # Draw all unfiltered frontiers as light green circles
    if unfiltered_frontiers:
        for frontier in unfiltered_frontiers:
            x, y = world_to_image_coords(frontier)
            radius = 3 * scale_factor
            draw.ellipse(
                [x - radius, y - radius, x + radius, y + radius],
                fill=(144, 238, 144),
                outline=(144, 238, 144),
            )  # Light green

    # Draw the remaining ranked frontiers (all but the best) as green circles
    for i, frontier in enumerate(frontiers[1:]):  # Skip the best one for now
        x, y = world_to_image_coords(frontier)
        radius = 4 * scale_factor
        draw.ellipse(
            [x - radius, y - radius, x + radius, y + radius],
            fill=(0, 255, 0),
            outline=(0, 128, 0),
            width=2,
        )  # Green

        # Add rank label (the best frontier is rank 1, so these start at 2)
        draw.text((x + radius + 2, y - radius), str(i + 2), fill=(0, 255, 0))

    # Draw best frontier as red circle
    if frontiers:
        best_frontier = frontiers[0]
        x, y = world_to_image_coords(best_frontier)
        radius = 6 * scale_factor
        draw.ellipse(
            [x - radius, y - radius, x + radius, y + radius],
            fill=(255, 0, 0),
            outline=(128, 0, 0),
            width=3,
        )  # Red

        # Add "BEST" label
        draw.text((x + radius + 2, y - radius), "BEST", fill=(255, 0, 0))

    return img_copy


def smooth_costmap_for_frontiers(
    costmap: Costmap,
) -> Costmap:
    """
    Smooth a costmap using morphological operations for frontier exploration.

    This function applies OpenCV morphological operations to smooth free space
    areas and improve connectivity for better frontier detection. It's designed
    specifically for frontier exploration.

    Args:
        costmap: Input Costmap object

    Returns:
        Smoothed Costmap object with enhanced free space connectivity
    """
    # Extract grid data and metadata from costmap
    grid = costmap.grid
    resolution = costmap.resolution

    # Work with a copy to avoid modifying input
    filtered_grid = grid.copy()

    # 1. Create binary mask for free space
    free_mask = (grid == CostValues.FREE).astype(np.uint8) * 255

    # 2. Apply morphological operations for smoothing
    kernel_size = 7
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))

    # Dilate free space to connect nearby areas
    dilated = cv2.dilate(free_mask, kernel, iterations=1)

    # Morphological closing to fill small gaps
    closed = cv2.morphologyEx(dilated, cv2.MORPH_CLOSE, kernel, iterations=1)

    # Erode back so free space is not grown beyond the smoothed outline
    eroded = cv2.erode(closed, kernel, iterations=1)

    # Apply the smoothed free space back to costmap
    # Only change unknown areas to free, don't override obstacles
    smoothed_free = eroded == 255
    unknown_mask = grid == CostValues.UNKNOWN
    filtered_grid[smoothed_free & unknown_mask] = CostValues.FREE

    return Costmap(grid=filtered_grid, origin=costmap.origin, resolution=resolution)
# Copyright 2025 Dimensional Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Simple wavefront frontier exploration algorithm implementation using dimos types.

This module provides frontier detection and exploration goal selection
for autonomous navigation using the dimos Costmap and Vector types.
"""

from typing import List, Tuple, Optional, Callable
from collections import deque
import numpy as np
from dataclasses import dataclass
from enum import IntFlag
import threading
from dimos.utils.logging_config import setup_logger

from dimos.types.costmap import Costmap, CostValues
from dimos.types.vector import Vector
from dimos.robot.frontier_exploration.utils import smooth_costmap_for_frontiers

logger = setup_logger("dimos.robot.unitree.frontier_exploration")


class PointClassification(IntFlag):
    """Point classification flags for frontier detection algorithm."""

    NoInformation = 0
    MapOpen = 1
    MapClosed = 2
    FrontierOpen = 4
    FrontierClosed = 8


@dataclass
class GridPoint:
    """Represents a point in the grid map with classification."""

    x: int
    y: int
    classification: int = PointClassification.NoInformation


class FrontierCache:
    """Cache for grid points to avoid duplicate point creation."""

    def __init__(self):
        self.points = {}

    def get_point(self, x: int, y: int) -> GridPoint:
        """Get or create a grid point at the given coordinates."""
        key = (x, y)
        if key not in self.points:
            self.points[key] = GridPoint(x, y)
        return self.points[key]

    def clear(self):
        """Clear the point cache."""
        self.points.clear()


class WavefrontFrontierExplorer:
    """
    Wavefront frontier exploration algorithm implementation.

    This class encapsulates the frontier detection and exploration goal selection
    functionality using the wavefront algorithm with BFS exploration.
    """

    def __init__(
        self,
        min_frontier_size: int = 10,
        occupancy_threshold: int = 65,
        subsample_resolution: int = 3,
        min_distance_from_robot: float = 0.5,
        explored_area_buffer: float = 0.5,
        min_distance_from_obstacles: float = 0.6,
        info_gain_threshold: float = 0.03,
        num_no_gain_attempts: int = 4,
        set_goal: Optional[Callable] = None,
        get_costmap: Optional[Callable] = None,
        get_robot_pos: Optional[Callable] = None,
    ):
        """
        Initialize the frontier explorer.

        Args:
            min_frontier_size: Minimum number of points to consider a valid frontier
            occupancy_threshold: Cost threshold above which a cell is considered occupied (0-255)
            subsample_resolution: Factor by which to subsample the costmap for faster processing (1=no subsampling, 2=half resolution, 4=quarter resolution)
            min_distance_from_robot: Minimum distance frontier must be from robot (meters)
            explored_area_buffer: Buffer distance around free areas to consider as explored (meters)
            min_distance_from_obstacles: Minimum distance frontier must be from obstacles (meters)
            info_gain_threshold: Minimum percentage increase in costmap information required to continue exploration (0.03 = 3%)
            num_no_gain_attempts: Maximum number of consecutive attempts with no information gain
            set_goal: Callable to set navigation goal, signature: (goal: Vector, stop_event: Optional[threading.Event]) -> bool
            get_costmap: Callable to get current costmap, signature: () -> Costmap
            get_robot_pos: Callable to get current robot position, signature: () -> Vector
        """
        self.min_frontier_size = min_frontier_size
        self.occupancy_threshold = occupancy_threshold
        self.subsample_resolution = subsample_resolution
        self.min_distance_from_robot = min_distance_from_robot
        self.explored_area_buffer = explored_area_buffer
        self.min_distance_from_obstacles = min_distance_from_obstacles
        self.info_gain_threshold = info_gain_threshold
        # BUG FIX: the original stored the configured maximum in
        # self.num_no_gain_attempts and later reused the same attribute as the
        # running counter, so the give-up check compared the counter with
        # itself (always true) and reset_exploration_session() clobbered the
        # configured limit.  Keep the limit and the counter separate; the
        # counter keeps the old public name since the tests read it as a
        # counter (expecting 0 after reset).
        self.max_no_gain_attempts = num_no_gain_attempts
        self.num_no_gain_attempts = 0  # consecutive no-information-gain attempts
        self.set_goal = set_goal
        self.get_costmap = get_costmap
        self.get_robot_pos = get_robot_pos
        self._cache = FrontierCache()
        self.explored_goals = []  # list of explored goals
        self.exploration_direction = Vector([0.0, 0.0])  # current exploration direction
        self.last_costmap = None  # store last costmap for information comparison

    def _count_costmap_information(self, costmap: Costmap) -> int:
        """
        Count the amount of information in a costmap (free space + obstacles).

        Args:
            costmap: Costmap to analyze

        Returns:
            Number of cells that are free space or obstacles (not unknown)
        """
        free_count = np.sum(costmap.grid == CostValues.FREE)
        obstacle_count = np.sum(costmap.grid >= self.occupancy_threshold)
        return int(free_count + obstacle_count)

    def _get_neighbors(self, point: GridPoint, costmap: Costmap) -> List[GridPoint]:
        """Get valid neighboring points (8-connected) for a given grid point."""
        neighbors = []

        for dx in [-1, 0, 1]:
            for dy in [-1, 0, 1]:
                if dx == 0 and dy == 0:
                    continue

                nx, ny = point.x + dx, point.y + dy

                # Check bounds
                if 0 <= nx < costmap.width and 0 <= ny < costmap.height:
                    neighbors.append(self._cache.get_point(nx, ny))

        return neighbors

    def _is_frontier_point(self, point: GridPoint, costmap: Costmap) -> bool:
        """
        Check if a point is a frontier point.
        A frontier point is an unknown cell adjacent to at least one free cell
        and not adjacent to any occupied cells.
        """
        # Point must be unknown
        world_pos = costmap.grid_to_world(Vector([float(point.x), float(point.y)]))
        cost = costmap.get_value(world_pos)
        if cost != CostValues.UNKNOWN:
            return False

        has_free = False

        for neighbor in self._get_neighbors(point, costmap):
            neighbor_world = costmap.grid_to_world(Vector([float(neighbor.x), float(neighbor.y)]))
            neighbor_cost = costmap.get_value(neighbor_world)

            # If adjacent to occupied space, not a frontier.  Explicit None
            # check (get_value may return None out of bounds) instead of the
            # original truthiness test, which only worked by accident.
            if neighbor_cost is not None and neighbor_cost > self.occupancy_threshold:
                return False

            # Check if adjacent to free space
            if neighbor_cost == CostValues.FREE:
                has_free = True

        return has_free

    def _find_free_space(self, start_x: int, start_y: int, costmap: Costmap) -> Tuple[int, int]:
        """
        Find the nearest free space point using BFS from the starting position.
        Returns the original position if no free space is reachable.
        """
        queue = deque([self._cache.get_point(start_x, start_y)])
        visited = set()

        while queue:
            point = queue.popleft()

            if (point.x, point.y) in visited:
                continue
            visited.add((point.x, point.y))

            # Check if this point is free space
            world_pos = costmap.grid_to_world(Vector([float(point.x), float(point.y)]))
            if costmap.get_value(world_pos) == CostValues.FREE:
                return (point.x, point.y)

            # Add neighbors to search
            for neighbor in self._get_neighbors(point, costmap):
                if (neighbor.x, neighbor.y) not in visited:
                    queue.append(neighbor)

        # If no free space found, return original position
        return (start_x, start_y)

    def _compute_centroid(self, frontier_points: List[Vector]) -> Vector:
        """Compute the centroid of a list of frontier points."""
        if not frontier_points:
            return Vector([0.0, 0.0])

        # Vectorized approach using numpy
        points_array = np.array([[point.x, point.y] for point in frontier_points])
        centroid = np.mean(points_array, axis=0)

        return Vector([centroid[0], centroid[1]])

    def detect_frontiers(self, robot_pose: Vector, costmap: Costmap) -> List[Vector]:
        """
        Main frontier detection algorithm using wavefront exploration.

        Args:
            robot_pose: Current robot position in world coordinates (Vector with x, y)
            costmap: Costmap for additional analysis

        Returns:
            List of frontier centroids in world coordinates, best first
        """
        self._cache.clear()

        # Apply filtered costmap (now default)
        working_costmap = smooth_costmap_for_frontiers(costmap)

        # Subsample the costmap for faster processing
        if self.subsample_resolution > 1:
            subsampled_costmap = working_costmap.subsample(self.subsample_resolution)
        else:
            subsampled_costmap = working_costmap

        # Convert robot pose to subsampled grid coordinates
        subsampled_grid_pos = subsampled_costmap.world_to_grid(robot_pose)
        grid_x, grid_y = int(subsampled_grid_pos.x), int(subsampled_grid_pos.y)

        # Find nearest free space to start exploration
        free_x, free_y = self._find_free_space(grid_x, grid_y, subsampled_costmap)
        start_point = self._cache.get_point(free_x, free_y)
        start_point.classification = PointClassification.MapOpen

        # Main exploration queue - explore ALL reachable free space
        map_queue = deque([start_point])
        frontiers = []
        frontier_sizes = []

        while map_queue:
            current_point = map_queue.popleft()

            # Skip if already processed
            if current_point.classification & PointClassification.MapClosed:
                continue

            # Mark as processed
            current_point.classification |= PointClassification.MapClosed

            # Check if this point starts a new frontier
            if self._is_frontier_point(current_point, subsampled_costmap):
                current_point.classification |= PointClassification.FrontierOpen
                frontier_queue = deque([current_point])
                new_frontier = []

                # Explore this frontier region using BFS
                while frontier_queue:
                    frontier_point = frontier_queue.popleft()

                    # Skip if already processed
                    if frontier_point.classification & PointClassification.FrontierClosed:
                        continue

                    # If this is still a frontier point, add to current frontier
                    if self._is_frontier_point(frontier_point, subsampled_costmap):
                        new_frontier.append(frontier_point)

                        # Add neighbors to frontier queue
                        for neighbor in self._get_neighbors(frontier_point, subsampled_costmap):
                            if not (
                                neighbor.classification
                                & (
                                    PointClassification.FrontierOpen
                                    | PointClassification.FrontierClosed
                                )
                            ):
                                neighbor.classification |= PointClassification.FrontierOpen
                                frontier_queue.append(neighbor)

                    frontier_point.classification |= PointClassification.FrontierClosed

                # Check if we found a large enough frontier
                if len(new_frontier) >= self.min_frontier_size:
                    world_points = []
                    for point in new_frontier:
                        world_pos = subsampled_costmap.grid_to_world(
                            Vector([float(point.x), float(point.y)])
                        )
                        world_points.append(world_pos)

                    # Compute centroid in world coordinates (already correctly scaled)
                    centroid = self._compute_centroid(world_points)
                    frontiers.append(centroid)  # Store centroid
                    frontier_sizes.append(len(new_frontier))  # Store frontier size

            # Add ALL neighbors to main exploration queue to explore entire free space
            for neighbor in self._get_neighbors(current_point, subsampled_costmap):
                if not (
                    neighbor.classification
                    & (PointClassification.MapOpen | PointClassification.MapClosed)
                ):
                    # Check if neighbor is free space or unknown (explorable)
                    neighbor_world = subsampled_costmap.grid_to_world(
                        Vector([float(neighbor.x), float(neighbor.y)])
                    )
                    neighbor_cost = subsampled_costmap.get_value(neighbor_world)

                    # Add free space and unknown space to exploration queue
                    if neighbor_cost is not None and (
                        neighbor_cost == CostValues.FREE or neighbor_cost == CostValues.UNKNOWN
                    ):
                        neighbor.classification |= PointClassification.MapOpen
                        map_queue.append(neighbor)

        if not frontiers:
            return []

        # Rank frontiers using original costmap for proper filtering
        return self._rank_frontiers(frontiers, frontier_sizes, robot_pose, costmap)

    def _update_exploration_direction(self, robot_pose: Vector, goal_pose: Optional[Vector] = None):
        """Update the current exploration direction based on robot movement or selected goal."""
        if goal_pose is not None:
            # Calculate direction from robot to goal
            direction = Vector([goal_pose.x - robot_pose.x, goal_pose.y - robot_pose.y])
            magnitude = np.sqrt(direction.x**2 + direction.y**2)
            if magnitude > 0.1:  # Avoid division by zero for very close goals
                self.exploration_direction = Vector(
                    [direction.x / magnitude, direction.y / magnitude]
                )

    def _compute_direction_momentum_score(self, frontier: Vector, robot_pose: Vector) -> float:
        """Compute direction momentum score for a frontier."""
        if self.exploration_direction.x == 0 and self.exploration_direction.y == 0:
            return 0.0  # No momentum if no previous direction

        # Calculate direction from robot to frontier
        frontier_direction = Vector([frontier.x - robot_pose.x, frontier.y - robot_pose.y])
        magnitude = np.sqrt(frontier_direction.x**2 + frontier_direction.y**2)

        if magnitude < 0.1:
            return 0.0  # Too close to calculate meaningful direction

        # Normalize frontier direction
        frontier_direction = Vector(
            [frontier_direction.x / magnitude, frontier_direction.y / magnitude]
        )

        # Calculate dot product for directional alignment
        dot_product = (
            self.exploration_direction.x * frontier_direction.x
            + self.exploration_direction.y * frontier_direction.y
        )

        # Return momentum score (higher for same direction, lower for opposite)
        return max(0.0, dot_product)  # Only positive momentum, no penalty for different directions

    def _compute_distance_to_explored_goals(self, frontier: Vector) -> float:
        """Compute distance from frontier to the nearest explored goal."""
        if not self.explored_goals:
            return 5.0  # Default consistent value when no explored goals

        # Calculate distance to nearest explored goal
        min_distance = float("inf")
        for goal in self.explored_goals:
            distance = np.sqrt((frontier.x - goal.x) ** 2 + (frontier.y - goal.y) ** 2)
            min_distance = min(min_distance, distance)

        return min_distance

    def _compute_distance_to_obstacles(self, frontier: Vector, costmap: Costmap) -> float:
        """
        Compute the minimum distance from a frontier point to the nearest obstacle.

        Args:
            frontier: Frontier point in world coordinates
            costmap: Costmap to check for obstacles

        Returns:
            Minimum distance to nearest obstacle in meters (inf if none found
            within the search window)
        """
        # Convert frontier to grid coordinates
        grid_pos = costmap.world_to_grid(frontier)
        grid_x, grid_y = int(grid_pos.x), int(grid_pos.y)

        # Check if frontier is within costmap bounds
        if grid_x < 0 or grid_x >= costmap.width or grid_y < 0 or grid_y >= costmap.height:
            return 0.0  # Consider out-of-bounds as obstacle

        min_distance = float("inf")
        search_radius = (
            int(self.min_distance_from_obstacles / costmap.resolution) + 5
        )  # Search a bit beyond minimum

        # Search in a square around the frontier point
        for dy in range(-search_radius, search_radius + 1):
            for dx in range(-search_radius, search_radius + 1):
                check_x = grid_x + dx
                check_y = grid_y + dy

                # Skip if out of bounds
                if (
                    check_x < 0
                    or check_x >= costmap.width
                    or check_y < 0
                    or check_y >= costmap.height
                ):
                    continue

                # Check if this cell is an obstacle
                if costmap.grid[check_y, check_x] >= self.occupancy_threshold:
                    # Calculate distance in meters
                    distance = np.sqrt(dx**2 + dy**2) * costmap.resolution
                    min_distance = min(min_distance, distance)

        return min_distance

    def _compute_comprehensive_frontier_score(
        self, frontier: Vector, frontier_size: int, robot_pose: Vector, costmap: Costmap
    ) -> float:
        """Compute comprehensive score considering multiple criteria."""

        # 1. Distance from robot (preference for moderate distances)
        robot_distance = np.sqrt(
            (frontier.x - robot_pose.x) ** 2 + (frontier.y - robot_pose.y) ** 2
        )

        # Distance score: prefer moderate distances (not too close, not too far)
        optimal_distance = 4.0  # meters
        distance_score = 1.0 / (1.0 + abs(robot_distance - optimal_distance))

        # 2. Information gain (frontier size)
        info_gain_score = frontier_size

        # 3. Distance to explored goals (bonus for being far from explored areas)
        explored_goals_distance = self._compute_distance_to_explored_goals(frontier)
        explored_goals_score = explored_goals_distance

        # 4. Distance to obstacles (penalty for being too close)
        obstacles_distance = self._compute_distance_to_obstacles(frontier, costmap)
        obstacles_score = obstacles_distance

        # 5. Direction momentum (if we have a current direction)
        momentum_score = self._compute_direction_momentum_score(frontier, robot_pose)

        # Combine scores with consistent scaling (no arbitrary multipliers)
        total_score = (
            0.3 * info_gain_score  # 30% information gain
            + 0.3 * explored_goals_score  # 30% distance from explored goals
            + 0.2 * distance_score  # 20% distance optimization
            + 0.15 * obstacles_score  # 15% distance from obstacles
            + 0.05 * momentum_score  # 5% direction momentum
        )

        return total_score

    def _rank_frontiers(
        self,
        frontier_centroids: List[Vector],
        frontier_sizes: List[int],
        robot_pose: Vector,
        costmap: Costmap,
    ) -> List[Vector]:
        """
        Rank frontiers using comprehensive scoring and filtering.

        Args:
            frontier_centroids: List of frontier centroids
            frontier_sizes: List of frontier sizes
            robot_pose: Current robot position
            costmap: Costmap for additional analysis

        Returns:
            Valid frontiers sorted best-first, or empty list if none suitable
        """
        if not frontier_centroids:
            return []

        valid_frontiers = []

        for i, frontier in enumerate(frontier_centroids):
            robot_distance = np.sqrt(
                (frontier.x - robot_pose.x) ** 2 + (frontier.y - robot_pose.y) ** 2
            )

            # Filter 1: Skip frontiers too close to robot
            if robot_distance < self.min_distance_from_robot:
                continue

            # Filter 2: Skip frontiers too close to obstacles
            obstacle_distance = self._compute_distance_to_obstacles(frontier, costmap)
            if obstacle_distance < self.min_distance_from_obstacles:
                continue

            # Compute comprehensive score
            frontier_size = frontier_sizes[i] if i < len(frontier_sizes) else 1
            score = self._compute_comprehensive_frontier_score(
                frontier, frontier_size, robot_pose, costmap
            )

            valid_frontiers.append((frontier, score))

        logger.info(f"Valid frontiers: {len(valid_frontiers)}")

        if not valid_frontiers:
            return []

        # Sort by score (highest scores first)
        valid_frontiers.sort(key=lambda x: x[1], reverse=True)

        # Extract just the frontiers (remove scores) and return as list
        return [frontier for frontier, _ in valid_frontiers]

    def get_exploration_goal(self, robot_pose: Vector, costmap: Costmap) -> Optional[Vector]:
        """
        Get the single best exploration goal using comprehensive frontier scoring.

        Args:
            robot_pose: Current robot position in world coordinates (Vector with x, y)
            costmap: Costmap for additional analysis

        Returns:
            Single best frontier goal in world coordinates, or None if no suitable frontiers found
        """
        # Check if we should compare costmaps for information gain
        if len(self.explored_goals) > 5 and self.last_costmap is not None:
            current_info = self._count_costmap_information(costmap)
            last_info = self._count_costmap_information(self.last_costmap)

            # Check if information increase meets minimum percentage threshold
            if last_info > 0:  # Avoid division by zero
                info_increase_percent = (current_info - last_info) / last_info
                if info_increase_percent < self.info_gain_threshold:
                    logger.info(
                        f"Information increase ({info_increase_percent:.2f}) below threshold ({self.info_gain_threshold:.2f})"
                    )
                    logger.info(
                        f"Current information: {current_info}, Last information: {last_info}"
                    )
                    self.num_no_gain_attempts += 1
                    # BUG FIX: the original tested the counter against itself
                    # (always true), so exploration gave up on the very first
                    # no-gain event.  Compare against the configured maximum.
                    if self.num_no_gain_attempts >= self.max_no_gain_attempts:
                        logger.info(
                            f"No information gain for {self.num_no_gain_attempts} consecutive attempts, skipping frontier selection"
                        )
                        self.reset_exploration_session()
                        return None
                else:
                    # Gain observed: the no-gain streak is broken, so reset the
                    # counter to honor the "consecutive attempts" semantics.
                    self.num_no_gain_attempts = 0

        # Always detect new frontiers to get most up-to-date information
        # The algorithm filters out explored areas and returns the best frontier first
        frontiers = self.detect_frontiers(robot_pose, costmap)

        # Store current costmap for the next information-gain comparison
        # (reset_exploration_session below clears it again, matching the
        # original behavior on the no-frontier path).
        self.last_costmap = costmap

        if not frontiers:
            self.reset_exploration_session()
            return None

        # Update exploration direction based on best goal selection
        selected_goal = frontiers[0]
        self._update_exploration_direction(robot_pose, selected_goal)

        # Store the selected goal as explored
        self.mark_explored_goal(selected_goal)

        return selected_goal

    def mark_explored_goal(self, goal: Vector):
        """Mark a goal as explored."""
        self.explored_goals.append(goal)

    def reset_exploration_session(self):
        """
        Reset all exploration state variables for a new exploration session.

        Call this method when starting a new exploration or when the robot
        needs to forget its previous exploration history.
        """
        self.explored_goals.clear()  # Clear all previously explored goals
        self.exploration_direction = Vector([0.0, 0.0])  # Reset exploration direction
        self.last_costmap = None  # Clear last costmap comparison
        self.num_no_gain_attempts = 0  # Reset no-gain attempt counter
        self._cache.clear()  # Clear frontier point cache

        logger.info("Exploration session reset - all state variables cleared")

    def explore(self, stop_event: Optional[threading.Event] = None) -> bool:
        """
        Perform autonomous frontier exploration by continuously finding and navigating to frontiers.

        Args:
            stop_event: Optional threading.Event to signal when exploration should stop

        Returns:
            bool: True if exploration completed successfully, False if stopped or failed
        """

        logger.info("Starting autonomous frontier exploration")

        while True:
            # Check if stop event is set
            if stop_event and stop_event.is_set():
                logger.info("Exploration stopped by stop event")
                return False

            # Get fresh robot position and costmap data
            robot_pose = self.get_robot_pos()
            costmap = self.get_costmap()

            # Get the next frontier goal
            next_goal = self.get_exploration_goal(robot_pose, costmap)
            if not next_goal:
                logger.info("No more frontiers found, exploration complete")
                return True

            # Navigate to the frontier
            logger.info(f"Navigating to frontier at {next_goal}")
            navigation_successful = self.set_goal(next_goal, stop_event=stop_event)

            if not navigation_successful:
                logger.warning("Failed to navigate to frontier, continuing exploration")
                # Continue to try other frontiers instead of stopping
                continue
+ + Args: + get_costmap: Function to get the latest local costmap + get_robot_pose: Function to get the latest robot pose (returning odom object) + move: Function to send velocity commands + safety_threshold: Distance to maintain from obstacles (meters) + max_linear_vel: Maximum linear velocity (m/s) + max_angular_vel: Maximum angular velocity (rad/s) + lookahead_distance: Lookahead distance for path following (meters) + goal_tolerance: Distance at which the goal is considered reached (meters) + angle_tolerance: Angle at which the goal orientation is considered reached (radians) + robot_width: Width of the robot for visualization (meters) + robot_length: Length of the robot for visualization (meters) + visualization_size: Size of the visualization image in pixels + control_frequency: Frequency at which the planner is called (Hz) + safe_goal_distance: Distance at which to adjust the goal and ignore obstacles (meters) + max_recovery_attempts: Maximum number of recovery attempts before failing navigation. + If the robot gets stuck and cannot recover within this many attempts, navigation will fail. + global_planner_plan: Optional callable to plan a global path to the goal. + If provided, this will be used to generate a path to the goal before local planning. 
""" def __init__( self, get_costmap: Callable[[], Optional[Costmap]], - transform: object, - move_vel_control: Callable[[float, float, float], None], + get_robot_pose: Callable[[], Any], + move: Callable[[Vector], None], safety_threshold: float = 0.5, max_linear_vel: float = 0.8, max_angular_vel: float = 1.0, lookahead_distance: float = 1.0, goal_tolerance: float = 0.75, - angle_tolerance: float = 0.15, + angle_tolerance: float = 0.5, robot_width: float = 0.5, robot_length: float = 0.7, visualization_size: int = 400, control_frequency: float = 10.0, safe_goal_distance: float = 1.5, + max_recovery_attempts: int = 4, + global_planner_plan: Optional[Callable[[VectorLike], Optional[Any]]] = None, ): # Control frequency in Hz - """ - Initialize the base local planner. - - Args: - get_costmap: Function to get the latest local costmap - transform: Object with transform methods (transform_point, transform_rot, etc.) - move_vel_control: Function to send velocity commands - safety_threshold: Distance to maintain from obstacles (meters) - max_linear_vel: Maximum linear velocity (m/s) - max_angular_vel: Maximum angular velocity (rad/s) - lookahead_distance: Lookahead distance for path following (meters) - goal_tolerance: Distance at which the goal is considered reached (meters) - angle_tolerance: Angle at which the goal orientation is considered reached (radians) - robot_width: Width of the robot for visualization (meters) - robot_length: Length of the robot for visualization (meters) - visualization_size: Size of the visualization image in pixels - control_frequency: Frequency at which the planner is called (Hz) - safe_goal_distance: Distance at which to adjust the goal and ignore obstacles (meters) - """ # Store callables for robot interactions self.get_costmap = get_costmap - self.transform = transform - self.move_vel_control = move_vel_control + self.get_robot_pose = get_robot_pose + self.move = move # Store parameters self.safety_threshold = safety_threshold @@ -98,32 
+101,66 @@ def __init__( self.control_period = 1.0 / control_frequency # Period in seconds self.safe_goal_distance = safe_goal_distance # Distance to ignore obstacles at goal self.ignore_obstacles = False # Flag for derived classes to check + self.max_recovery_attempts = max_recovery_attempts # Maximum recovery attempts + self.recovery_attempts = 0 # Current number of recovery attempts + self.global_planner_plan = global_planner_plan # Global planner function for replanning # Goal and Waypoint Tracking self.goal_xy: Optional[Tuple[float, float]] = None # Current target for planning - self.goal_theta: Optional[float] = None # Goal orientation in odom frame - self.position_reached: bool = False # Flag indicating if position goal is reached - self.waypoints: Optional[Path] = None # Full path if following waypoints - self.waypoints_in_odom: Optional[Path] = None # Full path in odom frame - self.waypoint_frame: Optional[str] = None # Frame of the waypoints + self.goal_theta: Optional[float] = None # Goal orientation (radians) + self.waypoints: Optional[Path] = None # List of waypoints to follow + self.waypoints_in_absolute: Optional[Path] = None # Full path in absolute frame + self.waypoint_is_relative: bool = False # Whether waypoints are in relative frame self.current_waypoint_index: int = 0 # Index of the next waypoint to reach self.final_goal_reached: bool = False # Flag indicating if the final waypoint is reached + self.position_reached: bool = False # Flag indicating if position goal is reached # Stuck detection - self.stuck_detection_window_seconds = 8.0 # Time window for stuck detection (seconds) + self.stuck_detection_window_seconds = 4.0 # Time window for stuck detection (seconds) self.position_history_size = int(self.stuck_detection_window_seconds * control_frequency) self.position_history = deque( maxlen=self.position_history_size ) # History of recent positions - self.stuck_distance_threshold = 0.1 # Distance threshold for stuck detection (meters) - 
self.unstuck_distance_threshold = 0.5 # Distance threshold for unstuck detection (meters) - self.stuck_time_threshold = 4.0 # Time threshold for stuck detection (seconds) + self.stuck_distance_threshold = 0.15 # Distance threshold for stuck detection (meters) + self.unstuck_distance_threshold = ( + 0.5 # Distance threshold for unstuck detection (meters) - increased hysteresis + ) + self.stuck_time_threshold = 3.0 # Time threshold for stuck detection (seconds) - increased self.is_recovery_active = False # Whether recovery behavior is active self.recovery_start_time = 0.0 # When recovery behavior started - self.recovery_duration = 8.0 # How long to run recovery before giving up (seconds) + self.recovery_duration = ( + 10.0 # How long to run recovery before giving up (seconds) - increased + ) self.last_update_time = time.time() # Last time position was updated self.navigation_failed = False # Flag indicating if navigation should be terminated + # Recovery improvements + self.recovery_cooldown_time = ( + 3.0 # Seconds to wait after recovery before checking stuck again + ) + self.last_recovery_end_time = 0.0 # When the last recovery ended + self.pre_recovery_position = ( + None # Position when recovery started (for better stuck detection) + ) + self.backup_duration = 4.0 # How long to backup when stuck (seconds) + + # Cached data updated periodically for consistent plan() execution time + self._robot_pose = None + self._costmap = None + self._update_frequency = 10.0 # Hz - how often to update cached data + self._update_timer = None + self._start_periodic_updates() + + def _start_periodic_updates(self): + self._update_timer = threading.Thread(target=self._periodic_update, daemon=True) + self._update_timer.start() + + def _periodic_update(self): + while True: + self._robot_pose = self.get_robot_pose() + self._costmap = self.get_costmap() + time.sleep(1.0 / self._update_frequency) + def reset(self): """ Reset all navigation and state tracking variables. 
@@ -141,37 +178,89 @@ def reset(self): self.final_goal_reached = False self.ignore_obstacles = False + # Reset recovery improvements + self.last_recovery_end_time = 0.0 + self.pre_recovery_position = None + + # Reset recovery attempts + self.recovery_attempts = 0 + + # Clear waypoint following state + self.waypoints = None + self.current_waypoint_index = 0 + self.goal_xy = None # Clear previous goal + self.goal_theta = None # Clear previous goal orientation + logger.info("Local planner state has been reset") + def _get_robot_pose(self) -> Tuple[Tuple[float, float], float]: + """ + Get the current robot position and orientation. + + Returns: + Tuple containing: + - position as (x, y) tuple + - orientation (theta) in radians + """ + if self._robot_pose is None: + return ((0.0, 0.0), 0.0) # Fallback if not yet initialized + pos, rot = self._robot_pose.pos, self._robot_pose.rot + return (pos.x, pos.y), rot.z + + def _get_costmap(self): + """Get cached costmap data.""" + return self._costmap + + def clear_cache(self): + """Clear all cached data to force fresh retrieval on next access.""" + self._robot_pose = None + self._costmap = None + def set_goal( - self, goal_xy: VectorLike, frame: str = "odom", goal_theta: Optional[float] = None + self, goal_xy: VectorLike, is_relative: bool = False, goal_theta: Optional[float] = None ): - """Set a single goal position, converting to odom frame if necessary. + """Set a single goal position, converting to absolute frame if necessary. This clears any existing waypoints being followed. Args: goal_xy: The goal position to set. - frame: The frame of the goal position. - goal_theta: Optional goal orientation in radians (in the specified frame) + is_relative: Whether the goal is in the robot's relative frame. 
+ goal_theta: Optional goal orientation in radians """ # Reset all state variables self.reset() - # Clear waypoint following state - self.waypoints = None - self.current_waypoint_index = 0 - self.goal_xy = None # Clear previous goal - self.goal_theta = None # Clear previous goal orientation - target_goal_xy: Optional[Tuple[float, float]] = None - target_goal_xy = self.transform.transform_point( - goal_xy, source_frame=frame, target_frame="odom" - ).to_tuple() - - logger.info( - f"Goal set directly in odom frame: ({target_goal_xy[0]:.2f}, {target_goal_xy[1]:.2f})" - ) + # Transform goal to absolute frame if it's relative + if is_relative: + # Get current robot pose + odom = self._robot_pose + if odom is None: + logger.warning("Robot pose not yet available, cannot set relative goal") + return + robot_pos, robot_rot = odom.pos, odom.rot + + # Extract current position and orientation + robot_x, robot_y = robot_pos.x, robot_pos.y + robot_theta = robot_rot.z # Assuming rotation is euler angles + + # Transform the relative goal into absolute coordinates + goal_x, goal_y = to_tuple(goal_xy) + # Rotate + abs_x = goal_x * math.cos(robot_theta) - goal_y * math.sin(robot_theta) + abs_y = goal_x * math.sin(robot_theta) + goal_y * math.cos(robot_theta) + # Translate + target_goal_xy = (robot_x + abs_x, robot_y + abs_y) + + logger.info( + f"Goal set in relative frame, converted to absolute: ({target_goal_xy[0]:.2f}, {target_goal_xy[1]:.2f})" + ) + else: + target_goal_xy = to_tuple(goal_xy) + logger.info( + f"Goal set directly in absolute frame: ({target_goal_xy[0]:.2f}, {target_goal_xy[1]:.2f})" + ) # Check if goal is valid (in bounds and not colliding) if not self.is_goal_in_costmap_bounds(target_goal_xy) or self.check_goal_collision( @@ -186,20 +275,25 @@ def set_goal( # Set goal orientation if provided if goal_theta is not None: - transformed_rot = self.transform.transform_rot( - Vector(0.0, 0.0, goal_theta), source_frame=frame, target_frame="odom" - ) - self.goal_theta = 
transformed_rot[2] + if is_relative: + # Transform the orientation to absolute frame + odom = self._robot_pose + if odom is None: + logger.warning( + "Robot pose not yet available, cannot set relative goal orientation" + ) + return + robot_theta = odom.rot.z + self.goal_theta = normalize_angle(goal_theta + robot_theta) + else: + self.goal_theta = goal_theta - def set_goal_waypoints( - self, waypoints: Path, frame: str = "map", goal_theta: Optional[float] = None - ): + def set_goal_waypoints(self, waypoints: Path, goal_theta: Optional[float] = None): """Sets a path of waypoints for the robot to follow. Args: - waypoints: A list of waypoints to follow. Each waypoint is a tuple of (x, y) coordinates in odom frame. - frame: The frame of the waypoints. - goal_theta: Optional final orientation in radians (in the specified frame) + waypoints: A list of waypoints to follow. Each waypoint is a tuple of (x, y) coordinates in absolute frame. + goal_theta: Optional final orientation in radians """ # Reset all state variables self.reset() @@ -207,7 +301,7 @@ def set_goal_waypoints( if not isinstance(waypoints, Path) or len(waypoints) == 0: logger.warning("Invalid or empty path provided to set_goal_waypoints. 
Ignoring.") self.waypoints = None - self.waypoint_frame = None + self.waypoint_is_relative = False self.goal_xy = None self.goal_theta = None self.current_waypoint_index = 0 @@ -215,16 +309,14 @@ def set_goal_waypoints( logger.info(f"Setting goal waypoints with {len(waypoints)} points.") self.waypoints = waypoints - self.waypoint_frame = frame + self.waypoint_is_relative = False self.current_waypoint_index = 0 - # Transform waypoints to odom frame - self.waypoints_in_odom = self.transform.transform_path( - self.waypoints, source_frame=frame, target_frame="odom" - ) + # Waypoints are always in absolute frame + self.waypoints_in_absolute = waypoints # Set the initial target to the first waypoint, adjusting if necessary - first_waypoint = self.waypoints_in_odom[0] + first_waypoint = self.waypoints_in_absolute[0] if not self.is_goal_in_costmap_bounds(first_waypoint) or self.check_goal_collision( first_waypoint ): @@ -235,22 +327,7 @@ def set_goal_waypoints( # Set goal orientation if provided if goal_theta is not None: - transformed_rot = self.transform.transform_rot( - Vector(0.0, 0.0, goal_theta), source_frame=frame, target_frame="odom" - ) - self.goal_theta = transformed_rot[2] - - def _get_robot_pose(self) -> Tuple[Tuple[float, float], float]: - """ - Get the current robot position and orientation. 
- - Returns: - Tuple containing: - - position as (x, y) tuple - - orientation (theta) in radians - """ - [pos, rot] = self.transform.transform_euler("base_link", "odom") - return (pos[0], pos[1]), rot[2] + self.goal_theta = goal_theta def _get_final_goal_position(self) -> Optional[Tuple[float, float]]: """ @@ -259,8 +336,8 @@ def _get_final_goal_position(self) -> Optional[Tuple[float, float]]: Returns: Tuple (x, y) of the final goal, or None if no goal is set """ - if self.waypoints_in_odom is not None and len(self.waypoints_in_odom) > 0: - return to_tuple(self.waypoints_in_odom[-1]) + if self.waypoints_in_absolute is not None and len(self.waypoints_in_absolute) > 0: + return to_tuple(self.waypoints_in_absolute[-1]) elif self.goal_xy is not None: return self.goal_xy return None @@ -295,8 +372,11 @@ def plan(self) -> Dict[str, float]: and self.goal_theta is not None and not self._is_goal_orientation_reached() ): - logger.info("Position goal reached. Rotating to target orientation.") return self._rotate_to_goal_orientation() + elif self.position_reached and self.goal_theta is None: + self.final_goal_reached = True + logger.info("Position goal reached. Stopping.") + return {"x_vel": 0.0, "angular_vel": 0.0} # Check if the robot is stuck and handle accordingly if self.check_if_stuck() and not self.position_reached: @@ -315,6 +395,9 @@ def plan(self) -> Dict[str, float]: self.position_reached = True return {"x_vel": 0.0, "angular_vel": 0.0} + if self.navigation_failed: + return {"x_vel": 0.0, "angular_vel": 0.0} + # Otherwise, execute normal recovery behavior logger.warning("Robot is stuck - executing recovery behavior") return self.execute_recovery_behavior() @@ -325,7 +408,6 @@ def plan(self) -> Dict[str, float]: # --- Waypoint Following Mode --- if self.waypoints is not None: if self.final_goal_reached: - logger.info("Final waypoint reached. 
Stopping.") return {"x_vel": 0.0, "angular_vel": 0.0} # Get current robot pose @@ -333,8 +415,8 @@ def plan(self) -> Dict[str, float]: robot_pos_np = np.array(robot_pos) # Check if close to final waypoint - if self.waypoints_in_odom is not None and len(self.waypoints_in_odom) > 0: - final_waypoint = self.waypoints_in_odom[-1] + if self.waypoints_in_absolute is not None and len(self.waypoints_in_absolute) > 0: + final_waypoint = self.waypoints_in_absolute[-1] dist_to_final = np.linalg.norm(robot_pos_np - final_waypoint) # If we're close to the final waypoint, adjust it and ignore obstacles @@ -342,11 +424,10 @@ def plan(self) -> Dict[str, float]: final_wp_tuple = to_tuple(final_waypoint) adjusted_goal = self.adjust_goal_to_valid_position(final_wp_tuple) # Create a new Path with the adjusted final waypoint - new_waypoints = self.waypoints_in_odom[:-1] # Get all but the last waypoint + new_waypoints = self.waypoints_in_absolute[:-1] # Get all but the last waypoint new_waypoints.append(adjusted_goal) # Append the adjusted goal - self.waypoints_in_odom = new_waypoints + self.waypoints_in_absolute = new_waypoints self.ignore_obstacles = True - logger.debug("Within safe distance of final waypoint. Ignoring obstacles.") # Update the target goal based on waypoint progression just_reached_final = self._update_waypoint_target(robot_pos_np) @@ -361,7 +442,7 @@ def plan(self) -> Dict[str, float]: return {"x_vel": 0.0, "angular_vel": 0.0} # Get necessary data for planning - costmap = self.get_costmap() + costmap = self._get_costmap() if costmap is None: logger.warning("Local costmap is None. Cannot plan.") return {"x_vel": 0.0, "angular_vel": 0.0} @@ -375,7 +456,6 @@ def plan(self) -> Dict[str, float]: if goal_distance < self.safe_goal_distance: self.goal_xy = self.adjust_goal_to_valid_position(self.goal_xy) self.ignore_obstacles = True - logger.debug("Within safe distance of goal. 
Ignoring obstacles.") # First check position if goal_distance < self.goal_tolerance or self.position_reached: @@ -418,9 +498,8 @@ def _rotate_to_goal_orientation(self) -> Dict[str, float]: # Calculate rotation speed - proportional to the angle difference # but capped at max_angular_vel direction = 1.0 if angle_diff > 0 else -1.0 - angular_vel = direction * min(abs(angle_diff) * 2.0, self.max_angular_vel) + angular_vel = direction * min(abs(angle_diff), self.max_angular_vel) - # logger.debug(f"Rotating to goal orientation: angle_diff={angle_diff:.4f}, angular_vel={angular_vel:.4f}") return {"x_vel": 0.0, "angular_vel": angular_vel} def _is_goal_orientation_reached(self) -> bool: @@ -438,9 +517,6 @@ def _is_goal_orientation_reached(self) -> bool: # Calculate the angle difference and normalize angle_diff = abs(normalize_angle(self.goal_theta - robot_theta)) - logger.debug( - f"Orientation error: {angle_diff:.4f} rad, tolerance: {self.angle_tolerance:.4f} rad" - ) return angle_diff <= self.angle_tolerance def _update_waypoint_target(self, robot_pos_np: np.ndarray) -> bool: @@ -455,31 +531,32 @@ def _update_waypoint_target(self, robot_pos_np: np.ndarray) -> bool: if self.waypoints is None or len(self.waypoints) == 0: return False # Not in waypoint mode or empty path - self.waypoints_in_odom = self.transform.transform_path( - self.waypoints, source_frame=self.waypoint_frame, target_frame="odom" - ) + # Waypoints are always in absolute frame + self.waypoints_in_absolute = self.waypoints # Check if final goal is reached - final_waypoint = self.waypoints_in_odom[-1] + final_waypoint = self.waypoints_in_absolute[-1] dist_to_final = np.linalg.norm(robot_pos_np - final_waypoint) - if dist_to_final < self.goal_tolerance: - self.position_reached = True - self.goal_xy = to_tuple(final_waypoint) - - # If goal orientation is not specified or achieved, consider fully reached - if self.goal_theta is None or self._is_goal_orientation_reached(): + if dist_to_final <= 
self.goal_tolerance: + # Final waypoint position reached + if self.goal_theta is not None: + # Check orientation if specified + if self._is_goal_orientation_reached(): + self.final_goal_reached = True + return True + # Continue rotating + self.position_reached = True + return False + else: + # No orientation goal, mark as reached self.final_goal_reached = True - logger.info("Reached final waypoint with correct orientation.") return True - else: - logger.info("Reached final waypoint position, rotating to target orientation.") - return False # Always find the lookahead point lookahead_point = None - for i in range(self.current_waypoint_index, len(self.waypoints_in_odom)): - wp = self.waypoints_in_odom[i] + for i in range(self.current_waypoint_index, len(self.waypoints_in_absolute)): + wp = self.waypoints_in_absolute[i] dist_to_wp = np.linalg.norm(robot_pos_np - wp) if dist_to_wp >= self.lookahead_distance: lookahead_point = wp @@ -489,14 +566,13 @@ def _update_waypoint_target(self, robot_pos_np: np.ndarray) -> bool: # If no point is far enough, target the final waypoint if lookahead_point is None: - lookahead_point = self.waypoints_in_odom[-1] - self.current_waypoint_index = len(self.waypoints_in_odom) - 1 + lookahead_point = self.waypoints_in_absolute[-1] + self.current_waypoint_index = len(self.waypoints_in_absolute) - 1 # Set the lookahead point as the immediate target, adjusting if needed if not self.is_goal_in_costmap_bounds(lookahead_point) or self.check_goal_collision( lookahead_point ): - logger.debug("Lookahead point is invalid. 
Adjusting...") adjusted_lookahead = self.adjust_goal_to_valid_position(lookahead_point) # Only update if adjustment didn't fail completely if adjusted_lookahead is not None: @@ -569,12 +645,15 @@ def is_goal_reached(self) -> bool: """Check if the final goal (single or last waypoint) is reached, including orientation.""" if self.waypoints is not None: # Waypoint mode: check if the final waypoint and orientation have been reached - return self.final_goal_reached + return self.final_goal_reached and self._is_goal_orientation_reached() else: # Single goal mode: check distance to the single goal and orientation if self.goal_xy is None: return False # No goal set + if self.goal_theta is None: + return self.position_reached + return self.position_reached and self._is_goal_orientation_reached() def check_goal_collision(self, goal_xy: VectorLike) -> bool: @@ -584,7 +663,7 @@ def check_goal_collision(self, goal_xy: VectorLike) -> bool: bool: True if goal is in collision, False if goal is safe or cannot be checked """ - costmap = self.get_costmap() + costmap = self._get_costmap() if costmap is None: logger.warning("Cannot check collision: No costmap available") return False @@ -604,7 +683,7 @@ def is_goal_in_costmap_bounds(self, goal_xy: VectorLike) -> bool: Returns: bool: True if the goal is within the costmap bounds, False otherwise """ - costmap = self.get_costmap() + costmap = self._get_costmap() if costmap is None: logger.warning("Cannot check bounds: No costmap available") return False @@ -633,7 +712,7 @@ def adjust_goal_to_valid_position( Returns: Tuple[float, float]: A valid goal position, or the original goal if already valid """ - [pos, rot] = self.transform.transform_euler("base_link", "odom") + [pos, rot] = self._get_robot_pose() robot_x, robot_y = pos[0], pos[1] @@ -701,26 +780,17 @@ def adjust_goal_to_valid_position( clearance_x = current_x + dx * clearance clearance_y = current_y + dy * clearance - logger.info( - f"Checking clearance position at 
({clearance_x:.2f}, {clearance_y:.2f})" - ) - # Check if the clearance position is also valid if not self.check_goal_collision( (clearance_x, clearance_y) ) and self.is_goal_in_costmap_bounds((clearance_x, clearance_y)): - logger.info( - f"Found valid goal with clearance at ({clearance_x:.2f}, {clearance_y:.2f})" - ) return (clearance_x, clearance_y) # Return the valid position without clearance - logger.info(f"Found valid goal at ({current_x:.2f}, {current_y:.2f})") return (current_x, current_y) # If we found a valid position earlier but couldn't add clearance if valid_found: - logger.info(f"Using valid goal found at ({valid_x:.2f}, {valid_y:.2f})") return (valid_x, valid_y) logger.warning( @@ -731,6 +801,7 @@ def adjust_goal_to_valid_position( def check_if_stuck(self) -> bool: """ Check if the robot is stuck by analyzing movement history. + Includes improvements to prevent oscillation between stuck and recovered states. Returns: bool: True if the robot is determined to be stuck, False otherwise @@ -739,18 +810,60 @@ def check_if_stuck(self) -> bool: current_time = time.time() # Get current robot position - [pos, _] = self.transform.transform_euler("base_link", "odom") + [pos, _] = self._get_robot_pose() current_position = (pos[0], pos[1], current_time) + # If we're already in recovery, don't add movements to history (they're intentional) + # Instead, check if we should continue or end recovery + if self.is_recovery_active: + # Check if we've moved far enough from our pre-recovery position to consider unstuck + if self.pre_recovery_position is not None: + pre_recovery_x, pre_recovery_y = self.pre_recovery_position[:2] + displacement_from_start = np.sqrt( + (pos[0] - pre_recovery_x) ** 2 + (pos[1] - pre_recovery_y) ** 2 + ) + + # If we've moved far enough, we're unstuck + if displacement_from_start > self.unstuck_distance_threshold: + logger.info( + f"Robot has escaped from stuck state (moved {displacement_from_start:.3f}m from start)" + ) + self.is_recovery_active 
= False + self.last_recovery_end_time = current_time + # Do not reset recovery attempts here - only reset during replanning or goal reaching + # Clear position history to start fresh tracking + self.position_history.clear() + return False + + # Check if we've been trying to recover for too long + recovery_time = current_time - self.recovery_start_time + if recovery_time > self.recovery_duration: + logger.error( + f"Recovery behavior has been active for {self.recovery_duration}s without success" + ) + self.navigation_failed = True + return True + + # Continue recovery + return True + + # Check cooldown period - don't immediately check for stuck after recovery + if current_time - self.last_recovery_end_time < self.recovery_cooldown_time: + # Add position to history but don't check for stuck yet + self.position_history.append(current_position) + return False + # Add current position to history (newest is appended at the end) self.position_history.append(current_position) # Need enough history to make a determination - min_history_size = self.stuck_detection_window_seconds * self.control_frequency + min_history_size = int( + self.stuck_detection_window_seconds * self.control_frequency * 0.6 + ) # 60% of window if len(self.position_history) < min_history_size: return False - # Find positions within our detection window (positions are already in order from oldest to newest) + # Find positions within our detection window window_start_time = current_time - self.stuck_detection_window_seconds window_positions = [] @@ -770,78 +883,109 @@ def check_if_stuck(self) -> bool: oldest_x, oldest_y, oldest_time = window_positions[0] newest_x, newest_y, newest_time = window_positions[-1] - # Calculate time range in the window (should always be positive) + # Calculate time range in the window time_range = newest_time - oldest_time # Calculate displacement from oldest to newest position displacement = np.sqrt((newest_x - oldest_x) ** 2 + (newest_y - oldest_y) ** 2) + # Also check 
average displacement over multiple sub-windows to avoid false positives + sub_window_size = max(3, len(window_positions) // 3) + avg_displacement = 0.0 + displacement_count = 0 + + for i in range(0, len(window_positions) - sub_window_size, sub_window_size // 2): + start_pos = window_positions[i] + end_pos = window_positions[min(i + sub_window_size, len(window_positions) - 1)] + sub_displacement = np.sqrt( + (end_pos[0] - start_pos[0]) ** 2 + (end_pos[1] - start_pos[1]) ** 2 + ) + avg_displacement += sub_displacement + displacement_count += 1 + + if displacement_count > 0: + avg_displacement /= displacement_count + # Check if we're stuck - moved less than threshold over minimum time - # Only consider it if the time range makes sense (positive and sufficient) is_currently_stuck = ( time_range >= self.stuck_time_threshold and time_range <= self.stuck_detection_window_seconds and displacement < self.stuck_distance_threshold + and avg_displacement < self.stuck_distance_threshold * 1.5 ) if is_currently_stuck: logger.warning( - f"Robot appears to be stuck! Displacement {displacement:.3f}m over {time_range:.1f}s" + f"Robot appears to be stuck! 
Total displacement: {displacement:.3f}m, " + f"avg displacement: {avg_displacement:.3f}m over {time_range:.1f}s" ) - # Don't trigger recovery if it's already active - if not self.is_recovery_active: - self.is_recovery_active = True - self.recovery_start_time = current_time - return True + # Start recovery behavior + self.is_recovery_active = True + self.recovery_start_time = current_time + self.pre_recovery_position = current_position - # Check if we've been trying to recover for too long - elif current_time - self.recovery_start_time > self.recovery_duration: + # Clear position history to avoid contamination during recovery + self.position_history.clear() + + # Increment recovery attempts + self.recovery_attempts += 1 + logger.warning( + f"Starting recovery attempt {self.recovery_attempts}/{self.max_recovery_attempts}" + ) + + # Check if maximum recovery attempts have been exceeded + if self.recovery_attempts > self.max_recovery_attempts: logger.error( - f"Recovery behavior has been active for {self.recovery_duration}s without success" + f"Maximum recovery attempts ({self.max_recovery_attempts}) exceeded. Navigation failed." ) - # Reset recovery state - maybe a different behavior will work - self.is_recovery_active = False - self.recovery_start_time = current_time + self.navigation_failed = True - # If we've moved enough, we're not stuck anymore - elif self.is_recovery_active and displacement > self.unstuck_distance_threshold: - logger.info(f"Robot has escaped from stuck state (moved {displacement:.3f}m)") - self.is_recovery_active = False + return True - return self.is_recovery_active + return False def execute_recovery_behavior(self) -> Dict[str, float]: """ - Execute a recovery behavior when the robot is stuck. + Execute enhanced recovery behavior when the robot is stuck. 
+ - First attempt: Backup for a set duration + - Second+ attempts: Replan to the original goal using global planner Returns: Dict[str, float]: Velocity commands for the recovery behavior """ - # Calculate how long we've been in recovery - recovery_time = time.time() - self.recovery_start_time - - # Calculate recovery phases based on control frequency - backup_phase_time = 3.0 # seconds - rotate_phase_time = 2.0 # seconds - - # Simple recovery behavior state machine - if recovery_time < backup_phase_time: - # First try backing up - logger.info("Recovery: backing up") - return {"x_vel": -0.2, "angular_vel": 0.0} - elif recovery_time < backup_phase_time + rotate_phase_time: - # Then try rotating - logger.info("Recovery: rotating to find new path") - rotation_direction = 1.0 if np.random.random() > 0.5 else -1.0 - return {"x_vel": 0.0, "angular_vel": rotation_direction * self.max_angular_vel * 0.7} + current_time = time.time() + recovery_time = current_time - self.recovery_start_time + + # First recovery attempt: Simple backup behavior + if self.recovery_attempts % 2 == 0: + if recovery_time < self.backup_duration: + logger.warning(f"Recovery attempt 1: backup for {recovery_time:.1f}s") + return {"x_vel": -0.5, "angular_vel": 0.0} # Backup at moderate speed + else: + logger.info("Recovery attempt 1: backup completed") + self.recovery_attempts += 1 + return {"x_vel": 0.0, "angular_vel": 0.0} + + final_goal = self.waypoints_in_absolute[-1] + logger.info( + f"Recovery attempt {self.recovery_attempts}: replanning to final waypoint {final_goal}" + ) + + new_path = self.global_planner_plan(Vector([final_goal[0], final_goal[1]])) + + if new_path is not None: + logger.info("Replanning successful. 
Setting new waypoints.") + attempts = self.recovery_attempts + self.set_goal_waypoints(new_path, self.goal_theta) + self.recovery_attempts = attempts + self.is_recovery_active = False + self.last_recovery_end_time = current_time else: - # If we're still stuck after backup and rotation, terminate navigation - logger.error("Recovery failed after backup and rotation. Navigation terminated.") - # Set a flag to indicate navigation should terminate + logger.error("Global planner could not find a path to the goal. Recovery failed.") self.navigation_failed = True - # Stop the robot - return {"x_vel": 0.0, "angular_vel": 0.0} + + return {"x_vel": 0.0, "angular_vel": 0.0} def navigate_to_goal_local( @@ -872,6 +1016,8 @@ def navigate_to_goal_local( f"Starting navigation to local goal {goal_xy_robot} with distance {distance}m and timeout {timeout}s." ) + robot.local_planner.reset() + goal_x, goal_y = goal_xy_robot # Calculate goal orientation to face the target @@ -888,7 +1034,7 @@ def navigate_to_goal_local( goal_x, goal_y = distance_angle_to_goal_xy(goal_distance - distance, goal_theta) # Set the goal in the robot's frame with orientation to face the original target - robot.local_planner.set_goal((goal_x, goal_y), frame="base_link", goal_theta=goal_theta) + robot.local_planner.set_goal((goal_x, goal_y), is_relative=True, goal_theta=goal_theta) # Get control period from robot's local planner for consistent timing control_period = 1.0 / robot.local_planner.control_frequency @@ -916,7 +1062,7 @@ def navigate_to_goal_local( angular_vel = vel_command.get("angular_vel", 0.0) # Send velocity command - robot.local_planner.move_vel_control(x=x_vel, y=0, yaw=angular_vel) + robot.local_planner.move(Vector(x_vel, 0, angular_vel)) # Control loop frequency - use robot's control frequency time.sleep(control_period) @@ -932,7 +1078,7 @@ def navigate_to_goal_local( goal_reached = False # Consider error as failure finally: logger.info("Stopping robot after navigation attempt.") - 
robot.local_planner.move_vel_control(0, 0, 0) # Stop the robot + robot.local_planner.move(Vector(0, 0, 0)) # Stop the robot return goal_reached @@ -950,7 +1096,7 @@ def navigate_path_local( Args: robot: Robot instance to control - path: Path object containing waypoints in odom/map frame + path: Path object containing waypoints in absolute frame timeout: Maximum time (in seconds) allowed to follow the complete path goal_theta: Optional final orientation in radians stop_event: Optional threading.Event to signal when navigation should stop @@ -962,6 +1108,8 @@ def navigate_path_local( f"Starting navigation along path with {len(path)} waypoints and timeout {timeout}s." ) + robot.local_planner.reset() + # Set the path in the local planner robot.local_planner.set_goal_waypoints(path, goal_theta=goal_theta) @@ -991,7 +1139,7 @@ def navigate_path_local( angular_vel = vel_command.get("angular_vel", 0.0) # Send velocity command - robot.local_planner.move_vel_control(x=x_vel, y=0, yaw=angular_vel) + robot.local_planner.move(Vector(x_vel, 0, angular_vel)) # Control loop frequency - use robot's control frequency time.sleep(control_period) @@ -1009,7 +1157,7 @@ def navigate_path_local( path_completed = False finally: logger.info("Stopping robot after path navigation attempt.") - robot.local_planner.move_vel_control(0, 0, 0) # Stop the robot + robot.local_planner.move(Vector(0, 0, 0)) # Stop the robot return path_completed @@ -1017,7 +1165,7 @@ def navigate_path_local( def visualize_local_planner_state( occupancy_grid: np.ndarray, grid_resolution: float, - grid_origin: Tuple[float, float, float], + grid_origin: Tuple[float, float], robot_pose: Tuple[float, float, float], visualization_size: int = 400, robot_width: float = 0.5, @@ -1036,7 +1184,7 @@ def visualize_local_planner_state( Args: occupancy_grid: 2D numpy array of the occupancy grid grid_resolution: Resolution of the grid in meters/cell - grid_origin: Tuple (x, y, theta) of the grid origin in the odom frame + grid_origin: 
Tuple (x, y) of the grid origin in the odom frame robot_pose: Tuple (x, y, theta) of the robot pose in the odom frame visualization_size: Size of the visualization image in pixels robot_width: Width of the robot in meters @@ -1051,7 +1199,7 @@ def visualize_local_planner_state( """ robot_x, robot_y, robot_theta = robot_pose - grid_origin_x, grid_origin_y, _ = grid_origin + grid_origin_x, grid_origin_y = grid_origin vis_size = visualization_size scale = vis_size / map_size_meters diff --git a/dimos/robot/local_planner/vfh_local_planner.py b/dimos/robot/local_planner/vfh_local_planner.py index df0ebcaee2..f97701e5a5 100644 --- a/dimos/robot/local_planner/vfh_local_planner.py +++ b/dimos/robot/local_planner/vfh_local_planner.py @@ -15,16 +15,16 @@ # limitations under the License. import numpy as np -from typing import Dict, Tuple, Optional, Callable +from typing import Dict, Tuple, Optional, Callable, Any import cv2 import logging from dimos.utils.logging_config import setup_logger -from dimos.utils.ros_utils import normalize_angle +from dimos.utils.transform_utils import normalize_angle from dimos.robot.local_planner.local_planner import BaseLocalPlanner, visualize_local_planner_state from dimos.types.costmap import Costmap -from nav_msgs.msg import OccupancyGrid +from dimos.types.vector import Vector, VectorLike logger = setup_logger("dimos.robot.unitree.vfh_local_planner", level=logging.DEBUG) @@ -37,29 +37,31 @@ class VFHPurePursuitPlanner(BaseLocalPlanner): def __init__( self, - get_costmap: Callable[[], Optional[OccupancyGrid]], - transform: object, - move_vel_control: Callable[[float, float, float], None], + get_costmap: Callable[[], Optional[Costmap]], + get_robot_pose: Callable[[], Any], + move: Callable[[Vector], None], safety_threshold: float = 0.8, histogram_bins: int = 144, max_linear_vel: float = 0.8, max_angular_vel: float = 1.0, lookahead_distance: float = 1.0, - goal_tolerance: float = 0.2, + goal_tolerance: float = 0.4, angle_tolerance: float = 0.1, 
# ~5.7 degrees robot_width: float = 0.5, robot_length: float = 0.7, visualization_size: int = 400, control_frequency: float = 10.0, safe_goal_distance: float = 1.0, + max_recovery_attempts: int = 3, + global_planner_plan: Optional[Callable[[VectorLike], Optional[Any]]] = None, ): """ Initialize the VFH + Pure Pursuit planner. Args: get_costmap: Function to get the latest local costmap - transform: Object with transform methods (transform_point, transform_rot, etc.) - move_vel_control: Function to send velocity commands + get_robot_pose: Function to get the latest robot pose (returning odom object) + move: Function to send velocity commands safety_threshold: Distance to maintain from obstacles (meters) histogram_bins: Number of directional bins in the polar histogram max_linear_vel: Maximum linear velocity (m/s) @@ -72,12 +74,14 @@ def __init__( visualization_size: Size of the visualization image in pixels control_frequency: Frequency at which the planner is called (Hz) safe_goal_distance: Distance at which to adjust the goal and ignore obstacles (meters) + max_recovery_attempts: Maximum number of recovery attempts + global_planner_plan: Optional function to get the global plan """ # Initialize base class super().__init__( get_costmap=get_costmap, - transform=transform, - move_vel_control=move_vel_control, + get_robot_pose=get_robot_pose, + move=move, safety_threshold=safety_threshold, max_linear_vel=max_linear_vel, max_angular_vel=max_angular_vel, @@ -89,6 +93,8 @@ def __init__( visualization_size=visualization_size, control_frequency=control_frequency, safe_goal_distance=safe_goal_distance, + max_recovery_attempts=max_recovery_attempts, + global_planner_plan=global_planner_plan, ) # VFH specific parameters @@ -97,10 +103,10 @@ def __init__( self.selected_direction = None # VFH tuning parameters - self.alpha = 0.2 # Histogram smoothing factor - self.obstacle_weight = 10.0 - self.goal_weight = 1.0 - self.prev_direction_weight = 0.5 + self.alpha = 0.25 # Histogram 
smoothing factor + self.obstacle_weight = 5.0 + self.goal_weight = 2.0 + self.prev_direction_weight = 1.0 self.prev_selected_angle = 0.0 self.prev_linear_vel = 0.0 self.linear_vel_filter_factor = 0.4 @@ -118,13 +124,13 @@ def _compute_velocity_commands(self) -> Dict[str, float]: Dict[str, float]: Velocity commands with 'x_vel' and 'angular_vel' keys """ # Get necessary data for planning - costmap = self.get_costmap() + costmap = self._get_costmap() if costmap is None: logger.warning("No costmap available for planning") return {"x_vel": 0.0, "angular_vel": 0.0} - [pos, rot] = self.transform.transform_euler("base_link", "odom") - robot_x, robot_y, robot_theta = pos[0], pos[1], rot[2] + robot_pos, robot_theta = self._get_robot_pose() + robot_x, robot_y = robot_pos robot_pose = (robot_x, robot_y, robot_theta) # Calculate goal-related parameters @@ -175,7 +181,6 @@ def _compute_velocity_commands(self) -> Dict[str, float]: ) if self.check_collision(0.0, safety_threshold=self.safety_threshold): - logger.warning("Collision detected ahead. 
Stopping.") linear_vel = 0.0 self.prev_linear_vel = linear_vel @@ -337,12 +342,12 @@ def check_collision(self, selected_direction: float, safety_threshold: float = 1 return False # Get the latest costmap and robot pose - costmap = self.get_costmap() + costmap = self._get_costmap() if costmap is None: return False # No costmap available - [pos, rot] = self.transform.transform_euler("base_link", "odom") - robot_x, robot_y, robot_theta = pos[0], pos[1], rot[2] + robot_pos, robot_theta = self._get_robot_pose() + robot_x, robot_y = robot_pos # Direction in world frame direction_world = robot_theta + selected_direction @@ -373,12 +378,12 @@ def check_collision(self, selected_direction: float, safety_threshold: float = 1 def update_visualization(self) -> np.ndarray: """Generate visualization of the planning state.""" try: - costmap = self.get_costmap() + costmap = self._get_costmap() if costmap is None: raise ValueError("Costmap is None") - [pos, rot] = self.transform.transform_euler("base_link", "odom") - robot_x, robot_y, robot_theta = pos[0], pos[1], rot[2] + robot_pos, robot_theta = self._get_robot_pose() + robot_x, robot_y = robot_pos robot_pose = (robot_x, robot_y, robot_theta) goal_xy = self.goal_xy # This could be a lookahead point or final goal @@ -388,9 +393,9 @@ def update_visualization(self) -> np.ndarray: selected_direction = getattr(self, "selected_direction", None) # Get waypoint data if in waypoint mode - waypoints_to_draw = self.waypoints_in_odom + waypoints_to_draw = self.waypoints_in_absolute current_wp_index_to_draw = ( - self.current_waypoint_index if self.waypoints_in_odom is not None else None + self.current_waypoint_index if self.waypoints_in_absolute is not None else None ) # Ensure index is valid before passing if waypoints_to_draw is not None and current_wp_index_to_draw is not None: @@ -400,7 +405,7 @@ def update_visualization(self) -> np.ndarray: return visualize_local_planner_state( occupancy_grid=costmap.grid, 
grid_resolution=costmap.resolution, - grid_origin=(costmap.origin.x, costmap.origin.y, costmap.origin_theta), + grid_origin=(costmap.origin.x, costmap.origin.y), robot_pose=robot_pose, goal_xy=goal_xy, # Current target (lookahead or final) goal_theta=self.goal_theta, # Pass goal orientation if available diff --git a/dimos/robot/robot.py b/dimos/robot/robot.py index 90022023ca..58526b5f0c 100644 --- a/dimos/robot/robot.py +++ b/dimos/robot/robot.py @@ -19,20 +19,17 @@ and video streaming. """ -from abc import ABC +from abc import ABC, abstractmethod import os -from typing import TYPE_CHECKING, Optional, List, Union, Dict, Any +from typing import Optional, List, Union, Dict, Any from dimos.hardware.interface import HardwareInterface from dimos.perception.spatial_perception import SpatialMemory from dimos.manipulation.manipulation_interface import ManipulationInterface from dimos.types.robot_capabilities import RobotCapability +from dimos.types.vector import Vector from dimos.utils.logging_config import setup_logger - -if TYPE_CHECKING: - from dimos.robot.ros_control import ROSControl -else: - ROSControl = "ROSControl" +from dimos.robot.connection_interface import ConnectionInterface from dimos.skills.skills import SkillLibrary from reactivex import Observable, operators as ops @@ -65,7 +62,7 @@ class Robot(ABC): def __init__( self, hardware_interface: HardwareInterface = None, - ros_control: ROSControl = None, + connection_interface: ConnectionInterface = None, output_dir: str = os.path.join(os.getcwd(), "assets", "output"), pool_scheduler: ThreadPoolScheduler = None, skill_library: SkillLibrary = None, @@ -73,24 +70,29 @@ def __init__( new_memory: bool = False, capabilities: List[RobotCapability] = None, video_stream: Optional[Observable] = None, + enable_perception: bool = True, ): """Initialize a Robot instance. Args: hardware_interface: Interface to the robot's hardware. Defaults to None. - ros_control: ROS-based control system. Defaults to None. 
+ connection_interface: Connection interface for robot control and communication. output_dir: Directory for storing output files. Defaults to "./assets/output". pool_scheduler: Thread pool scheduler. If None, one will be created. skill_library: Skill library instance. If None, one will be created. spatial_memory_collection: Name of the collection in the ChromaDB database. new_memory: If True, creates a new spatial memory from scratch. Defaults to False. + capabilities: List of robot capabilities. Defaults to None. + video_stream: Optional video stream. Defaults to None. + enable_perception: If True, enables perception streams and spatial memory. Defaults to True. """ self.hardware_interface = hardware_interface - self.ros_control = ros_control + self.connection_interface = connection_interface self.output_dir = output_dir self.disposables = CompositeDisposable() self.pool_scheduler = pool_scheduler if pool_scheduler else get_scheduler() self.skill_library = skill_library if skill_library else SkillLibrary() + self.enable_perception = enable_perception # Initialize robot capabilities self.capabilities = capabilities or [] @@ -113,35 +115,29 @@ def __init__( os.makedirs(self.spatial_memory_dir, exist_ok=True) os.makedirs(self.db_path, exist_ok=True) - # Initialize spatial memory - this will be handled by SpatialMemory class + # Initialize spatial memory properties self._video_stream = video_stream - transform_provider = None - - # Only create video stream if ROS control is available - if self.ros_control is not None and self.ros_control.video_provider is not None: - # Get video stream - self._video_stream = self.get_ros_video_stream(fps=10) # Lower FPS for processing - - # Define transform provider - def transform_provider(): - position, rotation = self.ros_control.transform_euler("base_link") - if position is None or rotation is None: - return {"position": None, "rotation": None} - return {"position": position, "rotation": rotation} - - # Avoids circular imports - 
from dimos.perception.spatial_perception import SpatialMemory - - # Create SpatialMemory instance - it will handle all initialization internally - self._spatial_memory = SpatialMemory( - collection_name=self.spatial_memory_collection, - db_path=self.db_path, - visual_memory_path=self.visual_memory_path, - new_memory=new_memory, - output_dir=self.spatial_memory_dir, - video_stream=self._video_stream, - transform_provider=transform_provider, - ) + + # Only create video stream if connection interface is available + if self.connection_interface is not None: + # Get video stream - always create this, regardless of enable_perception + self._video_stream = self.get_video_stream(fps=10) # Lower FPS for processing + + # Create SpatialMemory instance only if perception is enabled + if self.enable_perception: + self._spatial_memory = SpatialMemory( + collection_name=self.spatial_memory_collection, + db_path=self.db_path, + visual_memory_path=self.visual_memory_path, + new_memory=new_memory, + output_dir=self.spatial_memory_dir, + video_stream=self._video_stream, + get_pose=self.get_pose, + ) + logger.info("Spatial memory initialized") + else: + self._spatial_memory = None + logger.info("Spatial memory disabled (enable_perception=False)") # Initialize manipulation interface if the robot has manipulation capability self._manipulation_interface = None @@ -158,8 +154,8 @@ def transform_provider(): ) logger.info("Manipulation interface initialized") - def get_ros_video_stream(self, fps: int = 30) -> Observable: - """Get the ROS video stream with rate limiting and frame processing. + def get_video_stream(self, fps: int = 30) -> Observable: + """Get the video stream with rate limiting and frame processing. Args: fps: Frames per second for the video stream. Defaults to 30. @@ -168,58 +164,39 @@ def get_ros_video_stream(self, fps: int = 30) -> Observable: Observable: An observable stream of video frames. Raises: - RuntimeError: If no ROS video provider is available. 
+ RuntimeError: If no connection interface is available for video streaming. """ - if not self.ros_control or not self.ros_control.video_provider: - raise RuntimeError("No ROS video provider available") + if self.connection_interface is None: + raise RuntimeError("No connection interface available for video streaming") - print(f"Starting ROS video stream at {fps} FPS...") + stream = self.connection_interface.get_video_stream(fps) + if stream is None: + raise RuntimeError("No video stream available from connection interface") - # Get base stream from video provider - video_stream = self.ros_control.video_provider.capture_video_as_observable(fps=fps) - - # Add minimal processing pipeline with proper thread handling - processed_stream = video_stream.pipe( - ops.subscribe_on(self.pool_scheduler), - ops.observe_on(self.pool_scheduler), # Ensure thread safety - ops.share(), # Share the stream + return stream.pipe( + ops.observe_on(self.pool_scheduler), ) - return processed_stream - - def move(self, distance: float, speed: float = 0.5) -> bool: + def move(self, velocity: Vector, duration: float = 0.0) -> bool: """Move the robot using velocity commands. - DEPRECATED: Use move_vel instead for direct velocity control. - Args: - distance: Distance to move forward in meters (must be positive). - speed: Speed to move at in m/s. Defaults to 0.5. + velocity: Velocity vector [x, y, yaw] where: + x: Linear velocity in x direction (m/s) + y: Linear velocity in y direction (m/s) + yaw: Angular velocity (rad/s) + duration: Duration to apply command (seconds). If 0, apply once. Returns: bool: True if movement succeeded. Raises: - RuntimeError: If no ROS control interface is available. + RuntimeError: If no connection interface is available. """ - pass - - def reverse(self, distance: float, speed: float = 0.5) -> bool: - """Move the robot backward by a specified distance. - - DEPRECATED: Use move_vel with negative x value instead for direct velocity control. 
+ if self.connection_interface is None: + raise RuntimeError("No connection interface available for movement") - Args: - distance: Distance to move backward in meters (must be positive). - speed: Speed to move at in m/s. Defaults to 0.5. - - Returns: - bool: True if movement succeeded. - - Raises: - RuntimeError: If no ROS control interface is available. - """ - pass + return self.connection_interface.move(velocity, duration) def spin(self, degrees: float, speed: float = 45.0) -> bool: """Rotate the robot by a specified angle. @@ -233,11 +210,35 @@ def spin(self, degrees: float, speed: float = 45.0) -> bool: bool: True if rotation succeeded. Raises: - RuntimeError: If no ROS control interface is available. + RuntimeError: If no connection interface is available. + """ + if self.connection_interface is None: + raise RuntimeError("No connection interface available for rotation") + + # Convert degrees to radians + import math + + angular_velocity = math.radians(speed) + duration = abs(degrees) / speed if speed > 0 else 0 + + # Set direction based on sign of degrees + if degrees < 0: + angular_velocity = -angular_velocity + + velocity = Vector(0.0, 0.0, angular_velocity) + return self.connection_interface.move(velocity, duration) + + @abstractmethod + def get_pose(self) -> dict: + """ + Get the current pose (position and rotation) of the robot. + + Returns: + Dictionary containing: + - position: Tuple[float, float, float] (x, y, z) + - rotation: Tuple[float, float, float] (roll, pitch, yaw) in radians """ - if self.ros_control is None: - raise RuntimeError("No ROS control interface available for rotation") - return self.ros_control.spin(degrees, speed) + pass def webrtc_req( self, @@ -248,54 +249,40 @@ def webrtc_req( request_id: str = None, data=None, timeout: float = 1000.0, - ) -> bool: + ): """Send a WebRTC request command to the robot. Args: api_id: The API ID for the command. topic: The API topic to publish to. Defaults to ROSControl.webrtc_api_topic. 
- parameter: Optional parameter string. Defaults to ''. - priority: Priority level as defined by PriorityQueue(). Defaults to 0 (no priority). - data: Optional data dictionary. - timeout: Maximum time to wait for the command to complete. - - Returns: - bool: True if command was sent successfully. - - Raises: - RuntimeError: If no ROS control interface is available. - - """ - if self.ros_control is None: - raise RuntimeError("No ROS control interface available for WebRTC commands") - return self.ros_control.queue_webrtc_req( - api_id=api_id, - topic=topic, - parameter=parameter, - priority=priority, - request_id=request_id, - data=data, - timeout=timeout, - ) - - def move_vel(self, x: float, y: float, yaw: float, duration: float = 0.0) -> bool: - """Move the robot using direct movement commands. - - Args: - x: Forward/backward velocity (m/s) - y: Left/right velocity (m/s) - yaw: Rotational velocity (rad/s) - duration: How long to move (seconds). If 0, command is continuous + parameter: Additional parameter data. Defaults to "". + priority: Priority of the request. Defaults to 0. + request_id: Unique identifier for the request. If None, one will be generated. + data: Additional data to include with the request. Defaults to None. + timeout: Timeout for the request in milliseconds. Defaults to 1000.0. Returns: - bool: True if command was sent successfully + The result of the WebRTC request. Raises: - RuntimeError: If no ROS control interface is available. + RuntimeError: If no connection interface with WebRTC capability is available. 
""" - if self.ros_control is None: - raise RuntimeError("No ROS control interface available for movement") - return self.ros_control.move_vel(x, y, yaw, duration) + if self.connection_interface is None: + raise RuntimeError("No connection interface available for WebRTC commands") + + # WebRTC requests are only available on ROS control interfaces + if hasattr(self.connection_interface, "queue_webrtc_req"): + return self.connection_interface.queue_webrtc_req( + api_id=api_id, + topic=topic, + parameter=parameter, + priority=priority, + request_id=request_id, + data=data, + timeout=timeout, + ) + else: + raise RuntimeError("WebRTC requests not supported by this connection interface") def pose_command(self, roll: float, pitch: float, yaw: float) -> bool: """Send a pose command to the robot. @@ -309,11 +296,13 @@ def pose_command(self, roll: float, pitch: float, yaw: float) -> bool: bool: True if command was sent successfully. Raises: - RuntimeError: If no ROS control interface is available. + RuntimeError: If no connection interface with pose command capability is available. """ - if self.ros_control is None: - raise RuntimeError("No ROS control interface available for pose commands") - return self.ros_control.pose_command(roll, pitch, yaw) + # Pose commands are only available on ROS control interfaces + if hasattr(self.connection_interface, "pose_command"): + return self.connection_interface.pose_command(roll, pitch, yaw) + else: + raise RuntimeError("Pose commands not supported by this connection interface") def update_hardware_interface(self, new_hardware_interface: HardwareInterface): """Update the hardware interface with a new configuration. @@ -346,11 +335,11 @@ def set_hardware_configuration(self, configuration): self.hardware_interface.set_configuration(configuration) @property - def spatial_memory(self) -> SpatialMemory: + def spatial_memory(self) -> Optional[SpatialMemory]: """Get the robot's spatial memory. 
Returns: - SpatialMemory: The robot's spatial memory system. + SpatialMemory: The robot's spatial memory system, or None if perception is disabled. """ return self._spatial_memory @@ -392,6 +381,14 @@ def video_stream(self) -> Optional[Observable]: """ return self._video_stream + def get_skills(self): + """Get the robot's skill library. + + Returns: + The robot's skill library for adding/managing skills. + """ + return self.skill_library + def cleanup(self): """Clean up resources used by the robot. @@ -403,8 +400,10 @@ def cleanup(self): if self.disposables: self.disposables.dispose() - if self.ros_control: - self.ros_control.cleanup() + # Clean up connection interface + if self.connection_interface: + self.connection_interface.disconnect() + self.disposables.dispose() diff --git a/dimos/robot/ros_control.py b/dimos/robot/ros_control.py index 454f41c2b6..6aa51fc3a8 100644 --- a/dimos/robot/ros_control.py +++ b/dimos/robot/ros_control.py @@ -44,6 +44,8 @@ import tf2_ros from dimos.robot.ros_transform import ROSTransformAbility from dimos.robot.ros_observable_topic import ROSObservableTopicAbility +from dimos.robot.connection_interface import ConnectionInterface +from dimos.types.vector import Vector from nav_msgs.msg import Odometry @@ -62,7 +64,7 @@ class RobotMode(Enum): ERROR = auto() -class ROSControl(ROSTransformAbility, ROSObservableTopicAbility, ABC): +class ROSControl(ROSTransformAbility, ROSObservableTopicAbility, ConnectionInterface, ABC): """Abstract base class for ROS-controlled robots""" def __init__( @@ -412,6 +414,20 @@ def video_provider(self) -> Optional[ROSVideoProvider]: """Data provider property for streaming data""" return self._video_provider + def get_video_stream(self, fps: int = 30) -> Optional[Observable]: + """Get the video stream from the robot's camera. 
+ + Args: + fps: Frames per second for the video stream + + Returns: + Observable: An observable stream of video frames or None if not available + """ + if not self.video_provider: + return None + + return self.video_provider.get_stream(fps=fps) + def _send_action_client_goal(self, client, goal_msg, description=None, time_allowance=20.0): """ Generic function to send any action client goal and wait for completion. @@ -459,61 +475,46 @@ def _send_action_client_goal(self, client, goal_msg, description=None, time_allo logger.error("Action failed") return False - def move(self, distance: float, speed: float = 0.5, time_allowance: float = 120) -> bool: - """ - Move the robot forward by a specified distance + def move(self, velocity: Vector, duration: float = 0.0) -> bool: + """Send velocity commands to the robot. Args: - distance: Distance to move forward in meters (must be positive) - speed: Speed to move at in m/s (default 0.5) - time_allowance: Maximum time to wait for the request to complete + velocity: Velocity vector [x, y, yaw] where: + x: Linear velocity in x direction (m/s) + y: Linear velocity in y direction (m/s) + yaw: Angular velocity around z axis (rad/s) + duration: Duration to apply command (seconds). If 0, apply once. 
Returns: - bool: True if movement succeeded + bool: True if command was sent successfully """ - try: - if distance <= 0: - logger.error("Distance must be positive") - return False - - speed = min(abs(speed), self.MAX_LINEAR_VELOCITY) + x, y, yaw = velocity.x, velocity.y, velocity.z - # Define function to execute the move - def execute_move(): - # Create DriveOnHeading goal - goal = DriveOnHeading.Goal() - goal.target.x = distance - goal.target.y = 0.0 - goal.target.z = 0.0 - goal.speed = speed - goal.time_allowance = Duration(sec=time_allowance) - - logger.info(f"Moving forward: distance={distance}m, speed={speed}m/s") + # Clamp velocities to safe limits + x = self._clamp_velocity(x, self.MAX_LINEAR_VELOCITY) + y = self._clamp_velocity(y, self.MAX_LINEAR_VELOCITY) + yaw = self._clamp_velocity(yaw, self.MAX_ANGULAR_VELOCITY) - return self._send_action_client_goal( - self._drive_client, - goal, - f"Sending Action Client goal in ROSControl.execute_move for {distance}m at {speed}m/s", - time_allowance, - ) + # Create and send command + cmd = Twist() + cmd.linear.x = float(x) + cmd.linear.y = float(y) + cmd.angular.z = float(yaw) - # Queue the action - cmd_id = self._command_queue.queue_action_client_request( - action_name="move", - execute_func=execute_move, - priority=0, - timeout=time_allowance, - distance=distance, - speed=speed, - ) - logger.info(f"Queued move command: {cmd_id} - Distance: {distance}m, Speed: {speed}m/s") + try: + if duration > 0: + start_time = time.time() + while time.time() - start_time < duration: + self._move_vel_pub.publish(cmd) + time.sleep(0.1) # 10Hz update rate + # Stop after duration + self.stop() + else: + self._move_vel_pub.publish(cmd) return True except Exception as e: - logger.error(f"Forward movement failed: {e}") - import traceback - - logger.error(traceback.format_exc()) + self._logger.error(f"Failed to send movement command: {e}") return False def reverse(self, distance: float, speed: float = 0.5, time_allowance: float = 120) -> 
bool: @@ -644,32 +645,6 @@ def execute_spin(): logger.error(traceback.format_exc()) return False - def _goal_response_callback(self, future): - """Handle the goal response.""" - goal_handle = future.result() - if not goal_handle.accepted: - logger.warn("Goal was rejected!") - print("[ROSControl] Goal was REJECTED by the action server") - self._action_success = False - return - - logger.info("Goal accepted") - print("[ROSControl] Goal was ACCEPTED by the action server") - result_future = goal_handle.get_result_async() - result_future.add_done_callback(self._goal_result_callback) - - def _goal_result_callback(self, future): - """Handle the goal result.""" - try: - result = future.result().result - logger.info("Goal completed") - print(f"[ROSControl] Goal COMPLETED with result: {result}") - self._action_success = True - except Exception as e: - logger.error(f"Goal failed with error: {e}") - print(f"[ROSControl] Goal FAILED with error: {e}") - self._action_success = False - def stop(self) -> bool: """Stop all robot movement""" try: @@ -697,6 +672,10 @@ def cleanup(self): self._node.destroy_node() rclpy.shutdown() + def disconnect(self) -> None: + """Disconnect from the robot and clean up resources.""" + self.cleanup() + def webrtc_req( self, api_id: int, @@ -787,46 +766,6 @@ def queue_webrtc_req( data=data, ) - def move_vel(self, x: float, y: float, yaw: float, duration: float = 0.0) -> bool: - """ - Send movement command to the robot using velocity commands - - Args: - x: Forward/backward velocity (m/s) - y: Left/right velocity (m/s) - yaw: Rotational velocity (rad/s) - duration: How long to move (seconds). 
If 0, command is continuous - - Returns: - bool: True if command was sent successfully - """ - # Clamp velocities to safe limits - x = self._clamp_velocity(x, self.MAX_LINEAR_VELOCITY) - y = self._clamp_velocity(y, self.MAX_LINEAR_VELOCITY) - yaw = self._clamp_velocity(yaw, self.MAX_ANGULAR_VELOCITY) - - # Create and send command - cmd = Twist() - cmd.linear.x = float(x) - cmd.linear.y = float(y) - cmd.angular.z = float(yaw) - - try: - if duration > 0: - start_time = time.time() - while time.time() - start_time < duration: - self._move_vel_pub.publish(cmd) - time.sleep(0.1) # 10Hz update rate - # Stop after duration - self.stop() - else: - self._move_vel_pub.publish(cmd) - return True - - except Exception as e: - self._logger.error(f"Failed to send movement command: {e}") - return False - def move_vel_control(self, x: float, y: float, yaw: float) -> bool: """ Send a single velocity command without duration handling. @@ -900,3 +839,29 @@ def get_position_stream(self): ) return position_provider.get_position_stream() + + def _goal_response_callback(self, future): + """Handle the goal response.""" + goal_handle = future.result() + if not goal_handle.accepted: + logger.warn("Goal was rejected!") + print("[ROSControl] Goal was REJECTED by the action server") + self._action_success = False + return + + logger.info("Goal accepted") + print("[ROSControl] Goal was ACCEPTED by the action server") + result_future = goal_handle.get_result_async() + result_future.add_done_callback(self._goal_result_callback) + + def _goal_result_callback(self, future): + """Handle the goal result.""" + try: + result = future.result().result + logger.info("Goal completed") + print(f"[ROSControl] Goal COMPLETED with result: {result}") + self._action_success = True + except Exception as e: + logger.error(f"Goal failed with error: {e}") + print(f"[ROSControl] Goal FAILED with error: {e}") + self._action_success = False diff --git a/dimos/robot/unitree/unitree_go2.py 
b/dimos/robot/unitree/unitree_go2.py index 8667edfcf8..ca878e7134 100644 --- a/dimos/robot/unitree/unitree_go2.py +++ b/dimos/robot/unitree/unitree_go2.py @@ -13,24 +13,25 @@ # limitations under the License. import multiprocessing -from typing import Optional, Union, Tuple +from typing import Optional, Union, List import numpy as np from dimos.robot.robot import Robot from dimos.robot.unitree.unitree_skills import MyUnitreeSkills from dimos.skills.skills import AbstractRobotSkill, AbstractSkill, SkillLibrary -from dimos.stream.video_providers.unitree import UnitreeVideoProvider from reactivex.disposable import CompositeDisposable import logging -from go2_webrtc_driver.webrtc_driver import Go2WebRTCConnection, WebRTCConnectionMethod import os from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl from reactivex.scheduler import ThreadPoolScheduler from dimos.utils.logging_config import setup_logger from dimos.perception.person_tracker import PersonTrackingStream from dimos.perception.object_tracker import ObjectTrackingStream -from dimos.robot.local_planner import VFHPurePursuitPlanner, navigate_path_local +from dimos.robot.local_planner.local_planner import navigate_path_local +from dimos.robot.local_planner.vfh_local_planner import VFHPurePursuitPlanner from dimos.robot.global_planner.planner import AstarPlanner from dimos.types.costmap import Costmap +from dimos.types.robot_capabilities import RobotCapability +from dimos.types.vector import Vector # Set up logging logger = setup_logger("dimos.robot.unitree.unitree_go2", level=logging.DEBUG) @@ -41,58 +42,63 @@ class UnitreeGo2(Robot): + """Unitree Go2 robot implementation using ROS2 control interface. + + This class extends the base Robot class to provide specific functionality + for the Unitree Go2 quadruped robot using ROS2 for communication and control. 
+ """ + def __init__( self, - ros_control: Optional[UnitreeROSControl] = None, - ip=None, - connection_method: WebRTCConnectionMethod = WebRTCConnectionMethod.LocalSTA, - serial_number: str = None, + video_provider=None, output_dir: str = os.path.join(os.getcwd(), "assets", "output"), - use_ros: bool = True, - use_webrtc: bool = False, + skill_library: SkillLibrary = None, + robot_capabilities: List[RobotCapability] = None, + spatial_memory_collection: str = "spatial_memory", + new_memory: bool = False, disable_video_stream: bool = False, mock_connection: bool = False, - skills: Optional[Union[MyUnitreeSkills, AbstractSkill]] = None, - new_memory: bool = False, + enable_perception: bool = True, ): - """Initialize the UnitreeGo2 robot. + """Initialize UnitreeGo2 robot with ROS control interface. Args: - ros_control: ROS control interface, if None a new one will be created - ip: IP address of the robot (for LocalSTA connection) - connection_method: WebRTC connection method (LocalSTA or LocalAP) - serial_number: Serial number of the robot (for LocalSTA with serial) + video_provider: Provider for video streams output_dir: Directory for output files - use_ros: Whether to use ROSControl and ROS video provider - use_webrtc: Whether to use WebRTC video provider ONLY - disable_video_stream: Whether to disable the video stream - mock_connection: Whether to mock the connection to the robot - skills: Skills library or custom skill implementation. Default is MyUnitreeSkills() if None. - spatial_memory_dir: Directory for storing spatial memory data. If None, uses output_dir/spatial_memory. - spatial_memory_collection: Name of the collection in the ChromaDB database. - new_memory: If True, creates a new spatial memory from scratch. 
+ skill_library: Library of robot skills + robot_capabilities: List of robot capabilities + spatial_memory_collection: Collection name for spatial memory + new_memory: Whether to create new memory collection + disable_video_stream: Whether to disable video streaming + mock_connection: Whether to use mock connection for testing + enable_perception: Whether to enable perception streams and spatial memory """ - print(f"Initializing UnitreeGo2 with use_ros: {use_ros} and use_webrtc: {use_webrtc}") - if not (use_ros ^ use_webrtc): # XOR operator ensures exactly one is True - raise ValueError("Exactly one video/control provider (ROS or WebRTC) must be enabled") - - # Initialize ros_control if it is not provided and use_ros is True - if ros_control is None and use_ros: - ros_control = UnitreeROSControl( - node_name="unitree_go2", - disable_video_stream=disable_video_stream, - mock_connection=mock_connection, - ) + # Create ROS control interface + ros_control = UnitreeROSControl( + node_name="unitree_go2", + video_provider=video_provider, + disable_video_stream=disable_video_stream, + mock_connection=mock_connection, + ) - # Initialize skill library - if skills is None: - skills = MyUnitreeSkills(robot=self) + # Initialize skill library if not provided + if skill_library is None: + skill_library = MyUnitreeSkills() + # Initialize base robot with connection interface super().__init__( - ros_control=ros_control, + connection_interface=ros_control, output_dir=output_dir, - skill_library=skills, + skill_library=skill_library, + capabilities=robot_capabilities + or [ + RobotCapability.LOCOMOTION, + RobotCapability.VISION, + RobotCapability.AUDIO, + ], + spatial_memory_collection=spatial_memory_collection, new_memory=new_memory, + enable_perception=enable_perception, ) if self.skill_library is not None: @@ -110,7 +116,6 @@ def __init__( self.camera_height = 0.44 # meters # Initialize UnitreeGo2-specific attributes - self.ip = ip self.disposables = CompositeDisposable() 
self.main_stream_obs = None @@ -118,79 +123,86 @@ def __init__( self.optimal_thread_count = multiprocessing.cpu_count() self.thread_pool_scheduler = ThreadPoolScheduler(self.optimal_thread_count // 2) - if (connection_method == WebRTCConnectionMethod.LocalSTA) and (ip is None): - raise ValueError("IP address is required for LocalSTA connection") - - # Choose data provider based on configuration - if use_ros and not disable_video_stream: - # Use ROS video provider from ROSControl - self.video_stream = self.ros_control.video_provider - elif use_webrtc and not disable_video_stream: - # Use WebRTC ONLY video provider - self.video_stream = UnitreeVideoProvider( - dev_name="UnitreeGo2", - connection_method=connection_method, - serial_number=serial_number, - ip=self.ip if connection_method == WebRTCConnectionMethod.LocalSTA else None, - ) - else: - self.video_stream = None - # Initialize visual servoing if enabled - if self.video_stream is not None: - self.video_stream_ros = self.get_ros_video_stream(fps=8) - self.person_tracker = PersonTrackingStream( - camera_intrinsics=self.camera_intrinsics, - camera_pitch=self.camera_pitch, - camera_height=self.camera_height, - ) - self.object_tracker = ObjectTrackingStream( - camera_intrinsics=self.camera_intrinsics, - camera_pitch=self.camera_pitch, - camera_height=self.camera_height, - ) - person_tracking_stream = self.person_tracker.create_stream(self.video_stream_ros) - object_tracking_stream = self.object_tracker.create_stream(self.video_stream_ros) - - self.person_tracking_stream = person_tracking_stream - self.object_tracking_stream = object_tracking_stream + if not disable_video_stream: + self.video_stream_ros = self.get_video_stream(fps=8) + if enable_perception: + self.person_tracker = PersonTrackingStream( + camera_intrinsics=self.camera_intrinsics, + camera_pitch=self.camera_pitch, + camera_height=self.camera_height, + ) + self.object_tracker = ObjectTrackingStream( + camera_intrinsics=self.camera_intrinsics, + 
camera_pitch=self.camera_pitch, + camera_height=self.camera_height, + ) + person_tracking_stream = self.person_tracker.create_stream(self.video_stream_ros) + object_tracking_stream = self.object_tracker.create_stream(self.video_stream_ros) + + self.person_tracking_stream = person_tracking_stream + self.object_tracking_stream = object_tracking_stream + else: + # Video stream is available but perception tracking is disabled + self.person_tracker = None + self.object_tracker = None + self.person_tracking_stream = None + self.object_tracking_stream = None + else: + # Video stream is disabled + self.video_stream_ros = None + self.person_tracker = None + self.object_tracker = None + self.person_tracking_stream = None + self.object_tracking_stream = None # Initialize the local planner and create BEV visualization stream - self.local_planner = VFHPurePursuitPlanner( - get_costmap=self.ros_control.topic_latest("/local_costmap/costmap", Costmap), - transform=self.ros_control, - move_vel_control=self.ros_control.move_vel_control, - robot_width=0.36, # Unitree Go2 width in meters - robot_length=0.6, # Unitree Go2 length in meters - max_linear_vel=0.5, - lookahead_distance=2.0, - visualization_size=500, # 500x500 pixel visualization - ) + # Note: These features require ROS-specific methods that may not be available on all connection interfaces + if hasattr(self.connection_interface, "topic_latest") and hasattr( + self.connection_interface, "transform_euler" + ): + self.local_planner = VFHPurePursuitPlanner( + get_costmap=self.connection_interface.topic_latest( + "/local_costmap/costmap", Costmap + ), + transform=self.connection_interface, + move_vel_control=self.connection_interface.move_vel_control, + robot_width=0.36, # Unitree Go2 width in meters + robot_length=0.6, # Unitree Go2 length in meters + max_linear_vel=0.5, + lookahead_distance=2.0, + visualization_size=500, # 500x500 pixel visualization + ) - self.global_planner = AstarPlanner( - conservativism=20, # how close to 
obstacles robot is allowed to path plan - set_local_nav=lambda path, stop_event=None, goal_theta=None: navigate_path_local( - self, path, timeout=120.0, goal_theta=goal_theta, stop_event=stop_event - ), - get_costmap=self.ros_control.topic_latest("map", Costmap), - get_robot_pos=lambda: self.ros_control.transform_euler_pos("base_link"), - ) + self.global_planner = AstarPlanner( + conservativism=20, # how close to obstacles robot is allowed to path plan + set_local_nav=lambda path, stop_event=None, goal_theta=None: navigate_path_local( + self, path, timeout=120.0, goal_theta=goal_theta, stop_event=stop_event + ), + get_costmap=self.connection_interface.topic_latest("map", Costmap), + get_robot_pos=lambda: self.connection_interface.transform_euler_pos("base_link"), + ) - # Create the visualization stream at 5Hz - self.local_planner_viz_stream = self.local_planner.create_stream(frequency_hz=5.0) + # Create the visualization stream at 5Hz + self.local_planner_viz_stream = self.local_planner.create_stream(frequency_hz=5.0) + else: + self.local_planner = None + self.global_planner = None + self.local_planner_viz_stream = None def get_skills(self) -> Optional[SkillLibrary]: return self.skill_library - def get_pose(self) -> Tuple[Tuple[float, float, float], Tuple[float, float, float]]: + def get_pose(self) -> dict: """ Get the current pose (position and rotation) of the robot in the map frame. 
Returns: - Tuple containing: - - position: Tuple[float, float, float] (x, y, z) - - rotation: Tuple[float, float, float] (roll, pitch, yaw) in radians + Dictionary containing: + - position: Vector (x, y, z) + - rotation: Vector (roll, pitch, yaw) in radians """ - [position, rotation] = self.ros_control.transform_euler("base_link") - - return position, rotation + position_tuple, orientation_tuple = self.connection_interface.get_pose_odom_transform() + position = Vector(position_tuple[0], position_tuple[1], position_tuple[2]) + rotation = Vector(orientation_tuple[0], orientation_tuple[1], orientation_tuple[2]) + return {"position": position, "rotation": rotation} diff --git a/dimos/robot/unitree/unitree_skills.py b/dimos/robot/unitree/unitree_skills.py index 38adc399c8..d191753294 100644 --- a/dimos/robot/unitree/unitree_skills.py +++ b/dimos/robot/unitree/unitree_skills.py @@ -26,6 +26,7 @@ from dimos.skills.skills import AbstractRobotSkill, AbstractSkill, SkillLibrary from dimos.types.constants import Colors +from dimos.types.vector import Vector # Module-level constant for Unitree ROS control definitions UNITREE_ROS_CONTROLS: List[Tuple[str, int, str]] = [ @@ -84,7 +85,7 @@ "Sets or adjusts the speed at which the robot moves, with various levels available for different operational needs.", ), ( - "Hello", + "ShakeHand", 1016, "Performs a greeting action, which could involve a wave or other friendly gesture.", ), @@ -270,7 +271,7 @@ class Move(AbstractRobotSkill): def __call__(self): super().__call__() - return self._robot.move_vel(x=self.x, y=self.y, yaw=self.yaw, duration=self.duration) + return self._robot.move(Vector(self.x, self.y, self.yaw), duration=self.duration) class Reverse(AbstractRobotSkill): """Reverse the robot using direct velocity commands. 
Determine duration required based on user distance instructions.""" @@ -282,8 +283,8 @@ class Reverse(AbstractRobotSkill): def __call__(self): super().__call__() - # Use move_vel with negative x for backward movement - return self._robot.move_vel(x=-self.x, y=self.y, yaw=self.yaw, duration=self.duration) + # Use move with negative x for backward movement + return self._robot.move(Vector(-self.x, self.y, self.yaw), duration=self.duration) class SpinLeft(AbstractRobotSkill): """Spin the robot left using degree commands.""" diff --git a/dimos/robot/unitree_webrtc/connection.py b/dimos/robot/unitree_webrtc/connection.py index 16697c4378..a847b7f2df 100644 --- a/dimos/robot/unitree_webrtc/connection.py +++ b/dimos/robot/unitree_webrtc/connection.py @@ -29,13 +29,13 @@ from reactivex import operators as ops from aiortc import MediaStreamTrack from dimos.robot.unitree_webrtc.type.lowstate import LowStateMsg -from dimos.robot.abstract_robot import AbstractRobot - +from dimos.robot.connection_interface import ConnectionInterface +import time VideoMessage: TypeAlias = np.ndarray[tuple[int, int, Literal[3]], np.uint8] -class WebRTCRobot(AbstractRobot): +class WebRTCRobot(ConnectionInterface): def __init__(self, ip: str, mode: str = "ai"): self.ip = ip self.mode = mode @@ -74,18 +74,55 @@ def start_background_loop(): self.thread.start() self.connection_ready.wait() - def move(self, vector: Vector): + def move(self, velocity: Vector, duration: float = 0.0) -> bool: + """Send movement command to the robot using velocity commands. + + Args: + velocity: Velocity vector [x, y, yaw] where: + x: Forward/backward velocity (m/s) + y: Left/right velocity (m/s) + yaw: Rotational velocity (rad/s) + duration: How long to move (seconds). 
If 0, command is continuous + + Returns: + bool: True if command was sent successfully + """ + x, y, yaw = velocity.x, velocity.y, velocity.z + + # WebRTC coordinate mapping: # x - Positive right, negative left # y - positive forward, negative backwards - # z - Positive rotate right, negative rotate left + # yaw - Positive rotate right, negative rotate left async def async_move(): self.conn.datachannel.pub_sub.publish_without_callback( RTC_TOPIC["WIRELESS_CONTROLLER"], - data={"lx": vector.x, "ly": vector.y, "rx": vector.z, "ry": 0}, + data={"lx": y, "ly": x, "rx": -yaw, "ry": 0}, ) - future = asyncio.run_coroutine_threadsafe(async_move(), self.loop) - return future.result() + async def async_move_duration(): + """Send movement commands continuously for the specified duration.""" + start_time = time.time() + sleep_time = 0.01 + + while time.time() - start_time < duration: + await async_move() + await asyncio.sleep(sleep_time) + + try: + if duration > 0: + # Send continuous move commands for the duration + future = asyncio.run_coroutine_threadsafe(async_move_duration(), self.loop) + future.result() + # Stop after duration + self.stop() + else: + # Single command for continuous movement + future = asyncio.run_coroutine_threadsafe(async_move(), self.loop) + future.result() + return True + except Exception as e: + print(f"Failed to send movement command: {e}") + return False # Generic conversion of unitree subscription to Subject (used for all subs) def unitree_sub_stream(self, topic_name: str): @@ -129,7 +166,10 @@ def standup_ai(self): return self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["BalanceStand"]}) def standup_normal(self): - return self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["StandUp"]}) + self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["StandUp"]}) + time.sleep(0.5) + self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["RecoveryStand"]}) + return True def standup(self): if self.mode == 
"ai": @@ -178,7 +218,7 @@ def stop(cb): self.conn.video.track_callbacks.remove(accept_track) self.conn.video.switchVideoChannel(False) - return backpressure(subject.pipe(ops.finally_action(stop))) + return subject.pipe(ops.finally_action(stop)) def get_video_stream(self, fps: int = 30) -> Observable[VideoMessage]: """Get the video stream from the robot's camera. @@ -202,19 +242,28 @@ def get_video_stream(self, fps: int = 30) -> Observable[VideoMessage]: print(f"Error getting video stream: {e}") return None - def stop(self): + def stop(self) -> bool: + """Stop the robot's movement. + + Returns: + bool: True if stop command was sent successfully + """ + return self.move(Vector(0.0, 0.0, 0.0)) + + def disconnect(self) -> None: + """Disconnect from the robot and clean up resources.""" if hasattr(self, "task") and self.task: self.task.cancel() if hasattr(self, "conn"): - async def disconnect(): + async def async_disconnect(): try: await self.conn.disconnect() except: pass if hasattr(self, "loop") and self.loop.is_running(): - asyncio.run_coroutine_threadsafe(disconnect(), self.loop) + asyncio.run_coroutine_threadsafe(async_disconnect(), self.loop) if hasattr(self, "loop") and self.loop.is_running(): self.loop.call_soon_threadsafe(self.loop.stop) diff --git a/dimos/robot/unitree_webrtc/type/lidar.py b/dimos/robot/unitree_webrtc/type/lidar.py index 29ccab4555..251dd208db 100644 --- a/dimos/robot/unitree_webrtc/type/lidar.py +++ b/dimos/robot/unitree_webrtc/type/lidar.py @@ -156,9 +156,15 @@ def get_color(color_choice): self.estimate_normals() return self - def costmap(self) -> Costmap: + def costmap(self, voxel_size: float = 0.2) -> Costmap: if not self._costmap: - grid, origin_xy = pointcloud_to_costmap(self.pointcloud, resolution=self.resolution) + down_sampled_pointcloud = self.pointcloud.voxel_down_sample(voxel_size=voxel_size) + inflate_radius_m = 1.0 * voxel_size if voxel_size > self.resolution else 0.0 + grid, origin_xy = pointcloud_to_costmap( + 
down_sampled_pointcloud, + resolution=self.resolution, + inflate_radius_m=inflate_radius_m, + ) self._costmap = Costmap(grid=grid, origin=[*origin_xy, 0.0], resolution=self.resolution) return self._costmap diff --git a/dimos/robot/unitree_webrtc/type/odometry.py b/dimos/robot/unitree_webrtc/type/odometry.py index a2df73184f..389223e4a5 100644 --- a/dimos/robot/unitree_webrtc/type/odometry.py +++ b/dimos/robot/unitree_webrtc/type/odometry.py @@ -22,7 +22,9 @@ to_human_readable, ) from dimos.types.position import Position -from dimos.types.vector import VectorLike +from dimos.types.vector import VectorLike, Vector +from dimos.robot.unitree_webrtc.type.timeseries import Timestamped, to_human_readable +from scipy.spatial.transform import Rotation as R raw_odometry_msg_sample = { "type": "msg", @@ -81,15 +83,6 @@ def __init__(self, pos: VectorLike, rot: VectorLike, ts: EpochLike): super().__init__(pos, rot) self.ts = to_datetime(ts) if ts else datetime.now() - @staticmethod - def quaternion_to_yaw(x: float, y: float, z: float, w: float) -> float: - """Convert quaternion to yaw angle (rotation around z-axis) in radians.""" - # Calculate yaw (rotation around z-axis) - siny_cosp = 2 * (w * z + x * y) - cosy_cosp = 1 - 2 * (y * y + z * z) - yaw = math.atan2(siny_cosp, cosy_cosp) - return yaw - @classmethod def from_msg(cls, msg: RawOdometryMessage) -> "Odometry": pose = msg["data"]["pose"] @@ -99,16 +92,20 @@ def from_msg(cls, msg: RawOdometryMessage) -> "Odometry": # Extract position pos = [position.get("x"), position.get("y"), position.get("z")] - # Extract quaternion components - qx = orientation.get("x") - qy = orientation.get("y") - qz = orientation.get("z") - qw = orientation.get("w") + quat = [ + orientation.get("x"), + orientation.get("y"), + orientation.get("z"), + orientation.get("w"), + ] + + # Check if quaternion has zero norm (invalid) + quat_norm = sum(x**2 for x in quat) ** 0.5 + if quat_norm < 1e-8: + quat = [0.0, 0.0, 0.0, 1.0] - # Convert quaternion to yaw 
angle and store in rot.z - # Keep x,y as quaternion components for now, but z becomes the actual yaw angle - yaw_radians = cls.quaternion_to_yaw(qx, qy, qz, qw) - rot = [qx, qy, yaw_radians] + rotation = R.from_quat(quat) + rot = Vector(rotation.as_euler("xyz", degrees=False)) return cls(pos, rot, msg["data"]["header"]["stamp"]) diff --git a/dimos/robot/unitree_webrtc/unitree_go2.py b/dimos/robot/unitree_webrtc/unitree_go2.py index 0c9a585ccb..94676bfffc 100644 --- a/dimos/robot/unitree_webrtc/unitree_go2.py +++ b/dimos/robot/unitree_webrtc/unitree_go2.py @@ -12,36 +12,132 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dimos.types.vector import Vector -from typing import Union, Optional +from typing import Union, Optional, List +import time +import numpy as np +import os +from dimos.robot.robot import Robot from dimos.robot.unitree_webrtc.type.map import Map from dimos.robot.unitree_webrtc.connection import WebRTCRobot from dimos.robot.global_planner.planner import AstarPlanner from dimos.utils.reactive import getter_streaming -from dimos.robot.unitree.unitree_skills import MyUnitreeSkills -from dimos.skills.skills import AbstractSkill, SkillLibrary -import os +from dimos.skills.skills import AbstractRobotSkill, SkillLibrary from go2_webrtc_driver.constants import VUI_COLOR -from dimos.robot.local_planner import navigate_path_local +from go2_webrtc_driver.webrtc_driver import WebRTCConnectionMethod +from dimos.perception.person_tracker import PersonTrackingStream +from dimos.perception.object_tracker import ObjectTrackingStream +from dimos.robot.local_planner.local_planner import navigate_path_local +from dimos.robot.local_planner.vfh_local_planner import VFHPurePursuitPlanner +from dimos.types.robot_capabilities import RobotCapability +from dimos.types.vector import Vector +from dimos.robot.unitree_webrtc.unitree_skills import MyUnitreeSkills +from 
dimos.robot.frontier_exploration.qwen_frontier_predictor import QwenFrontierPredictor +from dimos.robot.frontier_exploration.wavefront_frontier_goal_selector import ( + WavefrontFrontierExplorer, +) +import threading class Color(VUI_COLOR): ... -class UnitreeGo2(WebRTCRobot): +class UnitreeGo2(Robot): def __init__( self, ip: str, mode: str = "ai", - skills: Optional[Union[MyUnitreeSkills, AbstractSkill]] = None, - skill_library: SkillLibrary = None, output_dir: str = os.path.join(os.getcwd(), "assets", "output"), + skill_library: SkillLibrary = None, + robot_capabilities: List[RobotCapability] = None, + spatial_memory_collection: str = "spatial_memory", + new_memory: bool = True, + enable_perception: bool = True, ): - super().__init__(ip=ip, mode=mode) + """Initialize Unitree Go2 robot with WebRTC control interface. + + Args: + ip: IP address of the robot + mode: Robot mode (ai, etc.) + output_dir: Directory for output files + skill_library: Skill library instance + robot_capabilities: List of robot capabilities + spatial_memory_collection: Collection name for spatial memory + new_memory: Whether to create new spatial memory + enable_perception: Whether to enable perception streams and spatial memory + """ + # Create WebRTC connection interface + self.webrtc_connection = WebRTCRobot( + ip=ip, + mode=mode, + ) + + print("standing up") + self.webrtc_connection.standup() + + # Initialize WebRTC-specific features + self.lidar_stream = self.webrtc_connection.lidar_stream() + self.odom = getter_streaming(self.webrtc_connection.odom_stream()) + self.map = Map(voxel_size=0.2) + self.map_stream = self.map.consume(self.lidar_stream) + self.lidar_message = getter_streaming(self.lidar_stream) + + if skill_library is None: + skill_library = MyUnitreeSkills() + + # Initialize base robot with connection interface + super().__init__( + connection_interface=self.webrtc_connection, + output_dir=output_dir, + skill_library=skill_library, + capabilities=robot_capabilities + or [ + 
RobotCapability.LOCOMOTION, + RobotCapability.VISION, + RobotCapability.AUDIO, + ], + spatial_memory_collection=spatial_memory_collection, + new_memory=new_memory, + enable_perception=enable_perception, + ) - self.odom = getter_streaming(self.odom_stream()) - self.map = Map() - self.map_stream = self.map.consume(self.lidar_stream()) + if self.skill_library is not None: + for skill in self.skill_library: + if isinstance(skill, AbstractRobotSkill): + self.skill_library.create_instance(skill.__name__, robot=self) + if isinstance(self.skill_library, MyUnitreeSkills): + self.skill_library._robot = self + self.skill_library.init() + self.skill_library.initialize_skills() + + # Camera configuration + self.camera_intrinsics = [819.553492, 820.646595, 625.284099, 336.808987] + self.camera_pitch = np.deg2rad(0) # negative for downward pitch + self.camera_height = 0.44 # meters + + # Initialize visual servoing using connection interface + video_stream = self.get_video_stream() + if video_stream is not None and enable_perception: + self.person_tracker = PersonTrackingStream( + camera_intrinsics=self.camera_intrinsics, + camera_pitch=self.camera_pitch, + camera_height=self.camera_height, + ) + self.object_tracker = ObjectTrackingStream( + camera_intrinsics=self.camera_intrinsics, + camera_pitch=self.camera_pitch, + camera_height=self.camera_height, + ) + person_tracking_stream = self.person_tracker.create_stream(video_stream) + object_tracking_stream = self.object_tracker.create_stream(video_stream) + + self.person_tracking_stream = person_tracking_stream + self.object_tracking_stream = object_tracking_stream + else: + # Video stream not available or perception disabled + self.person_tracker = None + self.object_tracker = None + self.person_tracking_stream = None + self.object_tracking_stream = None self.global_planner = AstarPlanner( set_local_nav=lambda path, stop_event=None, goal_theta=None: navigate_path_local( @@ -51,65 +147,78 @@ def __init__( get_robot_pos=lambda: 
self.odom().pos, ) - # # Initialize skills - # if skills is None: - # skills = MyUnitreeSkills(robot=self) - - # self.skill_library = skills if skills else SkillLibrary() - - # if self.skill_library is not None: - # for skill in self.skill_library: - # if isinstance(skill, AbstractRobotSkill): - # self.skill_library.create_instance(skill.__name__, robot=self) - # if isinstance(self.skill_library, MyUnitreeSkills): - # self.skill_library._robot = self - # self.skill_library.init() - # self.skill_library.initialize_skills() - - # # Camera stuff - # self.camera_intrinsics = [819.553492, 820.646595, 625.284099, 336.808987] - # self.camera_pitch = np.deg2rad(0) # negative for downward pitch - # self.camera_height = 0.44 # meters - - # os.makedirs(self.output_dir, exist_ok=True) - - # # Initialize visual servoing if enabled - # if self.get_video_stream() is not None: - # self.person_tracker = PersonTrackingStream( - # camera_intrinsics=self.camera_intrinsics, - # camera_pitch=self.camera_pitch, - # camera_height=self.camera_height, - # ) - # self.object_tracker = ObjectTrackingStream( - # camera_intrinsics=self.camera_intrinsics, - # camera_pitch=self.camera_pitch, - # camera_height=self.camera_height, - # ) - # person_tracking_stream = self.person_tracker.create_stream(self.get_video_stream()) - # object_tracking_stream = self.object_tracker.create_stream(self.get_video_stream()) - - # self.person_tracking_stream = person_tracking_stream - # self.object_tracking_stream = object_tracking_stream - - # Initialize the local planner and create BEV visualization stream - # self.local_planner = VFHPurePursuitPlanner( - # robot=self, - # robot_width=0.36, # Unitree Go2 width in meters - # robot_length=0.6, # Unitree Go2 length in meters - # max_linear_vel=0.5, - # lookahead_distance=0.6, - # visualization_size=500, # 500x500 pixel visualization - # ) + # Initialize the local planner using WebRTC-specific methods + self.local_planner = VFHPurePursuitPlanner( + 
get_costmap=lambda: self.lidar_message().costmap(), + get_robot_pose=lambda: self.odom(), + move=self.move, # Use the robot's move method directly + robot_width=0.36, # Unitree Go2 width in meters + robot_length=0.6, # Unitree Go2 length in meters + max_linear_vel=0.7, + max_angular_vel=0.65, + lookahead_distance=1.5, + visualization_size=500, # 500x500 pixel visualization + global_planner_plan=self.global_planner.plan, + ) + + # Initialize frontier exploration + self.frontier_explorer = WavefrontFrontierExplorer( + set_goal=self.global_planner.set_goal, + get_costmap=lambda: self.map.costmap, + get_robot_pos=lambda: self.odom().pos, + ) # Create the visualization stream at 5Hz - # self.local_planner_viz_stream = self.local_planner.create_stream(frequency_hz=5.0) + self.local_planner_viz_stream = self.local_planner.create_stream(frequency_hz=5.0) + + def get_pose(self) -> dict: + """ + Get the current pose (position and rotation) of the robot in the map frame. + + Returns: + Dictionary containing: + - position: Vector (x, y, z) + - rotation: Vector (roll, pitch, yaw) in radians + """ + position = Vector(self.odom().pos.x, self.odom().pos.y, self.odom().pos.z) + orientation = Vector(self.odom().rot.x, self.odom().rot.y, self.odom().rot.z) + return {"position": position, "rotation": orientation} + + def explore(self, stop_event: Optional[threading.Event] = None) -> bool: + """ + Start autonomous frontier exploration. + + Args: + stop_event: Optional threading.Event to signal when exploration should stop + + Returns: + bool: True if exploration completed successfully, False if stopped or failed + """ + return self.frontier_explorer.explore(stop_event=stop_event) + + def odom_stream(self): + """Get the odometry stream from the robot. + + Returns: + Observable stream of robot odometry data containing position and orientation. + """ + return self.webrtc_connection.odom_stream() + + def standup(self): + """Make the robot stand up. 
+ + Uses AI mode standup if robot is in AI mode, otherwise uses normal standup. + """ + return self.webrtc_connection.standup() - def move(self, vector: Vector): - super().move(vector) + def liedown(self): + """Make the robot lie down. - def get_skills(self) -> Optional[SkillLibrary]: - return self.skill_library + Commands the robot to lie down on the ground. + """ + return self.webrtc_connection.liedown() @property def costmap(self): + """Access to the costmap for navigation.""" return self.map.costmap diff --git a/dimos/robot/unitree_webrtc/unitree_skills.py b/dimos/robot/unitree_webrtc/unitree_skills.py new file mode 100644 index 0000000000..63bacb5121 --- /dev/null +++ b/dimos/robot/unitree_webrtc/unitree_skills.py @@ -0,0 +1,285 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from typing import TYPE_CHECKING, List, Optional, Tuple, Union +import time +from pydantic import Field + +if TYPE_CHECKING: + from dimos.robot.robot import Robot, MockRobot +else: + Robot = "Robot" + MockRobot = "MockRobot" + +from dimos.skills.skills import AbstractRobotSkill, AbstractSkill, SkillLibrary +from dimos.types.constants import Colors +from dimos.types.vector import Vector +from go2_webrtc_driver.constants import RTC_TOPIC, SPORT_CMD + +# Module-level constant for Unitree WebRTC control definitions +UNITREE_WEBRTC_CONTROLS: List[Tuple[str, int, str]] = [ + ("Damp", 1001, "Lowers the robot to the ground fully."), + ( + "BalanceStand", + 1002, + "Activates a mode that maintains the robot in a balanced standing position.", + ), + ( + "StandUp", + 1004, + "Commands the robot to transition from a sitting or prone position to a standing posture.", + ), + ( + "StandDown", + 1005, + "Instructs the robot to move from a standing position to a sitting or prone posture.", + ), + ( + "RecoveryStand", + 1006, + "Recovers the robot to a state from which it can take more commands. 
Useful to run after multiple dynamic commands like front flips.", + ), + ( + "Euler", + 1007, + "Adjusts the robot's orientation using Euler angles, providing precise control over its rotation.", + ), + # ("Move", 1008, "Move the robot using velocity commands."), # Handled separately + ("Sit", 1009, "Commands the robot to sit down from a standing or moving stance."), + ( + "RiseSit", + 1010, + "Commands the robot to rise back to a standing position from a sitting posture.", + ), + ( + "SwitchGait", + 1011, + "Switches the robot's walking pattern or style dynamically, suitable for different terrains or speeds.", + ), + ("Trigger", 1012, "Triggers a specific action or custom routine programmed into the robot."), + ( + "BodyHeight", + 1013, + "Adjusts the height of the robot's body from the ground, useful for navigating various obstacles.", + ), + ( + "FootRaiseHeight", + 1014, + "Controls how high the robot lifts its feet during movement, which can be adjusted for different surfaces.", + ), + ( + "SpeedLevel", + 1015, + "Sets or adjusts the speed at which the robot moves, with various levels available for different operational needs.", + ), + ( + "Hello", + 1016, + "Performs a greeting action, which could involve a wave or other friendly gesture.", + ), + ("Stretch", 1017, "Engages the robot in a stretching routine."), + ( + "TrajectoryFollow", + 1018, + "Directs the robot to follow a predefined trajectory, which could involve complex paths or maneuvers.", + ), + ( + "ContinuousGait", + 1019, + "Enables a mode for continuous walking or running, ideal for long-distance travel.", + ), + ("Content", 1020, "To display or trigger when the robot is happy."), + ("Wallow", 1021, "The robot falls onto its back and rolls around."), + ( + "Dance1", + 1022, + "Performs a predefined dance routine 1, programmed for entertainment or demonstration.", + ), + ("Dance2", 1023, "Performs another variant of a predefined dance routine 2."), + ("GetBodyHeight", 1024, "Retrieves the current 
height of the robot's body from the ground."), + ( + "GetFootRaiseHeight", + 1025, + "Retrieves the current height at which the robot's feet are being raised during movement.", + ), + ( + "GetSpeedLevel", + 1026, + "Retrieves the current speed level setting of the robot.", + ), + ( + "SwitchJoystick", + 1027, + "Switches the robot's control mode to respond to joystick input for manual operation.", + ), + ( + "Pose", + 1028, + "Commands the robot to assume a specific pose or posture as predefined in its programming.", + ), + ("Scrape", 1029, "The robot performs a scraping motion."), + ( + "FrontFlip", + 1030, + "Commands the robot to perform a front flip, showcasing its agility and dynamic movement capabilities.", + ), + ( + "FrontJump", + 1031, + "Instructs the robot to jump forward, demonstrating its explosive movement capabilities.", + ), + ( + "FrontPounce", + 1032, + "Commands the robot to perform a pouncing motion forward.", + ), + ( + "WiggleHips", + 1033, + "The robot performs a hip wiggling motion, often used for entertainment or demonstration purposes.", + ), + ( + "GetState", + 1034, + "Retrieves the current operational state of the robot, including its mode, position, and status.", + ), + ( + "EconomicGait", + 1035, + "Engages a more energy-efficient walking or running mode to conserve battery life.", + ), + ("FingerHeart", 1036, "Performs a finger heart gesture while on its hind legs."), + ( + "Handstand", + 1301, + "Commands the robot to perform a handstand, demonstrating balance and control.", + ), + ( + "CrossStep", + 1302, + "Commands the robot to perform cross-step movements.", + ), + ( + "OnesidedStep", + 1303, + "Commands the robot to perform one-sided step movements.", + ), + ("Bound", 1304, "Commands the robot to perform bounding movements."), + ("MoonWalk", 1305, "Commands the robot to perform a moonwalk motion."), + ("LeftFlip", 1042, "Executes a flip towards the left side."), + ("RightFlip", 1043, "Performs a flip towards the right side."), 
+ ("Backflip", 1044, "Executes a backflip, a complex and dynamic maneuver."), +] + +# region MyUnitreeSkills + + +class MyUnitreeSkills(SkillLibrary): + """My Unitree Skills for WebRTC interface.""" + + def __init__(self, robot: Optional[Robot] = None): + super().__init__() + self._robot: Robot = None + + # Add dynamic skills to this class + dynamic_skills = self.create_skills_live() + self.register_skills(dynamic_skills) + + @classmethod + def register_skills(cls, skill_classes: Union["AbstractSkill", list["AbstractSkill"]]): + """Add multiple skill classes as class attributes. + + Args: + skill_classes: List of skill classes to add + """ + if not isinstance(skill_classes, list): + skill_classes = [skill_classes] + + for skill_class in skill_classes: + # Add to the class as a skill + setattr(cls, skill_class.__name__, skill_class) + + def initialize_skills(self): + for skill_class in self.get_class_skills(): + self.create_instance(skill_class.__name__, robot=self._robot) + + # Refresh the class skills + self.refresh_class_skills() + + def create_skills_live(self) -> List[AbstractRobotSkill]: + # ================================================ + # Procedurally created skills + # ================================================ + class BaseUnitreeSkill(AbstractRobotSkill): + """Base skill for dynamic skill creation.""" + + def __call__(self): + string = f"{Colors.GREEN_PRINT_COLOR}This is a base skill, created for the specific skill: {self._app_id}{Colors.RESET_COLOR}" + print(string) + super().__call__() + if self._app_id is None: + raise RuntimeError( + f"{Colors.RED_PRINT_COLOR}" + f"No App ID provided to {self.__class__.__name__} Skill" + f"{Colors.RESET_COLOR}" + ) + else: + # Use WebRTC publish_request interface through the robot's webrtc_connection + result = self._robot.webrtc_connection.publish_request( + RTC_TOPIC["SPORT_MOD"], {"api_id": self._app_id} + ) + string = f"{Colors.GREEN_PRINT_COLOR}{self.__class__.__name__} was successful: 
id={self._app_id}{Colors.RESET_COLOR}" + print(string) + return string + + skills_classes = [] + for name, app_id, description in UNITREE_WEBRTC_CONTROLS: + if name not in ["Reverse", "Spin"]: # Exclude reverse and spin skills + skill_class = type( + name, # Name of the class + (BaseUnitreeSkill,), # Base classes + {"__doc__": description, "_app_id": app_id}, + ) + skills_classes.append(skill_class) + + return skills_classes + + # region Class-based Skills + + class Move(AbstractRobotSkill): + """Move the robot using direct velocity commands. Determine duration required based on user distance instructions.""" + + x: float = Field(..., description="Forward velocity (m/s).") + y: float = Field(default=0.0, description="Left/right velocity (m/s)") + yaw: float = Field(default=0.0, description="Rotational velocity (rad/s)") + duration: float = Field(default=0.0, description="How long to move (seconds).") + + def execute(self): + return self._robot.move(Vector(self.x, self.y, self.yaw), duration=self.duration) + + class Wait(AbstractSkill): + """Wait for a specified amount of time.""" + + seconds: float = Field(..., description="Seconds to wait") + + def __call__(self): + time.sleep(self.seconds) + return f"Wait completed with length={self.seconds}s" + + # endregion + + +# endregion diff --git a/dimos/skills/navigation.py b/dimos/skills/navigation.py index adb4d0e980..114e3dacb8 100644 --- a/dimos/skills/navigation.py +++ b/dimos/skills/navigation.py @@ -33,7 +33,7 @@ from dimos.types.robot_location import RobotLocation from dimos.utils.logging_config import setup_logger from dimos.models.qwen.video_query import get_bbox_from_qwen_frame -from dimos.utils.ros_utils import distance_angle_to_goal_xy +from dimos.utils.transform_utils import distance_angle_to_goal_xy from dimos.robot.local_planner.local_planner import navigate_to_goal_local logger = setup_logger("dimos.skills.semantic_map_skills") @@ -73,10 +73,6 @@ class NavigateWithText(AbstractRobotSkill): limit: int = 
Field(1, description="Maximum number of results to return") distance: float = Field(1.0, description="Desired distance to maintain from object in meters") timeout: float = Field(40.0, description="Maximum time to spend navigating in seconds") - similarity_threshold: float = Field( - 0.25, - description="Minimum similarity score required for semantic map results to be considered valid", - ) def __init__(self, robot=None, **data): """ @@ -92,6 +88,7 @@ def __init__(self, robot=None, **data): self._scheduler = get_scheduler() # Use the shared DiMOS thread pool self._navigation_disposable = None # Disposable returned by scheduler.schedule() self._tracking_subscriber = None # For object tracking + self._similarity_threshold = 0.25 def _navigate_to_object(self): """ @@ -111,8 +108,8 @@ def _navigate_to_object(self): # Try to get a bounding box from Qwen - only try once bbox = None try: - # Capture a single frame from the video stream - frame = self._robot.get_ros_video_stream().pipe(ops.take(1)).run() + # Use the robot's existing video stream instead of creating a new one + frame = self._robot.get_video_stream().pipe(ops.take(1)).run() # Use the frame-based function bbox, object_size = get_bbox_from_qwen_frame(frame, object_name=self.query) except Exception as e: @@ -169,12 +166,10 @@ def _navigate_to_object(self): break else: - logger.warning(f"No valid target tracking data found. target: {target}") + logger.warning("No valid target tracking data found.") else: - logger.warning( - f"No valid target tracking data found. 
tracking_data: {tracking_data}" - ) + logger.warning("No valid target tracking data found.") time.sleep(0.1) @@ -223,7 +218,6 @@ def _navigate_to_object(self): return {"success": False, "failure_reason": "Code Error", "error": f"Error: {e}"} finally: # Clean up - self._robot.ros_control.stop() self._robot.object_tracker.cleanup() def _navigate_using_semantic_map(self): @@ -277,9 +271,9 @@ def _navigate_using_semantic_map(self): ) # Check if similarity is below the threshold - if similarity < self.similarity_threshold: + if similarity < self._similarity_threshold: logger.warning( - f"Match found but similarity score ({similarity:.4f}) is below threshold ({self.similarity_threshold})" + f"Match found but similarity score ({similarity:.4f}) is below threshold ({self._similarity_threshold})" ) return { "success": False, @@ -287,7 +281,7 @@ def _navigate_using_semantic_map(self): "position": (pos_x, pos_y), "rotation": theta, "similarity": similarity, - "error": f"Match found but similarity score ({similarity:.4f}) is below threshold ({self.similarity_threshold})", + "error": f"Match found but similarity score ({similarity:.4f}) is below threshold ({self._similarity_threshold})", } # Reset the stop event before starting navigation @@ -347,6 +341,7 @@ def run_navigation(): "query": self.query, "error": "No valid position data found in semantic map", } + except Exception as e: logger.error(f"Error in semantic map navigation: {e}") return {"success": False, "error": f"Semantic map error: {e}"} @@ -416,9 +411,6 @@ def stop(self): self._spatial_memory.cleanup() self._spatial_memory = None - # Stop robot motion - self._robot.ros_control.stop() - return "Navigate skill stopped successfully." 
@@ -465,17 +457,21 @@ def __call__(self): try: # Get the current pose using the robot's get_pose method - position, rotation = self._robot.get_pose() + pose_data = self._robot.get_pose() + + # Extract position and rotation from the new dictionary format + position = pose_data["position"] + rotation = pose_data["rotation"] # Format the response result = { "success": True, "position": { - "x": position[0], - "y": position[1], - "z": position[2] if len(position) > 2 else 0.0, + "x": position.x, + "y": position.y, + "z": position.z, }, - "rotation": {"roll": rotation[0], "pitch": rotation[1], "yaw": rotation[2]}, + "rotation": {"roll": rotation.x, "pitch": rotation.y, "yaw": rotation.z}, } # If location_name is provided, remember this location @@ -485,7 +481,9 @@ def __call__(self): # Create a RobotLocation object location = RobotLocation( - name=self.location_name, position=position, rotation=rotation + name=self.location_name, + position=(position.x, position.y, position.z), + rotation=(rotation.x, rotation.y, rotation.z), ) # Add to spatial memory @@ -604,3 +602,96 @@ def stop(self): self.unregister_as_running("NavigateToGoal", skill_library) self._stop_event.set() return "Navigation stopped" + + +class Explore(AbstractRobotSkill): + """ + A skill that performs autonomous frontier exploration. + + This skill continuously finds and navigates to unknown frontiers in the environment + until no more frontiers are found or the exploration is stopped. + + Don't save GetPose locations when frontier exploring. Don't call any other skills except stop skill when needed. + """ + + timeout: float = Field(60.0, description="Maximum time (in seconds) allowed for exploration") + + def __init__(self, robot=None, **data): + """ + Initialize the Explore skill. 
+ + Args: + robot: The robot instance + **data: Additional data for configuration + """ + super().__init__(robot=robot, **data) + self._stop_event = threading.Event() + + def __call__(self): + """ + Start autonomous frontier exploration. + + Returns: + A dictionary containing the result of the exploration + """ + super().__call__() + + if self._robot is None: + error_msg = "No robot instance provided to Explore skill" + logger.error(error_msg) + return {"success": False, "error": error_msg} + + # Reset stop event to make sure we don't immediately abort + self._stop_event.clear() + + skill_library = self._robot.get_skills() + self.register_as_running("Explore", skill_library) + + logger.info("Starting autonomous frontier exploration") + + try: + # Start exploration using the robot's explore method + result = self._robot.explore(stop_event=self._stop_event) + + if result: + logger.info("Exploration completed successfully - no more frontiers found") + return { + "success": True, + "message": "Exploration completed - all accessible areas explored", + } + else: + if self._stop_event.is_set(): + logger.info("Exploration stopped by user") + return { + "success": False, + "message": "Exploration stopped by user", + } + else: + logger.warning("Exploration did not complete successfully") + return { + "success": False, + "message": "Exploration failed or was interrupted", + } + + except Exception as e: + error_msg = f"Error during exploration: {e}" + logger.error(error_msg) + return { + "success": False, + "error": error_msg, + } + finally: + self.stop() + + def stop(self): + """ + Stop the exploration. 
+ + Returns: + A message indicating that the exploration was stopped + """ + logger.info("Stopping Explore") + skill_library = self._robot.get_skills() + self.unregister_as_running("Explore", skill_library) + self._stop_event.set() + return "Exploration stopped" diff --git a/dimos/skills/observe.py b/dimos/skills/observe.py new file mode 100644 index 0000000000..844df11805 --- /dev/null +++ b/dimos/skills/observe.py @@ -0,0 +1,194 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Observer skill for an agent. + +This module provides a skill that sends a single image from any +Robot Data Stream to the Qwen VLM for inference and adds the response +to the agent's conversation history. +""" + +import time +from typing import Optional +import base64 +import cv2 +import numpy as np +import reactivex as rx +from reactivex import operators as ops +from pydantic import Field + +from dimos.skills.skills import AbstractRobotSkill +from dimos.agents.agent import LLMAgent +from dimos.models.qwen.video_query import query_single_frame +from dimos.utils.logging_config import setup_logger + +logger = setup_logger("dimos.skills.observe") + + +class Observe(AbstractRobotSkill): + """ + A skill that captures a single frame from a Robot Video Stream, sends it to a VLM, + and adds the response to the agent's conversation history. 
+ + This skill is used for visual reasoning, spatial understanding, or any queries involving visual information that require critical thinking. + """ + + query_text: str = Field( + "What do you see in this image? Describe the environment in detail.", + description="Query text to send to the VLM model with the image", + ) + + def __init__(self, robot=None, agent: Optional[LLMAgent] = None, **data): + """ + Initialize the Observe skill. + + Args: + robot: The robot instance + agent: The agent to store results in + **data: Additional data for configuration + """ + super().__init__(robot=robot, **data) + self._agent = agent + self._model_name = "qwen2.5-vl-72b-instruct" + + # Get the video stream from the robot + self._video_stream = self._robot.video_stream + if self._video_stream is None: + logger.error("Failed to get video stream from robot") + + def __call__(self): + """ + Capture a single frame, process it with Qwen, and add the result to conversation history. + + Returns: + A message indicating the observation result + """ + super().__call__() + + if self._agent is None: + error_msg = "No agent provided to Observe skill" + logger.error(error_msg) + return error_msg + + if self._robot is None: + error_msg = "No robot instance provided to Observe skill" + logger.error(error_msg) + return error_msg + + if self._video_stream is None: + error_msg = "No video stream available" + logger.error(error_msg) + return error_msg + + try: + logger.info("Capturing frame for Qwen observation") + + # Get a single frame from the video stream + frame = self._get_frame_from_stream() + + if frame is None: + error_msg = "Failed to capture frame from video stream" + logger.error(error_msg) + return error_msg + + # Process the frame with Qwen + response = self._process_frame_with_qwen(frame) + + # Add the response to the conversation history + # self._agent.append_to_history( + # f"Observation: {response}", + # ) + response = self._agent.run_observable_query(f"Observation: {response}") + 
+ logger.info(f"Added Qwen observation to conversation history") + return f"Observation complete: {response[:100]}..." + + except Exception as e: + error_msg = f"Error in Observe skill: {e}" + logger.error(error_msg) + return error_msg + + def _get_frame_from_stream(self): + """ + Get a single frame from the video stream. + + Returns: + A single frame from the video stream, or None if no frame is available + """ + frame = None + frame_subject = rx.subject.Subject() + + subscription = self._video_stream.pipe( + ops.take(1) # Take just one frame + ).subscribe( + on_next=lambda x: frame_subject.on_next(x), + on_error=lambda e: logger.error(f"Error getting frame: {e}"), + ) + + # Wait up to 5 seconds for a frame + timeout = 5.0 + start_time = time.time() + + def on_frame(f): + nonlocal frame + frame = f + + frame_subject.subscribe(on_frame) + + while frame is None and time.time() - start_time < timeout: + time.sleep(0.1) + + subscription.dispose() + return frame + + def _process_frame_with_qwen(self, frame): + """ + Process a frame with the Qwen model using query_single_frame. 
+ + Args: + frame: The video frame to process (numpy array) + + Returns: + The response from Qwen + """ + logger.info(f"Processing frame with Qwen model: {self._model_name}") + + try: + # Convert numpy array to PIL Image if needed + from PIL import Image + + if isinstance(frame, np.ndarray): + # OpenCV uses BGR, PIL uses RGB + if frame.shape[-1] == 3: # Check if it has color channels + frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + pil_image = Image.fromarray(frame_rgb) + else: + pil_image = Image.fromarray(frame) + else: + pil_image = frame + + # Query Qwen with the frame (direct function call) + response = query_single_frame( + pil_image, + self.query_text, + model_name=self._model_name, + ) + + logger.info(f"Qwen response received: {response[:100]}...") + return response + + except Exception as e: + logger.error(f"Error processing frame with Qwen: {e}") + raise diff --git a/dimos/skills/observe_stream.py b/dimos/skills/observe_stream.py index 4449de2995..7b4e08874e 100644 --- a/dimos/skills/observe_stream.py +++ b/dimos/skills/observe_stream.py @@ -24,12 +24,15 @@ from typing import Optional import base64 import cv2 +import numpy as np import reactivex as rx from reactivex import operators as ops from pydantic import Field +from PIL import Image from dimos.skills.skills import AbstractRobotSkill from dimos.agents.agent import LLMAgent +from dimos.models.qwen.video_query import query_single_frame from dimos.utils.threadpool import get_scheduler from dimos.utils.logging_config import setup_logger @@ -74,7 +77,7 @@ def __init__(self, robot=None, agent: Optional[LLMAgent] = None, video_stream=No # Get the video stream # TODO: Use the video stream provided in the constructor for dynamic video_stream selection by the agent - self._video_stream = self._robot.get_ros_video_stream() + self._video_stream = self._robot.video_stream if self._video_stream is None: logger.error("Failed to get video stream from robot") return @@ -189,32 +192,41 @@ def on_frame(f): def 
_process_frame(self, frame): """ - Process a frame with the Claude agent. + Process a frame with the Qwen VLM and add the response to conversation history. Args: frame: The video frame to process """ - logger.info("Processing frame with Claude agent") + logger.info("Processing frame with Qwen VLM") try: - _, buffer = cv2.imencode(".jpg", frame) - base64_image = base64.b64encode(buffer).decode("utf-8") - - observable = self._agent.run_observable_query( - f"{self.query_text}\n\nHere is the current camera view from the robot:", - base64_image=base64_image, - ) - - # Simple subscription to make sure the query executes - # The actual response content isn't important - observable.subscribe( - on_next=lambda x: logger.info(f"Got response from _observable_query: {x}"), - on_error=lambda e: logger.error(f"Error: {e}"), - on_completed=lambda: logger.info("ObserveStream query completed"), - ) + # Convert frame to PIL Image format + if isinstance(frame, np.ndarray): + # OpenCV uses BGR, PIL uses RGB + if frame.shape[-1] == 3: # Check if it has color channels + frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + pil_image = Image.fromarray(frame_rgb) + else: + pil_image = Image.fromarray(frame) + else: + pil_image = frame + + # Use Qwen to process the frame + model_name = "qwen2.5-vl-72b-instruct" # Using the most capable model + response = query_single_frame(pil_image, self.query_text, model_name=model_name) + + logger.info(f"Qwen response received: {response[:100]}...") + + # Add the response to the conversation history + # self._agent.append_to_history( + # f"Observation: {response}", + # ) + response = self._agent.run_observable_query(f"Observation: {response}") + + logger.info("Added Qwen observation to conversation history") except Exception as e: - logger.error(f"Error processing frame with agent: {e}") + logger.error(f"Error processing frame with Qwen VLM: {e}") def stop(self): """ diff --git a/dimos/skills/visual_navigation_skills.py 
b/dimos/skills/visual_navigation_skills.py index 72696f0427..96e21eb92d 100644 --- a/dimos/skills/visual_navigation_skills.py +++ b/dimos/skills/visual_navigation_skills.py @@ -28,6 +28,7 @@ from dimos.utils.logging_config import setup_logger from dimos.perception.visual_servoing import VisualServoing from pydantic import Field +from dimos.types.vector import Vector logger = setup_logger("dimos.skills.visual_navigation", level=logging.DEBUG) @@ -105,7 +106,7 @@ def __call__(self): x_vel = output.get("linear_vel") z_vel = output.get("angular_vel") logger.debug(f"Following human: x_vel: {x_vel}, z_vel: {z_vel}") - self._robot.ros_control.move_vel_control(x=x_vel, y=0, yaw=z_vel) + self._robot.move(Vector(x_vel, 0, z_vel)) time.sleep(0.05) # If we completed the full timeout duration, consider it success @@ -127,7 +128,6 @@ def __call__(self): if self._visual_servoing: self._visual_servoing.stop_tracking() self._visual_servoing = None - self._robot.ros_control.stop() def stop(self): """ @@ -144,8 +144,5 @@ def stop(self): self._visual_servoing.stop_tracking() self._visual_servoing = None - # Stop the robot - self._robot.ros_control.stop() - return True return False diff --git a/dimos/types/costmap.py b/dimos/types/costmap.py index 1107abf8bf..2d9b1c433e 100644 --- a/dimos/types/costmap.py +++ b/dimos/types/costmap.py @@ -18,8 +18,12 @@ from typing import Optional from scipy import ndimage from dimos.types.ros_polyfill import OccupancyGrid +from scipy.ndimage import binary_dilation from dimos.types.vector import Vector, VectorLike, x, y, to_vector import open3d as o3d +from matplotlib.path import Path +from PIL import Image +import cv2 DTYPE2STR = { np.float32: "f32", @@ -31,6 +35,14 @@ STR2DTYPE = {v: k for k, v in DTYPE2STR.items()} +class CostValues: + """Standard cost values for occupancy grid cells.""" + + FREE = 0 # Free space + UNKNOWN = -1 # Unknown space + OCCUPIED = 100 # Occupied/lethal space + + def encode_ndarray(arr: np.ndarray, compress: bool = False): 
arr_c = np.ascontiguousarray(arr) payload = arr_c.tobytes() @@ -118,11 +130,17 @@ def save_pickle(self, pickle_path: str): @classmethod def from_pickle(cls, pickle_path: str) -> "Costmap": - """Load costmap from a pickle file containing a ROS OccupancyGrid message.""" + """Load costmap from a pickle file containing either a Costmap object or constructor arguments.""" with open(pickle_path, "rb") as f: data = pickle.load(f) - costmap = cls(*data) - return costmap + + # Check if data is already a Costmap object + if isinstance(data, cls): + return data + else: + # Assume it's constructor arguments + costmap = cls(*data) + return costmap @classmethod def create_empty( @@ -172,14 +190,14 @@ def get_value(self, point: VectorLike) -> Optional[int]: point = self.world_to_grid(point) if 0 <= point.x < self.width and 0 <= point.y < self.height: - return int(self.grid[point.y, point.x]) + return int(self.grid[int(point.y), int(point.x)]) return None def set_value(self, point: VectorLike, value: int = 0) -> bool: point = self.world_to_grid(point) if 0 <= point.x < self.width and 0 <= point.y < self.height: - self.grid[point.y, point.x] = value + self.grid[int(point.y), int(point.x)] = value return value return False @@ -289,6 +307,56 @@ def smudge( origin_theta=self.origin_theta, ) + def subsample(self, subsample_factor: int = 2) -> "Costmap": + """ + Create a subsampled (lower resolution) version of the costmap. 
+ + Args: + subsample_factor: Factor by which to reduce resolution (e.g., 2 = half resolution, 4 = quarter resolution) + + Returns: + New Costmap instance with reduced resolution + """ + if subsample_factor <= 1: + return self # No subsampling needed + + # Calculate new grid dimensions + new_height = self.height // subsample_factor + new_width = self.width // subsample_factor + + # Create new grid by subsampling + subsampled_grid = np.zeros((new_height, new_width), dtype=self.grid.dtype) + + # Sample every subsample_factor-th point + for i in range(new_height): + for j in range(new_width): + orig_i = i * subsample_factor + orig_j = j * subsample_factor + + # Take a small neighborhood and use the most conservative value + # (prioritize occupied > unknown > free for safety) + neighborhood = self.grid[ + orig_i : min(orig_i + subsample_factor, self.height), + orig_j : min(orig_j + subsample_factor, self.width), + ] + + # Priority: Occupied (100) > Unknown (-1) > Free (0) + if np.any(neighborhood == CostValues.OCCUPIED): + subsampled_grid[i, j] = CostValues.OCCUPIED + elif np.any(neighborhood == CostValues.UNKNOWN): + subsampled_grid[i, j] = CostValues.UNKNOWN + else: + subsampled_grid[i, j] = CostValues.FREE + + # Create new costmap with adjusted resolution and origin + new_resolution = self.resolution * subsample_factor + + return Costmap( + grid=subsampled_grid, + resolution=new_resolution, + origin=self.origin, # Origin stays the same + ) + @property def total_cells(self) -> int: return self.width * self.height @@ -338,6 +406,38 @@ def __str__(self) -> str: return " ".join(cell_info) + def costmap_to_image(self, image_path: str) -> None: + """ + Convert costmap to JPEG image with ROS-style coloring. 
+ Free space: light grey, Obstacles: black, Unknown: dark gray + + Args: + image_path: Path to save the JPEG image + """ + # Create image array (height, width, 3 for RGB) + img_array = np.zeros((self.height, self.width, 3), dtype=np.uint8) + + # Apply ROS-style coloring based on costmap values + for i in range(self.height): + for j in range(self.width): + value = self.grid[i, j] + if value == CostValues.FREE: # Free space = light grey (205, 205, 205) + img_array[i, j] = [205, 205, 205] + elif value == CostValues.UNKNOWN: # Unknown = dark gray (128, 128, 128) + img_array[i, j] = [128, 128, 128] + elif value >= CostValues.OCCUPIED: # Occupied/obstacles = black (0, 0, 0) + img_array[i, j] = [0, 0, 0] + else: # Any other values (low cost) = light grey + img_array[i, j] = [205, 205, 205] + + # Flip vertically to match ROS convention (origin at bottom-left) + img_array = np.flipud(img_array) + + # Create PIL image and save as JPEG + img = Image.fromarray(img_array, "RGB") + img.save(image_path, "JPEG", quality=95) + print(f"Costmap image saved to: {image_path}") + def _inflate_lethal(costmap: np.ndarray, radius: int, lethal_val: int = 100) -> np.ndarray: """Return *costmap* with lethal cells dilated by *radius* grid steps (circular).""" diff --git a/dimos/utils/ros_utils.py b/dimos/utils/ros_utils.py deleted file mode 100644 index f7b47a003a..0000000000 --- a/dimos/utils/ros_utils.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -from typing import Tuple -import logging - -logger = logging.getLogger(__name__) - - -def normalize_angle(angle: float) -> float: - """Normalize angle to [-pi, pi] range""" - return np.arctan2(np.sin(angle), np.cos(angle)) - - -def distance_angle_to_goal_xy(distance: float, angle: float) -> Tuple[float, float]: - """Convert distance and angle to goal x, y in robot frame""" - return distance * np.cos(angle), distance * np.sin(angle) diff --git a/dimos/utils/transform_utils.py b/dimos/utils/transform_utils.py new file mode 100644 index 0000000000..31d3840884 --- /dev/null +++ b/dimos/utils/transform_utils.py @@ -0,0 +1,85 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import numpy as np +from typing import Tuple, Dict, Any +import logging + +from dimos.types.vector import Vector + +logger = logging.getLogger(__name__) + + +def normalize_angle(angle: float) -> float: + """Normalize angle to [-pi, pi] range""" + return np.arctan2(np.sin(angle), np.cos(angle)) + + +def distance_angle_to_goal_xy(distance: float, angle: float) -> Tuple[float, float]: + """Convert distance and angle to goal x, y in robot frame""" + return distance * np.cos(angle), distance * np.sin(angle) + + +def transform_robot_to_map( + robot_position: Vector, robot_rotation: Vector, position: Vector, rotation: Vector +) -> Tuple[Vector, Vector]: + """Transform position and rotation from robot frame to map frame. + + Args: + robot_position: Current robot position in map frame + robot_rotation: Current robot rotation in map frame + position: Position in robot frame as Vector (x, y, z) + rotation: Rotation in robot frame as Vector (roll, pitch, yaw) in radians + + Returns: + Tuple of (transformed_position, transformed_rotation) where: + - transformed_position: Vector (x, y, z) in map frame + - transformed_rotation: Vector (roll, pitch, yaw) in map frame + + Example: + obj_pos_robot = Vector(1.0, 0.5, 0.0) # 1m forward, 0.5m left of robot + obj_rot_robot = Vector(0.0, 0.0, 0.0) # No rotation relative to robot + + map_pos, map_rot = transform_robot_to_map(robot_position, robot_rotation, obj_pos_robot, obj_rot_robot) + """ + # Extract robot pose components + robot_pos = robot_position + robot_rot = robot_rotation + + # Robot position and orientation in map frame + robot_x, robot_y, robot_z = robot_pos.x, robot_pos.y, robot_pos.z + robot_yaw = robot_rot.z # yaw is rotation around z-axis + + # Position in robot frame + pos_x, pos_y, pos_z = position.x, position.y, position.z + + # Apply 2D transformation (rotation + translation) for x,y coordinates + cos_yaw = np.cos(robot_yaw) + sin_yaw = np.sin(robot_yaw) + + # Transform position from robot frame to map frame + map_x 
= robot_x + cos_yaw * pos_x - sin_yaw * pos_y + map_y = robot_y + sin_yaw * pos_x + cos_yaw * pos_y + map_z = robot_z + pos_z # Z translation (assume flat ground) + + # Transform rotation from robot frame to map frame + rot_roll, rot_pitch, rot_yaw = rotation.x, rotation.y, rotation.z + map_roll = robot_rot.x + rot_roll # Add robot's roll + map_pitch = robot_rot.y + rot_pitch # Add robot's pitch + map_yaw_rot = normalize_angle(robot_yaw + rot_yaw) # Add robot's yaw and normalize + + transformed_position = Vector(map_x, map_y, map_z) + transformed_rotation = Vector(map_roll, map_pitch, map_yaw_rot) + + return transformed_position, transformed_rotation diff --git a/dimos/web/websocket_vis/server.py b/dimos/web/websocket_vis/server.py index cfb404fadd..a7aca2da2b 100644 --- a/dimos/web/websocket_vis/server.py +++ b/dimos/web/websocket_vis/server.py @@ -25,6 +25,7 @@ from starlette.staticfiles import StaticFiles from dimos.web.websocket_vis.types import Drawable from reactivex import Observable +import concurrent.futures async def serve_index(request): @@ -113,6 +114,7 @@ def __init__(self, port=7779, use_reload=False, msg_handler=None): self.use_reload = use_reload self.main_state = main_state # Reference to global main_state self.msg_handler = msg_handler + self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) # Store reference to this instance on the sio object for message handling sio.vis_instance = self @@ -167,7 +169,7 @@ def new_update(data): [name, drawable] = data self.update_state({"draw": {name: self.process_drawable(drawable)}}) - obs.subscribe( + return obs.subscribe( on_next=new_update, on_error=lambda e: print(f"Error in stream: {e}"), on_completed=lambda: print("Stream completed"), @@ -176,30 +178,42 @@ def new_update(data): def stop(self): if self.server_thread and self.server_thread.is_alive(): self.server_thread.join() - self.sio.disconnect() + if hasattr(self, "_executor"): + self._executor.shutdown(wait=True) + if hasattr(self.sio, 
"disconnect"): + try: + asyncio.run(self.sio.disconnect()) + except: + pass async def update_state_async(self, new_data): """Update main_state and broadcast to all connected clients""" await update_state(new_data) def update_state(self, new_data): - """Synchronous wrapper for update_state""" - - # Get or create an event loop + """Thread-safe wrapper for update_state""" try: - loop = asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - # Run the coroutine in the loop - if loop.is_running(): - # Create a future and run it in the existing loop + # Try to get the current event loop + loop = asyncio.get_running_loop() + # If we're in an event loop, schedule the coroutine future = asyncio.run_coroutine_threadsafe(update_state(new_data), loop) - return future.result() - else: - # Run the coroutine in a new loop - return loop.run_until_complete(update_state(new_data)) + return future.result(timeout=5.0) # Add timeout to prevent blocking + except RuntimeError: + # No event loop running, this is likely called from a different thread + # Use the executor to run in a controlled manner + def run_async(): + try: + return asyncio.run(update_state(new_data)) + except Exception as e: + print(f"Error updating state: {e}") + return None + + future = self._executor.submit(run_async) + try: + return future.result(timeout=5.0) + except concurrent.futures.TimeoutError: + print("Warning: State update timed out") + return None # Test timer function that updates state with current Unix time diff --git a/tests/run.py b/tests/run.py index 4c4bfc036e..8ddb2261e0 100644 --- a/tests/run.py +++ b/tests/run.py @@ -17,14 +17,18 @@ import time from dotenv import load_dotenv +from dimos.agents.cerebras_agent import CerebrasAgent from dimos.agents.claude_agent import ClaudeAgent from dimos.robot.unitree_webrtc.unitree_go2 import UnitreeGo2 -from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl + +# from 
dimos.robot.unitree.unitree_ros_control import UnitreeROSControl from dimos.robot.unitree.unitree_skills import MyUnitreeSkills from dimos.web.robot_web_interface import RobotWebInterface +from dimos.web.websocket_vis.server import WebsocketVis from dimos.skills.observe_stream import ObserveStream +from dimos.skills.observe import Observe from dimos.skills.kill_skill import KillSkill -from dimos.skills.navigation import NavigateWithText, GetPose, NavigateToGoal +from dimos.skills.navigation import NavigateWithText, GetPose, NavigateToGoal, Explore from dimos.skills.visual_navigation_skills import FollowHuman import reactivex as rx import reactivex.operators as ops @@ -33,9 +37,23 @@ import json from dimos.types.vector import Vector from dimos.skills.speak import Speak + from dimos.perception.object_detection_stream import ObjectDetectionStream from dimos.perception.detection2d.detic_2d_det import Detic2DDetector from dimos.utils.reactive import backpressure +import asyncio +import atexit +import signal +import sys +import warnings +import logging + +# Filter out known WebRTC warnings that don't affect functionality +warnings.filterwarnings("ignore", message="coroutine.*was never awaited") +warnings.filterwarnings("ignore", message=".*RTCSctpTransport.*") + +# Set up logging to reduce asyncio noise +logging.getLogger("asyncio").setLevel(logging.ERROR) # Load API key from environment load_dotenv() @@ -59,16 +77,136 @@ def parse_arguments(): args = parse_arguments() -# Initialize robot with spatial memory parameters +# Initialize robot with spatial memory parameters - using WebRTC mode instead of "ai" robot = UnitreeGo2( ip=os.getenv("ROBOT_IP"), - skills=MyUnitreeSkills(), - mock_connection=False, - spatial_memory_dir=args.spatial_memory_dir, # Will use default if None - new_memory=args.new_memory, # Create a new memory if specified - mode="ai", + mode="normal", ) + +# Add graceful shutdown handling to prevent WebRTC task destruction errors +def cleanup_robot(): + 
print("Cleaning up robot connection...") + try: + # Make cleanup non-blocking to avoid hangs + def quick_cleanup(): + try: + robot.liedown() + except: + pass + + # Run cleanup in a separate thread with timeout + cleanup_thread = threading.Thread(target=quick_cleanup) + cleanup_thread.daemon = True + cleanup_thread.start() + cleanup_thread.join(timeout=3.0) # Max 3 seconds for cleanup + + # Force stop the robot's WebRTC connection + try: + robot.stop() + except: + pass + + except Exception as e: + print(f"Error during cleanup: {e}") + # Continue anyway + + +atexit.register(cleanup_robot) + + +def signal_handler(signum, frame): + print("Received shutdown signal, cleaning up...") + try: + cleanup_robot() + except: + pass + # Force exit if cleanup hangs + os._exit(0) + + +signal.signal(signal.SIGINT, signal_handler) +signal.signal(signal.SIGTERM, signal_handler) + +# Initialize WebSocket visualization +websocket_vis = WebsocketVis() +websocket_vis.start() +websocket_vis.connect(robot.global_planner.vis_stream()) + + +def msg_handler(msgtype, data): + if msgtype == "click": + print(f"Received click at position: {data['position']}") + + try: + print("Setting goal...") + + # Instead of disabling visualization, make it timeout-safe + original_vis = robot.global_planner.vis + + def safe_vis(name, drawable): + """Visualization wrapper that won't block on timeouts""" + try: + # Use a separate thread for visualization to avoid blocking + def vis_update(): + try: + original_vis(name, drawable) + except Exception as e: + print(f"Visualization update failed (non-critical): {e}") + + vis_thread = threading.Thread(target=vis_update) + vis_thread.daemon = True + vis_thread.start() + # Don't wait for completion - let it run asynchronously + except Exception as e: + print(f"Visualization setup failed (non-critical): {e}") + + robot.global_planner.vis = safe_vis + robot.global_planner.set_goal(Vector(data["position"])) + robot.global_planner.vis = original_vis + + print("Goal set 
successfully") + except Exception as e: + print(f"Error setting goal: {e}") + import traceback + + traceback.print_exc() + + +def threaded_msg_handler(msgtype, data): + print(f"Processing message: {msgtype}") + + # Create a dedicated event loop for goal setting to avoid conflicts + def run_with_dedicated_loop(): + try: + # Use asyncio.run which creates and manages its own event loop + # This won't conflict with the robot's or websocket's event loops + async def async_msg_handler(): + msg_handler(msgtype, data) + + asyncio.run(async_msg_handler()) + print("Goal setting completed successfully") + except Exception as e: + print(f"Error in goal setting thread: {e}") + import traceback + + traceback.print_exc() + + thread = threading.Thread(target=run_with_dedicated_loop) + thread.daemon = True + thread.start() + + +websocket_vis.msg_handler = threaded_msg_handler + + +def newmap(msg): + return ["costmap", robot.map.costmap.smudge()] + + +websocket_vis.connect(robot.map_stream.pipe(ops.map(newmap))) +websocket_vis.connect(robot.odom_stream().pipe(ops.map(lambda pos: ["robot_pos", pos.pos.to_2d()]))) + # Create a subject for agent responses agent_response_subject = rx.subject.Subject() agent_response_stream = agent_response_subject.pipe(ops.share()) @@ -77,39 +215,45 @@ def parse_arguments(): # Initialize object detection stream min_confidence = 0.6 class_filter = None # No class filtering -detector = Detic2DDetector(vocabulary=None, threshold=min_confidence) +min_confidence = 0.99 # temporarily disable detections +# detector = Detic2DDetector(vocabulary=None, threshold=min_confidence) # Create video stream from robot's camera -video_stream = backpressure(robot.get_ros_video_stream()) - -# Initialize ObjectDetectionStream with robot -object_detector = ObjectDetectionStream( - camera_intrinsics=robot.camera_intrinsics, - min_confidence=min_confidence, - class_filter=class_filter, - transform_to_map=robot.ros_control.transform_pose, - detector=detector, - 
video_stream=video_stream, -) +video_stream = robot.get_video_stream() # WebRTC doesn't use ROS video stream -# Create visualization stream for web interface -viz_stream = backpressure(object_detector.get_stream()).pipe( - ops.share(), - ops.map(lambda x: x["viz_frame"] if x is not None else None), - ops.filter(lambda x: x is not None), -) +# # Initialize ObjectDetectionStream with robot +# object_detector = ObjectDetectionStream( +# camera_intrinsics=robot.camera_intrinsics, +# min_confidence=min_confidence, +# class_filter=class_filter, +# get_pose=robot.get_pose, +# detector=detector, +# video_stream=video_stream, +# ) -# Get the formatted detection stream -formatted_detection_stream = object_detector.get_formatted_stream().pipe( - ops.filter(lambda x: x is not None) -) +# # Create visualization stream for web interface +# viz_stream = backpressure(object_detector.get_stream()).pipe( +# ops.share(), +# ops.map(lambda x: x["viz_frame"] if x is not None else None), +# ops.filter(lambda x: x is not None), +# ) + +# # Get the formatted detection stream +# formatted_detection_stream = object_detector.get_formatted_stream().pipe( +# ops.filter(lambda x: x is not None) +# ) # Create a direct mapping that combines detection data with locations def combine_with_locations(object_detections): # Get locations from spatial memory try: - locations = robot.get_spatial_memory().get_robot_locations() + spatial_memory = robot.get_spatial_memory() + if spatial_memory is None: + # If spatial memory is disabled, just return the object detections + return object_detections + + locations = spatial_memory.get_robot_locations() # Format the locations section locations_text = "\n\nSaved Robot Locations:\n" @@ -128,12 +272,12 @@ def combine_with_locations(object_detections): # Create the combined stream with a simple pipe operation -enhanced_data_stream = formatted_detection_stream.pipe(ops.map(combine_with_locations), ops.share()) +# enhanced_data_stream = 
formatted_detection_stream.pipe(ops.map(combine_with_locations), ops.share()) streams = { - "unitree_video": robot.get_ros_video_stream(), + "unitree_video": robot.get_video_stream(), # Changed from get_ros_video_stream to get_video_stream for WebRTC "local_planner_viz": local_planner_viz_stream, - "object_detection": viz_stream, + # "object_detection": viz_stream, # Uncommented object detection } text_streams = { "agent_responses": agent_response_stream, @@ -144,17 +288,18 @@ def combine_with_locations(object_detections): # stt_node = stt() # Read system query from prompt.txt file -# with open(os.path.join(os.path.dirname(os.path.dirname(__file__)), 'prompt.txt'), 'r') as f: -# system_query = f.read() +with open( + os.path.join(os.path.dirname(os.path.dirname(__file__)), "assets/agent/prompt.txt"), "r" +) as f: + system_query = f.read() # Create a ClaudeAgent instance agent = ClaudeAgent( dev_name="test_agent", # input_query_stream=stt_node.emit_text(), input_query_stream=web_interface.query_stream, - input_data_stream=enhanced_data_stream, # Add the enhanced data stream skills=robot.get_skills(), - system_query="What do you see", + system_query=system_query, model_name="claude-3-7-sonnet-latest", thinking_budget_tokens=0, ) @@ -164,19 +309,24 @@ def combine_with_locations(object_detections): robot_skills = robot.get_skills() robot_skills.add(ObserveStream) +robot_skills.add(Observe) robot_skills.add(KillSkill) robot_skills.add(NavigateWithText) robot_skills.add(FollowHuman) robot_skills.add(GetPose) -robot_skills.add(Speak) -# robot_skills.add(NavigateToGoal) +# robot_skills.add(Speak) +robot_skills.add(NavigateToGoal) +robot_skills.add(Explore) + robot_skills.create_instance("ObserveStream", robot=robot, agent=agent) +robot_skills.create_instance("Observe", robot=robot, agent=agent) robot_skills.create_instance("KillSkill", robot=robot, skill_library=robot_skills) robot_skills.create_instance("NavigateWithText", robot=robot) 
robot_skills.create_instance("FollowHuman", robot=robot) robot_skills.create_instance("GetPose", robot=robot) -# robot_skills.create_instance("NavigateToGoal", robot=robot) -robot_skills.create_instance("Speak", tts_node=tts_node) +robot_skills.create_instance("NavigateToGoal", robot=robot) +robot_skills.create_instance("Explore", robot=robot) +# robot_skills.create_instance("Speak", tts_node=tts_node) # Subscribe to agent responses and send them to the subject agent.get_response_observable().subscribe(lambda x: agent_response_subject.on_next(x)) @@ -184,4 +334,24 @@ def combine_with_locations(object_detections): print("ObserveStream and Kill skills registered and ready for use") print("Created memory.txt file") -web_interface.run() +# Start web interface in a separate thread to avoid blocking +web_thread = threading.Thread(target=web_interface.run) +web_thread.daemon = True +web_thread.start() + +try: + while True: + # Main loop - can add robot movement or other logic here + time.sleep(0.01) + +except KeyboardInterrupt: + print("Stopping robot") + robot.liedown() +except Exception as e: + print(f"Unexpected error in main loop: {e}") + import traceback + + traceback.print_exc() +finally: + print("Cleaning up...") + cleanup_robot() diff --git a/tests/run_navigation_only.py b/tests/run_navigation_only.py new file mode 100644 index 0000000000..2995750e2b --- /dev/null +++ b/tests/run_navigation_only.py @@ -0,0 +1,191 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import os +from dotenv import load_dotenv +from dimos.robot.unitree_webrtc.unitree_go2 import UnitreeGo2 +from dimos.robot.unitree_webrtc.testing.helpers import show3d_stream +from dimos.web.websocket_vis.server import WebsocketVis +from dimos.web.robot_web_interface import RobotWebInterface +from dimos.types.vector import Vector +import reactivex.operators as ops +import time +import threading +import asyncio +import atexit +import signal +import sys +import warnings +import logging +# logging.basicConfig(level=logging.DEBUG) + +# Filter out known WebRTC warnings that don't affect functionality +warnings.filterwarnings("ignore", message="coroutine.*was never awaited") +warnings.filterwarnings("ignore", message=".*RTCSctpTransport.*") + +# Set up logging to reduce asyncio noise +logging.getLogger("asyncio").setLevel(logging.ERROR) + +load_dotenv() +robot = UnitreeGo2(ip=os.getenv("ROBOT_IP"), mode="normal", enable_perception=False) + + +# Add graceful shutdown handling to prevent WebRTC task destruction errors +def cleanup_robot(): + print("Cleaning up robot connection...") + try: + # Make cleanup non-blocking to avoid hangs + def quick_cleanup(): + try: + robot.liedown() + except: + pass + + # Run cleanup in a separate thread with timeout + cleanup_thread = threading.Thread(target=quick_cleanup) + cleanup_thread.daemon = True + cleanup_thread.start() + cleanup_thread.join(timeout=3.0) # Max 3 seconds for cleanup + + # Force stop the robot's WebRTC connection + try: + robot.stop() + except: + pass + + except Exception as e: + print(f"Error during cleanup: {e}") + # Continue anyway + + +atexit.register(cleanup_robot) + + +def signal_handler(signum, frame): + print("Received shutdown signal, cleaning up...") + try: + cleanup_robot() + except: + pass + # Force exit if cleanup hangs + os._exit(0) + + +signal.signal(signal.SIGINT, signal_handler) 
+signal.signal(signal.SIGTERM, signal_handler) + +websocket_vis = WebsocketVis() +websocket_vis.start() +websocket_vis.connect(robot.global_planner.vis_stream()) + + +def msg_handler(msgtype, data): + if msgtype == "click": + print(f"Received click at position: {data['position']}") + + try: + print("Setting goal...") + + # Instead of disabling visualization, make it timeout-safe + original_vis = robot.global_planner.vis + + def safe_vis(name, drawable): + """Visualization wrapper that won't block on timeouts""" + try: + # Use a separate thread for visualization to avoid blocking + def vis_update(): + try: + original_vis(name, drawable) + except Exception as e: + print(f"Visualization update failed (non-critical): {e}") + + vis_thread = threading.Thread(target=vis_update) + vis_thread.daemon = True + vis_thread.start() + # Don't wait for completion - let it run asynchronously + except Exception as e: + print(f"Visualization setup failed (non-critical): {e}") + + robot.global_planner.vis = safe_vis + robot.global_planner.set_goal(Vector(data["position"])) + robot.global_planner.vis = original_vis + + print("Goal set successfully") + except Exception as e: + print(f"Error setting goal: {e}") + import traceback + + traceback.print_exc() + + +def threaded_msg_handler(msgtype, data): + print(f"Processing message: {msgtype}") + + # Create a dedicated event loop for goal setting to avoid conflicts + def run_with_dedicated_loop(): + try: + # Use asyncio.run which creates and manages its own event loop + # This won't conflict with the robot's or websocket's event loops + async def async_msg_handler(): + msg_handler(msgtype, data) + + asyncio.run(async_msg_handler()) + print("Goal setting completed successfully") + except Exception as e: + print(f"Error in goal setting thread: {e}") + import traceback + + traceback.print_exc() + + thread = threading.Thread(target=run_with_dedicated_loop) + thread.daemon = True + thread.start() + + +websocket_vis.msg_handler = 
threaded_msg_handler + +print("standing up") +robot.standup() +print("robot is up") + + +def newmap(msg): + return ["costmap", robot.map.costmap.smudge()] + + +websocket_vis.connect(robot.map_stream.pipe(ops.map(newmap))) +websocket_vis.connect(robot.odom_stream().pipe(ops.map(lambda pos: ["robot_pos", pos.pos.to_2d()]))) + +local_planner_viz_stream = robot.local_planner_viz_stream.pipe(ops.share()) + +# Add RobotWebInterface with video stream +streams = {"unitree_video": robot.get_video_stream(), "local_planner_viz": local_planner_viz_stream} +web_interface = RobotWebInterface(port=5555, **streams) +web_interface.run() + +try: + while True: + # robot.move_vel(Vector(0.1, 0.1, 0.1)) + time.sleep(0.01) + +except KeyboardInterrupt: + print("Stopping robot") + robot.liedown() +except Exception as e: + print(f"Unexpected error in main loop: {e}") + import traceback + + traceback.print_exc() +finally: + print("Cleaning up...") + cleanup_robot() diff --git a/tests/run_webrtc.py b/tests/run_webrtc.py deleted file mode 100644 index ff96bed14e..0000000000 --- a/tests/run_webrtc.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import cv2 -import os -import asyncio -from dotenv import load_dotenv -from dimos.robot.unitree_webrtc.unitree_go2 import UnitreeGo2, Color -from dimos.robot.unitree_webrtc.testing.helpers import show3d_stream -from dimos.web.websocket_vis.server import WebsocketVis -from dimos.types.vector import Vector -import logging -import open3d as o3d -import reactivex.operators as ops -import numpy as np -import time -import threading - -# logging.basicConfig(level=logging.DEBUG) - -load_dotenv() -robot = UnitreeGo2(ip=os.getenv("ROBOT_IP"), mode="ai") - -websocket_vis = WebsocketVis() -websocket_vis.start() -websocket_vis.connect(robot.global_planner.vis_stream()) - - -def msg_handler(msgtype, data): - if msgtype == "click": - try: - robot.global_planner.set_goal(Vector(data["position"])) - except Exception as e: - print(f"Error setting goal: {e}") - return - - -def threaded_msg_handler(msgtype, data): - thread = threading.Thread(target=msg_handler, args=(msgtype, data)) - thread.daemon = True - thread.start() - - -websocket_vis.msg_handler = threaded_msg_handler - -print("standing up") -robot.standup() -print("robot is up") - - -def newmap(msg): - return ["costmap", robot.map.costmap.smudge()] - - -websocket_vis.connect(robot.map_stream.pipe(ops.map(newmap))) -websocket_vis.connect(robot.odom_stream().pipe(ops.map(lambda pos: ["robot_pos", pos.pos.to_2d()]))) - -try: - while True: - # robot.move_vel(Vector(0.1, 0.1, 0.1)) - time.sleep(0.01) - -except KeyboardInterrupt: - print("Stopping robot") - robot.liedown() diff --git a/tests/test_cerebras_agent_query.py b/tests/test_cerebras_agent_query.py deleted file mode 100644 index 5ed4007eed..0000000000 --- a/tests/test_cerebras_agent_query.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import test_header - -from dotenv import load_dotenv -from dimos.agents.cerebras_agent import CerebrasAgent - -# Load API key from environment -load_dotenv() - -# Create a CerebrasAgent instance -agent = CerebrasAgent(dev_name="test_agent", query="What is the capital of France?") - -# Use the stream_query method to get a response -response = agent.run_observable_query("What is the capital of France?").run() - -print(f"Response from Cerebras Agent: {response}") diff --git a/tests/test_object_detection_stream.py b/tests/test_object_detection_stream.py index d98006ab4d..ed0a64fa9e 100644 --- a/tests/test_object_detection_stream.py +++ b/tests/test_object_detection_stream.py @@ -20,8 +20,6 @@ from typing import List, Dict, Any from reactivex import Subject, operators as ops -from dimos.robot.unitree.unitree_go2 import UnitreeGo2 -from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl from dimos.robot.unitree.unitree_skills import MyUnitreeSkills from dimos.web.robot_web_interface import RobotWebInterface from dimos.utils.logging_config import logger @@ -135,7 +133,7 @@ def main(): transform_to_map=robot.ros_control.transform_pose, detector=detector, video_stream=video_stream, - disable_depth=True, + disable_depth=False, ) else: # webcam mode @@ -170,7 +168,7 @@ def main(): class_filter=class_filter, detector=detector, video_stream=video_stream, - disable_depth=True, + disable_depth=False, ) # Set placeholder robot for cleanup diff --git a/tests/test_robot.py b/tests/test_robot.py index 5b2dd89d3d..76289273f7 100644 --- a/tests/test_robot.py 
+++ b/tests/test_robot.py @@ -15,10 +15,8 @@ import os import time import threading -from dimos.robot.unitree.unitree_go2 import UnitreeGo2 -from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl -from dimos.robot.unitree.unitree_skills import MyUnitreeSkills -from dimos.robot.local_planner import navigate_to_goal_local +from dimos.robot.unitree_webrtc.unitree_go2 import UnitreeGo2 +from dimos.robot.local_planner.local_planner import navigate_to_goal_local from dimos.web.robot_web_interface import RobotWebInterface from reactivex import operators as RxOps import tests.test_header @@ -27,15 +25,11 @@ def main(): print("Initializing Unitree Go2 robot with local planner visualization...") - # Initialize the robot with ROS control and skills - robot = UnitreeGo2( - ip=os.getenv("ROBOT_IP"), - ros_control=UnitreeROSControl(), - skills=MyUnitreeSkills(), - ) + # Initialize the robot with webrtc interface + robot = UnitreeGo2(ip=os.getenv("ROBOT_IP"), mode="ai") # Get the camera stream - video_stream = robot.get_ros_video_stream() + video_stream = robot.get_video_stream() # The local planner visualization stream is created during robot initialization local_planner_stream = robot.local_planner_viz_stream @@ -80,7 +74,11 @@ def main(): print(f"Error during test: {e}") finally: print("Cleaning up...") - robot.cleanup() + # Make sure the robot stands down safely + try: + robot.liedown() + except: + pass print("Test completed")