diff --git a/dimos/agents/memory/image_embedding.py b/dimos/agents/memory/image_embedding.py index 1ad0e9132d..142839abd9 100644 --- a/dimos/agents/memory/image_embedding.py +++ b/dimos/agents/memory/image_embedding.py @@ -54,6 +54,7 @@ def __init__(self, model_name: str = "clip", dimensions: int = 512): self.dimensions = dimensions self.model = None self.processor = None + self.model_path = None self._initialize_model() @@ -68,10 +69,16 @@ def _initialize_model(self): if self.model_name == "clip": model_id = get_data("models_clip") / "model.onnx" + self.model_path = str(model_id) # Store for pickling processor_id = "openai/clip-vit-base-patch32" - self.model = ort.InferenceSession(model_id) + + providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] + + self.model = ort.InferenceSession(str(model_id), providers=providers) + + actual_providers = self.model.get_providers() self.processor = CLIPProcessor.from_pretrained(processor_id) - logger.info(f"Loaded CLIP model: {model_id}") + logger.info(f"Loaded CLIP model: {model_id} with providers: {actual_providers}") elif self.model_name == "resnet": model_id = "microsoft/resnet-50" self.model = AutoModel.from_pretrained(model_id) diff --git a/dimos/agents/memory/spatial_vector_db.py b/dimos/agents/memory/spatial_vector_db.py index cf44d0c589..e144e99757 100644 --- a/dimos/agents/memory/spatial_vector_db.py +++ b/dimos/agents/memory/spatial_vector_db.py @@ -38,7 +38,11 @@ class SpatialVectorDB: """ def __init__( - self, collection_name: str = "spatial_memory", chroma_client=None, visual_memory=None + self, + collection_name: str = "spatial_memory", + chroma_client=None, + visual_memory=None, + embedding_provider=None, ): """ Initialize the spatial vector database. @@ -47,6 +51,7 @@ def __init__( collection_name: Name of the vector database collection chroma_client: Optional ChromaDB client for persistence. If None, an in-memory client is used. visual_memory: Optional VisualMemory instance for storing images. 
If None, a new one is created. + embedding_provider: Optional ImageEmbeddingProvider instance for computing embeddings. If None, one will be created. """ self.collection_name = collection_name @@ -77,6 +82,9 @@ def __init__( # Use provided visual memory or create a new one self.visual_memory = visual_memory if visual_memory is not None else VisualMemory() + # Store the embedding provider to reuse for all operations + self.embedding_provider = embedding_provider + # Log initialization info with details about whether using existing collection client_type = "persistent" if chroma_client is not None else "in-memory" try: @@ -223,11 +231,12 @@ def query_by_text(self, text: str, limit: int = 5) -> List[Dict]: Returns: List of results, each containing the image, its metadata, and similarity score """ - from dimos.agents.memory.image_embedding import ImageEmbeddingProvider + if self.embedding_provider is None: + from dimos.agents.memory.image_embedding import ImageEmbeddingProvider - embedding_provider = ImageEmbeddingProvider(model_name="clip") + self.embedding_provider = ImageEmbeddingProvider(model_name="clip") - text_embedding = embedding_provider.get_text_embedding(text) + text_embedding = self.embedding_provider.get_text_embedding(text) results = self.image_collection.query( query_embeddings=[text_embedding.tolist()], diff --git a/dimos/manipulation/manip_aio_processer.py b/dimos/manipulation/manip_aio_processer.py index 2d80b527ce..aa439d2814 100644 --- a/dimos/manipulation/manip_aio_processer.py +++ b/dimos/manipulation/manip_aio_processer.py @@ -98,7 +98,6 @@ def __init__( self.segmenter = None if self.enable_segmentation: self.segmenter = Sam2DSegmenter( - device="cuda", use_tracker=False, # Disable tracker for simple segmentation use_analyzer=False, # Disable analyzer for simple segmentation ) diff --git a/dimos/manipulation/visual_servoing/manipulation_module.py b/dimos/manipulation/visual_servoing/manipulation_module.py index b2724fb59a..dfcb1dbcb0 100644 --- 
a/dimos/manipulation/visual_servoing/manipulation_module.py +++ b/dimos/manipulation/visual_servoing/manipulation_module.py @@ -40,7 +40,6 @@ select_points_from_depth, transform_points_3d, update_target_grasp_pose, - apply_grasp_distance, is_target_reached, ) from dimos.utils.transform_utils import ( diff --git a/dimos/msgs/nav_msgs/OccupancyGrid.py b/dimos/msgs/nav_msgs/OccupancyGrid.py index e66da51e89..4bb7495e86 100644 --- a/dimos/msgs/nav_msgs/OccupancyGrid.py +++ b/dimos/msgs/nav_msgs/OccupancyGrid.py @@ -14,13 +14,14 @@ from __future__ import annotations -import math +import time from enum import IntEnum from typing import TYPE_CHECKING, BinaryIO, Optional import numpy as np from dimos_lcm.nav_msgs import MapMetaData from dimos_lcm.nav_msgs import OccupancyGrid as LCMOccupancyGrid +from dimos_lcm.std_msgs import Time as LCMTime from scipy import ndimage from dimos.msgs.geometry_msgs import Pose, Vector3, VectorLike @@ -79,7 +80,6 @@ def __init__( frame_id: Reference frame ts: Timestamp (defaults to current time if 0) """ - import time self.frame_id = frame_id self.ts = ts if ts != 0 else time.time() @@ -114,7 +114,6 @@ def __init__( def _to_lcm_time(self): """Convert timestamp to LCM Time.""" - from dimos_lcm.std_msgs import Time as LCMTime s = int(self.ts) return LCMTime(sec=s, nsec=int((self.ts - s) * 1_000_000_000)) @@ -174,11 +173,10 @@ def unknown_percent(self) -> float: """Percentage of cells that are unknown.""" return (self.unknown_cells / self.total_cells * 100) if self.total_cells > 0 else 0.0 - def inflate(self, radius: float, cost_scaling_factor: float = 0.0) -> "OccupancyGrid": - """Inflate obstacles by a given radius (vectorized). + def inflate(self, radius: float) -> "OccupancyGrid": + """Inflate obstacles by a given radius (binary inflation). 
Args: radius: Inflation radius in meters - cost_scaling_factor: Factor for decay (0.0 = no decay, binary inflation) Returns: New OccupancyGrid with inflated obstacles """ @@ -188,31 +186,18 @@ def inflate(self, radius: float, cost_scaling_factor: float = 0.0) -> "Occupancy # Get grid as numpy array grid_array = self.grid - # Create square kernel for binary inflation + # Create circular kernel for binary inflation kernel_size = 2 * cell_radius + 1 - kernel = np.ones((kernel_size, kernel_size), dtype=np.uint8) + y, x = np.ogrid[-cell_radius : cell_radius + 1, -cell_radius : cell_radius + 1] + kernel = (x**2 + y**2 <= cell_radius**2).astype(np.uint8) # Find occupied cells occupied_mask = grid_array >= CostValues.OCCUPIED - if cost_scaling_factor == 0.0: - # Binary inflation - inflated = ndimage.binary_dilation(occupied_mask, structure=kernel) - result_grid = grid_array.copy() - result_grid[inflated] = CostValues.OCCUPIED - else: - # Distance-based inflation with decay - # Create distance transform from occupied cells - distance_field = ndimage.distance_transform_edt(~occupied_mask) - - # Apply exponential decay based on distance - cost_field = CostValues.OCCUPIED * np.exp(-cost_scaling_factor * distance_field) - - # Combine with original grid, keeping higher values - result_grid = np.maximum(grid_array, cost_field).astype(np.int8) - - # Ensure occupied cells remain at max value - result_grid[occupied_mask] = CostValues.OCCUPIED + # Binary inflation + inflated = ndimage.binary_dilation(occupied_mask, structure=kernel) + result_grid = grid_array.copy() + result_grid[inflated] = CostValues.OCCUPIED # Create new OccupancyGrid with inflated data using numpy constructor return OccupancyGrid( @@ -350,7 +335,7 @@ def from_pointcloud( min_height: float = 0.1, max_height: float = 2.0, frame_id: Optional[str] = None, - mark_free_radius: float = 0.0, + mark_free_radius: float = 0.4, ) -> "OccupancyGrid": """Create an OccupancyGrid from a PointCloud2 message. 
@@ -367,8 +352,6 @@ def from_pointcloud( Returns: OccupancyGrid with occupied cells where points were projected """ - # Import here to avoid circular dependency - from dimos.msgs.sensor_msgs import PointCloud2 # Get points as numpy array points = cloud.as_numpy() @@ -379,22 +362,26 @@ def from_pointcloud( width=1, height=1, resolution=resolution, frame_id=frame_id or cloud.frame_id ) - # Filter points by height - height_mask = (points[:, 2] >= min_height) & (points[:, 2] <= max_height) - filtered_points = points[height_mask] + # Filter points by height for obstacles + obstacle_mask = (points[:, 2] >= min_height) & (points[:, 2] <= max_height) + obstacle_points = points[obstacle_mask] - if len(filtered_points) == 0: - # No points in height range + # Get points below min_height for marking as free space + ground_mask = points[:, 2] < min_height + ground_points = points[ground_mask] + + # Find bounds of the point cloud in X-Y plane (use all points) + if len(points) > 0: + min_x = np.min(points[:, 0]) + max_x = np.max(points[:, 0]) + min_y = np.min(points[:, 1]) + max_y = np.max(points[:, 1]) + else: + # Return empty grid if no points at all return cls( width=1, height=1, resolution=resolution, frame_id=frame_id or cloud.frame_id ) - # Find bounds of the point cloud in X-Y plane - min_x = np.min(filtered_points[:, 0]) - max_x = np.max(filtered_points[:, 0]) - min_y = np.min(filtered_points[:, 1]) - max_y = np.max(filtered_points[:, 1]) - # Add some padding around the bounds padding = 1.0 # 1 meter padding min_x -= padding @@ -416,23 +403,36 @@ def from_pointcloud( # Initialize grid (all unknown) grid = np.full((height, width), -1, dtype=np.int8) - # Convert points to grid indices - grid_x = ((filtered_points[:, 0] - min_x) / resolution).astype(np.int32) - grid_y = ((filtered_points[:, 1] - min_y) / resolution).astype(np.int32) + # First, mark ground points as free space + if len(ground_points) > 0: + ground_x = ((ground_points[:, 0] - min_x) / 
resolution).astype(np.int32) + ground_y = ((ground_points[:, 1] - min_y) / resolution).astype(np.int32) - # Clip indices to grid bounds - grid_x = np.clip(grid_x, 0, width - 1) - grid_y = np.clip(grid_y, 0, height - 1) + # Clip indices to grid bounds + ground_x = np.clip(ground_x, 0, width - 1) + ground_y = np.clip(ground_y, 0, height - 1) - # Mark cells as occupied - grid[grid_y, grid_x] = 100 # Lethal obstacle + # Mark ground cells as free + grid[ground_y, ground_x] = 0 # Free space - # Mark free space around obstacles based on mark_free_radius + # Then mark obstacle points (will override ground if at same location) + if len(obstacle_points) > 0: + obs_x = ((obstacle_points[:, 0] - min_x) / resolution).astype(np.int32) + obs_y = ((obstacle_points[:, 1] - min_y) / resolution).astype(np.int32) + + # Clip indices to grid bounds + obs_x = np.clip(obs_x, 0, width - 1) + obs_y = np.clip(obs_y, 0, height - 1) + + # Mark cells as occupied + grid[obs_y, obs_x] = 100 # Lethal obstacle + + # Apply mark_free_radius to expand free space areas if mark_free_radius > 0: - # Mark a specified radius around occupied cells as free - from scipy.ndimage import binary_dilation + # Expand existing free space areas by the specified radius + # This will NOT expand from obstacles, only from free space - occupied_mask = grid == 100 + free_mask = grid == 0 # Current free space free_radius_cells = int(np.ceil(mark_free_radius / resolution)) # Create circular kernel @@ -442,20 +442,11 @@ def from_pointcloud( ] kernel = x**2 + y**2 <= free_radius_cells**2 - known_area = binary_dilation(occupied_mask, structure=kernel, iterations=1) - # Mark non-occupied cells in the known area as free - grid[known_area & (grid != 100)] = 0 - else: - # Default: only mark immediate neighbors as free to preserve unknown - from scipy.ndimage import binary_dilation - - occupied_mask = grid == 100 - # Use a small 3x3 kernel to only mark immediate neighbors - structure = np.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]]) - 
immediate_neighbors = binary_dilation(occupied_mask, structure=structure, iterations=1) + # Dilate free space areas + expanded_free = ndimage.binary_dilation(free_mask, structure=kernel, iterations=1) - # Mark only immediate neighbors as free (not the occupied cells themselves) - grid[immediate_neighbors & (grid != 100)] = 0 + # Mark expanded areas as free, but don't override obstacles + grid[expanded_free & (grid != 100)] = 0 # Create and return OccupancyGrid # Get timestamp from cloud if available @@ -479,31 +470,28 @@ def gradient(self, obstacle_threshold: int = 50, max_distance: float = 2.0) -> " Args: obstacle_threshold: Cell values >= this are considered obstacles (default: 50) - max_distance: Maximum distance to compute gradient in meters (default: 5.0) + max_distance: Maximum distance to compute gradient in meters (default: 2.0) Returns: New OccupancyGrid with gradient values: - - -1: Unknown cells far from obstacles (beyond max_distance) + - -1: Unknown cells (preserved as-is) - 0: Free space far from obstacles - 1-99: Increasing cost as you approach obstacles - 100: At obstacles - Note: Unknown cells within max_distance of obstacles will have gradient - values assigned, allowing path planning through unknown areas. + Note: Unknown cells remain as unknown (-1) and do not receive gradient values. 
""" # Remember which cells are unknown - unknown_mask = self.grid == -1 - - # Create a working grid where unknown cells are treated as free for distance calculation - working_grid = self.grid.copy() - working_grid[unknown_mask] = 0 # Treat unknown as free for gradient computation + unknown_mask = self.grid == CostValues.UNKNOWN - # Create binary obstacle map from working grid + # Create binary obstacle map # Consider cells >= threshold as obstacles (1), everything else as free (0) - obstacle_map = (working_grid >= obstacle_threshold).astype(np.float32) + # Unknown cells are not considered obstacles for distance calculation + obstacle_map = (self.grid >= obstacle_threshold).astype(np.float32) # Compute distance transform (distance to nearest obstacle in cells) + # Unknown cells are treated as if they don't exist for distance calculation distance_cells = ndimage.distance_transform_edt(1 - obstacle_map) # Convert to meters and clip to max distance @@ -515,15 +503,13 @@ def gradient(self, obstacle_threshold: int = 50, max_distance: float = 2.0) -> " gradient_values = (1 - distance_meters / max_distance) * 100 # Ensure obstacles are exactly 100 - gradient_values[obstacle_map > 0] = 100 + gradient_values[obstacle_map > 0] = CostValues.OCCUPIED # Convert to int8 for OccupancyGrid gradient_data = gradient_values.astype(np.int8) - # Only preserve unknown cells that are beyond max_distance from any obstacle - # This allows gradient to spread through unknown areas near obstacles - far_unknown_mask = unknown_mask & (distance_meters >= max_distance) - gradient_data[far_unknown_mask] = -1 + # Preserve unknown cells as unknown (don't apply gradient to them) + gradient_data[unknown_mask] = CostValues.UNKNOWN # Create new OccupancyGrid with gradient gradient_grid = OccupancyGrid( diff --git a/dimos/msgs/nav_msgs/test_OccupancyGrid.py b/dimos/msgs/nav_msgs/test_OccupancyGrid.py index 94467726c7..83277b54bc 100644 --- a/dimos/msgs/nav_msgs/test_OccupancyGrid.py +++ 
b/dimos/msgs/nav_msgs/test_OccupancyGrid.py @@ -231,14 +231,14 @@ def test_gradient(): grid_with_unknown = OccupancyGrid(data_with_unknown, resolution=0.1) gradient_with_unknown = grid_with_unknown.gradient(max_distance=1.0) # 1m max distance - # Unknown cells close to obstacles should get gradient values - assert gradient_with_unknown.grid[0, 0] != -1 # Should have gradient - assert gradient_with_unknown.grid[1, 1] != -1 # Should have gradient - - # Unknown cells far from obstacles (beyond max_distance) should remain unknown - # The far corner (8,8) is ~0.57m from nearest obstacle, within 1m threshold - # So it will get a gradient value, not remain unknown - assert gradient_with_unknown.unknown_cells < 8 # Some unknowns converted to gradient + # Unknown cells should remain unknown (new behavior - unknowns are preserved) + assert gradient_with_unknown.grid[0, 0] == -1 # Should remain unknown + assert gradient_with_unknown.grid[1, 1] == -1 # Should remain unknown + assert gradient_with_unknown.grid[8, 8] == -1 # Should remain unknown + assert gradient_with_unknown.grid[9, 9] == -1 # Should remain unknown + + # Unknown cells count should be preserved + assert gradient_with_unknown.unknown_cells == 8 # All unknowns preserved def test_filter_above(): diff --git a/dimos/navigation/bt_navigator/__init__.py b/dimos/navigation/bt_navigator/__init__.py new file mode 100644 index 0000000000..cfd252ff6a --- /dev/null +++ b/dimos/navigation/bt_navigator/__init__.py @@ -0,0 +1 @@ +from .navigator import BehaviorTreeNavigator diff --git a/dimos/navigation/bt_navigator/goal_validator.py b/dimos/navigation/bt_navigator/goal_validator.py new file mode 100644 index 0000000000..871c351db0 --- /dev/null +++ b/dimos/navigation/bt_navigator/goal_validator.py @@ -0,0 +1,440 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import deque +from typing import Optional, Tuple + +import numpy as np +from dimos.msgs.geometry_msgs import VectorLike, Vector3 +from dimos.msgs.nav_msgs import CostValues, OccupancyGrid + + +def find_safe_goal( + costmap: OccupancyGrid, + goal: VectorLike, + algorithm: str = "bfs", + cost_threshold: int = 50, + min_clearance: float = 0.3, + max_search_distance: float = 5.0, + connectivity_check_radius: int = 3, +) -> Optional[Vector3]: + """ + Find a safe goal position when the original goal is in collision or too close to obstacles. + + Args: + costmap: The occupancy grid/costmap + goal: Original goal position in world coordinates + algorithm: Algorithm to use ("bfs", "spiral", "voronoi", "gradient_descent") + cost_threshold: Maximum acceptable cost for a safe position (default: 50) + min_clearance: Minimum clearance from obstacles in meters (default: 0.3m) + max_search_distance: Maximum distance to search from original goal in meters (default: 5.0m) + connectivity_check_radius: Radius in cells to check for connectivity (default: 3) + + Returns: + Safe goal position in world coordinates, or None if no safe position found + """ + + if algorithm == "bfs": + return _find_safe_goal_bfs( + costmap, + goal, + cost_threshold, + min_clearance, + max_search_distance, + connectivity_check_radius, + ) + elif algorithm == "spiral": + return _find_safe_goal_spiral( + costmap, + goal, + cost_threshold, + min_clearance, + max_search_distance, + connectivity_check_radius, + ) + elif algorithm == "voronoi": + return _find_safe_goal_voronoi( + costmap, 
goal, cost_threshold, min_clearance, max_search_distance + ) + elif algorithm == "gradient_descent": + return _find_safe_goal_gradient( + costmap, + goal, + cost_threshold, + min_clearance, + max_search_distance, + connectivity_check_radius, + ) + else: + raise ValueError(f"Unknown algorithm: {algorithm}") + + +def _find_safe_goal_bfs( + costmap: OccupancyGrid, + goal: VectorLike, + cost_threshold: int, + min_clearance: float, + max_search_distance: float, + connectivity_check_radius: int, +) -> Optional[Vector3]: + """ + BFS-based search for nearest safe goal position. + This guarantees finding the closest valid position. + + Pros: + - Guarantees finding the closest safe position + - Can check connectivity to avoid isolated spots + - Efficient for small to medium search areas + + Cons: + - Can be slower for large search areas + - Memory usage scales with search area + """ + + # Convert goal to grid coordinates + goal_grid = costmap.world_to_grid(goal) + gx, gy = int(goal_grid.x), int(goal_grid.y) + + # Convert distances to grid cells + clearance_cells = int(np.ceil(min_clearance / costmap.resolution)) + max_search_cells = int(np.ceil(max_search_distance / costmap.resolution)) + + # BFS queue and visited set + queue = deque([(gx, gy, 0)]) + visited = set([(gx, gy)]) + + # 8-connected neighbors + neighbors = [(0, 1), (1, 0), (0, -1), (-1, 0), (1, 1), (1, -1), (-1, 1), (-1, -1)] + + while queue: + x, y, dist = queue.popleft() + + # Check if we've exceeded max search distance + if dist > max_search_cells: + break + + # Check if position is valid + if _is_position_safe( + costmap, x, y, cost_threshold, clearance_cells, connectivity_check_radius + ): + # Convert back to world coordinates + return costmap.grid_to_world((x, y)) + + # Add neighbors to queue + for dx, dy in neighbors: + nx, ny = x + dx, y + dy + + # Check bounds + if 0 <= nx < costmap.width and 0 <= ny < costmap.height: + if (nx, ny) not in visited: + visited.add((nx, ny)) + queue.append((nx, ny, dist + 1)) 
+ + return None + + +def _find_safe_goal_spiral( + costmap: OccupancyGrid, + goal: VectorLike, + cost_threshold: int, + min_clearance: float, + max_search_distance: float, + connectivity_check_radius: int, +) -> Optional[Vector3]: + """ + Spiral search pattern from goal outward. + + Pros: + - Simple and predictable pattern + - Memory efficient + - Good for uniformly distributed obstacles + + Cons: + - May not find the absolute closest safe position + - Can miss nearby safe spots due to spiral pattern + """ + + # Convert goal to grid coordinates + goal_grid = costmap.world_to_grid(goal) + cx, cy = int(goal_grid.x), int(goal_grid.y) + + # Convert distances to grid cells + clearance_cells = int(np.ceil(min_clearance / costmap.resolution)) + max_radius = int(np.ceil(max_search_distance / costmap.resolution)) + + # Spiral outward + for radius in range(0, max_radius + 1): + if radius == 0: + # Check center point + if _is_position_safe( + costmap, cx, cy, cost_threshold, clearance_cells, connectivity_check_radius + ): + return costmap.grid_to_world((cx, cy)) + else: + # Check points on the square perimeter at this radius + points = [] + + # Top and bottom edges + for x in range(cx - radius, cx + radius + 1): + points.append((x, cy - radius)) # Top + points.append((x, cy + radius)) # Bottom + + # Left and right edges (excluding corners to avoid duplicates) + for y in range(cy - radius + 1, cy + radius): + points.append((cx - radius, y)) # Left + points.append((cx + radius, y)) # Right + + # Check each point + for x, y in points: + if 0 <= x < costmap.width and 0 <= y < costmap.height: + if _is_position_safe( + costmap, x, y, cost_threshold, clearance_cells, connectivity_check_radius + ): + return costmap.grid_to_world((x, y)) + + return None + + +def _find_safe_goal_voronoi( + costmap: OccupancyGrid, + goal: VectorLike, + cost_threshold: int, + min_clearance: float, + max_search_distance: float, +) -> Optional[Vector3]: + """ + Find safe position using Voronoi diagram 
(ridge points equidistant from obstacles). + + Pros: + - Finds positions maximally far from obstacles + - Good for narrow passages + - Natural safety margin + + Cons: + - More computationally expensive + - May find positions unnecessarily far from obstacles + - Requires scipy for efficient implementation + """ + + from scipy import ndimage + from skimage.morphology import skeletonize + + # Convert goal to grid coordinates + goal_grid = costmap.world_to_grid(goal) + gx, gy = int(goal_grid.x), int(goal_grid.y) + + # Create binary obstacle map + obstacle_map = costmap.grid >= cost_threshold + free_map = (costmap.grid < cost_threshold) & (costmap.grid != CostValues.UNKNOWN) + + # Compute distance transform + distance_field = ndimage.distance_transform_edt(free_map) + + # Find skeleton/medial axis (approximation of Voronoi diagram) + skeleton = skeletonize(free_map) + + # Filter skeleton points by minimum clearance + clearance_cells = int(np.ceil(min_clearance / costmap.resolution)) + valid_skeleton = skeleton & (distance_field >= clearance_cells) + + if not np.any(valid_skeleton): + # Fall back to BFS if no valid skeleton points + return _find_safe_goal_bfs( + costmap, goal, cost_threshold, min_clearance, max_search_distance, 3 + ) + + # Find nearest valid skeleton point to goal + skeleton_points = np.argwhere(valid_skeleton) + if len(skeleton_points) == 0: + return None + + # Calculate distances from goal to all skeleton points + distances = np.sqrt((skeleton_points[:, 1] - gx) ** 2 + (skeleton_points[:, 0] - gy) ** 2) + + # Filter by max search distance + max_search_cells = max_search_distance / costmap.resolution + valid_indices = distances <= max_search_cells + + if not np.any(valid_indices): + return None + + # Find closest valid point + valid_distances = distances[valid_indices] + valid_points = skeleton_points[valid_indices] + closest_idx = np.argmin(valid_distances) + best_y, best_x = valid_points[closest_idx] + + return costmap.grid_to_world((best_x, best_y)) 
+ + +def _find_safe_goal_gradient( + costmap: OccupancyGrid, + goal: VectorLike, + cost_threshold: int, + min_clearance: float, + max_search_distance: float, + connectivity_check_radius: int, +) -> Optional[Vector3]: + """ + Use gradient descent on the costmap to find a safe position. + + Pros: + - Naturally flows away from obstacles + - Works well with gradient costmaps + - Can handle complex cost distributions + + Cons: + - Can get stuck in local minima + - Requires a gradient costmap + - May not find globally optimal position + """ + + # Convert goal to grid coordinates + goal_grid = costmap.world_to_grid(goal) + x, y = goal_grid.x, goal_grid.y + + # Convert distances to grid cells + clearance_cells = int(np.ceil(min_clearance / costmap.resolution)) + max_search_cells = int(np.ceil(max_search_distance / costmap.resolution)) + + # Create gradient if needed (assuming costmap might already be a gradient) + if np.all((costmap.grid == 0) | (costmap.grid == 100) | (costmap.grid == -1)): + # Binary map, create gradient + gradient_map = costmap.gradient( + obstacle_threshold=cost_threshold, max_distance=min_clearance * 2 + ) + grid = gradient_map.grid + else: + grid = costmap.grid + + # Gradient descent with momentum + momentum = 0.9 + learning_rate = 1.0 + vx, vy = 0.0, 0.0 + + best_x, best_y = None, None + best_cost = float("inf") + + for iteration in range(100): # Max iterations + ix, iy = int(x), int(y) + + # Check if current position is valid + if 0 <= ix < costmap.width and 0 <= iy < costmap.height: + current_cost = grid[iy, ix] + + # Check distance from original goal + dist = np.sqrt((x - goal_grid.x) ** 2 + (y - goal_grid.y) ** 2) + if dist > max_search_cells: + break + + # Check if position is safe + if _is_position_safe( + costmap, ix, iy, cost_threshold, clearance_cells, connectivity_check_radius + ): + if current_cost < best_cost: + best_x, best_y = ix, iy + best_cost = current_cost + + # If cost is very low, we found a good spot + if current_cost < 10: + 
break + + # Compute gradient using finite differences + gx, gy = 0.0, 0.0 + + if 0 < ix < costmap.width - 1: + gx = (grid[iy, min(ix + 1, costmap.width - 1)] - grid[iy, max(ix - 1, 0)]) / 2.0 + + if 0 < iy < costmap.height - 1: + gy = (grid[min(iy + 1, costmap.height - 1), ix] - grid[max(iy - 1, 0), ix]) / 2.0 + + # Update with momentum + vx = momentum * vx - learning_rate * gx + vy = momentum * vy - learning_rate * gy + + # Update position + x += vx + y += vy + + # Add small random noise to escape local minima + if iteration % 20 == 0: + x += np.random.randn() * 0.5 + y += np.random.randn() * 0.5 + + if best_x is not None and best_y is not None: + return costmap.grid_to_world((best_x, best_y)) + + return None + + +def _is_position_safe( + costmap: OccupancyGrid, + x: int, + y: int, + cost_threshold: int, + clearance_cells: int, + connectivity_check_radius: int, +) -> bool: + """ + Check if a position is safe based on multiple criteria. + + Args: + costmap: The occupancy grid + x, y: Grid coordinates to check + cost_threshold: Maximum acceptable cost + clearance_cells: Minimum clearance in cells + connectivity_check_radius: Radius to check for connectivity + + Returns: + True if position is safe, False otherwise + """ + + # Check if position itself is free + if costmap.grid[y, x] >= cost_threshold or costmap.grid[y, x] == CostValues.UNKNOWN: + return False + + # Check clearance around position + for dy in range(-clearance_cells, clearance_cells + 1): + for dx in range(-clearance_cells, clearance_cells + 1): + nx, ny = x + dx, y + dy + if 0 <= nx < costmap.width and 0 <= ny < costmap.height: + # Check if within circular clearance + if dx * dx + dy * dy <= clearance_cells * clearance_cells: + if costmap.grid[ny, nx] >= cost_threshold: + return False + + # Check connectivity (not surrounded by obstacles) + # Count free neighbors in a larger radius + free_count = 0 + total_count = 0 + + for dy in range(-connectivity_check_radius, connectivity_check_radius + 1): + for 
dx in range(-connectivity_check_radius, connectivity_check_radius + 1): + if dx == 0 and dy == 0: + continue + + nx, ny = x + dx, y + dy + if 0 <= nx < costmap.width and 0 <= ny < costmap.height: + total_count += 1 + if ( + costmap.grid[ny, nx] < cost_threshold + and costmap.grid[ny, nx] != CostValues.UNKNOWN + ): + free_count += 1 + + # Require at least 50% of neighbors to be free (not surrounded) + if total_count > 0 and free_count < total_count * 0.5: + return False + + return True diff --git a/dimos/navigation/bt_navigator/navigator.py b/dimos/navigation/bt_navigator/navigator.py new file mode 100644 index 0000000000..3ca4587cb8 --- /dev/null +++ b/dimos/navigation/bt_navigator/navigator.py @@ -0,0 +1,310 @@ +#!/usr/bin/env python3 + +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Navigator module for coordinating global and local planning. 
+""" + +import threading +import time +from enum import Enum +from typing import Optional + +from dimos.core import Module, In, Out, rpc +from dimos.msgs.geometry_msgs import PoseStamped +from dimos.msgs.nav_msgs import OccupancyGrid +from dimos.navigation.local_planner.local_planner import BaseLocalPlanner +from dimos.navigation.bt_navigator.goal_validator import find_safe_goal +from dimos.protocol.tf import TF +from dimos.utils.logging_config import setup_logger +from dimos_lcm.std_msgs import Bool +from dimos.utils.transform_utils import apply_transform + +logger = setup_logger("dimos.navigation.bt_navigator") + + +class NavigatorState(Enum): + """Navigator state machine states.""" + + IDLE = "idle" + FOLLOWING_PATH = "following_path" + RECOVERY = "recovery" + + +class BehaviorTreeNavigator(Module): + """ + Navigator module for coordinating navigation tasks. + + Manages the state machine for navigation, coordinates between global + and local planners, and monitors goal completion. + + Inputs: + - odom: Current robot odometry + + Outputs: + - goal: Goal pose for global planner + """ + + # LCM inputs + odom: In[PoseStamped] = None + goal_request: In[PoseStamped] = None # Input for receiving goal requests + global_costmap: In[OccupancyGrid] = None + + # LCM outputs + goal: Out[PoseStamped] = None + goal_reached: Out[Bool] = None + + def __init__( + self, + local_planner: BaseLocalPlanner, + publishing_frequency: float = 1.0, + **kwargs, + ): + """Initialize the Navigator. 
+ + Args: + publishing_frequency: Frequency to publish goals to global planner (Hz) + """ + super().__init__(**kwargs) + + # Parameters + self.publishing_frequency = publishing_frequency + self.publishing_period = 1.0 / publishing_frequency + + # State machine + self.state = NavigatorState.IDLE + self.state_lock = threading.Lock() + + # Current goal + self.current_goal: Optional[PoseStamped] = None + self.goal_lock = threading.Lock() + + # Goal reached state + self._goal_reached = False + self._goal_reached_lock = threading.Lock() + + # Latest data + self.latest_odom: Optional[PoseStamped] = None + self.latest_costmap: Optional[OccupancyGrid] = None + + # Control thread + self.control_thread: Optional[threading.Thread] = None + self.stop_event = threading.Event() + + self.local_planner = local_planner + # TF listener + self.tf = TF() + + logger.info("Navigator initialized") + + @rpc + def start(self): + """Start the navigator module.""" + # Subscribe to inputs + self.odom.subscribe(self._on_odom) + self.goal_request.subscribe(self._on_goal_request) + self.global_costmap.subscribe(self._on_costmap) + + # Start control thread + self.stop_event.clear() + self.control_thread = threading.Thread(target=self._control_loop, daemon=True) + self.control_thread.start() + + logger.info("Navigator started") + + @rpc + def cancel_goal(self) -> bool: + """ + Cancel the current navigation goal. + + Returns: + True if goal was cancelled, False if no goal was active + """ + self.stop() + return True + + @rpc + def cleanup(self): + """Clean up resources including stopping the control thread.""" + # First stop navigation + self.stop() + + # Then clean up the control thread + self.stop_event.set() + if self.control_thread and self.control_thread.is_alive(): + self.control_thread.join(timeout=2.0) + + logger.info("Navigator cleanup complete") + + @rpc + def set_goal(self, goal: PoseStamped, blocking: bool = False) -> bool: + """ + Set a new navigation goal. 
+ + Args: + goal: Target pose to navigate to + + Returns: + non-blocking: True if goal was accepted, False otherwise + blocking: True if goal was reached, False otherwise + """ + transformed_goal = self._transform_goal_to_odom_frame(goal) + if not transformed_goal: + logger.error("Failed to transform goal to odometry frame") + return False + + with self.goal_lock: + self.current_goal = transformed_goal + + with self._goal_reached_lock: + self._goal_reached = False + + with self.state_lock: + self.state = NavigatorState.FOLLOWING_PATH + + if blocking: + while not self.is_goal_reached(): + if self.state == NavigatorState.IDLE: + logger.info("Navigation was cancelled") + return False + + time.sleep(self.publishing_period) + + return True + + @rpc + def get_state(self) -> NavigatorState: + """Get the current state of the navigator.""" + return self.state + + def _on_odom(self, msg: PoseStamped): + """Handle incoming odometry messages.""" + self.latest_odom = msg + + def _on_goal_request(self, msg: PoseStamped): + """Handle incoming goal requests.""" + self.set_goal(msg) + + def _on_costmap(self, msg: OccupancyGrid): + """Handle incoming costmap messages.""" + self.latest_costmap = msg + + def _transform_goal_to_odom_frame(self, goal: PoseStamped) -> Optional[PoseStamped]: + """Transform goal pose to the odometry frame.""" + if not goal.frame_id: + return goal + + odom_frame = self.latest_odom.frame_id + if goal.frame_id == odom_frame: + return goal + + try: + transform = self.tf.get( + parent_frame=odom_frame, + child_frame=goal.frame_id, + time_point=goal.ts, + time_tolerance=1.0, + ) + + if not transform: + logger.error(f"Could not find transform from '{goal.frame_id}' to '{odom_frame}'") + return None + + pose = apply_transform(goal, transform) + transformed_goal = PoseStamped( + position=pose.position, + orientation=pose.orientation, + frame_id=odom_frame, + ts=goal.ts, + ) + return transformed_goal + + except Exception as e: + logger.error(f"Failed to transform 
goal: {e}") + return None + + def _control_loop(self): + """Main control loop running in separate thread.""" + while not self.stop_event.is_set(): + with self.state_lock: + current_state = self.state + + if current_state == NavigatorState.FOLLOWING_PATH: + with self.goal_lock: + goal = self.current_goal + + if goal is not None and self.latest_costmap is not None: + # Find safe goal position + safe_goal_pos = find_safe_goal( + self.latest_costmap, + goal.position, + algorithm="bfs", + cost_threshold=80, + min_clearance=0.1, + max_search_distance=5.0, + ) + + # Create new goal with safe position + if safe_goal_pos: + safe_goal = PoseStamped( + position=safe_goal_pos, + orientation=goal.orientation, + frame_id=goal.frame_id, + ts=goal.ts, + ) + self.goal.publish(safe_goal) + else: + self.cancel_goal() + + if self.local_planner.is_goal_reached(): + with self._goal_reached_lock: + self._goal_reached = True + logger.info("Goal reached!") + reached_msg = Bool() + reached_msg.data = True + self.goal_reached.publish(reached_msg) + self.local_planner.reset() + with self.goal_lock: + self.current_goal = None + with self.state_lock: + self.state = NavigatorState.IDLE + + elif current_state == NavigatorState.RECOVERY: + with self.state_lock: + self.state = NavigatorState.IDLE + + time.sleep(self.publishing_period) + + @rpc + def is_goal_reached(self) -> bool: + """Check if the current goal has been reached.""" + with self._goal_reached_lock: + return self._goal_reached + + def stop(self): + """Stop navigation and return to IDLE state.""" + with self.goal_lock: + self.current_goal = None + + with self._goal_reached_lock: + self._goal_reached = False + + with self.state_lock: + self.state = NavigatorState.IDLE + + self.local_planner.reset() + + logger.info("Navigator stopped") diff --git a/dimos/navigation/frontier_exploration/__init__.py b/dimos/navigation/frontier_exploration/__init__.py new file mode 100644 index 0000000000..388a5bfe6f --- /dev/null +++ 
b/dimos/navigation/frontier_exploration/__init__.py @@ -0,0 +1 @@ +from .wavefront_frontier_goal_selector import WavefrontFrontierExplorer diff --git a/dimos/robot/frontier_exploration/test_wavefront_frontier_goal_selector.py b/dimos/navigation/frontier_exploration/test_wavefront_frontier_goal_selector.py similarity index 52% rename from dimos/robot/frontier_exploration/test_wavefront_frontier_goal_selector.py rename to dimos/navigation/frontier_exploration/test_wavefront_frontier_goal_selector.py index cd344dd0b4..14d792ca65 100644 --- a/dimos/robot/frontier_exploration/test_wavefront_frontier_goal_selector.py +++ b/dimos/navigation/frontier_exploration/test_wavefront_frontier_goal_selector.py @@ -12,65 +12,69 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional +import threading +import time +from typing import List, Optional, Tuple +from unittest.mock import MagicMock import numpy as np import pytest from PIL import Image, ImageDraw from reactivex import operators as ops -from dimos.robot.frontier_exploration.utils import costmap_to_pil_image -from dimos.robot.frontier_exploration.wavefront_frontier_goal_selector import ( +from dimos import core +from dimos.msgs.geometry_msgs import PoseStamped, Vector3 +from dimos.msgs.nav_msgs import OccupancyGrid, CostValues +from dimos.navigation.frontier_exploration.utils import costmap_to_pil_image +from dimos.navigation.frontier_exploration.wavefront_frontier_goal_selector import ( WavefrontFrontierExplorer, ) -from dimos.robot.unitree_webrtc.type.lidar import LidarMessage -from dimos.robot.unitree_webrtc.type.map import Map -from dimos.msgs.geometry_msgs import Vector3 as Vector -from dimos.utils.testing import SensorReplay -def get_office_lidar_costmap(take_frames: int = 1, voxel_size: float = 0.5) -> tuple: - """ - Get a costmap from office_lidar data using SensorReplay. 
+def create_test_costmap(width=100, height=100, resolution=0.1): + """Create a simple test costmap with free, occupied, and unknown regions.""" + grid = np.full((height, width), CostValues.UNKNOWN, dtype=np.int8) - Args: - take_frames: Number of lidar frames to take (default 1) - voxel_size: Voxel size for map construction + # Create a larger free space region with more complex shape + # Central room + grid[40:60, 40:60] = CostValues.FREE - Returns: - Tuple of (costmap, first_lidar_message) for testing - """ - # Load office lidar data using SensorReplay as documented - lidar_stream = SensorReplay("office_lidar", autocast=LidarMessage.from_msg) + # Corridors extending from central room + grid[45:55, 20:40] = CostValues.FREE # Left corridor + grid[45:55, 60:80] = CostValues.FREE # Right corridor + grid[20:40, 45:55] = CostValues.FREE # Top corridor + grid[60:80, 45:55] = CostValues.FREE # Bottom corridor - # Create map with specified voxel size - map_obj = Map(voxel_size=voxel_size) + # Add some obstacles + grid[48:52, 48:52] = CostValues.OCCUPIED # Central obstacle + grid[35:38, 45:55] = CostValues.OCCUPIED # Top corridor obstacle + grid[62:65, 45:55] = CostValues.OCCUPIED # Bottom corridor obstacle - # Take only the specified number of frames and build map - limited_stream = lidar_stream.stream().pipe(ops.take(take_frames)) + # Create origin at bottom-left + from dimos.msgs.geometry_msgs import Pose - # Store the first lidar message for reference - first_lidar = None + origin = Pose() + origin.position.x = -5.0 # Center the map + origin.position.y = -5.0 + origin.position.z = 0.0 + origin.orientation.w = 1.0 - def capture_first_and_add(lidar_msg): - nonlocal first_lidar - if first_lidar is None: - first_lidar = lidar_msg - return map_obj.add_frame(lidar_msg) - - # Process the stream - limited_stream.pipe(ops.map(capture_first_and_add)).run() + occupancy_grid = OccupancyGrid( + grid=grid, resolution=resolution, origin=origin, frame_id="map", ts=time.time() + ) - # 
Get the resulting costmap - costmap = map_obj.costmap() + # Create a mock lidar message with origin + class MockLidar: + def __init__(self): + self.origin = Vector3(0.0, 0.0, 0.0) - return costmap, first_lidar + return occupancy_grid, MockLidar() def test_frontier_detection_with_office_lidar(): - """Test frontier detection using a single frame from office_lidar data.""" - # Get costmap from office lidar data - costmap, first_lidar = get_office_lidar_costmap(take_frames=1, voxel_size=0.3) + """Test frontier detection using a test costmap.""" + # Get test costmap + costmap, first_lidar = create_test_costmap() # Verify we have a valid costmap assert costmap is not None, "Costmap should not be None" @@ -103,7 +107,7 @@ def test_frontier_detection_with_office_lidar(): # Verify frontiers are Vector objects with valid coordinates for i, frontier in enumerate(frontiers[:5]): # Check first 5 - assert isinstance(frontier, Vector), f"Frontier {i} should be a Vector" + assert isinstance(frontier, Vector3), f"Frontier {i} should be a Vector3" assert hasattr(frontier, "x") and hasattr(frontier, "y"), ( f"Frontier {i} should have x,y coordinates" ) @@ -114,8 +118,8 @@ def test_frontier_detection_with_office_lidar(): def test_exploration_goal_selection(): """Test the complete exploration goal selection pipeline.""" - # Get costmap from office lidar data - costmap, first_lidar = get_office_lidar_costmap(take_frames=1, voxel_size=0.3) + # Get test costmap + costmap, first_lidar = create_test_costmap() # Initialize frontier explorer with default parameters explorer = WavefrontFrontierExplorer() @@ -127,21 +131,30 @@ def test_exploration_goal_selection(): goal = explorer.get_exploration_goal(robot_pose, costmap) if goal is not None: - assert isinstance(goal, Vector), "Goal should be a Vector" + assert isinstance(goal, Vector3), "Goal should be a Vector3" print(f"Selected exploration goal: ({goal.x:.2f}, {goal.y:.2f})") # Test that goal gets marked as explored assert 
len(explorer.explored_goals) == 1, "Goal should be marked as explored" assert explorer.explored_goals[0] == goal, "Explored goal should match selected goal" + # Test that goal is within costmap bounds + grid_pos = costmap.world_to_grid(goal) + assert 0 <= grid_pos.x < costmap.width, "Goal x should be within costmap bounds" + assert 0 <= grid_pos.y < costmap.height, "Goal y should be within costmap bounds" + + # Test that goal is at a reasonable distance from robot + distance = np.sqrt((goal.x - robot_pose.x) ** 2 + (goal.y - robot_pose.y) ** 2) + assert 0.1 < distance < 20.0, f"Goal distance {distance:.2f}m should be reasonable" + else: print("No exploration goal selected - map may be fully explored") def test_exploration_session_reset(): """Test exploration session reset functionality.""" - # Get costmap - costmap, first_lidar = get_office_lidar_costmap(take_frames=1, voxel_size=0.3) + # Get test costmap + costmap, first_lidar = create_test_costmap() # Initialize explorer and select a goal explorer = WavefrontFrontierExplorer() @@ -163,16 +176,100 @@ def test_exploration_session_reset(): "Exploration direction should be reset" ) assert explorer.last_costmap is None, "Last costmap should be cleared" - assert explorer.num_no_gain_attempts == 0, "No-gain attempts should be reset" + assert explorer.no_gain_counter == 0, "No-gain counter should be reset" print("Exploration session reset successfully") +def test_frontier_ranking(): + """Test frontier ranking and scoring logic.""" + # Get test costmap + costmap, first_lidar = create_test_costmap() + + # Initialize explorer with custom parameters + explorer = WavefrontFrontierExplorer( + min_frontier_size=5, min_distance_from_obstacles=0.5, info_gain_threshold=0.02 + ) + + robot_pose = first_lidar.origin + + # Get first set of frontiers + frontiers1 = explorer.detect_frontiers(robot_pose, costmap) + goal1 = explorer.get_exploration_goal(robot_pose, costmap) + + if goal1: + # Verify the selected goal is the first in the 
ranked list + assert frontiers1[0].x == goal1.x and frontiers1[0].y == goal1.y, ( + "Selected goal should be the highest ranked frontier" + ) + + # Test that goals are being marked as explored + assert len(explorer.explored_goals) == 1, "Goal should be marked as explored" + assert ( + explorer.explored_goals[0].x == goal1.x and explorer.explored_goals[0].y == goal1.y + ), "Explored goal should match selected goal" + + # Get another goal + goal2 = explorer.get_exploration_goal(robot_pose, costmap) + if goal2: + assert len(explorer.explored_goals) == 2, ( + "Second goal should also be marked as explored" + ) + + # Test distance to obstacles + obstacle_dist = explorer._compute_distance_to_obstacles(goal1, costmap) + assert obstacle_dist >= explorer.min_distance_from_obstacles, ( + f"Goal should be at least {explorer.min_distance_from_obstacles}m from obstacles" + ) + + print(f"Frontier ranking test passed - selected goal at ({goal1.x:.2f}, {goal1.y:.2f})") + print(f"Distance to obstacles: {obstacle_dist:.2f}m") + print(f"Total frontiers detected: {len(frontiers1)}") + else: + print("No frontiers found for ranking test") + + +def test_exploration_with_no_gain_detection(): + """Test information gain detection and exploration termination.""" + # Get initial costmap + costmap1, first_lidar = create_test_costmap() + + # Initialize explorer with low no-gain threshold for testing + explorer = WavefrontFrontierExplorer(info_gain_threshold=0.01, num_no_gain_attempts=2) + + robot_pose = first_lidar.origin + + # Select multiple goals to populate history + for i in range(6): + goal = explorer.get_exploration_goal(robot_pose, costmap1) + if goal: + print(f"Goal {i + 1}: ({goal.x:.2f}, {goal.y:.2f})") + + # Now use same costmap repeatedly to trigger no-gain detection + initial_counter = explorer.no_gain_counter + + # This should increment no-gain counter + goal = explorer.get_exploration_goal(robot_pose, costmap1) + assert explorer.no_gain_counter > initial_counter, "No-gain 
counter should increment" + + # Continue until exploration stops + for _ in range(3): + goal = explorer.get_exploration_goal(robot_pose, costmap1) + if goal is None: + break + + # Should have stopped due to no information gain + assert goal is None, "Exploration should stop after no-gain threshold" + assert explorer.no_gain_counter == 0, "Counter should reset after stopping" + + print("No-gain detection test passed") + + @pytest.mark.vis def test_frontier_detection_visualization(): """Test frontier detection with visualization (marked with @pytest.mark.vis).""" - # Get costmap from office lidar data - costmap, first_lidar = get_office_lidar_costmap(take_frames=1, voxel_size=0.2) + # Get test costmap + costmap, first_lidar = create_test_costmap() # Initialize frontier explorer with default parameters explorer = WavefrontFrontierExplorer() @@ -195,7 +292,7 @@ def test_frontier_detection_visualization(): base_image = costmap_to_pil_image(costmap, image_scale_factor) # Helper function to convert world coordinates to image coordinates - def world_to_image_coords(world_pos: Vector) -> tuple[int, int]: + def world_to_image_coords(world_pos: Vector3) -> tuple[int, int]: grid_pos = costmap.world_to_grid(world_pos) img_x = int(grid_pos.x * image_scale_factor) img_y = int((costmap.height - grid_pos.y) * image_scale_factor) # Flip Y @@ -250,41 +347,3 @@ def world_to_image_coords(world_pos: Vector) -> tuple[int, int]: base_image.show(title="Frontier Detection - Office Lidar") print("Visualization displayed. 
Close the image window to continue.") - - -def test_multi_frame_exploration(): - """Tool test for multi-frame exploration analysis.""" - print("=== Multi-Frame Exploration Analysis ===") - - # Test with different numbers of frames - frame_counts = [1, 3, 5] - - for frame_count in frame_counts: - print(f"\n--- Testing with {frame_count} lidar frame(s) ---") - - # Get costmap with multiple frames - costmap, first_lidar = get_office_lidar_costmap(take_frames=frame_count, voxel_size=0.3) - - print( - f"Costmap: {costmap.width}x{costmap.height}, " - f"unknown: {costmap.unknown_percent:.1f}%, " - f"free: {costmap.free_percent:.1f}%, " - f"occupied: {costmap.occupied_percent:.1f}%" - ) - - # Initialize explorer with default parameters - explorer = WavefrontFrontierExplorer() - - # Detect frontiers - robot_pose = first_lidar.origin - frontiers = explorer.detect_frontiers(robot_pose, costmap) - - print(f"Detected {len(frontiers)} frontiers") - - # Get exploration goal - goal = explorer.get_exploration_goal(robot_pose, costmap) - if goal: - distance = np.sqrt((goal.x - robot_pose.x) ** 2 + (goal.y - robot_pose.y) ** 2) - print(f"Selected goal at distance {distance:.2f}m") - else: - print("No exploration goal selected") diff --git a/dimos/robot/frontier_exploration/utils.py b/dimos/navigation/frontier_exploration/utils.py similarity index 69% rename from dimos/robot/frontier_exploration/utils.py rename to dimos/navigation/frontier_exploration/utils.py index 746f72e2f5..680af142fb 100644 --- a/dimos/robot/frontier_exploration/utils.py +++ b/dimos/navigation/frontier_exploration/utils.py @@ -19,14 +19,14 @@ import numpy as np from PIL import Image, ImageDraw from typing import List, Tuple -from dimos.types.costmap import Costmap, CostValues -from dimos.types.vector import Vector +from dimos.msgs.nav_msgs import OccupancyGrid, CostValues +from dimos.msgs.geometry_msgs import Vector3 import os import pickle import cv2 -def costmap_to_pil_image(costmap: Costmap, scale_factor: int 
= 2) -> Image.Image: +def costmap_to_pil_image(costmap: OccupancyGrid, scale_factor: int = 2) -> Image.Image: """ Convert costmap to PIL Image with ROS-style coloring and optional scaling. @@ -69,10 +69,10 @@ def costmap_to_pil_image(costmap: Costmap, scale_factor: int = 2) -> Image.Image def draw_frontiers_on_image( image: Image.Image, - costmap: Costmap, - frontiers: List[Vector], + costmap: OccupancyGrid, + frontiers: List[Vector3], scale_factor: int = 2, - unfiltered_frontiers: List[Vector] = None, + unfiltered_frontiers: List[Vector3] = None, ) -> Image.Image: """ Draw frontier points on the costmap image. @@ -90,7 +90,7 @@ def draw_frontiers_on_image( img_copy = image.copy() draw = ImageDraw.Draw(img_copy) - def world_to_image_coords(world_pos: Vector) -> Tuple[int, int]: + def world_to_image_coords(world_pos: Vector3) -> Tuple[int, int]: """Convert world coordinates to image pixel coordinates.""" grid_pos = costmap.world_to_grid(world_pos) # Flip Y coordinate and apply scaling @@ -139,50 +139,3 @@ def world_to_image_coords(world_pos: Vector) -> Tuple[int, int]: draw.text((x + radius + 2, y - radius), "BEST", fill=(255, 0, 0)) return img_copy - - -def smooth_costmap_for_frontiers( - costmap: Costmap, -) -> Costmap: - """ - Smooth a costmap using morphological operations for frontier exploration. - - This function applies OpenCV morphological operations to smooth free space - areas and improve connectivity for better frontier detection. It's designed - specifically for frontier exploration. - - Args: - costmap: Input Costmap object - - Returns: - Smoothed Costmap object with enhanced free space connectivity - """ - # Extract grid data and metadata from costmap - grid = costmap.grid - resolution = costmap.resolution - - # Work with a copy to avoid modifying input - filtered_grid = grid.copy() - - # 1. Create binary mask for free space - free_mask = (grid == CostValues.FREE).astype(np.uint8) * 255 - - # 2. 
Apply morphological operations for smoothing - kernel_size = 7 - kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size)) - - # Dilate free space to connect nearby areas - dilated = cv2.dilate(free_mask, kernel, iterations=1) - - # Morphological closing to fill small gaps - closed = cv2.morphologyEx(dilated, cv2.MORPH_CLOSE, kernel, iterations=1) - - eroded = cv2.erode(closed, kernel, iterations=1) - - # Apply the smoothed free space back to costmap - # Only change unknown areas to free, don't override obstacles - smoothed_free = eroded == 255 - unknown_mask = grid == CostValues.UNKNOWN - filtered_grid[smoothed_free & unknown_mask] = CostValues.FREE - - return Costmap(grid=filtered_grid, origin=costmap.origin, resolution=resolution) diff --git a/dimos/robot/frontier_exploration/wavefront_frontier_goal_selector.py b/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py similarity index 66% rename from dimos/robot/frontier_exploration/wavefront_frontier_goal_selector.py rename to dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py index 454a70e803..dd26f6f79c 100644 --- a/dimos/robot/frontier_exploration/wavefront_frontier_goal_selector.py +++ b/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py @@ -23,14 +23,15 @@ from collections import deque from dataclasses import dataclass from enum import IntFlag -from typing import Callable, List, Optional, Tuple +from typing import List, Optional, Tuple import numpy as np -from dimos.msgs.geometry_msgs import Vector3 as Vector -from dimos.robot.frontier_exploration.utils import smooth_costmap_for_frontiers -from dimos.types.costmap import Costmap, CostValues +from dimos.core import Module, In, Out, rpc +from dimos.msgs.geometry_msgs import PoseStamped, Vector3 +from dimos.msgs.nav_msgs import OccupancyGrid, CostValues from dimos.utils.logging_config import setup_logger +from dimos_lcm.std_msgs import Bool logger = 
setup_logger("dimos.robot.unitree.frontier_exploration") @@ -72,25 +73,38 @@ def clear(self): self.points.clear() -class WavefrontFrontierExplorer: +class WavefrontFrontierExplorer(Module): """ Wavefront frontier exploration algorithm implementation. This class encapsulates the frontier detection and exploration goal selection functionality using the wavefront algorithm with BFS exploration. + + Inputs: + - costmap: Current costmap for frontier detection + - odometry: Current robot pose + + Outputs: + - goal_request: Exploration goals sent to the navigator """ + # LCM inputs + costmap: In[OccupancyGrid] = None + odometry: In[PoseStamped] = None + goal_reached: In[Bool] = None + + # LCM outputs + goal_request: Out[PoseStamped] = None + def __init__( self, - min_frontier_size: int = 8, - occupancy_threshold: int = 65, - subsample_resolution: int = 3, - min_distance_from_obstacles: float = 0.6, + min_frontier_size: int = 5, + occupancy_threshold: int = 99, + min_distance_from_obstacles: float = 0.2, info_gain_threshold: float = 0.03, num_no_gain_attempts: int = 4, - set_goal: Optional[Callable] = None, - get_costmap: Optional[Callable] = None, - get_robot_pos: Optional[Callable] = None, + goal_timeout: float = 30.0, + **kwargs, ): """ Initialize the frontier explorer. 
@@ -98,29 +112,70 @@ def __init__( Args: min_frontier_size: Minimum number of points to consider a valid frontier occupancy_threshold: Cost threshold above which a cell is considered occupied (0-255) - subsample_resolution: Factor by which to subsample the costmap for faster processing (1=no subsampling, 2=half resolution, 4=quarter resolution) min_distance_from_obstacles: Minimum distance frontier must be from obstacles (meters) info_gain_threshold: Minimum percentage increase in costmap information required to continue exploration (0.05 = 5%) num_no_gain_attempts: Maximum number of consecutive attempts with no information gain - set_goal: Callable to set navigation goal, signature: (goal: Vector, stop_event: Optional[threading.Event]) -> bool - get_costmap: Callable to get current costmap, signature: () -> Costmap - get_robot_pos: Callable to get current robot position, signature: () -> Vector """ + super().__init__(**kwargs) self.min_frontier_size = min_frontier_size self.occupancy_threshold = occupancy_threshold - self.subsample_resolution = subsample_resolution self.min_distance_from_obstacles = min_distance_from_obstacles self.info_gain_threshold = info_gain_threshold self.num_no_gain_attempts = num_no_gain_attempts - self.set_goal = set_goal - self.get_costmap = get_costmap - self.get_robot_pos = get_robot_pos self._cache = FrontierCache() self.explored_goals = [] # list of explored goals - self.exploration_direction = Vector([0.0, 0.0]) # current exploration direction + self.exploration_direction = Vector3(0.0, 0.0, 0.0) # current exploration direction self.last_costmap = None # store last costmap for information comparison + self.no_gain_counter = 0 # track consecutive no-gain attempts + self.goal_timeout = goal_timeout + + # Latest data + self.latest_costmap: Optional[OccupancyGrid] = None + self.latest_odometry: Optional[PoseStamped] = None + + # Goal reached event + self.goal_reached_event = threading.Event() + + # Exploration state + 
self.exploration_active = False + self.exploration_thread: Optional[threading.Thread] = None + self.stop_event = threading.Event() + + logger.info("WavefrontFrontierExplorer module initialized") + + @rpc + def start(self): + """Start the frontier exploration module.""" + # Subscribe to inputs + self.costmap.subscribe(self._on_costmap) + self.odometry.subscribe(self._on_odometry) + + # Subscribe to goal_reached if available + if self.goal_reached.transport is not None: + self.goal_reached.subscribe(self._on_goal_reached) + + logger.info("WavefrontFrontierExplorer started") - def _count_costmap_information(self, costmap: Costmap) -> int: + @rpc + def cleanup(self): + """Clean up resources.""" + self.stop_exploration() + logger.info("WavefrontFrontierExplorer cleanup complete") + + def _on_costmap(self, msg: OccupancyGrid): + """Handle incoming costmap messages.""" + self.latest_costmap = msg + + def _on_odometry(self, msg: PoseStamped): + """Handle incoming odometry messages.""" + self.latest_odometry = msg + + def _on_goal_reached(self, msg: Bool): + """Handle goal reached messages.""" + if msg.data: + self.goal_reached_event.set() + + def _count_costmap_information(self, costmap: OccupancyGrid) -> int: """ Count the amount of information in a costmap (free space + obstacles). 
@@ -134,7 +189,7 @@ def _count_costmap_information(self, costmap: Costmap) -> int: obstacle_count = np.sum(costmap.grid >= self.occupancy_threshold) return int(free_count + obstacle_count) - def _get_neighbors(self, point: GridPoint, costmap: Costmap) -> List[GridPoint]: + def _get_neighbors(self, point: GridPoint, costmap: OccupancyGrid) -> List[GridPoint]: """Get valid neighboring points for a given grid point.""" neighbors = [] @@ -152,26 +207,24 @@ def _get_neighbors(self, point: GridPoint, costmap: Costmap) -> List[GridPoint]: return neighbors - def _is_frontier_point(self, point: GridPoint, costmap: Costmap) -> bool: + def _is_frontier_point(self, point: GridPoint, costmap: OccupancyGrid) -> bool: """ Check if a point is a frontier point. A frontier point is an unknown cell adjacent to at least one free cell and not adjacent to any occupied cells. """ # Point must be unknown - world_pos = costmap.grid_to_world(Vector([float(point.x), float(point.y)])) - cost = costmap.get_value(world_pos) + cost = costmap.grid[point.y, point.x] if cost != CostValues.UNKNOWN: return False has_free = False for neighbor in self._get_neighbors(point, costmap): - neighbor_world = costmap.grid_to_world(Vector([float(neighbor.x), float(neighbor.y)])) - neighbor_cost = costmap.get_value(neighbor_world) + neighbor_cost = costmap.grid[neighbor.y, neighbor.x] # If adjacent to occupied space, not a frontier - if neighbor_cost and neighbor_cost > self.occupancy_threshold: + if neighbor_cost > self.occupancy_threshold: return False # Check if adjacent to free space @@ -180,7 +233,9 @@ def _is_frontier_point(self, point: GridPoint, costmap: Costmap) -> bool: return has_free - def _find_free_space(self, start_x: int, start_y: int, costmap: Costmap) -> Tuple[int, int]: + def _find_free_space( + self, start_x: int, start_y: int, costmap: OccupancyGrid + ) -> Tuple[int, int]: """ Find the nearest free space point using BFS from the starting position. 
""" @@ -195,8 +250,7 @@ def _find_free_space(self, start_x: int, start_y: int, costmap: Costmap) -> Tupl visited.add((point.x, point.y)) # Check if this point is free space - world_pos = costmap.grid_to_world(Vector([float(point.x), float(point.y)])) - if costmap.get_value(world_pos) == CostValues.FREE: + if costmap.grid[point.y, point.x] == CostValues.FREE: return (point.x, point.y) # Add neighbors to search @@ -207,45 +261,36 @@ def _find_free_space(self, start_x: int, start_y: int, costmap: Costmap) -> Tupl # If no free space found, return original position return (start_x, start_y) - def _compute_centroid(self, frontier_points: List[Vector]) -> Vector: + def _compute_centroid(self, frontier_points: List[Vector3]) -> Vector3: """Compute the centroid of a list of frontier points.""" if not frontier_points: - return Vector([0.0, 0.0]) + return Vector3(0.0, 0.0, 0.0) # Vectorized approach using numpy points_array = np.array([[point.x, point.y] for point in frontier_points]) centroid = np.mean(points_array, axis=0) - return Vector([centroid[0], centroid[1]]) + return Vector3(centroid[0], centroid[1], 0.0) - def detect_frontiers(self, robot_pose: Vector, costmap: Costmap) -> List[Vector]: + def detect_frontiers(self, robot_pose: Vector3, costmap: OccupancyGrid) -> List[Vector3]: """ Main frontier detection algorithm using wavefront exploration. 
Args: - robot_pose: Current robot position in world coordinates (Vector with x, y) - costmap: Costmap for additional analysis + robot_pose: Current robot position in world coordinates + costmap: Costmap for frontier detection Returns: List of frontier centroids in world coordinates """ self._cache.clear() - # Apply filtered costmap (now default) - working_costmap = smooth_costmap_for_frontiers(costmap) - - # Subsample the costmap for faster processing - if self.subsample_resolution > 1: - subsampled_costmap = working_costmap.subsample(self.subsample_resolution) - else: - subsampled_costmap = working_costmap - - # Convert robot pose to subsampled grid coordinates - subsampled_grid_pos = subsampled_costmap.world_to_grid(robot_pose) - grid_x, grid_y = int(subsampled_grid_pos.x), int(subsampled_grid_pos.y) + # Convert robot pose to grid coordinates + grid_pos = costmap.world_to_grid(robot_pose) + grid_x, grid_y = int(grid_pos.x), int(grid_pos.y) # Find nearest free space to start exploration - free_x, free_y = self._find_free_space(grid_x, grid_y, subsampled_costmap) + free_x, free_y = self._find_free_space(grid_x, grid_y, costmap) start_point = self._cache.get_point(free_x, free_y) start_point.classification = PointClassification.MapOpen @@ -269,7 +314,7 @@ def detect_frontiers(self, robot_pose: Vector, costmap: Costmap) -> List[Vector] current_point.classification |= PointClassification.MapClosed # Check if this point starts a new frontier - if self._is_frontier_point(current_point, subsampled_costmap): + if self._is_frontier_point(current_point, costmap): frontier_candidates += 1 current_point.classification |= PointClassification.FrontierOpen frontier_queue = deque([current_point]) @@ -284,11 +329,11 @@ def detect_frontiers(self, robot_pose: Vector, costmap: Costmap) -> List[Vector] continue # If this is still a frontier point, add to current frontier - if self._is_frontier_point(frontier_point, subsampled_costmap): + if self._is_frontier_point(frontier_point, 
costmap): new_frontier.append(frontier_point) # Add neighbors to frontier queue - for neighbor in self._get_neighbors(frontier_point, subsampled_costmap): + for neighbor in self._get_neighbors(frontier_point, costmap): if not ( neighbor.classification & ( @@ -305,8 +350,8 @@ def detect_frontiers(self, robot_pose: Vector, costmap: Costmap) -> List[Vector] if len(new_frontier) >= self.min_frontier_size: world_points = [] for point in new_frontier: - world_pos = subsampled_costmap.grid_to_world( - Vector([float(point.x), float(point.y)]) + world_pos = costmap.grid_to_world( + Vector3(float(point.x), float(point.y), 0.0) ) world_points.append(world_pos) @@ -316,21 +361,16 @@ def detect_frontiers(self, robot_pose: Vector, costmap: Costmap) -> List[Vector] frontier_sizes.append(len(new_frontier)) # Store frontier size # Add ALL neighbors to main exploration queue to explore entire free space - for neighbor in self._get_neighbors(current_point, subsampled_costmap): + for neighbor in self._get_neighbors(current_point, costmap): if not ( neighbor.classification & (PointClassification.MapOpen | PointClassification.MapClosed) ): # Check if neighbor is free space or unknown (explorable) - neighbor_world = subsampled_costmap.grid_to_world( - Vector([float(neighbor.x), float(neighbor.y)]) - ) - neighbor_cost = subsampled_costmap.get_value(neighbor_world) + neighbor_cost = costmap.grid[neighbor.y, neighbor.x] # Add free space and unknown space to exploration queue - if neighbor_cost is not None and ( - neighbor_cost == CostValues.FREE or neighbor_cost == CostValues.UNKNOWN - ): + if neighbor_cost == CostValues.FREE or neighbor_cost == CostValues.UNKNOWN: neighbor.classification |= PointClassification.MapOpen map_queue.append(neighbor) @@ -347,32 +387,34 @@ def detect_frontiers(self, robot_pose: Vector, costmap: Costmap) -> List[Vector] return ranked_frontiers - def _update_exploration_direction(self, robot_pose: Vector, goal_pose: Optional[Vector] = None): + def 
_update_exploration_direction( + self, robot_pose: Vector3, goal_pose: Optional[Vector3] = None + ): """Update the current exploration direction based on robot movement or selected goal.""" if goal_pose is not None: # Calculate direction from robot to goal - direction = Vector([goal_pose.x - robot_pose.x, goal_pose.y - robot_pose.y]) + direction = Vector3(goal_pose.x - robot_pose.x, goal_pose.y - robot_pose.y, 0.0) magnitude = np.sqrt(direction.x**2 + direction.y**2) if magnitude > 0.1: # Avoid division by zero for very close goals - self.exploration_direction = Vector( - [direction.x / magnitude, direction.y / magnitude] + self.exploration_direction = Vector3( + direction.x / magnitude, direction.y / magnitude, 0.0 ) - def _compute_direction_momentum_score(self, frontier: Vector, robot_pose: Vector) -> float: + def _compute_direction_momentum_score(self, frontier: Vector3, robot_pose: Vector3) -> float: """Compute direction momentum score for a frontier.""" if self.exploration_direction.x == 0 and self.exploration_direction.y == 0: return 0.0 # No momentum if no previous direction # Calculate direction from robot to frontier - frontier_direction = Vector([frontier.x - robot_pose.x, frontier.y - robot_pose.y]) + frontier_direction = Vector3(frontier.x - robot_pose.x, frontier.y - robot_pose.y, 0.0) magnitude = np.sqrt(frontier_direction.x**2 + frontier_direction.y**2) if magnitude < 0.1: return 0.0 # Too close to calculate meaningful direction # Normalize frontier direction - frontier_direction = Vector( - [frontier_direction.x / magnitude, frontier_direction.y / magnitude] + frontier_direction = Vector3( + frontier_direction.x / magnitude, frontier_direction.y / magnitude, 0.0 ) # Calculate dot product for directional alignment @@ -384,7 +426,7 @@ def _compute_direction_momentum_score(self, frontier: Vector, robot_pose: Vector # Return momentum score (higher for same direction, lower for opposite) return max(0.0, dot_product) # Only positive momentum, no penalty 
for different directions - def _compute_distance_to_explored_goals(self, frontier: Vector) -> float: + def _compute_distance_to_explored_goals(self, frontier: Vector3) -> float: """Compute distance from frontier to the nearest explored goal.""" if not self.explored_goals: return 5.0 # Default consistent value when no explored goals @@ -396,7 +438,7 @@ def _compute_distance_to_explored_goals(self, frontier: Vector) -> float: return min_distance - def _compute_distance_to_obstacles(self, frontier: Vector, costmap: Costmap) -> float: + def _compute_distance_to_obstacles(self, frontier: Vector3, costmap: OccupancyGrid) -> float: """ Compute the minimum distance from a frontier point to the nearest obstacle. @@ -444,7 +486,7 @@ def _compute_distance_to_obstacles(self, frontier: Vector, costmap: Costmap) -> return min_distance if min_distance != float("inf") else float("inf") def _compute_comprehensive_frontier_score( - self, frontier: Vector, frontier_size: int, robot_pose: Vector, costmap: Costmap + self, frontier: Vector3, frontier_size: int, robot_pose: Vector3, costmap: OccupancyGrid ) -> float: """Compute comprehensive score considering multiple criteria.""" @@ -484,11 +526,11 @@ def _compute_comprehensive_frontier_score( def _rank_frontiers( self, - frontier_centroids: List[Vector], + frontier_centroids: List[Vector3], frontier_sizes: List[int], - robot_pose: Vector, - costmap: Costmap, - ) -> List[Vector]: + robot_pose: Vector3, + costmap: OccupancyGrid, + ) -> List[Vector3]: """ Find the single best frontier using comprehensive scoring and filtering. 
@@ -530,12 +572,14 @@ def _rank_frontiers( # Extract just the frontiers (remove scores) and return as list return [frontier for frontier, _ in valid_frontiers] - def get_exploration_goal(self, robot_pose: Vector, costmap: Costmap) -> Optional[Vector]: + def get_exploration_goal( + self, robot_pose: Vector3, costmap: OccupancyGrid + ) -> Optional[Vector3]: """ Get the single best exploration goal using comprehensive frontier scoring. Args: - robot_pose: Current robot position in world coordinates (Vector with x, y) + robot_pose: Current robot position in world coordinates costmap: Costmap for additional analysis Returns: @@ -556,15 +600,15 @@ def get_exploration_goal(self, robot_pose: Vector, costmap: Costmap) -> Optional logger.info( f"Current information: {current_info}, Last information: {last_info}" ) - self.num_no_gain_attempts += 1 - if self.num_no_gain_attempts >= self.num_no_gain_attempts: + self.no_gain_counter += 1 + if self.no_gain_counter >= self.num_no_gain_attempts: logger.info( - "No information gain for {} consecutive attempts, skipping frontier selection".format( - self.num_no_gain_attempts - ) + f"No information gain for {self.no_gain_counter} consecutive attempts" ) self.reset_exploration_session() return None + else: + self.no_gain_counter = 0 # Always detect new frontiers to get most up-to-date information # The new algorithm filters out explored areas and returns only the best frontier @@ -593,7 +637,7 @@ def get_exploration_goal(self, robot_pose: Vector, costmap: Costmap) -> Optional self.last_costmap = costmap return None - def mark_explored_goal(self, goal: Vector): + def mark_explored_goal(self, goal: Vector3): """Mark a goal as explored.""" self.explored_goals.append(goal) @@ -605,49 +649,120 @@ def reset_exploration_session(self): needs to forget its previous exploration history. 
""" self.explored_goals.clear() # Clear all previously explored goals - self.exploration_direction = Vector([0.0, 0.0]) # Reset exploration direction + self.exploration_direction = Vector3(0.0, 0.0, 0.0) # Reset exploration direction self.last_costmap = None # Clear last costmap comparison - self.num_no_gain_attempts = 0 # Reset no-gain attempt counter + self.no_gain_counter = 0 # Reset no-gain attempt counter self._cache.clear() # Clear frontier point cache logger.info("Exploration session reset - all state variables cleared") - def explore(self, stop_event: Optional[threading.Event] = None) -> bool: + @rpc + def explore(self) -> bool: """ - Perform autonomous frontier exploration by continuously finding and navigating to frontiers. + Start autonomous frontier exploration. - Args: - stop_event: Optional threading.Event to signal when exploration should stop + Returns: + bool: True if exploration started, False if already exploring + """ + if self.exploration_active: + logger.warning("Exploration already active") + return False + + self.exploration_active = True + self.stop_event.clear() + + # Start exploration thread + self.exploration_thread = threading.Thread(target=self._exploration_loop, daemon=True) + self.exploration_thread.start() + + logger.info("Started autonomous frontier exploration") + return True + + @rpc + def stop_exploration(self) -> bool: + """ + Stop autonomous frontier exploration. 
Returns: - bool: True if exploration completed successfully, False if stopped or failed + bool: True if exploration was stopped, False if not exploring """ + if not self.exploration_active: + return False - logger.info("Starting autonomous frontier exploration") + self.exploration_active = False + self.stop_event.set() - while True: - # Check if stop event is set - if stop_event and stop_event.is_set(): - logger.info("Exploration stopped by stop event") - return False + if self.exploration_thread and self.exploration_thread.is_alive(): + self.exploration_thread.join(timeout=2.0) + + logger.info("Stopped autonomous frontier exploration") + return True - # Get fresh robot position and costmap data - robot_pose = self.get_robot_pos() - costmap = self.get_costmap() + def _exploration_loop(self): + """Main exploration loop running in separate thread.""" + # Track number of goals published + goals_published = 0 + consecutive_failures = 0 + max_consecutive_failures = 10 # Allow more attempts before giving up - # Get the next frontier goal - next_goal = self.get_exploration_goal(robot_pose, costmap) - if not next_goal: - logger.info("No more frontiers found, exploration complete") - return True + while self.exploration_active and not self.stop_event.is_set(): + # Check if we have required data + if self.latest_costmap is None or self.latest_odometry is None: + threading.Event().wait(0.5) + continue - # Navigate to the frontier - logger.info(f"Navigating to frontier at {next_goal}") - navigation_successful = self.set_goal( - next_goal, + # Get robot pose from odometry + robot_pose = Vector3( + self.latest_odometry.position.x, self.latest_odometry.position.y, 0.0 ) - if not navigation_successful: - logger.warning("Failed to navigate to frontier, continuing exploration") - # Continue to try other frontiers instead of stopping - continue + # Get exploration goal + goal = self.get_exploration_goal(robot_pose, self.latest_costmap) + + if goal: + # Publish goal to navigator + 
goal_msg = PoseStamped() + goal_msg.position.x = goal.x + goal_msg.position.y = goal.y + goal_msg.position.z = 0.0 + goal_msg.orientation.w = 1.0 # No rotation + goal_msg.frame_id = "world" + goal_msg.ts = self.latest_costmap.ts + + self.goal_request.publish(goal_msg) + logger.info(f"Published frontier goal: ({goal.x:.2f}, {goal.y:.2f})") + + goals_published += 1 + consecutive_failures = 0 # Reset failure counter on success + + # Clear the goal reached event for next iteration + self.goal_reached_event.clear() + + # Wait for goal to be reached or timeout + logger.info("Waiting for goal to be reached...") + goal_reached = self.goal_reached_event.wait(timeout=self.goal_timeout) + + if goal_reached: + logger.info("Goal reached, finding next frontier") + else: + logger.warning("Goal timeout after 30 seconds, finding next frontier anyway") + else: + consecutive_failures += 1 + + # Only give up if we've published at least 2 goals AND had many consecutive failures + if goals_published >= 2 and consecutive_failures >= max_consecutive_failures: + logger.info( + f"Exploration complete after {goals_published} goals and {consecutive_failures} consecutive failures finding new frontiers" + ) + self.exploration_active = False + break + elif goals_published < 2: + logger.info( + f"No frontier found, but only {goals_published} goals published so far. Retrying in 2 seconds..." + ) + threading.Event().wait(2.0) + else: + logger.info( + f"No frontier found (attempt {consecutive_failures}/{max_consecutive_failures}). Retrying in 2 seconds..." 
+ ) + threading.Event().wait(2.0) diff --git a/dimos/navigation/global_planner/__init__.py b/dimos/navigation/global_planner/__init__.py new file mode 100644 index 0000000000..4b158f73a1 --- /dev/null +++ b/dimos/navigation/global_planner/__init__.py @@ -0,0 +1,2 @@ +from dimos.navigation.global_planner.planner import AstarPlanner, Planner +from dimos.navigation.global_planner.algo import astar diff --git a/dimos/navigation/global_planner/algo.py b/dimos/navigation/global_planner/algo.py new file mode 100644 index 0000000000..08cae6d147 --- /dev/null +++ b/dimos/navigation/global_planner/algo.py @@ -0,0 +1,217 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import heapq +import math +from typing import Optional + +from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, VectorLike +from dimos.msgs.nav_msgs import CostValues, OccupancyGrid, Path +from dimos.utils.logging_config import setup_logger + +logger = setup_logger("dimos.robot.unitree.global_planner.astar") + + +def astar( + costmap: OccupancyGrid, + goal: VectorLike, + start: VectorLike = (0.0, 0.0), + cost_threshold: int = 90, + unknown_penalty: float = 0.8, +) -> Optional[Path]: + """ + A* path planning algorithm from start to goal position. 
+ + Args: + costmap: Costmap object containing the environment + goal: Goal position as any vector-like object + start: Start position as any vector-like object (default: origin [0,0]) + cost_threshold: Cost threshold above which a cell is considered an obstacle + + Returns: + Path object containing waypoints, or None if no path found + """ + + # Convert world coordinates to grid coordinates directly using vector-like inputs + start_vector = costmap.world_to_grid(start) + goal_vector = costmap.world_to_grid(goal) + logger.debug(f"ASTAR {costmap} {start_vector} -> {goal_vector}") + + # Store positions as tuples for dictionary keys + start_tuple = (int(start_vector.x), int(start_vector.y)) + goal_tuple = (int(goal_vector.x), int(goal_vector.y)) + + # Check if goal is out of bounds + if not (0 <= goal_tuple[0] < costmap.width and 0 <= goal_tuple[1] < costmap.height): + return None + + # Define possible movements (8-connected grid with diagonal movements) + directions = [ + (0, 1), + (1, 0), + (0, -1), + (-1, 0), + (1, 1), + (1, -1), + (-1, 1), + (-1, -1), + ] + + # Cost for each movement (straight vs diagonal) + sc = 1.0 # Straight cost + dc = 1.42 # Diagonal cost (approximately sqrt(2)) + movement_costs = [sc, sc, sc, sc, dc, dc, dc, dc] + + # A* algorithm implementation + open_set = [] # Priority queue for nodes to explore + closed_set = set() # Set of explored nodes + + # Dictionary to store cost from start and parents for each node + g_score = {start_tuple: 0} + parents = {} + + # Heuristic function (Octile distance for 8-connected grid) + def heuristic(x1, y1, x2, y2): + dx = abs(x2 - x1) + dy = abs(y2 - y1) + # Octile distance: optimal for 8-connected grids with diagonal movement + return (dx + dy) + (dc - 2 * sc) * min(dx, dy) + + # Start with the starting node + f_score = g_score[start_tuple] + heuristic( + start_tuple[0], start_tuple[1], goal_tuple[0], goal_tuple[1] + ) + heapq.heappush(open_set, (f_score, start_tuple)) + + # Track nodes already in open set 
to avoid duplicates + open_set_hash = {start_tuple} + + while open_set: + # Get the node with the lowest f_score + current_f, current = heapq.heappop(open_set) + current_x, current_y = current + + # Remove from open set hash + if current in open_set_hash: + open_set_hash.remove(current) + + # Skip if already processed (can happen with duplicate entries) + if current in closed_set: + continue + + # Check if we've reached the goal + if current == goal_tuple: + # Reconstruct the path + waypoints = [] + while current in parents: + world_point = costmap.grid_to_world(current) + # Create PoseStamped with identity quaternion (no orientation) + pose = PoseStamped( + frame_id="world", + position=[world_point.x, world_point.y, 0.0], + orientation=Quaternion(0, 0, 0, 1), # Identity quaternion + ) + waypoints.append(pose) + current = parents[current] + + # Add the start position + start_world_point = costmap.grid_to_world(start_tuple) + start_pose = PoseStamped( + frame_id="world", + position=[start_world_point.x, start_world_point.y, 0.0], + orientation=Quaternion(0, 0, 0, 1), + ) + waypoints.append(start_pose) + + # Reverse the path (start to goal) + waypoints.reverse() + + # Add the goal position if it's not already included + goal_point = costmap.grid_to_world(goal_tuple) + + if ( + not waypoints + or (waypoints[-1].x - goal_point.x) ** 2 + (waypoints[-1].y - goal_point.y) ** 2 + > 1e-10 + ): + goal_pose = PoseStamped( + frame_id="world", + position=[goal_point.x, goal_point.y, 0.0], + orientation=Quaternion(0, 0, 0, 1), + ) + waypoints.append(goal_pose) + + return Path(frame_id="world", poses=waypoints) + + # Add current node to closed set + closed_set.add(current) + + # Explore neighbors + for i, (dx, dy) in enumerate(directions): + neighbor_x, neighbor_y = current_x + dx, current_y + dy + neighbor = (neighbor_x, neighbor_y) + + # Check if the neighbor is valid + if not (0 <= neighbor_x < costmap.width and 0 <= neighbor_y < costmap.height): + continue + + # Check if the 
neighbor is already explored + if neighbor in closed_set: + continue + + # Get the neighbor's cost value + neighbor_val = costmap.grid[neighbor_y, neighbor_x] + + # Skip if it's a hard obstacle + if neighbor_val >= cost_threshold: + continue + + # Calculate movement cost with penalties + # Unknown cells get half the penalty of obstacles + if neighbor_val == CostValues.UNKNOWN: # Unknown cell (-1) + # Unknown cells have a moderate traversal cost (half of obstacle threshold) + cell_cost = cost_threshold * unknown_penalty + elif neighbor_val == CostValues.FREE: # Free space (0) + # Free cells have minimal cost + cell_cost = 0.0 + else: + # Other cells use their actual cost value (1-99) + cell_cost = neighbor_val + + # Calculate cost penalty based on cell cost (higher cost = higher penalty) + # This encourages the planner to prefer lower-cost paths + cost_penalty = cell_cost / CostValues.OCCUPIED # Normalized penalty (divide by 100) + + tentative_g_score = g_score[current] + movement_costs[i] * (1.0 + cost_penalty) + + # Get the current g_score for the neighbor or set to infinity if not yet explored + neighbor_g_score = g_score.get(neighbor, float("inf")) + + # If this path to the neighbor is better than any previous one + if tentative_g_score < neighbor_g_score: + # Update the neighbor's scores and parent + parents[neighbor] = current + g_score[neighbor] = tentative_g_score + f_score = tentative_g_score + heuristic( + neighbor_x, neighbor_y, goal_tuple[0], goal_tuple[1] + ) + + # Add the neighbor to the open set with its f_score + # Only add if not already in open set to reduce duplicates + if neighbor not in open_set_hash: + heapq.heappush(open_set, (f_score, neighbor)) + open_set_hash.add(neighbor) + + # If we get here, no path was found + return None diff --git a/dimos/robot/global_planner/planner.py b/dimos/navigation/global_planner/planner.py similarity index 63% rename from dimos/robot/global_planner/planner.py rename to 
dimos/navigation/global_planner/planner.py index 95b0d060e5..186163cffb 100644 --- a/dimos/robot/global_planner/planner.py +++ b/dimos/navigation/global_planner/planner.py @@ -12,17 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -import threading from abc import abstractmethod from dataclasses import dataclass -from typing import Callable, Optional +from typing import Optional from dimos.core import In, Module, Out, rpc -from dimos.msgs.geometry_msgs import Pose, PoseLike, PoseStamped, Vector3, VectorLike, to_pose +from dimos.msgs.geometry_msgs import Pose, PoseStamped from dimos.msgs.nav_msgs import OccupancyGrid, Path -from dimos.robot.global_planner.algo import astar +from dimos.navigation.global_planner.algo import astar from dimos.utils.logging_config import setup_logger -from dimos.web.websocket_vis.helpers import Visualizable logger = setup_logger("dimos.robot.unitree.global_planner") @@ -95,69 +93,75 @@ def resample_path(path: Path, spacing: float) -> Path: @dataclass -class Planner(Visualizable, Module): +class Planner(Module): target: In[PoseStamped] = None path: Out[Path] = None def __init__(self): Module.__init__(self) - Visualizable.__init__(self) - - @rpc - def set_goal( - self, - goal: VectorLike, - goal_theta: Optional[float] = None, - stop_event: Optional[threading.Event] = None, - ): - path = self.plan(goal) - if not path: - logger.warning("No path found to the goal.") - return False - - print("pathing success", path) class AstarPlanner(Planner): - target: In[Vector3] = None - path: Out[Path] = None - - get_costmap: Callable[[], OccupancyGrid] - get_robot_pos: Callable[[], Vector3] - set_local_nav: Callable[[Path, Optional[threading.Event], Optional[float]], bool] = None + # LCM inputs + target: In[PoseStamped] = None + global_costmap: In[OccupancyGrid] = None + odom: In[PoseStamped] = None - conservativism: int = 8 + # LCM outputs + path: Out[Path] = None - def __init__( - self, - 
get_costmap: Callable[[], OccupancyGrid], - get_robot_pos: Callable[[], Vector3], - set_local_nav: Callable[[Path, Optional[threading.Event], Optional[float]], bool] = None, - ): + def __init__(self): super().__init__() - self.get_costmap = get_costmap - self.get_robot_pos = get_robot_pos - self.set_local_nav = set_local_nav + + # Latest data + self.latest_costmap: Optional[OccupancyGrid] = None + self.latest_odom: Optional[PoseStamped] = None @rpc def start(self): - self.target.subscribe(self.plan) + # Subscribe to inputs + self.target.subscribe(self._on_target) + self.global_costmap.subscribe(self._on_costmap) + self.odom.subscribe(self._on_odom) + + logger.info("A* planner started") - def plan(self, goallike: PoseLike) -> Path: - goal = to_pose(goallike) - logger.info(f"planning path to goal {goal}") - pos = self.get_robot_pos() - costmap = self.get_costmap().gradient() + def _on_costmap(self, msg: OccupancyGrid): + """Handle incoming costmap messages.""" + self.latest_costmap = msg - self.vis("target", goal) + def _on_odom(self, msg: PoseStamped): + """Handle incoming odometry messages.""" + self.latest_odom = msg - path = astar(costmap, goal.position, pos) + def _on_target(self, msg: PoseStamped): + """Handle incoming target messages and trigger planning.""" + if self.latest_costmap is None or self.latest_odom is None: + logger.warning("Cannot plan: missing costmap or odometry data") + return + path = self.plan(msg) if path: - path = resample_path(path, 0.1) self.path.publish(path) - if hasattr(self, "set_local_nav") and self.set_local_nav: - self.set_local_nav(path) - logger.warning(f"Path found: {path}") + + def plan(self, goal: Pose) -> Optional[Path]: + """Plan a path from current position to goal.""" + if self.latest_costmap is None or self.latest_odom is None: + logger.warning("Cannot plan: missing costmap or odometry data") + return None + + logger.debug(f"Planning path to goal {goal}") + + # Get current position from odometry + robot_pos = 
self.latest_odom.position + + # Run A* planning + path = astar(self.latest_costmap, goal.position, robot_pos) + + if path: + path = resample_path(path, 0.1) + logger.debug(f"Path found with {len(path.poses)} waypoints") return path + logger.warning("No path found to the goal.") + return None diff --git a/dimos/navigation/local_planner/__init__.py b/dimos/navigation/local_planner/__init__.py new file mode 100644 index 0000000000..f6b97d6762 --- /dev/null +++ b/dimos/navigation/local_planner/__init__.py @@ -0,0 +1,2 @@ +from dimos.navigation.local_planner.local_planner import BaseLocalPlanner +from dimos.navigation.local_planner.holonomic_local_planner import HolonomicLocalPlanner diff --git a/dimos/navigation/local_planner/holonomic_local_planner.py b/dimos/navigation/local_planner/holonomic_local_planner.py new file mode 100644 index 0000000000..3a8c73d3e2 --- /dev/null +++ b/dimos/navigation/local_planner/holonomic_local_planner.py @@ -0,0 +1,234 @@ +#!/usr/bin/env python3 + +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Gradient-Augmented Look-Ahead Pursuit (GLAP) holonomic local planner. 
+""" + +from typing import Optional, Tuple + +import numpy as np + +from dimos.msgs.geometry_msgs import Vector3 +from dimos.navigation.local_planner import BaseLocalPlanner +from dimos.utils.transform_utils import quaternion_to_euler, normalize_angle + + +class HolonomicLocalPlanner(BaseLocalPlanner): + """ + Gradient-Augmented Look-Ahead Pursuit (GLAP) holonomic local planner. + + This planner combines path following with obstacle avoidance using + costmap gradients to produce smooth holonomic velocity commands. + + Args: + lookahead_dist: Look-ahead distance in meters (default: 1.0) + k_rep: Repulsion gain for obstacle avoidance (default: 1.0) + alpha: Low-pass filter coefficient [0-1] (default: 0.5) + v_max: Maximum velocity per component in m/s (default: 0.8) + goal_tolerance: Distance threshold to consider goal reached (default: 0.5) + control_frequency: Control loop frequency in Hz (default: 10.0) + """ + + def __init__( + self, + lookahead_dist: float = 1.0, + k_rep: float = 0.5, + alpha: float = 0.5, + v_max: float = 0.8, + goal_tolerance: float = 0.5, + control_frequency: float = 10.0, + **kwargs, + ): + """Initialize the GLAP planner with specified parameters.""" + super().__init__( + goal_tolerance=goal_tolerance, control_frequency=control_frequency, **kwargs + ) + + # Algorithm parameters + self.lookahead_dist = lookahead_dist + self.k_rep = k_rep + self.alpha = alpha + self.v_max = v_max + + # Previous velocity for filtering (vx, vy, vtheta) + self.v_prev = np.array([0.0, 0.0, 0.0]) + + def compute_velocity(self) -> Optional[Vector3]: + """ + Compute velocity commands using GLAP algorithm. 
+ + Returns: + Vector3 with x, y velocities in robot frame and z as angular velocity + """ + if self.latest_odom is None or self.latest_path is None or self.latest_costmap is None: + return None + + pose = np.array([self.latest_odom.position.x, self.latest_odom.position.y]) + + euler = quaternion_to_euler(self.latest_odom.orientation) + robot_yaw = euler.z + + path_points = [] + for pose_stamped in self.latest_path.poses: + path_points.append([pose_stamped.position.x, pose_stamped.position.y]) + + if len(path_points) == 0: + return None + + path = np.array(path_points) + + costmap = self.latest_costmap.grid + + v_follow_odom = self._compute_path_following(pose, path) + + v_rep_odom = self._compute_obstacle_repulsion(pose, costmap) + + v_odom = v_follow_odom + v_rep_odom + + # Transform velocity from odom frame to robot frame + cos_yaw = np.cos(robot_yaw) + sin_yaw = np.sin(robot_yaw) + + v_robot_x = cos_yaw * v_odom[0] + sin_yaw * v_odom[1] + v_robot_y = -sin_yaw * v_odom[0] + cos_yaw * v_odom[1] + + # Compute angular velocity to align with path direction + closest_idx, _ = self._find_closest_point_on_path(pose, path) + lookahead_point = self._find_lookahead_point(path, closest_idx) + + dx = lookahead_point[0] - pose[0] + dy = lookahead_point[1] - pose[1] + desired_yaw = np.arctan2(dy, dx) + + yaw_error = normalize_angle(desired_yaw - robot_yaw) + k_angular = 2.0 # Angular gain + v_theta = k_angular * yaw_error + + v_robot_x = np.clip(v_robot_x, -self.v_max, self.v_max) + v_robot_y = np.clip(v_robot_y, -self.v_max, self.v_max) + v_theta = np.clip(v_theta, -self.v_max, self.v_max) + + v_raw = np.array([v_robot_x, v_robot_y, v_theta]) + v_filtered = self.alpha * v_raw + (1 - self.alpha) * self.v_prev + self.v_prev = v_filtered + + return Vector3(v_filtered[0], v_filtered[1], v_filtered[2]) + + def _compute_path_following(self, pose: np.ndarray, path: np.ndarray) -> np.ndarray: + """ + Compute path following velocity using pure pursuit. 
+ + Args: + pose: Current robot position [x, y] + path: Path waypoints as Nx2 array + + Returns: + Path following velocity vector [vx, vy] + """ + closest_idx, _ = self._find_closest_point_on_path(pose, path) + + carrot = self._find_lookahead_point(path, closest_idx) + + direction = carrot - pose + distance = np.linalg.norm(direction) + + if distance < 1e-6: + return np.zeros(2) + + v_follow = self.v_max * direction / distance + + return v_follow + + def _compute_obstacle_repulsion(self, pose: np.ndarray, costmap: np.ndarray) -> np.ndarray: + """ + Compute obstacle repulsion velocity from costmap gradient. + + Args: + pose: Current robot position [x, y] + costmap: 2D costmap array + + Returns: + Repulsion velocity vector [vx, vy] + """ + grid_point = self.latest_costmap.world_to_grid(pose) + grid_x = int(grid_point.x) + grid_y = int(grid_point.y) + + height, width = costmap.shape + if not (1 <= grid_x < width - 1 and 1 <= grid_y < height - 1): + return np.zeros(2) + + # Compute gradient using central differences + # Note: costmap is in row-major order (y, x) + gx = (costmap[grid_y, grid_x + 1] - costmap[grid_y, grid_x - 1]) / ( + 2.0 * self.latest_costmap.resolution + ) + gy = (costmap[grid_y + 1, grid_x] - costmap[grid_y - 1, grid_x]) / ( + 2.0 * self.latest_costmap.resolution + ) + + # Gradient points towards higher cost, so negate for repulsion + v_rep = -self.k_rep * np.array([gx, gy]) + + return v_rep + + def _find_closest_point_on_path( + self, pose: np.ndarray, path: np.ndarray + ) -> Tuple[int, np.ndarray]: + """ + Find the closest point on the path to current pose. 
+ + Args: + pose: Current position [x, y] + path: Path waypoints as Nx2 array + + Returns: + Tuple of (closest_index, closest_point) + """ + distances = np.linalg.norm(path - pose, axis=1) + closest_idx = np.argmin(distances) + return closest_idx, path[closest_idx] + + def _find_lookahead_point(self, path: np.ndarray, start_idx: int) -> np.ndarray: + """ + Find look-ahead point on path at specified distance. + + Args: + path: Path waypoints as Nx2 array + start_idx: Starting index for search + + Returns: + Look-ahead point [x, y] + """ + accumulated_dist = 0.0 + + for i in range(start_idx, len(path) - 1): + segment_dist = np.linalg.norm(path[i + 1] - path[i]) + + if accumulated_dist + segment_dist >= self.lookahead_dist: + remaining_dist = self.lookahead_dist - accumulated_dist + t = remaining_dist / segment_dist + carrot = path[i] + t * (path[i + 1] - path[i]) + return carrot + + accumulated_dist += segment_dist + + return path[-1] + + def _clip(self, v: np.ndarray) -> np.ndarray: + """Instance method to clip velocity with access to v_max.""" + return np.clip(v, -self.v_max, self.v_max) diff --git a/dimos/navigation/local_planner/local_planner.py b/dimos/navigation/local_planner/local_planner.py new file mode 100644 index 0000000000..2fa8fc6f37 --- /dev/null +++ b/dimos/navigation/local_planner/local_planner.py @@ -0,0 +1,189 @@ +#!/usr/bin/env python3 + +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +Base Local Planner Module for robot navigation. +Subscribes to local costmap, odometry, and path, publishes movement commands. +""" + +import threading +import time +from abc import abstractmethod +from typing import Optional + +from dimos.core import Module, In, Out, rpc +from dimos.msgs.geometry_msgs import Vector3, PoseStamped +from dimos.msgs.nav_msgs import OccupancyGrid, Path +from dimos.utils.logging_config import setup_logger +from dimos.utils.transform_utils import get_distance + +logger = setup_logger("dimos.robot.local_planner") + + +class BaseLocalPlanner(Module): + """ + local planner module for obstacle avoidance and path following. + + Subscribes to: + - /local_costmap: Local occupancy grid for obstacle detection + - /odom: Robot odometry for current pose + - /path: Path to follow (continuously updated at ~1Hz) + + Publishes: + - /cmd_vel: Velocity commands for robot movement + """ + + # LCM inputs + local_costmap: In[OccupancyGrid] = None + odom: In[PoseStamped] = None + path: In[Path] = None + + # LCM outputs + cmd_vel: Out[Vector3] = None + + def __init__(self, goal_tolerance: float = 0.5, control_frequency: float = 10.0, **kwargs): + """Initialize the local planner module. 
+ + Args: + goal_tolerance: Distance threshold to consider goal reached (meters) + control_frequency: Frequency for control loop (Hz) + """ + super().__init__(**kwargs) + + # Parameters + self.goal_tolerance = goal_tolerance + self.control_frequency = control_frequency + self.control_period = 1.0 / control_frequency + + # Latest data + self.latest_costmap: Optional[OccupancyGrid] = None + self.latest_odom: Optional[PoseStamped] = None + self.latest_path: Optional[Path] = None + + # Control thread + self.planning_thread: Optional[threading.Thread] = None + self.stop_planning = threading.Event() + + logger.info("Local planner module initialized") + + @rpc + def start(self): + """Start the local planner module.""" + # Subscribe to inputs + self.local_costmap.subscribe(self._on_costmap) + self.odom.subscribe(self._on_odom) + self.path.subscribe(self._on_path) + + logger.info("Local planner module started") + + def _on_costmap(self, msg: OccupancyGrid): + self.latest_costmap = msg + + def _on_odom(self, msg: PoseStamped): + self.latest_odom = msg + + def _on_path(self, msg: Path): + self.latest_path = msg + + if msg and len(msg.poses) > 0: + if self.planning_thread is None or not self.planning_thread.is_alive(): + self._start_planning_thread() + + def _start_planning_thread(self): + """Start the planning thread.""" + self.stop_planning.clear() + self.planning_thread = threading.Thread(target=self._follow_path_loop, daemon=True) + self.planning_thread.start() + logger.debug("Started follow path thread") + + def _follow_path_loop(self): + """Main planning loop that runs in a separate thread.""" + while not self.stop_planning.is_set(): + if self.is_goal_reached(): + logger.info("Goal reached, stopping planning thread") + self.stop_planning.set() + stop_cmd = Vector3(0, 0, 0) + self.cmd_vel.publish(stop_cmd) + break + + # Compute and publish velocity + self._plan() + + time.sleep(self.control_period) + + def _plan(self): + """Compute and publish velocity command.""" + 
cmd_vel = self.compute_velocity() + + if cmd_vel is not None: + self.cmd_vel.publish(cmd_vel) + + @abstractmethod + def compute_velocity(self) -> Optional[Vector3]: + """ + Compute velocity commands based on current costmap, odometry, and path. + Must be implemented by derived classes. + + Returns: + Vector3 message with velocity commands x, y, theta, or None if no command + """ + pass + + @rpc + def is_goal_reached(self) -> bool: + """ + Check if the robot has reached the goal position. + + Returns: + True if goal is reached within tolerance, False otherwise + """ + if self.latest_odom is None or self.latest_path is None: + return False + + if len(self.latest_path.poses) == 0: + return True + + goal_pose = self.latest_path.poses[-1] + distance = get_distance(self.latest_odom, goal_pose) + + goal_reached = distance < self.goal_tolerance + + if goal_reached: + logger.info(f"Goal reached! Distance: {distance:.3f}m < {self.goal_tolerance}m") + + return goal_reached + + @rpc + def reset(self): + """Reset the local planner state, clearing the current path.""" + # Clear the latest path + self.latest_path = None + self.latest_odom = None + self.latest_costmap = None + self.stop() + logger.info("Local planner reset") + + @rpc + def stop(self): + """Stop the local planner and any running threads.""" + if self.planning_thread and self.planning_thread.is_alive(): + self.stop_planning.set() + self.planning_thread.join(timeout=1.0) + self.planning_thread = None + stop_cmd = Vector3(0, 0, 0) + self.cmd_vel.publish(stop_cmd) + + logger.info("Local planner stopped") diff --git a/dimos/navigation/local_planner/test_base_local_planner.py b/dimos/navigation/local_planner/test_base_local_planner.py new file mode 100644 index 0000000000..93ec26882b --- /dev/null +++ b/dimos/navigation/local_planner/test_base_local_planner.py @@ -0,0 +1,382 @@ +#!/usr/bin/env python3 + +# Copyright 2025 Dimensional Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unit tests for the GLAP (Gradient-Augmented Look-Ahead Pursuit) holonomic local planner. +""" + +import numpy as np +import pytest + +from dimos.msgs.geometry_msgs import Pose, PoseStamped, Quaternion +from dimos.msgs.nav_msgs import Path, OccupancyGrid +from dimos.navigation.local_planner.holonomic_local_planner import HolonomicLocalPlanner + + +class TestHolonomicLocalPlanner: + """Test suite for HolonomicLocalPlanner.""" + + @pytest.fixture + def planner(self): + """Create a planner instance for testing.""" + return HolonomicLocalPlanner( + lookahead_dist=1.5, + k_rep=1.0, + alpha=1.0, # No filtering for deterministic tests + v_max=1.0, + goal_tolerance=0.5, + control_frequency=10.0, + ) + + @pytest.fixture + def empty_costmap(self): + """Create an empty costmap (all free space).""" + costmap = OccupancyGrid( + grid=np.zeros((100, 100), dtype=np.int8), resolution=0.1, origin=Pose() + ) + costmap.origin.position.x = -5.0 + costmap.origin.position.y = -5.0 + return costmap + + def test_straight_path_no_obstacles(self, planner, empty_costmap): + """Test that planner follows straight path with no obstacles.""" + # Set current position at origin + planner.latest_odom = PoseStamped() + planner.latest_odom.position.x = 0.0 + planner.latest_odom.position.y = 0.0 + + # Create straight path along +X + path = Path() + for i in range(10): + ps = PoseStamped() + ps.position.x = float(i) + ps.position.y = 0.0 + 
ps.orientation.w = 1.0 # Identity quaternion + path.poses.append(ps) + planner.latest_path = path + + # Set empty costmap + planner.latest_costmap = empty_costmap + + # Compute velocity + vel = planner.compute_velocity() + + # Should move along +X + assert vel is not None + assert vel.x > 0.9 # Close to v_max + assert abs(vel.y) < 0.1 # Near zero + assert abs(vel.z) < 0.1 # Small angular velocity when aligned with path + + def test_obstacle_gradient_repulsion(self, planner): + """Test that obstacle gradients create repulsive forces.""" + # Set position at origin + planner.latest_odom = PoseStamped() + planner.latest_odom.position.x = 0.0 + planner.latest_odom.position.y = 0.0 + + # Simple path forward + path = Path() + ps = PoseStamped() + ps.position.x = 5.0 + ps.position.y = 0.0 + ps.orientation.w = 1.0 + path.poses.append(ps) + planner.latest_path = path + + # Create costmap with gradient pointing south (higher cost north) + costmap_grid = np.zeros((100, 100), dtype=np.int8) + for i in range(100): + costmap_grid[i, :] = max(0, 50 - i) # Gradient from north to south + + planner.latest_costmap = OccupancyGrid(grid=costmap_grid, resolution=0.1, origin=Pose()) + planner.latest_costmap.origin.position.x = -5.0 + planner.latest_costmap.origin.position.y = -5.0 + + # Compute velocity + vel = planner.compute_velocity() + + # Should have positive Y component (pushed north by gradient) + assert vel is not None + assert vel.y > 0.1 # Repulsion pushes north + + def test_lowpass_filter(self): + """Test that low-pass filter smooths velocity commands.""" + # Create planner with alpha=0.5 for filtering + planner = HolonomicLocalPlanner( + lookahead_dist=1.0, + k_rep=0.0, # No repulsion + alpha=0.5, # 50% filtering + v_max=1.0, + ) + + # Setup similar to straight path test + planner.latest_odom = PoseStamped() + planner.latest_odom.position.x = 0.0 + planner.latest_odom.position.y = 0.0 + + path = Path() + ps = PoseStamped() + ps.position.x = 5.0 + ps.position.y = 0.0 + 
ps.orientation.w = 1.0 + path.poses.append(ps) + planner.latest_path = path + + planner.latest_costmap = OccupancyGrid( + grid=np.zeros((100, 100), dtype=np.int8), resolution=0.1, origin=Pose() + ) + planner.latest_costmap.origin.position.x = -5.0 + planner.latest_costmap.origin.position.y = -5.0 + + # First call - previous velocity is zero + vel1 = planner.compute_velocity() + assert vel1 is not None + + # Store first velocity + first_vx = vel1.x + + # Second call - should be filtered + vel2 = planner.compute_velocity() + assert vel2 is not None + + # With alpha=0.5 and same conditions: + # v2 = 0.5 * v_raw + 0.5 * v1 + # The filtering effect should be visible + # v2 should be between v1 and the raw velocity + assert vel2.x != first_vx # Should be different due to filtering + assert 0 < vel2.x <= planner.v_max # Should still be positive and within limits + + def test_no_path(self, planner, empty_costmap): + """Test that planner returns None when no path is available.""" + planner.latest_odom = PoseStamped() + planner.latest_costmap = empty_costmap + planner.latest_path = Path() # Empty path + + vel = planner.compute_velocity() + assert vel is None + + def test_no_odometry(self, planner, empty_costmap): + """Test that planner returns None when no odometry is available.""" + planner.latest_odom = None + planner.latest_costmap = empty_costmap + + path = Path() + ps = PoseStamped() + ps.position.x = 1.0 + ps.position.y = 0.0 + path.poses.append(ps) + planner.latest_path = path + + vel = planner.compute_velocity() + assert vel is None + + def test_no_costmap(self, planner): + """Test that planner returns None when no costmap is available.""" + planner.latest_odom = PoseStamped() + planner.latest_costmap = None + + path = Path() + ps = PoseStamped() + ps.position.x = 1.0 + ps.position.y = 0.0 + path.poses.append(ps) + planner.latest_path = path + + vel = planner.compute_velocity() + assert vel is None + + def test_goal_reached(self, planner, empty_costmap): + """Test 
velocity when robot is at goal.""" + # Set robot at goal position + planner.latest_odom = PoseStamped() + planner.latest_odom.position.x = 5.0 + planner.latest_odom.position.y = 0.0 + + # Path with single point at robot position + path = Path() + ps = PoseStamped() + ps.position.x = 5.0 + ps.position.y = 0.0 + ps.orientation.w = 1.0 + path.poses.append(ps) + planner.latest_path = path + + planner.latest_costmap = empty_costmap + + # Compute velocity + vel = planner.compute_velocity() + + # Should have near-zero velocity + assert vel is not None + assert abs(vel.x) < 0.1 + assert abs(vel.y) < 0.1 + + def test_velocity_saturation(self, planner, empty_costmap): + """Test that velocities are capped at v_max.""" + # Set robot far from goal to maximize commanded velocity + planner.latest_odom = PoseStamped() + planner.latest_odom.position.x = 0.0 + planner.latest_odom.position.y = 0.0 + + # Create path far away + path = Path() + ps = PoseStamped() + ps.position.x = 100.0 # Very far + ps.position.y = 0.0 + ps.orientation.w = 1.0 + path.poses.append(ps) + planner.latest_path = path + + planner.latest_costmap = empty_costmap + + # Compute velocity + vel = planner.compute_velocity() + + # Velocity should be saturated at v_max + assert vel is not None + assert abs(vel.x) <= planner.v_max + 0.01 # Small tolerance + assert abs(vel.y) <= planner.v_max + 0.01 + + def test_lookahead_interpolation(self, planner, empty_costmap): + """Test that lookahead point is correctly interpolated on path.""" + # Set robot at origin + planner.latest_odom = PoseStamped() + planner.latest_odom.position.x = 0.0 + planner.latest_odom.position.y = 0.0 + + # Create path with waypoints closer than lookahead distance + path = Path() + for i in range(5): + ps = PoseStamped() + ps.position.x = i * 0.5 # 0.5m spacing + ps.position.y = 0.0 + ps.orientation.w = 1.0 + path.poses.append(ps) + planner.latest_path = path + + planner.latest_costmap = empty_costmap + + # Compute velocity + vel = 
planner.compute_velocity() + + # Should move forward along path + assert vel is not None + assert vel.x > 0.5 # Moving forward + assert abs(vel.y) < 0.1 # Staying on path + + def test_curved_path_following(self, planner, empty_costmap): + """Test following a curved path.""" + # Set robot at origin + planner.latest_odom = PoseStamped() + planner.latest_odom.position.x = 0.0 + planner.latest_odom.position.y = 0.0 + + # Create curved path (quarter circle) + path = Path() + for i in range(10): + angle = (np.pi / 2) * (i / 9.0) # 0 to 90 degrees + ps = PoseStamped() + ps.position.x = 2.0 * np.cos(angle) + ps.position.y = 2.0 * np.sin(angle) + ps.orientation.w = 1.0 + path.poses.append(ps) + planner.latest_path = path + + planner.latest_costmap = empty_costmap + + # Compute velocity + vel = planner.compute_velocity() + + # Should have both X and Y components for curved motion + assert vel is not None + assert vel.x > 0.3 # Moving forward + assert vel.y > 0.1 # Turning left (positive Y) + + def test_robot_frame_transformation(self, empty_costmap): + """Test that velocities are correctly transformed to robot frame.""" + # Create planner with no filtering for deterministic test + planner = HolonomicLocalPlanner( + lookahead_dist=1.0, + k_rep=0.0, # No repulsion + alpha=1.0, # No filtering + v_max=1.0, + ) + + # Set robot at origin but rotated 90 degrees (facing +Y in odom frame) + planner.latest_odom = PoseStamped() + planner.latest_odom.position.x = 0.0 + planner.latest_odom.position.y = 0.0 + # Quaternion for 90 degree rotation around Z + planner.latest_odom.orientation = Quaternion(0.0, 0.0, 0.7071068, 0.7071068) + + # Create path along +X axis in odom frame + path = Path() + for i in range(5): + ps = PoseStamped() + ps.position.x = float(i) + ps.position.y = 0.0 + ps.orientation.w = 1.0 + path.poses.append(ps) + planner.latest_path = path + + planner.latest_costmap = empty_costmap + + # Compute velocity + vel = planner.compute_velocity() + + # Robot is facing +Y, path 
is along +X + # So in robot frame: forward is +Y direction, path is to the right + assert vel is not None + assert abs(vel.x) < 0.1 # No forward velocity in robot frame + assert vel.y < -0.5 # Should move right (negative Y in robot frame) + assert vel.z < -0.5 # Should turn right (negative angular velocity) + + def test_angular_velocity_computation(self, empty_costmap): + """Test that angular velocity is computed to align with path.""" + planner = HolonomicLocalPlanner( + lookahead_dist=2.0, + k_rep=0.0, # No repulsion + alpha=1.0, # No filtering + v_max=1.0, + ) + + # Robot at origin facing +X + planner.latest_odom = PoseStamped() + planner.latest_odom.position.x = 0.0 + planner.latest_odom.position.y = 0.0 + planner.latest_odom.orientation.w = 1.0 # Identity quaternion + + # Create path at 45 degrees + path = Path() + for i in range(5): + ps = PoseStamped() + ps.position.x = float(i) + ps.position.y = float(i) # Diagonal path + ps.orientation.w = 1.0 + path.poses.append(ps) + planner.latest_path = path + + planner.latest_costmap = empty_costmap + + # Compute velocity + vel = planner.compute_velocity() + + # Path is at 45 degrees, robot facing 0 degrees + # Should have positive angular velocity to turn left + assert vel is not None + assert vel.x > 0.5 # Moving forward + assert vel.y > 0.5 # Also moving left (diagonal path) + assert vel.z > 0.5 # Positive angular velocity to turn towards path diff --git a/dimos/perception/detection2d/utils.py b/dimos/perception/detection2d/utils.py index dbe19baf30..73e0eb5671 100644 --- a/dimos/perception/detection2d/utils.py +++ b/dimos/perception/detection2d/utils.py @@ -15,7 +15,6 @@ import numpy as np import cv2 from dimos.types.vector import Vector -from dimos.utils.transform_utils import distance_angle_to_goal_xy def filter_detections( @@ -305,34 +304,3 @@ def calculate_object_size_from_bbox(bbox, depth, camera_intrinsics): height_m = (height_px * depth) / fy return width_m, height_m - - -def 
calculate_position_rotation_from_bbox(bbox, depth, camera_intrinsics): - """ - Calculate position (xyz) and rotation (roll, pitch, yaw) for an object - based on its bounding box and depth. - - Args: - bbox: Bounding box [x1, y1, x2, y2] - depth: Depth value in meters - camera_intrinsics: List [fx, fy, cx, cy] with camera parameters - - Returns: - Vector: position - Vector: rotation - """ - # Calculate distance and angle to object - distance, angle = calculate_distance_angle_from_bbox(bbox, depth, camera_intrinsics) - - # Convert distance and angle to x,y coordinates (in camera frame) - # Note: We negate the angle since positive angle means object is to the right, - # but we want positive y to be to the left in the standard coordinate system - x, y = distance_angle_to_goal_xy(distance, -angle) - - # For now, rotation is only in yaw (around z-axis) - # We can use the negative of the angle as an estimate of the object's yaw - # assuming objects tend to face the camera - position = Vector([x, y, 0.0]) - rotation = Vector([0.0, 0.0, -angle]) - - return position, rotation diff --git a/dimos/perception/semantic_seg.py b/dimos/perception/semantic_seg.py index a07e69c279..6626e3bc9f 100644 --- a/dimos/perception/semantic_seg.py +++ b/dimos/perception/semantic_seg.py @@ -25,7 +25,6 @@ class SemanticSegmentationStream: def __init__( self, - device: str = "cuda", enable_mono_depth: bool = True, enable_rich_labeling: bool = True, camera_params: dict = None, @@ -35,7 +34,6 @@ def __init__( Initialize a semantic segmentation stream using Sam2DSegmenter. 
Args: - device: Computation device ("cuda" or "cpu") enable_mono_depth: Whether to enable monocular depth processing enable_rich_labeling: Whether to enable rich labeling camera_params: Dictionary containing either: @@ -43,7 +41,6 @@ def __init__( - Physical parameters: resolution, focal_length, sensor_size """ self.segmenter = Sam2DSegmenter( - device=device, min_analysis_interval=5.0, use_tracker=True, use_analyzer=True, diff --git a/dimos/perception/spatial_perception.py b/dimos/perception/spatial_perception.py index 60ffe2105a..188b9b81d9 100644 --- a/dimos/perception/spatial_perception.py +++ b/dimos/perception/spatial_perception.py @@ -28,6 +28,7 @@ from dimos.core import In, Module, Out, rpc from dimos.msgs.sensor_msgs import Image +from dimos.msgs.geometry_msgs import Vector3, Quaternion, Pose, PoseStamped from dimos.robot.unitree_webrtc.type.odometry import Odometry from dimos.utils.logging_config import setup_logger from dimos.agents.memory.spatial_vector_db import SpatialVectorDB @@ -144,18 +145,18 @@ def __init__( logger.error(f"Error loading visual memory: {e}") self._visual_memory = VisualMemory(output_dir=output_dir) - # Initialize vector database + self.embedding_provider: ImageEmbeddingProvider = ImageEmbeddingProvider( + model_name=embedding_model, dimensions=embedding_dimensions + ) + self.vector_db: SpatialVectorDB = SpatialVectorDB( collection_name=collection_name, chroma_client=self._chroma_client, visual_memory=self._visual_memory, + embedding_provider=self.embedding_provider, ) - self.embedding_provider: ImageEmbeddingProvider = ImageEmbeddingProvider( - model_name=embedding_model, dimensions=embedding_dimensions - ) - - self.last_position: Optional[Vector] = None + self.last_position: Optional[Vector3] = None self.last_record_time: Optional[float] = None self.frame_count: int = 0 @@ -206,12 +207,10 @@ def _process_frame(self): position = self._latest_odom.position orientation = self._latest_odom.orientation - # Convert to Vector objects - 
position_vec = Vector([position.x, position.y, position.z]) - - # Get euler angles from quaternion orientation - euler = orientation.to_euler() - rotation_vec = Vector([euler.x, euler.y, euler.z]) + # Create Pose object with position and orientation + current_pose = Pose( + position=Vector3(position.x, position.y, position.z), orientation=orientation + ) # Process the frame directly try: @@ -219,7 +218,13 @@ def _process_frame(self): # Check distance constraint if self.last_position is not None: - distance_moved = (self.last_position - position_vec).length() + distance_moved = np.linalg.norm( + [ + current_pose.position.x - self.last_position.x, + current_pose.position.y - self.last_position.y, + current_pose.position.z - self.last_position.z, + ] + ) if distance_moved < self.min_distance_threshold: logger.debug( f"Position has not moved enough: {distance_moved:.4f}m < {self.min_distance_threshold}m, skipping frame" @@ -242,14 +247,17 @@ def _process_frame(self): frame_id = f"frame_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}" + # Get euler angles from quaternion orientation for metadata + euler = orientation.to_euler() + # Create metadata dictionary with primitive types only metadata = { - "pos_x": float(position_vec.x), - "pos_y": float(position_vec.y), - "pos_z": float(position_vec.z), - "rot_x": float(rotation_vec.x), - "rot_y": float(rotation_vec.y), - "rot_z": float(rotation_vec.z), + "pos_x": float(current_pose.position.x), + "pos_y": float(current_pose.position.y), + "pos_z": float(current_pose.position.z), + "rot_x": float(euler.x), + "rot_y": float(euler.y), + "rot_z": float(euler.z), "timestamp": current_time, "frame_id": frame_id, } @@ -263,12 +271,13 @@ def _process_frame(self): ) # Update tracking variables - self.last_position = position_vec + self.last_position = current_pose.position self.last_record_time = current_time self.stored_frame_count += 1 logger.info( - f"Stored frame at position {position_vec}, rotation 
{rotation_vec}) " + f"Stored frame at position ({current_pose.position.x:.2f}, {current_pose.position.y:.2f}, {current_pose.position.z:.2f}), " + f"rotation ({euler.x:.2f}, {euler.y:.2f}, {euler.z:.2f}) " f"stored {self.stored_frame_count}/{self.frame_count} frames" ) @@ -403,16 +412,23 @@ def process_combined_data(data): position_vec = data.get("position") # Use .get() for consistency rotation_vec = data.get("rotation") # Get rotation data if available - if not position_vec or not rotation_vec: + if position_vec is None or rotation_vec is None: logger.info("No position or rotation data available, skipping frame") return None - if ( - self.last_position is not None - and (self.last_position - position_vec).length() < self.min_distance_threshold - ): - logger.debug("Position has not moved, skipping frame") - return None + position_v3 = Vector3(position_vec.x, position_vec.y, position_vec.z) + + if self.last_position is not None: + distance_moved = np.linalg.norm( + [ + position_v3.x - self.last_position.x, + position_v3.y - self.last_position.y, + position_v3.z - self.last_position.z, + ] + ) + if distance_moved < self.min_distance_threshold: + logger.debug("Position has not moved, skipping frame") + return None if ( self.last_record_time is not None @@ -429,9 +445,9 @@ def process_combined_data(data): # Create metadata dictionary with primitive types only metadata = { - "pos_x": float(position_vec.x), - "pos_y": float(position_vec.y), - "pos_z": float(position_vec.z), + "pos_x": float(position_v3.x), + "pos_y": float(position_v3.y), + "pos_z": float(position_v3.z), "rot_x": float(rotation_vec.x), "rot_y": float(rotation_vec.y), "rot_z": float(rotation_vec.z), @@ -443,19 +459,20 @@ def process_combined_data(data): vector_id=frame_id, image=frame, embedding=frame_embedding, metadata=metadata ) - self.last_position = position_vec + self.last_position = position_v3 self.last_record_time = current_time self.stored_frame_count += 1 logger.info( - f"Stored frame at 
position {position_vec}, rotation {rotation_vec})" - f" stored {self.stored_frame_count}/{self.frame_count} frames" + f"Stored frame at position ({position_v3.x:.2f}, {position_v3.y:.2f}, {position_v3.z:.2f}), " + f"rotation ({rotation_vec.x:.2f}, {rotation_vec.y:.2f}, {rotation_vec.z:.2f}) " + f"stored {self.stored_frame_count}/{self.frame_count} frames" ) # Create return dictionary with primitive-compatible values return { "frame": frame, - "position": (position_vec.x, position_vec.y, position_vec.z), + "position": (position_v3.x, position_v3.y, position_v3.z), "rotation": (rotation_vec.x, rotation_vec.y, rotation_vec.z), "frame_id": frame_id, "timestamp": current_time, diff --git a/dimos/perception/test_spatial_memory.py b/dimos/perception/test_spatial_memory.py index 11af5edb42..c8cf8de26b 100644 --- a/dimos/perception/test_spatial_memory.py +++ b/dimos/perception/test_spatial_memory.py @@ -28,12 +28,11 @@ from dimos.msgs.geometry_msgs import Pose from dimos.perception.spatial_perception import SpatialMemory from dimos.stream.video_provider import VideoProvider -from dimos.types.vector import Vector @pytest.mark.heavy class TestSpatialMemory: - @pytest.fixture(scope="function") + @pytest.fixture(scope="class") def temp_dir(self): # Create a temporary directory for storing spatial memory data temp_dir = tempfile.mkdtemp() @@ -41,66 +40,60 @@ def temp_dir(self): # Clean up shutil.rmtree(temp_dir) - def test_spatial_memory_initialization(self): + @pytest.fixture(scope="class") + def spatial_memory(self, temp_dir): + # Create a single SpatialMemory instance to be reused across all tests + memory = SpatialMemory( + collection_name="test_collection", + embedding_model="clip", + new_memory=True, + db_path=os.path.join(temp_dir, "chroma_db"), + visual_memory_path=os.path.join(temp_dir, "visual_memory.pkl"), + output_dir=os.path.join(temp_dir, "images"), + min_distance_threshold=0.01, + min_time_threshold=0.01, + ) + yield memory + # Clean up + memory.cleanup() + + def 
test_spatial_memory_initialization(self, spatial_memory): """Test SpatialMemory initializes correctly with CLIP model.""" - try: - # Initialize spatial memory with default CLIP model - memory = SpatialMemory( - collection_name="test_collection", embedding_model="clip", new_memory=True - ) - assert memory is not None - assert memory.embedding_model == "clip" - assert memory.embedding_provider is not None - except Exception as e: - # If the model doesn't initialize, skip the test - pytest.fail(f"Failed to initialize model: {e}") + # Use the shared spatial_memory fixture + assert spatial_memory is not None + assert spatial_memory.embedding_model == "clip" + assert spatial_memory.embedding_provider is not None - def test_image_embedding(self): + def test_image_embedding(self, spatial_memory): """Test generating image embeddings using CLIP.""" - try: - # Initialize spatial memory with CLIP model - memory = SpatialMemory( - collection_name="test_collection", embedding_model="clip", new_memory=True - ) - - # Create a test image - use a simple colored square - test_image = np.zeros((224, 224, 3), dtype=np.uint8) - test_image[50:150, 50:150] = [0, 0, 255] # Blue square - - # Generate embedding - embedding = memory.embedding_provider.get_embedding(test_image) - - # Check embedding shape and characteristics - assert embedding is not None - assert isinstance(embedding, np.ndarray) - assert embedding.shape[0] == memory.embedding_dimensions - - # Check that embedding is normalized (unit vector) - assert np.isclose(np.linalg.norm(embedding), 1.0, atol=1e-5) - - # Test text embedding - text_embedding = memory.embedding_provider.get_text_embedding("a blue square") - assert text_embedding is not None - assert isinstance(text_embedding, np.ndarray) - assert text_embedding.shape[0] == memory.embedding_dimensions - assert np.isclose(np.linalg.norm(text_embedding), 1.0, atol=1e-5) - except Exception as e: - pytest.fail(f"Error in test: {e}") - - def test_spatial_memory_processing(self, 
temp_dir): + # Use the shared spatial_memory fixture + # Create a test image - use a simple colored square + test_image = np.zeros((224, 224, 3), dtype=np.uint8) + test_image[50:150, 50:150] = [0, 0, 255] # Blue square + + # Generate embedding + embedding = spatial_memory.embedding_provider.get_embedding(test_image) + + # Check embedding shape and characteristics + assert embedding is not None + assert isinstance(embedding, np.ndarray) + assert embedding.shape[0] == spatial_memory.embedding_dimensions + + # Check that embedding is normalized (unit vector) + assert np.isclose(np.linalg.norm(embedding), 1.0, atol=1e-5) + + # Test text embedding + text_embedding = spatial_memory.embedding_provider.get_text_embedding("a blue square") + assert text_embedding is not None + assert isinstance(text_embedding, np.ndarray) + assert text_embedding.shape[0] == spatial_memory.embedding_dimensions + assert np.isclose(np.linalg.norm(text_embedding), 1.0, atol=1e-5) + + def test_spatial_memory_processing(self, spatial_memory, temp_dir): """Test processing video frames and building spatial memory with CLIP embeddings.""" try: - # Initialize spatial memory with temporary storage - memory = SpatialMemory( - collection_name="test_collection", - embedding_model="clip", - new_memory=True, - db_path=os.path.join(temp_dir, "chroma_db"), - visual_memory_path=os.path.join(temp_dir, "visual_memory.pkl"), - output_dir=os.path.join(temp_dir, "images"), - min_distance_threshold=0.01, - min_time_threshold=0.01, - ) + # Use the shared spatial_memory fixture + memory = spatial_memory from dimos.utils.data import get_data @@ -206,7 +199,6 @@ def on_completed(): except Exception as e: pytest.fail(f"Error in test: {e}") finally: - memory.cleanup() video_provider.dispose_all() diff --git a/dimos/perception/test_spatial_memory_module.py b/dimos/perception/test_spatial_memory_module.py index 0339ae038f..5166ef2443 100644 --- a/dimos/perception/test_spatial_memory_module.py +++ 
b/dimos/perception/test_spatial_memory_module.py @@ -26,11 +26,9 @@ from dimos import core from dimos.core import Module, In, Out, rpc from dimos.msgs.sensor_msgs import Image -from dimos.msgs.geometry_msgs import Vector3 from dimos.robot.unitree_webrtc.type.odometry import Odometry from dimos.perception.spatial_perception import SpatialMemory from dimos.protocol import pubsub -from dimos.types.vector import Vector from dimos.utils.data import get_data from dimos.utils.testing import TimedSensorReplay from dimos.utils.logging_config import setup_logger diff --git a/dimos/perception/test_tracking_modules.py b/dimos/perception/test_tracking_modules.py deleted file mode 100644 index affb8ace57..0000000000 --- a/dimos/perception/test_tracking_modules.py +++ /dev/null @@ -1,334 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Tests for object and person tracking modules with LCM integration.""" - -import asyncio -import os -import pytest -import numpy as np -from typing import Dict -from reactivex import operators as ops - -from dimos import core -from dimos.core import Module, Out, rpc -from dimos.msgs.sensor_msgs import Image -from dimos.perception.object_tracker import ObjectTrackingStream -from dimos.perception.person_tracker import PersonTrackingStream -from dimos.protocol import pubsub -from dimos.utils.data import get_data -from dimos.utils.testing import TimedSensorReplay -from dimos.utils.logging_config import setup_logger -import tempfile -from dimos.core import stop - -logger = setup_logger("test_tracking_modules") - -pubsub.lcm.autoconf() - - -class VideoReplayModule(Module): - """Module that replays video data from TimedSensorReplay.""" - - video_out: Out[Image] = None - - def __init__(self, video_path: str): - super().__init__() - self.video_path = video_path - self._subscription = None - - @rpc - def start(self): - """Start replaying video data.""" - # Use TimedSensorReplay to replay video frames - video_replay = TimedSensorReplay(self.video_path, autocast=Image.from_numpy) - - self._subscription = ( - video_replay.stream().pipe(ops.sample(0.1)).subscribe(self.video_out.publish) - ) - - logger.info("VideoReplayModule started") - - @rpc - def stop(self): - if self._subscription: - self._subscription.dispose() - self._subscription = None - logger.info("VideoReplayModule stopped") - - -@pytest.mark.skip(reason="Tracking tests hanging due to ONNX/CUDA cleanup issues") -@pytest.mark.heavy -class TestTrackingModules: - @pytest.fixture(scope="function") - def temp_dir(self): - temp_dir = tempfile.mkdtemp(prefix="tracking_test_") - yield temp_dir - - @pytest.mark.asyncio - async def test_person_tracking_module_with_replay(self, temp_dir): - """Test PersonTrackingStream module with TimedSensorReplay inputs.""" - - # Start Dask - dimos = core.start(1) - - try: - data_path = 
get_data("unitree_office_walk") - video_path = os.path.join(data_path, "video") - - video_module = dimos.deploy(VideoReplayModule, video_path) - video_module.video_out.transport = core.LCMTransport("/test_video", Image) - - person_tracker = dimos.deploy( - PersonTrackingStream, - camera_intrinsics=[619.061157, 619.061157, 317.883459, 238.543800], - camera_pitch=-0.174533, - camera_height=0.3, - ) - - person_tracker.tracking_data.transport = core.pLCMTransport("/person_tracking") - person_tracker.video.connect(video_module.video_out) - - video_module.start() - person_tracker.start() - person_tracker.enable_tracking() - await asyncio.sleep(2) - - results = [] - - from dimos.protocol.pubsub.lcmpubsub import PickleLCM - - lcm_instance = PickleLCM() - lcm_instance.start() - - def on_message(msg, topic): - results.append(msg) - - lcm_instance.subscribe("/person_tracking", on_message) - - await asyncio.sleep(3) - - video_module.stop() - - assert len(results) > 0 - - for msg in results: - assert "targets" in msg - assert isinstance(msg["targets"], list) - - tracking_data = person_tracker.get_tracking_data() - assert isinstance(tracking_data, dict) - assert "targets" in tracking_data - - logger.info(f"Person tracking test passed with {len(results)} messages") - - finally: - lcm_instance.stop() - # stop(dimos) - dimos.close() - dimos.shutdown() - - @pytest.mark.asyncio - async def test_object_tracking_module_with_replay(self, temp_dir): - """Test ObjectTrackingStream module with TimedSensorReplay inputs.""" - - # Start Dask - dimos = core.start(1) - - try: - data_path = get_data("unitree_office_walk") - video_path = os.path.join(data_path, "video") - - video_module = dimos.deploy(VideoReplayModule, video_path) - video_module.video_out.transport = core.LCMTransport("/test_video", Image) - - object_tracker = dimos.deploy( - ObjectTrackingStream, - camera_intrinsics=[619.061157, 619.061157, 317.883459, 238.543800], - camera_pitch=-0.174533, - camera_height=0.3, - ) - - 
object_tracker.tracking_data.transport = core.pLCMTransport("/object_tracking") - object_tracker.video.connect(video_module.video_out) - - video_module.start() - object_tracker.start() - # object_tracker.track([100, 100, 200, 200]) - results = [] - - from dimos.protocol.pubsub.lcmpubsub import PickleLCM - - lcm_instance = PickleLCM() - lcm_instance.start() - - def on_message(msg, topic): - results.append(msg) - - lcm_instance.subscribe("/object_tracking", on_message) - - await asyncio.sleep(5) - - video_module.stop() - - assert len(results) > 0 - - for msg in results: - assert "targets" in msg - assert isinstance(msg["targets"], list) - - logger.info(f"Object tracking test passed with {len(results)} messages") - - finally: - lcm_instance.stop() - # stop(dimos) - dimos.close() - dimos.shutdown() - - @pytest.mark.asyncio - async def test_tracking_rpc_methods(self, temp_dir): - """Test RPC methods on tracking modules while they're running with video.""" - - # Start Dask - dimos = core.start(1) - - try: - data_path = get_data("unitree_office_walk") - video_path = os.path.join(data_path, "video") - - video_module = dimos.deploy(VideoReplayModule, video_path) - video_module.video_out.transport = core.LCMTransport("/test_video", Image) - - person_tracker = dimos.deploy( - PersonTrackingStream, - camera_intrinsics=[619.061157, 619.061157, 317.883459, 238.543800], - camera_pitch=-0.174533, - camera_height=0.3, - ) - - object_tracker = dimos.deploy( - ObjectTrackingStream, - camera_intrinsics=[619.061157, 619.061157, 317.883459, 238.543800], - camera_pitch=-0.174533, - camera_height=0.3, - ) - - person_tracker.tracking_data.transport = core.pLCMTransport("/person_tracking") - object_tracker.tracking_data.transport = core.pLCMTransport("/object_tracking") - - person_tracker.video.connect(video_module.video_out) - object_tracker.video.connect(video_module.video_out) - - video_module.start() - person_tracker.start() - object_tracker.start() - - # 
person_tracker.enable_tracking() - # object_tracker.track([100, 100, 200, 200]) - await asyncio.sleep(2) - - person_data = person_tracker.get_tracking_data() - assert isinstance(person_data, dict) - assert "frame" in person_data - assert "viz_frame" in person_data - assert "targets" in person_data - assert isinstance(person_data["targets"], list) - - object_data = object_tracker.get_tracking_data() - assert isinstance(object_data, dict) - assert "frame" in object_data - assert "viz_frame" in object_data - assert "targets" in object_data - assert isinstance(object_data["targets"], list) - - assert person_data["frame"] is not None - assert object_data["frame"] is not None - - video_module.stop() - - logger.info("RPC methods test passed") - - finally: - # stop(dimos) - dimos.close() - dimos.shutdown() - - @pytest.mark.asyncio - async def test_visualization_streams(self, temp_dir): - """Test that visualization frames are properly generated.""" - - # Start Dask - dimos = core.start(1) - - try: - data_path = get_data("unitree_office_walk") - video_path = os.path.join(data_path, "video") - - video_module = dimos.deploy(VideoReplayModule, video_path) - video_module.video_out.transport = core.LCMTransport("/test_video", Image) - - person_tracker = dimos.deploy( - PersonTrackingStream, - camera_intrinsics=[619.061157, 619.061157, 317.883459, 238.543800], - camera_pitch=-0.174533, - camera_height=0.3, - ) - - object_tracker = dimos.deploy( - ObjectTrackingStream, - camera_intrinsics=[619.061157, 619.061157, 317.883459, 238.543800], - camera_pitch=-0.174533, - camera_height=0.3, - ) - - person_tracker.tracking_data.transport = core.pLCMTransport("/person_tracking") - object_tracker.tracking_data.transport = core.pLCMTransport("/object_tracking") - - person_tracker.video.connect(video_module.video_out) - object_tracker.video.connect(video_module.video_out) - - video_module.start() - person_tracker.start() - object_tracker.start() - - # person_tracker.enable_tracking() - # 
object_tracker.track([100, 100, 200, 200]) - - person_data = person_tracker.get_tracking_data() - object_data = object_tracker.get_tracking_data() - - video_module.stop() - - if person_data["viz_frame"] is not None: - viz_frame = person_data["viz_frame"] - assert isinstance(viz_frame, np.ndarray) - assert len(viz_frame.shape) == 3 - assert viz_frame.shape[2] == 3 - logger.info("Person tracking visualization frame verified") - - if object_data["viz_frame"] is not None: - viz_frame = object_data["viz_frame"] - assert isinstance(viz_frame, np.ndarray) - assert len(viz_frame.shape) == 3 - assert viz_frame.shape[2] == 3 - logger.info("Object tracking visualization frame verified") - - finally: - # stop(dimos) - dimos.close() - dimos.shutdown() - - -if __name__ == "__main__": - pytest.main(["-v", "-s", __file__]) diff --git a/dimos/perception/visual_servoing.py b/dimos/perception/visual_servoing.py deleted file mode 100644 index 40cee7c60c..0000000000 --- a/dimos/perception/visual_servoing.py +++ /dev/null @@ -1,500 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import time -import threading -from typing import Dict, Optional, List, Tuple -import logging -import numpy as np - -from dimos.utils.simple_controller import VisualServoingController - -# Configure logging -logger = logging.getLogger(__name__) - - -def calculate_iou(box1, box2): - """Calculate Intersection over Union between two bounding boxes.""" - x1 = max(box1[0], box2[0]) - y1 = max(box1[1], box2[1]) - x2 = min(box1[2], box2[2]) - y2 = min(box1[3], box2[3]) - - intersection = max(0, x2 - x1) * max(0, y2 - y1) - area1 = (box1[2] - box1[0]) * (box1[3] - box1[1]) - area2 = (box2[2] - box2[0]) * (box2[3] - box2[1]) - union = area1 + area2 - intersection - - return intersection / union if union > 0 else 0 - - -class VisualServoing: - """ - A class that performs visual servoing to track and follow a human target. - - The class will use the provided tracking stream to detect people and estimate - their distance and angle, then use a VisualServoingController to generate - appropriate velocity commands to track the target. - """ - - def __init__( - self, - tracking_stream=None, - max_linear_speed=0.8, - max_angular_speed=1.5, - desired_distance=1.5, - max_lost_frames=10000, - iou_threshold=0.6, - ): - """Initialize the visual servoing. 
- - Args: - tracking_stream: Observable tracking stream (must be already set up) - max_linear_speed: Maximum linear speed in m/s - max_angular_speed: Maximum angular speed in rad/s - desired_distance: Desired distance to maintain from target in meters - max_lost_frames: Maximum number of frames target can be lost before stopping tracking - iou_threshold: Minimum IOU threshold to consider bounding boxes as matching - """ - self.tracking_stream = tracking_stream - self.max_linear_speed = max_linear_speed - self.max_angular_speed = max_angular_speed - self.desired_distance = desired_distance - self.max_lost_frames = max_lost_frames - self.iou_threshold = iou_threshold - - # Initialize the controller with PID parameters tuned for slow-moving robot - # Distance PID: (kp, ki, kd, output_limits, integral_limit, deadband, output_deadband) - distance_pid_params = ( - 1.0, # kp: Moderate proportional gain for smooth approach - 0.2, # ki: Small integral gain to eliminate steady-state error - 0.1, # kd: Some damping for smooth motion - (-self.max_linear_speed, self.max_linear_speed), # output_limits - 0.5, # integral_limit: Prevent windup - 0.1, # deadband: Small deadband for distance control - 0.05, # output_deadband: Minimum output to overcome friction - ) - - # Angle PID: (kp, ki, kd, output_limits, integral_limit, deadband, output_deadband) - angle_pid_params = ( - 1.4, # kp: Higher proportional gain for responsive turning - 0.1, # ki: Small integral gain - 0.05, # kd: Light damping to prevent oscillation - (-self.max_angular_speed, self.max_angular_speed), # output_limits - 0.3, # integral_limit: Prevent windup - 0.1, # deadband: Small deadband for angle control - 0.1, # output_deadband: Minimum output to overcome friction - True, # Invert output for angular control - ) - - # Initialize the visual servoing controller - self.controller = VisualServoingController( - distance_pid_params=distance_pid_params, angle_pid_params=angle_pid_params - ) - - # Initialize tracking 
state - self.last_control_time = time.time() - self.running = False - self.current_target = None # (target_id, bbox) - self.target_lost_frames = 0 - - # Add variables to track current distance and angle - self.current_distance = None - self.current_angle = None - - # Stream subscription management - self.subscription = None - self.latest_result = None - self.result_lock = threading.Lock() - self.stop_event = threading.Event() - - # Subscribe to the tracking stream - self._subscribe_to_tracking_stream() - - def start_tracking( - self, - desired_distance: int = None, - point: Tuple[int, int] = None, - timeout_wait_for_target: float = 20.0, - ) -> bool: - """ - Start tracking a human target using visual servoing. - - Args: - point: Optional tuple of (x, y) coordinates in image space. If provided, - will find the target whose bounding box contains this point. - If None, will track the closest person. - - Returns: - bool: True if tracking was successfully started, False otherwise - """ - if desired_distance is not None: - self.desired_distance = desired_distance - - if self.tracking_stream is None: - self.running = False - return False - - # Get the latest frame and targets from person tracker - try: - # Try getting the result multiple times with delays - for attempt in range(10): - result = self._get_current_tracking_result() - - if result is not None: - break - - logger.warning( - f"Attempt {attempt + 1}: No tracking result, retrying in 1 second..." 
- ) - time.sleep(3) # Wait 1 second between attempts - - if result is None: - logger.warning("Stream error, no targets found after multiple attempts") - return False - - targets = result.get("targets") - - # If bbox is provided, find matching target based on IOU - if point is not None and not self.running: - # Find the target with highest IOU to the provided bbox - best_target = self._find_target_by_point(point, targets) - # If no bbox is provided, find the closest person - elif not self.running: - if timeout_wait_for_target > 0.0 and len(targets) == 0: - # Wait for target to appear - start_time = time.time() - while time.time() - start_time < timeout_wait_for_target: - time.sleep(0.2) - result = self._get_current_tracking_result() - targets = result.get("targets") - if len(targets) > 0: - break - best_target = self._find_closest_target(targets) - else: - # Already tracking - return True - - if best_target: - # Set as current target and reset lost counter - target_id = best_target.get("target_id") - target_bbox = best_target.get("bbox") - self.current_target = (target_id, target_bbox) - self.target_lost_frames = 0 - self.running = True - logger.info(f"Started tracking target ID: {target_id}") - - # Get distance and angle and compute control (store as initial control values) - distance = best_target.get("distance") - angle = best_target.get("angle") - self._compute_control(distance, angle) - return True - else: - if point is not None: - logger.warning("No matching target found") - else: - logger.warning("No suitable target found for tracking") - self.running = False - return False - except Exception as e: - logger.error(f"Error starting tracking: {e}") - self.running = False - return False - - def _find_target_by_point(self, point, targets): - """Find the target whose bounding box contains the given point. 
- - Args: - point: Tuple of (x, y) coordinates in image space - targets: List of target dictionaries - - Returns: - dict: The target whose bbox contains the point, or None if no match - """ - x, y = point - for target in targets: - bbox = target.get("bbox") - if not bbox: - continue - - x1, y1, x2, y2 = bbox - if x1 <= x <= x2 and y1 <= y <= y2: - return target - return None - - def updateTracking(self) -> Dict[str, any]: - """ - Update tracking of current target. - - Returns: - Dict with linear_vel, angular_vel, and running state - """ - if not self.running or self.current_target is None: - self.running = False - self.current_distance = None - self.current_angle = None - return {"linear_vel": 0.0, "angular_vel": 0.0} - - # Get the latest tracking result - result = self._get_current_tracking_result() - - # Get targets from result - targets = result.get("targets") - - # Try to find current target by ID or IOU - current_target_id, current_bbox = self.current_target - target_found = False - - # First try to find by ID - for target in targets: - if target.get("target_id") == current_target_id: - # Found by ID, update bbox - self.current_target = (current_target_id, target.get("bbox")) - self.target_lost_frames = 0 - target_found = True - - # Store current distance and angle - self.current_distance = target.get("distance") - self.current_angle = target.get("angle") - - # Compute control - control = self._compute_control(self.current_distance, self.current_angle) - return control - - # If not found by ID, try to find by IOU - if not target_found and current_bbox is not None: - best_target = self._find_best_target_by_iou(current_bbox, targets) - if best_target: - # Update target - new_id = best_target.get("target_id") - new_bbox = best_target.get("bbox") - self.current_target = (new_id, new_bbox) - self.target_lost_frames = 0 - logger.info(f"Target ID updated: {current_target_id} -> {new_id}") - - # Store current distance and angle - self.current_distance = 
best_target.get("distance") - self.current_angle = best_target.get("angle") - - # Compute control - control = self._compute_control(self.current_distance, self.current_angle) - return control - - # Target not found, increment lost counter - if not target_found: - self.target_lost_frames += 1 - logger.warning(f"Target lost: frame {self.target_lost_frames}/{self.max_lost_frames}") - - # Check if target is lost for too many frames - if self.target_lost_frames >= self.max_lost_frames: - logger.info("Target lost for too many frames, stopping tracking") - self.stop_tracking() - return {"linear_vel": 0.0, "angular_vel": 0.0, "running": False} - - return {"linear_vel": 0.0, "angular_vel": 0.0} - - def _compute_control(self, distance: float, angle: float) -> Dict[str, float]: - """ - Compute control commands based on measured distance and angle. - - Args: - distance: Measured distance to target in meters - angle: Measured angle to target in radians - - Returns: - Dict with linear_vel and angular_vel keys - """ - current_time = time.time() - dt = current_time - self.last_control_time - self.last_control_time = current_time - - # Compute control with visual servoing controller - linear_vel, angular_vel = self.controller.compute_control( - measured_distance=distance, - measured_angle=angle, - desired_distance=self.desired_distance, - desired_angle=0.0, # Keep target centered - dt=dt, - ) - - # Log control values for debugging - logger.debug(f"Distance: {distance:.2f}m, Angle: {np.rad2deg(angle):.1f}°") - logger.debug(f"Control: linear={linear_vel:.2f}m/s, angular={angular_vel:.2f}rad/s") - - return {"linear_vel": linear_vel, "angular_vel": angular_vel} - - def _find_best_target_by_iou(self, bbox: List[float], targets: List[Dict]) -> Optional[Dict]: - """ - Find the target with highest IOU to the given bbox. 
- - Args: - bbox: Bounding box to match [x1, y1, x2, y2] - targets: List of target dictionaries - - Returns: - Best matching target or None if no match found - """ - if not targets: - return None - - best_iou = self.iou_threshold - best_target = None - - for target in targets: - target_bbox = target.get("bbox") - if target_bbox is None: - continue - - iou = calculate_iou(bbox, target_bbox) - if iou > best_iou: - best_iou = iou - best_target = target - - return best_target - - def _find_closest_target(self, targets: List[Dict]) -> Optional[Dict]: - """ - Find the target with shortest distance to the camera. - - Args: - targets: List of target dictionaries - - Returns: - The closest target or None if no targets available - """ - if not targets: - return None - - closest_target = None - min_distance = float("inf") - - for target in targets: - distance = target.get("distance") - if distance is not None and distance < min_distance: - min_distance = distance - closest_target = target - - return closest_target - - def _subscribe_to_tracking_stream(self): - """ - Subscribe to the already set up tracking stream. - """ - if self.tracking_stream is None: - logger.warning("No tracking stream provided to subscribe to") - return - - try: - # Set up subscription to process frames - self.subscription = self.tracking_stream.subscribe( - on_next=self._on_tracking_result, - on_error=self._on_tracking_error, - on_completed=self._on_tracking_completed, - ) - - logger.info("Subscribed to tracking stream successfully") - except Exception as e: - logger.error(f"Error subscribing to tracking stream: {e}") - - def _on_tracking_result(self, result): - """ - Callback for tracking stream results. - - This updates the latest result for use by _get_current_tracking_result. 
- - Args: - result: The result from the tracking stream - """ - if self.stop_event.is_set(): - return - - # Update the latest result - with self.result_lock: - self.latest_result = result - - def _on_tracking_error(self, error): - """ - Callback for tracking stream errors. - - Args: - error: The error from the tracking stream - """ - logger.error(f"Tracking stream error: {error}") - self.stop_event.set() - - def _on_tracking_completed(self): - """Callback for tracking stream completion.""" - logger.info("Tracking stream completed") - self.stop_event.set() - - def _get_current_tracking_result(self) -> Optional[Dict]: - """ - Get the current tracking result. - - Returns the latest result cached from the tracking stream subscription. - - Returns: - Dict with 'frame' and 'targets' or None if not available - """ - # Return the latest cached result - with self.result_lock: - return self.latest_result - - def stop_tracking(self): - """Stop tracking and reset controller state.""" - self.running = False - self.current_target = None - self.target_lost_frames = 0 - self.current_distance = None - self.current_angle = None - return {"linear_vel": 0.0, "angular_vel": 0.0, "running": False} - - def is_goal_reached(self, distance_threshold=0.2, angle_threshold=0.1) -> bool: - """ - Check if the robot has reached the tracking goal (desired distance and angle). 
- - Args: - distance_threshold: Maximum allowed difference between current and desired distance (meters) - angle_threshold: Maximum allowed difference between current and desired angle (radians) - - Returns: - bool: True if both distance and angle are within threshold of desired values - """ - if not self.running or self.current_target is None: - return False - - # Use the stored distance and angle values - if self.current_distance is None or self.current_angle is None: - return False - - # Check if within thresholds - distance_error = abs(self.current_distance - self.desired_distance) - angle_error = abs(self.current_angle) # Desired angle is always 0 (centered) - - logger.debug( - f"Goal check - Distance error: {distance_error:.2f}m, Angle error: {angle_error:.2f}rad" - ) - - return (distance_error <= distance_threshold) and (angle_error <= angle_threshold) - - def cleanup(self): - """Clean up all resources used by the visual servoing.""" - self.stop_event.set() - if self.subscription: - self.subscription.dispose() - self.subscription = None - - def __del__(self): - """Destructor to ensure cleanup on object deletion.""" - self.cleanup() diff --git a/dimos/protocol/pubsub/lcmpubsub.py b/dimos/protocol/pubsub/lcmpubsub.py index 68a5bd3008..b01ae40cca 100644 --- a/dimos/protocol/pubsub/lcmpubsub.py +++ b/dimos/protocol/pubsub/lcmpubsub.py @@ -80,6 +80,43 @@ def unsubscribe(): return unsubscribe + def wait_for_message(self, topic: Topic, timeout: float = 1.0) -> Any: + """Wait for a single message on the specified topic. 
+ + Args: + topic: The topic to listen on + timeout: Maximum time to wait for a message in seconds + + Returns: + The received message or None if timeout occurred + """ + received_message = None + message_event = threading.Event() + + def message_handler(channel, data): + nonlocal received_message + try: + # Decode the message if type is specified + if hasattr(self, "decode") and topic.lcm_type is not None: + received_message = self.decode(data, topic) + else: + received_message = data + message_event.set() + except Exception as e: + print(f"Error decoding message: {e}") + message_event.set() + + # Subscribe to the topic + subscription = self.l.subscribe(str(topic), message_handler) + + try: + # Wait for message or timeout + message_event.wait(timeout) + return received_message + finally: + # Clean up subscription + self.l.unsubscribe(subscription) + class LCMEncoderMixin(PubSubEncoderMixin[Topic, Any]): def encode(self, msg: LCMMsg, _: Topic) -> bytes: diff --git a/dimos/protocol/pubsub/test_lcmpubsub.py b/dimos/protocol/pubsub/test_lcmpubsub.py index fb30f41a07..452255e4c6 100644 --- a/dimos/protocol/pubsub/test_lcmpubsub.py +++ b/dimos/protocol/pubsub/test_lcmpubsub.py @@ -13,6 +13,7 @@ # limitations under the License. 
import subprocess +import threading import time from unittest.mock import patch @@ -179,3 +180,211 @@ def callback(msg, topic): assert received_topic == topic print(test_message, topic) + + +def test_wait_for_message_basic(): + """Test basic wait_for_message functionality - message arrives before timeout.""" + lcm = LCM(autoconf=True) + lcm.start() + + topic = Topic(topic="/test_wait", lcm_type=MockLCMMessage) + test_message = MockLCMMessage("wait_test_data") + + # Publish message after a short delay in another thread + def publish_delayed(): + time.sleep(0.1) + lcm.publish(topic, test_message) + + publisher_thread = threading.Thread(target=publish_delayed) + publisher_thread.start() + + # Wait for message with 1 second timeout + start_time = time.time() + received_msg = lcm.wait_for_message(topic, timeout=1.0) + elapsed_time = time.time() - start_time + + publisher_thread.join() + + # Check that we received the message + assert received_msg is not None + assert isinstance(received_msg, MockLCMMessage) + assert received_msg.data == "wait_test_data" + + # Check that we didn't wait the full timeout + assert elapsed_time < 0.5 # Should receive message in ~0.1 seconds + + +def test_wait_for_message_timeout(): + """Test wait_for_message timeout - no message published.""" + lcm = LCM(autoconf=True) + lcm.start() + + topic = Topic(topic="/test_timeout", lcm_type=MockLCMMessage) + + # Wait for message that will never come + start_time = time.time() + received_msg = lcm.wait_for_message(topic, timeout=0.5) + elapsed_time = time.time() - start_time + + # Check that we got None (timeout) + assert received_msg is None + + # Check that we waited approximately the timeout duration + assert 0.4 < elapsed_time < 0.7 # Allow some tolerance + + +def test_wait_for_message_immediate(): + """Test wait_for_message with message published immediately after subscription.""" + lcm = LCM(autoconf=True) + lcm.start() + + topic = Topic(topic="/test_immediate", lcm_type=MockLCMMessage) + 
test_message = MockLCMMessage("immediate_data") + + # Start waiting in a thread + received_msg = None + + def wait_for_msg(): + nonlocal received_msg + received_msg = lcm.wait_for_message(topic, timeout=1.0) + + wait_thread = threading.Thread(target=wait_for_msg) + wait_thread.start() + + # Give a tiny bit of time for subscription to be established + time.sleep(0.01) + + # Now publish the message + start_time = time.time() + lcm.publish(topic, test_message) + + # Wait for the thread to complete + wait_thread.join() + elapsed_time = time.time() - start_time + + # Check that we received the message quickly + assert received_msg is not None + assert isinstance(received_msg, MockLCMMessage) + assert received_msg.data == "immediate_data" + assert elapsed_time < 0.2 # Should be nearly immediate + + +def test_wait_for_message_multiple_sequential(): + """Test multiple sequential wait_for_message calls.""" + lcm = LCM(autoconf=True) + lcm.start() + + topic = Topic(topic="/test_sequential", lcm_type=MockLCMMessage) + + # Test multiple messages in sequence + messages = ["msg1", "msg2", "msg3"] + + for msg_data in messages: + test_message = MockLCMMessage(msg_data) + + # Publish in background + def publish_delayed(msg=test_message): + time.sleep(0.05) + lcm.publish(topic, msg) + + publisher_thread = threading.Thread(target=publish_delayed) + publisher_thread.start() + + # Wait and verify + received_msg = lcm.wait_for_message(topic, timeout=1.0) + assert received_msg is not None + assert received_msg.data == msg_data + + publisher_thread.join() + + +def test_wait_for_message_concurrent(): + """Test concurrent wait_for_message calls on different topics.""" + lcm = LCM(autoconf=True) + lcm.start() + + topic1 = Topic(topic="/test_concurrent1", lcm_type=MockLCMMessage) + topic2 = Topic(topic="/test_concurrent2", lcm_type=MockLCMMessage) + + message1 = MockLCMMessage("concurrent1") + message2 = MockLCMMessage("concurrent2") + + received_messages = {} + + def 
wait_for_topic(topic_name, topic): + msg = lcm.wait_for_message(topic, timeout=2.0) + received_messages[topic_name] = msg + + # Start waiting on both topics + thread1 = threading.Thread(target=wait_for_topic, args=("topic1", topic1)) + thread2 = threading.Thread(target=wait_for_topic, args=("topic2", topic2)) + + thread1.start() + thread2.start() + + # Publish to both topics after a delay + time.sleep(0.1) + lcm.publish(topic1, message1) + lcm.publish(topic2, message2) + + # Wait for both threads to complete + thread1.join(timeout=3.0) + thread2.join(timeout=3.0) + + # Verify both messages were received + assert "topic1" in received_messages + assert "topic2" in received_messages + assert received_messages["topic1"].data == "concurrent1" + assert received_messages["topic2"].data == "concurrent2" + + +def test_wait_for_message_wrong_topic(): + """Test wait_for_message doesn't receive messages from wrong topic.""" + lcm = LCM(autoconf=True) + lcm.start() + + topic_correct = Topic(topic="/test_correct", lcm_type=MockLCMMessage) + topic_wrong = Topic(topic="/test_wrong", lcm_type=MockLCMMessage) + + message = MockLCMMessage("wrong_topic_data") + + # Publish to wrong topic + lcm.publish(topic_wrong, message) + + # Wait on correct topic + received_msg = lcm.wait_for_message(topic_correct, timeout=0.3) + + # Should timeout and return None + assert received_msg is None + + +def test_wait_for_message_pickle(): + """Test wait_for_message with PickleLCM.""" + lcm = PickleLCM(autoconf=True) + lcm.start() + + topic = Topic(topic="/test_pickle") + test_obj = {"key": "value", "number": 42} + + # Publish after delay + def publish_delayed(): + time.sleep(0.1) + lcm.publish(topic, test_obj) + + publisher_thread = threading.Thread(target=publish_delayed) + publisher_thread.start() + + # Wait for message + received_msg = lcm.wait_for_message(topic, timeout=1.0) + + publisher_thread.join() + + # Verify received object + assert received_msg is not None + # PickleLCM's wait_for_message 
returns the pickled bytes, need to decode + import pickle + + decoded_msg = pickle.loads(received_msg) + assert decoded_msg == test_obj + assert decoded_msg["key"] == "value" + assert decoded_msg["number"] == 42 diff --git a/dimos/protocol/tf/__init__.py b/dimos/protocol/tf/__init__.py index 36d82e8394..518a9b97f0 100644 --- a/dimos/protocol/tf/__init__.py +++ b/dimos/protocol/tf/__init__.py @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dimos.protocol.tf.tf import TF, LCMTF, PubSubTF, TFSpec, TFConfig +from dimos.protocol.tf.tf import TF, LCMTF, PubSubTF, TFSpec, TFConfig, TBuffer, MultiTBuffer -__all__ = ["TF", "LCMTF", "PubSubTF", "TFSpec", "TFConfig"] +__all__ = ["TF", "LCMTF", "PubSubTF", "TFSpec", "TFConfig", "TBuffer", "MultiTBuffer"] diff --git a/dimos/protocol/tf/test_tf.py b/dimos/protocol/tf/test_tf.py index 5c9489c87d..72fdfc5d3e 100644 --- a/dimos/protocol/tf/test_tf.py +++ b/dimos/protocol/tf/test_tf.py @@ -18,7 +18,7 @@ from dimos.core import TF from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Vector3 -from dimos.protocol.tf.tf import MultiTBuffer, TBuffer +from dimos.protocol.tf import MultiTBuffer, TBuffer def test_tf_main(): diff --git a/dimos/robot/frontier_exploration/__init__.py b/dimos/robot/frontier_exploration/__init__.py deleted file mode 100644 index 8b13789179..0000000000 --- a/dimos/robot/frontier_exploration/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/dimos/robot/frontier_exploration/qwen_frontier_predictor.py b/dimos/robot/frontier_exploration/qwen_frontier_predictor.py deleted file mode 100644 index 2ccdb89a17..0000000000 --- a/dimos/robot/frontier_exploration/qwen_frontier_predictor.py +++ /dev/null @@ -1,372 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Qwen-based frontier exploration goal predictor using vision language model. - -This module provides a frontier goal detector that uses Qwen's vision capabilities -to analyze costmap images and predict optimal exploration goals. -""" - -import os -import glob -import json -import re -from typing import Optional, List, Tuple - -import numpy as np -from PIL import Image, ImageDraw - -from dimos.types.costmap import Costmap -from dimos.types.vector import Vector -from dimos.models.qwen.video_query import query_single_frame -from dimos.robot.frontier_exploration.utils import ( - costmap_to_pil_image, - smooth_costmap_for_frontiers, -) - - -class QwenFrontierPredictor: - """ - Qwen-based frontier exploration goal predictor. - - Uses Qwen's vision capabilities to analyze costmap images and predict - optimal exploration goals based on visual understanding of the map structure. - """ - - def __init__( - self, - api_key: Optional[str] = None, - model_name: str = "qwen2.5-vl-72b-instruct", - use_smoothed_costmap: bool = True, - image_scale_factor: int = 4, - ): - """ - Initialize the Qwen frontier predictor. 
- - Args: - api_key: Alibaba API key for Qwen access - model_name: Qwen model to use for predictions - image_scale_factor: Scale factor for image processing - """ - self.api_key = api_key or os.getenv("ALIBABA_API_KEY") - if not self.api_key: - raise ValueError( - "Alibaba API key must be provided or set in ALIBABA_API_KEY environment variable" - ) - - self.model_name = model_name - self.image_scale_factor = image_scale_factor - self.use_smoothed_costmap = use_smoothed_costmap - - # Storage for previously explored goals - self.explored_goals: List[Vector] = [] - - def _world_to_image_coords(self, world_pos: Vector, costmap: Costmap) -> Tuple[int, int]: - """Convert world coordinates to image pixel coordinates.""" - grid_pos = costmap.world_to_grid(world_pos) - img_x = int(grid_pos.x * self.image_scale_factor) - img_y = int((costmap.height - grid_pos.y) * self.image_scale_factor) # Flip Y - return img_x, img_y - - def _image_to_world_coords(self, img_x: int, img_y: int, costmap: Costmap) -> Vector: - """Convert image pixel coordinates to world coordinates.""" - # Unscale and flip Y coordinate - grid_x = img_x / self.image_scale_factor - grid_y = costmap.height - (img_y / self.image_scale_factor) - - # Convert grid to world coordinates - world_pos = costmap.grid_to_world(Vector([grid_x, grid_y])) - return world_pos - - def _draw_goals_on_image( - self, - image: Image.Image, - robot_pose: Vector, - costmap: Costmap, - latest_goal: Optional[Vector] = None, - ) -> Image.Image: - """ - Draw explored goals and robot position on the costmap image. 
- - Args: - image: PIL Image to draw on - robot_pose: Current robot position - costmap: Costmap for coordinate conversion - latest_goal: Latest predicted goal to highlight in red - - Returns: - PIL Image with goals drawn - """ - img_copy = image.copy() - draw = ImageDraw.Draw(img_copy) - - # Draw previously explored goals as green dots - for explored_goal in self.explored_goals: - x, y = self._world_to_image_coords(explored_goal, costmap) - radius = 8 - draw.ellipse( - [x - radius, y - radius, x + radius, y + radius], - fill=(0, 255, 0), - outline=(0, 128, 0), - width=2, - ) - - # Draw robot position as blue dot - robot_x, robot_y = self._world_to_image_coords(robot_pose, costmap) - robot_radius = 10 - draw.ellipse( - [ - robot_x - robot_radius, - robot_y - robot_radius, - robot_x + robot_radius, - robot_y + robot_radius, - ], - fill=(0, 0, 255), - outline=(0, 0, 128), - width=3, - ) - - # Draw latest predicted goal as red dot - if latest_goal: - goal_x, goal_y = self._world_to_image_coords(latest_goal, costmap) - goal_radius = 12 - draw.ellipse( - [ - goal_x - goal_radius, - goal_y - goal_radius, - goal_x + goal_radius, - goal_y + goal_radius, - ], - fill=(255, 0, 0), - outline=(128, 0, 0), - width=3, - ) - - return img_copy - - def _create_vision_prompt(self) -> str: - """Create the vision prompt for Qwen model.""" - prompt = """You are an expert robot navigation system analyzing a costmap for frontier exploration. - -COSTMAP LEGEND: -- Light gray pixels (205,205,205): FREE SPACE - areas the robot can navigate -- Dark gray pixels (128,128,128): UNKNOWN SPACE - unexplored areas that need exploration -- Black pixels (0,0,0): OBSTACLES - walls, furniture, blocked areas -- Blue dot: CURRENT ROBOT POSITION -- Green dots: PREVIOUSLY EXPLORED GOALS - avoid these areas - -TASK: Find the best frontier exploration goal by identifying the optimal point where: -1. It's at the boundary between FREE SPACE (light gray) and UNKNOWN SPACE (dark gray) (HIGHEST Priority) -2. 
It's reasonably far from the robot position (blue dot) (MEDIUM Priority) -3. It's reasonably far from previously explored goals (green dots) (MEDIUM Priority) -4. It leads to a large area of unknown space to explore (HIGH Priority) -5. It's accessible from the robot's current position through free space (MEDIUM Priority) -6. It's not near or on obstacles (HIGHEST Priority) - -RESPONSE FORMAT: Return ONLY the pixel coordinates as a JSON object: -{"x": pixel_x_coordinate, "y": pixel_y_coordinate, "reasoning": "brief explanation"} - -Example: {"x": 245, "y": 187, "reasoning": "Large unknown area to the north, good distance from robot and previous goals"} - -Analyze the image and identify the single best frontier exploration goal.""" - - return prompt - - def _parse_prediction_response(self, response: str) -> Optional[Tuple[int, int, str]]: - """ - Parse the model's response to extract coordinates and reasoning. - - Args: - response: Raw response from Qwen model - - Returns: - Tuple of (x, y, reasoning) or None if parsing failed - """ - try: - # Try to find JSON object in response - json_match = re.search(r"\{[^}]*\}", response) - if json_match: - json_str = json_match.group() - data = json.loads(json_str) - - if "x" in data and "y" in data: - x = int(data["x"]) - y = int(data["y"]) - reasoning = data.get("reasoning", "No reasoning provided") - return (x, y, reasoning) - - # Fallback: try to extract coordinates with regex - coord_match = re.search(r"[^\d]*(\d+)[^\d]+(\d+)", response) - if coord_match: - x = int(coord_match.group(1)) - y = int(coord_match.group(2)) - return (x, y, "Coordinates extracted from response") - - except (json.JSONDecodeError, ValueError, KeyError) as e: - print(f"DEBUG: Failed to parse prediction response: {e}") - - return None - - def get_exploration_goal(self, robot_pose: Vector, costmap: Costmap) -> Optional[Vector]: - """ - Get the best exploration goal using Qwen vision analysis. 
- - Args: - robot_pose: Current robot position in world coordinates - costmap: Current costmap for analysis - - Returns: - Single best frontier goal in world coordinates, or None if no suitable goal found - """ - print( - f"DEBUG: Qwen frontier prediction starting with {len(self.explored_goals)} explored goals" - ) - - # Create costmap image - if self.use_smoothed_costmap: - costmap = smooth_costmap_for_frontiers(costmap, alpha=4.0) - - base_image = costmap_to_pil_image(costmap, self.image_scale_factor) - - # Draw goals on image (without latest goal initially) - annotated_image = self._draw_goals_on_image(base_image, robot_pose, costmap) - - # Query Qwen model for frontier prediction - try: - prompt = self._create_vision_prompt() - - # Convert PIL image to numpy array for query_single_frame - annotated_array = np.array(annotated_image) - - response = query_single_frame( - annotated_array, prompt, api_key=self.api_key, model_name=self.model_name - ) - - print(f"DEBUG: Qwen response: {response}") - - # Parse response to get coordinates - parsed_result = self._parse_prediction_response(response) - if not parsed_result: - print("DEBUG: Failed to parse Qwen response") - return None - - img_x, img_y, reasoning = parsed_result - print(f"DEBUG: Parsed coordinates: ({img_x}, {img_y}), Reasoning: {reasoning}") - - # Convert image coordinates to world coordinates - predicted_goal = self._image_to_world_coords(img_x, img_y, costmap) - print( - f"DEBUG: Predicted goal in world coordinates: ({predicted_goal.x:.2f}, {predicted_goal.y:.2f})" - ) - - # Store the goal in explored_goals for future reference - self.explored_goals.append(predicted_goal) - - print(f"DEBUG: Successfully predicted frontier goal: {predicted_goal}") - return predicted_goal - - except Exception as e: - print(f"DEBUG: Error during Qwen prediction: {e}") - return None - - -def test_qwen_frontier_detection(): - """ - Visual test for Qwen frontier detection using saved costmaps. 
- Shows frontier detection results with Qwen predictions. - """ - - # Path to saved costmaps - saved_maps_dir = os.path.join(os.getcwd(), "assets", "saved_maps") - - if not os.path.exists(saved_maps_dir): - print(f"Error: Saved maps directory not found: {saved_maps_dir}") - return - - # Get all pickle files - pickle_files = sorted(glob.glob(os.path.join(saved_maps_dir, "*.pickle"))) - - if not pickle_files: - print(f"No pickle files found in {saved_maps_dir}") - return - - print(f"Found {len(pickle_files)} costmap files for Qwen testing") - - # Initialize Qwen frontier predictor - predictor = QwenFrontierPredictor(image_scale_factor=4, use_smoothed_costmap=False) - - # Track the robot pose across iterations - robot_pose = None - - # Process each costmap - for i, pickle_file in enumerate(pickle_files): - print( - f"\n--- Processing costmap {i + 1}/{len(pickle_files)}: {os.path.basename(pickle_file)} ---" - ) - - try: - # Load the costmap - costmap = Costmap.from_pickle(pickle_file) - print( - f"Loaded costmap: {costmap.width}x{costmap.height}, resolution: {costmap.resolution}" - ) - - # Set robot pose: first iteration uses center, subsequent use last predicted goal - if robot_pose is None: - # First iteration: use center of costmap as robot position - center_world = costmap.grid_to_world( - Vector([costmap.width / 2, costmap.height / 2]) - ) - robot_pose = Vector([center_world.x, center_world.y]) - # else: robot_pose remains the last predicted goal - - print(f"Using robot position: {robot_pose}") - - # Get frontier prediction from Qwen - print("Getting Qwen frontier prediction...") - predicted_goal = predictor.get_exploration_goal(robot_pose, costmap) - - if predicted_goal: - distance = np.sqrt( - (predicted_goal.x - robot_pose.x) ** 2 + (predicted_goal.y - robot_pose.y) ** 2 - ) - print(f"Predicted goal: {predicted_goal}, Distance: {distance:.2f}m") - - # Show the final visualization - base_image = costmap_to_pil_image(costmap, predictor.image_scale_factor) - 
final_image = predictor._draw_goals_on_image( - base_image, robot_pose, costmap, predicted_goal - ) - - # Display image - title = f"Qwen Frontier Prediction {i + 1:04d}" - final_image.show(title=title) - - # Update robot pose for next iteration - robot_pose = predicted_goal - - else: - print("No suitable frontier goal predicted by Qwen") - - except Exception as e: - print(f"Error processing {pickle_file}: {e}") - continue - - print(f"\n=== Qwen Frontier Detection Test Complete ===") - print(f"Final explored goals count: {len(predictor.explored_goals)}") - - -if __name__ == "__main__": - test_qwen_frontier_detection() diff --git a/dimos/robot/global_planner/__init__.py b/dimos/robot/global_planner/__init__.py deleted file mode 100644 index f26a5e8f7c..0000000000 --- a/dimos/robot/global_planner/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from dimos.robot.global_planner.planner import AstarPlanner, Planner diff --git a/dimos/robot/global_planner/algo.py b/dimos/robot/global_planner/algo.py deleted file mode 100644 index 893efa4de9..0000000000 --- a/dimos/robot/global_planner/algo.py +++ /dev/null @@ -1,307 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import heapq -import math -from collections import deque -from typing import Optional, Tuple - -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, VectorLike -from dimos.msgs.geometry_msgs import Vector3 as Vector -from dimos.msgs.nav_msgs import OccupancyGrid, Path -from dimos.utils.logging_config import setup_logger - -logger = setup_logger("dimos.robot.unitree.global_planner.astar") - - -def find_nearest_free_cell( - costmap: OccupancyGrid, - position: VectorLike, - cost_threshold: int = 90, - max_search_radius: int = 20, -) -> Tuple[int, int]: - """ - Find the nearest unoccupied cell in the costmap using BFS. - - Args: - costmap: Costmap object containing the environment - position: Position to find nearest free cell from - cost_threshold: Cost threshold above which a cell is considered an obstacle - max_search_radius: Maximum search radius in cells - - Returns: - Tuple of (x, y) in grid coordinates of the nearest free cell, - or the original position if no free cell is found within max_search_radius - """ - # Convert world coordinates to grid coordinates - grid_pos = costmap.world_to_grid(position) - start_x, start_y = int(grid_pos.x), int(grid_pos.y) - - # If the cell is already free, return it - if 0 <= start_x < costmap.width and 0 <= start_y < costmap.height: - if costmap.grid[start_y, start_x] < cost_threshold: - return (start_x, start_y) - - # BFS to find nearest free cell - queue = deque([(start_x, start_y, 0)]) # (x, y, distance) - visited = set([(start_x, start_y)]) - - # Possible movements (8-connected grid) - directions = [ - (0, 1), - (1, 0), - (0, -1), - (-1, 0), # horizontal/vertical - (1, 1), - (1, -1), - (-1, 1), - (-1, -1), # diagonal - ] - - while queue: - x, y, dist = queue.popleft() - - # Check if we've reached the maximum search radius - if dist > max_search_radius: - logger.info( - f"Could not find free cell within {max_search_radius} cells of ({start_x}, {start_y})" - ) - return (start_x, start_y) # Return original position 
if no free cell found - - # Check if this cell is valid and free - if 0 <= x < costmap.width and 0 <= y < costmap.height: - if costmap.grid[y, x] < cost_threshold: - logger.info( - f"Found free cell at ({x}, {y}), {dist} cells away from ({start_x}, {start_y})" - ) - return (x, y) - - # Add neighbors to the queue - for dx, dy in directions: - nx, ny = x + dx, y + dy - if (nx, ny) not in visited: - visited.add((nx, ny)) - queue.append((nx, ny, dist + 1)) - - # If the queue is empty and no free cell is found, return the original position - return (start_x, start_y) - - -def astar( - costmap: OccupancyGrid, - goal: VectorLike, - start: VectorLike = (0.0, 0.0), - cost_threshold: int = 90, - allow_diagonal: bool = True, -) -> Optional[Path]: - """ - A* path planning algorithm from start to goal position. - - Args: - costmap: Costmap object containing the environment - goal: Goal position as any vector-like object - start: Start position as any vector-like object (default: origin [0,0]) - cost_threshold: Cost threshold above which a cell is considered an obstacle - allow_diagonal: Whether to allow diagonal movements - - Returns: - Path object containing waypoints, or None if no path found - """ - - # Convert world coordinates to grid coordinates directly using vector-like inputs - start_vector = costmap.world_to_grid(start) - goal_vector = costmap.world_to_grid(goal) - logger.info(f"ASTAR {costmap} {start_vector} -> {goal_vector}") - - # Store original positions for reference - original_start = (int(start_vector.x), int(start_vector.y)) - original_goal = (int(goal_vector.x), int(goal_vector.y)) - - adjusted_start = original_start - adjusted_goal = original_goal - - # Check if start is out of bounds or in an obstacle - start_valid = 0 <= start_vector.x < costmap.width and 0 <= start_vector.y < costmap.height - - start_in_obstacle = False - if start_valid: - start_in_obstacle = costmap.grid[int(start_vector.y), int(start_vector.x)] >= cost_threshold - - if not start_valid 
or start_in_obstacle: - logger.info("Start position is out of bounds or in an obstacle, finding nearest free cell") - adjusted_start = find_nearest_free_cell(costmap, start, cost_threshold) - # Update start_vector for later use - start_vector = Vector(adjusted_start[0], adjusted_start[1]) - - # Check if goal is out of bounds or in an obstacle - goal_valid = 0 <= goal_vector.x < costmap.width and 0 <= goal_vector.y < costmap.height - - goal_in_obstacle = False - if goal_valid: - goal_in_obstacle = costmap.grid[int(goal_vector.y), int(goal_vector.x)] >= cost_threshold - - if not goal_valid or goal_in_obstacle: - logger.info("Goal position is out of bounds or in an obstacle, finding nearest free cell") - adjusted_goal = find_nearest_free_cell(costmap, goal, cost_threshold) - # Update goal_vector for later use - goal_vector = Vector(adjusted_goal[0], adjusted_goal[1]) - - # Define possible movements (8-connected grid) - if allow_diagonal: - # 8-connected grid: horizontal, vertical, and diagonal movements - directions = [ - (0, 1), - (1, 0), - (0, -1), - (-1, 0), - (1, 1), - (1, -1), - (-1, 1), - (-1, -1), - ] - else: - # 4-connected grid: only horizontal and vertical ts - directions = [(0, 1), (1, 0), (0, -1), (-1, 0)] - - # Cost for each movement (straight vs diagonal) - sc = 1.0 - dc = 1.42 - movement_costs = [sc, sc, sc, sc, dc, dc, dc, dc] if allow_diagonal else [sc, sc, sc, sc] - - # A* algorithm implementation - open_set = [] # Priority queue for nodes to explore - closed_set = set() # Set of explored nodes - - # Use adjusted positions as tuples for dictionary keys - start_tuple = adjusted_start - goal_tuple = adjusted_goal - - # Dictionary to store cost from start and parents for each node - g_score = {start_tuple: 0} - parents = {} - - # Heuristic function (Euclidean distance) - def heuristic(x1, y1, x2, y2): - return math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2) - - # Start with the starting node - f_score = g_score[start_tuple] + heuristic( - start_tuple[0], 
start_tuple[1], goal_tuple[0], goal_tuple[1] - ) - heapq.heappush(open_set, (f_score, start_tuple)) - - while open_set: - # Get the node with the lowest f_score - _, current = heapq.heappop(open_set) - current_x, current_y = current - - # Check if we've reached the goal - if current == goal_tuple: - # Reconstruct the path - waypoints = [] - while current in parents: - world_point = costmap.grid_to_world(current) - # Create PoseStamped with identity quaternion (no orientation) - pose = PoseStamped( - frame_id="world", - position=[world_point.x, world_point.y, 0.0], - orientation=Quaternion(0, 0, 0, 1), # Identity quaternion - ) - waypoints.append(pose) - current = parents[current] - - # Add the start position - start_world_point = costmap.grid_to_world(start_tuple) - start_pose = PoseStamped( - frame_id="world", - position=[start_world_point.x, start_world_point.y, 0.0], - orientation=Quaternion(0, 0, 0, 1), - ) - waypoints.append(start_pose) - - # Reverse the path (start to goal) - waypoints.reverse() - - # Add the goal position if it's not already included - goal_point = costmap.grid_to_world(goal_tuple) - - if ( - not waypoints - or (waypoints[-1].x - goal_point.x) ** 2 + (waypoints[-1].y - goal_point.y) ** 2 - > 1e-10 - ): - goal_pose = PoseStamped( - frame_id="world", - position=[goal_point.x, goal_point.y, 0.0], - orientation=Quaternion(0, 0, 0, 1), - ) - waypoints.append(goal_pose) - - # If we adjusted the goal, add the original goal as the final point - if adjusted_goal != original_goal and goal_valid: - original_goal_point = costmap.grid_to_world(original_goal) - original_goal_pose = PoseStamped( - frame_id="world", - position=[original_goal_point.x, original_goal_point.y, 0.0], - orientation=Quaternion(0, 0, 0, 1), - ) - waypoints.append(original_goal_pose) - - return Path(frame_id="world", poses=waypoints) - - # Add current node to closed set - closed_set.add(current) - - # Explore neighbors - for i, (dx, dy) in enumerate(directions): - neighbor_x, 
neighbor_y = current_x + dx, current_y + dy - neighbor = (neighbor_x, neighbor_y) - - # Check if the neighbor is valid - if not (0 <= neighbor_x < costmap.width and 0 <= neighbor_y < costmap.height): - continue - - # Check if the neighbor is already explored - if neighbor in closed_set: - continue - - # Check if the neighbor is an obstacle - neighbor_val = costmap.grid[neighbor_y, neighbor_x] - if neighbor_val >= cost_threshold: # or neighbor_val < 0: - continue - - obstacle_proximity_penalty = costmap.grid[neighbor_y, neighbor_x] / 25 - tentative_g_score = ( - g_score[current] - + movement_costs[i] - + (obstacle_proximity_penalty * movement_costs[i]) - ) - - # Get the current g_score for the neighbor or set to infinity if not yet explored - neighbor_g_score = g_score.get(neighbor, float("inf")) - - # If this path to the neighbor is better than any previous one - if tentative_g_score < neighbor_g_score: - # Update the neighbor's scores and parent - parents[neighbor] = current - g_score[neighbor] = tentative_g_score - f_score = tentative_g_score + heuristic( - neighbor_x, neighbor_y, goal_tuple[0], goal_tuple[1] - ) - - # Add the neighbor to the open set with its f_score - heapq.heappush(open_set, (f_score, neighbor)) - - # If we get here, no path was found - return None diff --git a/dimos/robot/local_planner/__init__.py b/dimos/robot/local_planner/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/robot/local_planner/local_planner.py b/dimos/robot/local_planner/local_planner.py deleted file mode 100644 index 3c9633a67c..0000000000 --- a/dimos/robot/local_planner/local_planner.py +++ /dev/null @@ -1,1464 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import math -import threading -import time -from abc import ABC, abstractmethod -from collections import deque -from typing import Any, Callable, Dict, Optional, Tuple - -import cv2 -import numpy as np -import reactivex as rx -from reactivex import Observable -from reactivex import operators as ops -from reactivex.subject import Subject - -from dimos.core import In, Module, Out, rpc -from dimos.msgs.geometry_msgs import PoseStamped, Vector3 -from dimos.robot.unitree_webrtc.type.odometry import Odometry -from dimos.types.costmap import Costmap -from dimos.types.path import Path -from dimos.types.vector import Vector, VectorLike, to_tuple -from dimos.utils.logging_config import setup_logger -from dimos.utils.transform_utils import distance_angle_to_goal_xy, normalize_angle - -logger = setup_logger("dimos.robot.unitree.local_planner", level=logging.DEBUG) - - -class BaseLocalPlanner(Module, ABC): - """ - Abstract base class for local planners that handle obstacle avoidance and path following. - - This class defines the common interface and shared functionality that all local planners - must implement, regardless of the specific algorithm used. 
- - Args: - get_costmap: Function to get the latest local costmap - get_robot_pose: Function to get the latest robot pose (returning odom object) - move: Function to send velocity commands - safety_threshold: Distance to maintain from obstacles (meters) - max_linear_vel: Maximum linear velocity (m/s) - max_angular_vel: Maximum angular velocity (rad/s) - lookahead_distance: Lookahead distance for path following (meters) - goal_tolerance: Distance at which the goal is considered reached (meters) - angle_tolerance: Angle at which the goal orientation is considered reached (radians) - robot_width: Width of the robot for visualization (meters) - robot_length: Length of the robot for visualization (meters) - visualization_size: Size of the visualization image in pixels - control_frequency: Frequency at which the planner is called (Hz) - safe_goal_distance: Distance at which to adjust the goal and ignore obstacles (meters) - max_recovery_attempts: Maximum number of recovery attempts before failing navigation. - If the robot gets stuck and cannot recover within this many attempts, navigation will fail. - global_planner_plan: Optional callable to plan a global path to the goal. - If provided, this will be used to generate a path to the goal before local planning. 
- """ - - odom: In[PoseStamped] = None - movecmd: Out[Vector3] = None - latest_odom: PoseStamped = None - - def __init__( - self, - get_costmap: Callable[[], Optional[Costmap]], - safety_threshold: float = 0.5, - max_linear_vel: float = 0.8, - max_angular_vel: float = 1.0, - lookahead_distance: float = 1.0, - goal_tolerance: float = 0.75, - angle_tolerance: float = 0.5, - robot_width: float = 0.5, - robot_length: float = 0.7, - visualization_size: int = 400, - control_frequency: float = 10.0, - safe_goal_distance: float = 1.5, - max_recovery_attempts: int = 4, - global_planner_plan: Optional[Callable[[VectorLike], Optional[Any]]] = None, - ): # Control frequency in Hz - # Store callables for robot interactions - Module.__init__(self) - - self.get_costmap = get_costmap - - # Store parameters - self.safety_threshold = safety_threshold - self.max_linear_vel = max_linear_vel - self.max_angular_vel = max_angular_vel - self.lookahead_distance = lookahead_distance - self.goal_tolerance = goal_tolerance - self.angle_tolerance = angle_tolerance - self.robot_width = robot_width - self.robot_length = robot_length - self.visualization_size = visualization_size - self.control_frequency = control_frequency - self.control_period = 1.0 / control_frequency # Period in seconds - self.safe_goal_distance = safe_goal_distance # Distance to ignore obstacles at goal - self.ignore_obstacles = False # Flag for derived classes to check - self.max_recovery_attempts = max_recovery_attempts # Maximum recovery attempts - self.recovery_attempts = 0 # Current number of recovery attempts - self.global_planner_plan = global_planner_plan # Global planner function for replanning - - # Goal and Waypoint Tracking - self.goal_xy: Optional[Tuple[float, float]] = None # Current target for planning - self.goal_theta: Optional[float] = None # Goal orientation (radians) - self.waypoints: Optional[Path] = None # List of waypoints to follow - self.waypoints_in_absolute: Optional[Path] = None # Full path in 
absolute frame - self.waypoint_is_relative: bool = False # Whether waypoints are in relative frame - self.current_waypoint_index: int = 0 # Index of the next waypoint to reach - self.final_goal_reached: bool = False # Flag indicating if the final waypoint is reached - self.position_reached: bool = False # Flag indicating if position goal is reached - - # Stuck detection - self.stuck_detection_window_seconds = 4.0 # Time window for stuck detection (seconds) - self.position_history_size = int(self.stuck_detection_window_seconds * control_frequency) - self.position_history = deque( - maxlen=self.position_history_size - ) # History of recent positions - self.stuck_distance_threshold = 0.15 # Distance threshold for stuck detection (meters) - self.unstuck_distance_threshold = ( - 0.5 # Distance threshold for unstuck detection (meters) - increased hysteresis - ) - self.stuck_time_threshold = 3.0 # Time threshold for stuck detection (seconds) - increased - self.is_recovery_active = False # Whether recovery behavior is active - self.recovery_start_time = 0.0 # When recovery behavior started - self.recovery_duration = ( - 10.0 # How long to run recovery before giving up (seconds) - increased - ) - self.last_update_time = time.time() # Last time position was updated - self.navigation_failed = False # Flag indicating if navigation should be terminated - - # Recovery improvements - self.recovery_cooldown_time = ( - 3.0 # Seconds to wait after recovery before checking stuck again - ) - self.last_recovery_end_time = 0.0 # When the last recovery ended - self.pre_recovery_position = ( - None # Position when recovery started (for better stuck detection) - ) - self.backup_duration = 4.0 # How long to backup when stuck (seconds) - - # Cached data updated periodically for consistent plan() execution time - self._robot_pose = None - self._costmap = None - self._update_frequency = 10.0 # Hz - how often to update cached data - self._update_timer = None - - async def start(self): - 
"""Start the local planner's periodic updates and any other initialization.""" - self._start_periodic_updates() - - def setodom(odom: Odometry): - self.latest_odom = odom - - self.odom.subscribe(setodom) - # self.get_move_stream(frequency=20.0).subscribe(self.movecmd.publish) - - def _start_periodic_updates(self): - self._update_timer = threading.Thread(target=self._periodic_update, daemon=True) - self._update_timer.start() - - def _periodic_update(self): - while True: - self._robot_pose = self._format_robot_pose() - # print("robot pose", self._robot_pose) - self._costmap = self.get_costmap() - time.sleep(1.0 / self._update_frequency) - - def reset(self): - """ - Reset all navigation and state tracking variables. - Should be called whenever a new goal is set. - """ - # Reset stuck detection state - self.position_history.clear() - self.is_recovery_active = False - self.recovery_start_time = 0.0 - self.last_update_time = time.time() - - # Reset navigation state flags - self.navigation_failed = False - self.position_reached = False - self.final_goal_reached = False - self.ignore_obstacles = False - - # Reset recovery improvements - self.last_recovery_end_time = 0.0 - self.pre_recovery_position = None - - # Reset recovery attempts - self.recovery_attempts = 0 - - # Clear waypoint following state - self.waypoints = None - self.current_waypoint_index = 0 - self.goal_xy = None # Clear previous goal - self.goal_theta = None # Clear previous goal orientation - - logger.info("Local planner state has been reset") - - def _format_robot_pose(self) -> Tuple[Tuple[float, float], float]: - """ - Get the current robot position and orientation. 
- - Returns: - Tuple containing: - - position as (x, y) tuple - - orientation (theta) in radians - """ - if self._robot_pose is None: - return ((0.0, 0.0), 0.0) # Fallback if not yet initialized - - pos = self.latest_odom.position - euler = self.latest_odom.orientation.to_euler() - return (pos.x, pos.y), euler.z - - def _get_costmap(self): - """Get cached costmap data.""" - return self._costmap - - def clear_cache(self): - """Clear all cached data to force fresh retrieval on next access.""" - self._robot_pose = None - self._costmap = None - - def set_goal( - self, goal_xy: VectorLike, is_relative: bool = False, goal_theta: Optional[float] = None - ): - """Set a single goal position, converting to absolute frame if necessary. - This clears any existing waypoints being followed. - - Args: - goal_xy: The goal position to set. - is_relative: Whether the goal is in the robot's relative frame. - goal_theta: Optional goal orientation in radians - """ - # Reset all state variables - self.reset() - - target_goal_xy: Optional[Tuple[float, float]] = None - - # Transform goal to absolute frame if it's relative - if is_relative: - # Get current robot pose - odom = self._robot_pose - if odom is None: - logger.warning("Robot pose not yet available, cannot set relative goal") - return - robot_pos, robot_rot = odom.pos, odom.rot - - # Extract current position and orientation - robot_x, robot_y = robot_pos.x, robot_pos.y - robot_theta = robot_rot.z # Assuming rotation is euler angles - - # Transform the relative goal into absolute coordinates - goal_x, goal_y = to_tuple(goal_xy) - # Rotate - abs_x = goal_x * math.cos(robot_theta) - goal_y * math.sin(robot_theta) - abs_y = goal_x * math.sin(robot_theta) + goal_y * math.cos(robot_theta) - # Translate - target_goal_xy = (robot_x + abs_x, robot_y + abs_y) - - logger.info( - f"Goal set in relative frame, converted to absolute: ({target_goal_xy[0]:.2f}, {target_goal_xy[1]:.2f})" - ) - else: - target_goal_xy = to_tuple(goal_xy) - 
logger.info( - f"Goal set directly in absolute frame: ({target_goal_xy[0]:.2f}, {target_goal_xy[1]:.2f})" - ) - - # Check if goal is valid (in bounds and not colliding) - if not self.is_goal_in_costmap_bounds(target_goal_xy) or self.check_goal_collision( - target_goal_xy - ): - logger.warning( - "Goal is in collision or out of bounds. Adjusting goal to valid position." - ) - self.goal_xy = self.adjust_goal_to_valid_position(target_goal_xy) - else: - self.goal_xy = target_goal_xy # Set the adjusted or original valid goal - - # Set goal orientation if provided - if goal_theta is not None: - if is_relative: - # Transform the orientation to absolute frame - odom = self._robot_pose - if odom is None: - logger.warning( - "Robot pose not yet available, cannot set relative goal orientation" - ) - return - robot_theta = odom.rot.z - self.goal_theta = normalize_angle(goal_theta + robot_theta) - else: - self.goal_theta = goal_theta - - def set_goal_waypoints(self, waypoints: Path, goal_theta: Optional[float] = None): - """Sets a path of waypoints for the robot to follow. - - Args: - waypoints: A list of waypoints to follow. Each waypoint is a tuple of (x, y) coordinates in absolute frame. - goal_theta: Optional final orientation in radians - """ - # Reset all state variables - self.reset() - - if not isinstance(waypoints, Path) or len(waypoints) == 0: - logger.warning("Invalid or empty path provided to set_goal_waypoints. 
Ignoring.") - self.waypoints = None - self.waypoint_is_relative = False - self.goal_xy = None - self.goal_theta = None - self.current_waypoint_index = 0 - return - - logger.info(f"Setting goal waypoints with {len(waypoints)} points.") - self.waypoints = waypoints - self.waypoint_is_relative = False - self.current_waypoint_index = 0 - - # Waypoints are always in absolute frame - self.waypoints_in_absolute = waypoints - - # Set the initial target to the first waypoint, adjusting if necessary - first_waypoint = self.waypoints_in_absolute[0] - if not self.is_goal_in_costmap_bounds(first_waypoint) or self.check_goal_collision( - first_waypoint - ): - logger.warning("First waypoint is invalid. Adjusting...") - self.goal_xy = self.adjust_goal_to_valid_position(first_waypoint) - else: - self.goal_xy = to_tuple(first_waypoint) # Initial target - - # Set goal orientation if provided - if goal_theta is not None: - self.goal_theta = goal_theta - - def _get_final_goal_position(self) -> Optional[Tuple[float, float]]: - """ - Get the final goal position (either last waypoint or direct goal). - - Returns: - Tuple (x, y) of the final goal, or None if no goal is set - """ - if self.waypoints_in_absolute is not None and len(self.waypoints_in_absolute) > 0: - return to_tuple(self.waypoints_in_absolute[-1]) - elif self.goal_xy is not None: - return self.goal_xy - return None - - def _distance_to_position(self, target_position: Tuple[float, float]) -> float: - """ - Calculate distance from the robot to a target position. - - Args: - target_position: Target (x, y) position - - Returns: - Distance in meters - """ - robot_pos, _ = self._format_robot_pose() - return np.linalg.norm( - [target_position[0] - robot_pos[0], target_position[1] - robot_pos[1]] - ) - - def plan(self) -> Dict[str, float]: - """ - Main planning method that computes velocity commands. - This includes common planning logic like waypoint following, - with algorithm-specific calculations delegated to subclasses. 
- - Returns: - Dict[str, float]: Velocity commands with 'x_vel' and 'angular_vel' keys - """ - # If goal orientation is specified, rotate to match it - if ( - self.position_reached - and self.goal_theta is not None - and not self._is_goal_orientation_reached() - ): - return self._rotate_to_goal_orientation() - elif self.position_reached and self.goal_theta is None: - self.final_goal_reached = True - logger.info("Position goal reached. Stopping.") - return {"x_vel": 0.0, "angular_vel": 0.0} - - # Check if the robot is stuck and handle accordingly - if self.check_if_stuck() and not self.position_reached: - # Check if we're stuck but close to our goal - final_goal_pos = self._get_final_goal_position() - - # If we have a goal position, check distance to it - if final_goal_pos is not None: - distance_to_goal = self._distance_to_position(final_goal_pos) - - # If we're stuck but within 2x safe_goal_distance of the goal, consider it a success - if distance_to_goal < 2.0 * self.safe_goal_distance: - logger.info( - f"Robot is stuck but within {distance_to_goal:.2f}m of goal (< {2.0 * self.safe_goal_distance:.2f}m). Considering navigation successful." 
- ) - self.position_reached = True - return {"x_vel": 0.0, "angular_vel": 0.0} - - if self.navigation_failed: - return {"x_vel": 0.0, "angular_vel": 0.0} - - # Otherwise, execute normal recovery behavior - logger.warning("Robot is stuck - executing recovery behavior") - return self.execute_recovery_behavior() - - # Reset obstacle ignore flag - self.ignore_obstacles = False - - # --- Waypoint Following Mode --- - if self.waypoints is not None: - if self.final_goal_reached: - return {"x_vel": 0.0, "angular_vel": 0.0} - - # Get current robot pose - robot_pos, robot_theta = self._format_robot_pose() - robot_pos_np = np.array(robot_pos) - - # Check if close to final waypoint - if self.waypoints_in_absolute is not None and len(self.waypoints_in_absolute) > 0: - final_waypoint = self.waypoints_in_absolute[-1] - dist_to_final = np.linalg.norm(robot_pos_np - final_waypoint) - - # If we're close to the final waypoint, adjust it and ignore obstacles - if dist_to_final < self.safe_goal_distance: - final_wp_tuple = to_tuple(final_waypoint) - adjusted_goal = self.adjust_goal_to_valid_position(final_wp_tuple) - # Create a new Path with the adjusted final waypoint - new_waypoints = self.waypoints_in_absolute[:-1] # Get all but the last waypoint - new_waypoints.append(adjusted_goal) # Append the adjusted goal - self.waypoints_in_absolute = new_waypoints - self.ignore_obstacles = True - - # Update the target goal based on waypoint progression - just_reached_final = self._update_waypoint_target(robot_pos_np) - - # If the helper indicates the final goal was just reached, stop immediately - if just_reached_final: - return {"x_vel": 0.0, "angular_vel": 0.0} - - # --- Single Goal or Current Waypoint Target Set --- - if self.goal_xy is None: - # If no goal is set (e.g., empty path or rejected goal), stop. - return {"x_vel": 0.0, "angular_vel": 0.0} - - # Get necessary data for planning - costmap = self._get_costmap() - if costmap is None: - logger.warning("Local costmap is None. 
Cannot plan.") - return {"x_vel": 0.0, "angular_vel": 0.0} - - # Check if close to single goal mode goal - if self.waypoints is None: - # Get distance to goal - goal_distance = self._distance_to_position(self.goal_xy) - - # If within safe distance of goal, adjust it and ignore obstacles - if goal_distance < self.safe_goal_distance: - self.goal_xy = self.adjust_goal_to_valid_position(self.goal_xy) - self.ignore_obstacles = True - - # First check position - if goal_distance < self.goal_tolerance or self.position_reached: - self.position_reached = True - - else: - self.position_reached = False - - # Call the algorithm-specific planning implementation - return self._compute_velocity_commands() - - @abstractmethod - def _compute_velocity_commands(self) -> Dict[str, float]: - """ - Algorithm-specific method to compute velocity commands. - Must be implemented by derived classes. - - Returns: - Dict[str, float]: Velocity commands with 'x_vel' and 'angular_vel' keys - """ - pass - - def _rotate_to_goal_orientation(self) -> Dict[str, float]: - """Compute velocity commands to rotate to the goal orientation. - - Returns: - Dict[str, float]: Velocity commands with zero linear velocity - """ - # Get current robot orientation - _, robot_theta = self._format_robot_pose() - - # Calculate the angle difference - angle_diff = normalize_angle(self.goal_theta - robot_theta) - - # Determine rotation direction and speed - if abs(angle_diff) < self.angle_tolerance: - # Already at correct orientation - return {"x_vel": 0.0, "angular_vel": 0.0} - - # Calculate rotation speed - proportional to the angle difference - # but capped at max_angular_vel - direction = 1.0 if angle_diff > 0 else -1.0 - angular_vel = direction * min(abs(angle_diff), self.max_angular_vel) - - return {"x_vel": 0.0, "angular_vel": angular_vel} - - def _is_goal_orientation_reached(self) -> bool: - """Check if the current robot orientation matches the goal orientation. 
- - Returns: - bool: True if orientation is reached or no orientation goal is set - """ - if self.goal_theta is None: - return True # No orientation goal set - - # Get current robot orientation - _, robot_theta = self._format_robot_pose() - - # Calculate the angle difference and normalize - angle_diff = abs(normalize_angle(self.goal_theta - robot_theta)) - - return angle_diff <= self.angle_tolerance - - def _update_waypoint_target(self, robot_pos_np: np.ndarray) -> bool: - """Helper function to manage waypoint progression and update the target goal. - - Args: - robot_pos_np: Current robot position as a numpy array [x, y]. - - Returns: - bool: True if the final waypoint has just been reached, False otherwise. - """ - if self.waypoints is None or len(self.waypoints) == 0: - return False # Not in waypoint mode or empty path - - # Waypoints are always in absolute frame - self.waypoints_in_absolute = self.waypoints - - # Check if final goal is reached - final_waypoint = self.waypoints_in_absolute[-1] - dist_to_final = np.linalg.norm(robot_pos_np - final_waypoint) - - if dist_to_final <= self.goal_tolerance: - # Final waypoint position reached - if self.goal_theta is not None: - # Check orientation if specified - if self._is_goal_orientation_reached(): - self.final_goal_reached = True - return True - # Continue rotating - self.position_reached = True - return False - else: - # No orientation goal, mark as reached - self.final_goal_reached = True - return True - - # Always find the lookahead point - lookahead_point = None - for i in range(self.current_waypoint_index, len(self.waypoints_in_absolute)): - wp = self.waypoints_in_absolute[i] - dist_to_wp = np.linalg.norm(robot_pos_np - wp) - if dist_to_wp >= self.lookahead_distance: - lookahead_point = wp - # Update current waypoint index to this point - self.current_waypoint_index = i - break - - # If no point is far enough, target the final waypoint - if lookahead_point is None: - lookahead_point = 
self.waypoints_in_absolute[-1] - self.current_waypoint_index = len(self.waypoints_in_absolute) - 1 - - # Set the lookahead point as the immediate target, adjusting if needed - if not self.is_goal_in_costmap_bounds(lookahead_point) or self.check_goal_collision( - lookahead_point - ): - adjusted_lookahead = self.adjust_goal_to_valid_position(lookahead_point) - # Only update if adjustment didn't fail completely - if adjusted_lookahead is not None: - self.goal_xy = adjusted_lookahead - else: - self.goal_xy = to_tuple(lookahead_point) - - return False # Final goal not reached in this update cycle - - @abstractmethod - def update_visualization(self) -> np.ndarray: - """ - Generate visualization of the planning state. - Must be implemented by derived classes. - - Returns: - np.ndarray: Visualization image as numpy array - """ - pass - - def create_stream(self, frequency_hz: float = None) -> Observable: - """ - Create an Observable stream that emits the visualization image at a fixed frequency. 
- - Args: - frequency_hz: Optional frequency override (defaults to 1/4 of control_frequency if None) - - Returns: - Observable: Stream of visualization frames - """ - # Default to 1/4 of control frequency if not specified (to reduce CPU usage) - if frequency_hz is None: - frequency_hz = self.control_frequency / 4.0 - - subject = Subject() - sleep_time = 1.0 / frequency_hz - - def frame_emitter(): - while True: - try: - # Generate the frame using the updated method - frame = self.update_visualization() - subject.on_next(frame) - except Exception as e: - logger.error(f"Error in frame emitter thread: {e}") - # Optionally, emit an error frame or simply skip - # subject.on_error(e) # This would terminate the stream - time.sleep(sleep_time) - - emitter_thread = threading.Thread(target=frame_emitter, daemon=True) - emitter_thread.start() - logger.info(f"Started visualization frame emitter thread at {frequency_hz:.1f} Hz") - return subject - - @abstractmethod - def check_collision(self, direction: float) -> bool: - """ - Check if there's a collision in the given direction. - Must be implemented by derived classes. 
- - Args: - direction: Direction to check for collision in radians - - Returns: - bool: True if collision detected, False otherwise - """ - pass - - def is_goal_reached(self) -> bool: - """Check if the final goal (single or last waypoint) is reached, including orientation.""" - if self.waypoints is not None: - # Waypoint mode: check if the final waypoint and orientation have been reached - return self.final_goal_reached and self._is_goal_orientation_reached() - else: - # Single goal mode: check distance to the single goal and orientation - if self.goal_xy is None: - return False # No goal set - - if self.goal_theta is None: - return self.position_reached - - return self.position_reached and self._is_goal_orientation_reached() - - def check_goal_collision(self, goal_xy: VectorLike) -> bool: - """Check if the current goal is in collision with obstacles in the costmap. - - Returns: - bool: True if goal is in collision, False if goal is safe or cannot be checked - """ - - costmap = self._get_costmap() - if costmap is None: - logger.warning("Cannot check collision: No costmap available") - return False - - # Check if the position is occupied - collision_threshold = 80 # Consider values above 80 as obstacles - - # Use Costmap's is_occupied method - return costmap.is_occupied(goal_xy, threshold=collision_threshold) - - def is_goal_in_costmap_bounds(self, goal_xy: VectorLike) -> bool: - """Check if the goal position is within the bounds of the costmap. 
- - Args: - goal_xy: Goal position (x, y) in odom frame - - Returns: - bool: True if the goal is within the costmap bounds, False otherwise - """ - costmap = self._get_costmap() - if costmap is None: - logger.warning("Cannot check bounds: No costmap available") - return False - - # Get goal position in grid coordinates - goal_point = costmap.world_to_grid(goal_xy) - goal_cell_x, goal_cell_y = goal_point.x, goal_point.y - - # Check if goal is within the costmap bounds - is_in_bounds = 0 <= goal_cell_x < costmap.width and 0 <= goal_cell_y < costmap.height - - if not is_in_bounds: - logger.warning(f"Goal ({goal_xy[0]:.2f}, {goal_xy[1]:.2f}) is outside costmap bounds") - - return is_in_bounds - - def adjust_goal_to_valid_position( - self, goal_xy: VectorLike, clearance: float = 0.5 - ) -> Tuple[float, float]: - """Find a valid (non-colliding) goal position by moving it towards the robot. - - Args: - goal_xy: Original goal position (x, y) in odom frame - clearance: Additional distance to move back from obstacles for better clearance (meters) - - Returns: - Tuple[float, float]: A valid goal position, or the original goal if already valid - """ - [pos, rot] = self._format_robot_pose() - - robot_x, robot_y = pos[0], pos[1] - - # Original goal - goal_x, goal_y = to_tuple(goal_xy) - - if not self.check_goal_collision((goal_x, goal_y)): - return (goal_x, goal_y) - - # Calculate vector from goal to robot - dx = robot_x - goal_x - dy = robot_y - goal_y - distance = np.sqrt(dx * dx + dy * dy) - - if distance < 0.001: # Goal is at robot position - return to_tuple(goal_xy) - - # Normalize direction vector - dx /= distance - dy /= distance - - # Step size - step_size = 0.25 # meters - - # Move goal towards robot step by step - current_x, current_y = goal_x, goal_y - steps = 0 - max_steps = 50 # Safety limit - - # Variables to store the first valid position found - valid_found = False - valid_x, valid_y = None, None - - while steps < max_steps: - # Move towards robot - current_x += 
dx * step_size - current_y += dy * step_size - steps += 1 - - # Check if we've reached or passed the robot - new_distance = np.sqrt((current_x - robot_x) ** 2 + (current_y - robot_y) ** 2) - if new_distance < step_size: - # We've reached the robot without finding a valid point - # Move back one step from robot to avoid self-collision - current_x = robot_x - dx * step_size - current_y = robot_y - dy * step_size - break - - # Check if this position is valid - if not self.check_goal_collision( - (current_x, current_y) - ) and self.is_goal_in_costmap_bounds((current_x, current_y)): - # Store the first valid position - if not valid_found: - valid_found = True - valid_x, valid_y = current_x, current_y - - # If clearance is requested, continue searching for a better position - if clearance > 0: - continue - - # Calculate position with additional clearance - if clearance > 0: - # Calculate clearance position - clearance_x = current_x + dx * clearance - clearance_y = current_y + dy * clearance - - # Check if the clearance position is also valid - if not self.check_goal_collision( - (clearance_x, clearance_y) - ) and self.is_goal_in_costmap_bounds((clearance_x, clearance_y)): - return (clearance_x, clearance_y) - - # Return the valid position without clearance - return (current_x, current_y) - - # If we found a valid position earlier but couldn't add clearance - if valid_found: - return (valid_x, valid_y) - - logger.warning( - f"Could not find valid goal after {steps} steps, using closest point to robot" - ) - return (current_x, current_y) - - def check_if_stuck(self) -> bool: - """ - Check if the robot is stuck by analyzing movement history. - Includes improvements to prevent oscillation between stuck and recovered states. 
- - Returns: - bool: True if the robot is determined to be stuck, False otherwise - """ - # Get current position and time - current_time = time.time() - - # Get current robot position - [pos, _] = self._format_robot_pose() - current_position = (pos[0], pos[1], current_time) - - # If we're already in recovery, don't add movements to history (they're intentional) - # Instead, check if we should continue or end recovery - if self.is_recovery_active: - # Check if we've moved far enough from our pre-recovery position to consider unstuck - if self.pre_recovery_position is not None: - pre_recovery_x, pre_recovery_y = self.pre_recovery_position[:2] - displacement_from_start = np.sqrt( - (pos[0] - pre_recovery_x) ** 2 + (pos[1] - pre_recovery_y) ** 2 - ) - - # If we've moved far enough, we're unstuck - if displacement_from_start > self.unstuck_distance_threshold: - logger.info( - f"Robot has escaped from stuck state (moved {displacement_from_start:.3f}m from start)" - ) - self.is_recovery_active = False - self.last_recovery_end_time = current_time - # Do not reset recovery attempts here - only reset during replanning or goal reaching - # Clear position history to start fresh tracking - self.position_history.clear() - return False - - # Check if we've been trying to recover for too long - recovery_time = current_time - self.recovery_start_time - if recovery_time > self.recovery_duration: - logger.error( - f"Recovery behavior has been active for {self.recovery_duration}s without success" - ) - self.navigation_failed = True - return True - - # Continue recovery - return True - - # Check cooldown period - don't immediately check for stuck after recovery - if current_time - self.last_recovery_end_time < self.recovery_cooldown_time: - # Add position to history but don't check for stuck yet - self.position_history.append(current_position) - return False - - # Add current position to history (newest is appended at the end) - self.position_history.append(current_position) - - # Need 
enough history to make a determination - min_history_size = int( - self.stuck_detection_window_seconds * self.control_frequency * 0.6 - ) # 60% of window - if len(self.position_history) < min_history_size: - return False - - # Find positions within our detection window - window_start_time = current_time - self.stuck_detection_window_seconds - window_positions = [] - - # Collect positions within the window (newest entries will be at the end) - for pos_x, pos_y, timestamp in self.position_history: - if timestamp >= window_start_time: - window_positions.append((pos_x, pos_y, timestamp)) - - # Need at least a few positions in the window - if len(window_positions) < 3: - return False - - # Ensure correct order: oldest to newest - window_positions.sort(key=lambda p: p[2]) - - # Get the oldest and newest positions in the window - oldest_x, oldest_y, oldest_time = window_positions[0] - newest_x, newest_y, newest_time = window_positions[-1] - - # Calculate time range in the window - time_range = newest_time - oldest_time - - # Calculate displacement from oldest to newest position - displacement = np.sqrt((newest_x - oldest_x) ** 2 + (newest_y - oldest_y) ** 2) - - # Also check average displacement over multiple sub-windows to avoid false positives - sub_window_size = max(3, len(window_positions) // 3) - avg_displacement = 0.0 - displacement_count = 0 - - for i in range(0, len(window_positions) - sub_window_size, sub_window_size // 2): - start_pos = window_positions[i] - end_pos = window_positions[min(i + sub_window_size, len(window_positions) - 1)] - sub_displacement = np.sqrt( - (end_pos[0] - start_pos[0]) ** 2 + (end_pos[1] - start_pos[1]) ** 2 - ) - avg_displacement += sub_displacement - displacement_count += 1 - - if displacement_count > 0: - avg_displacement /= displacement_count - - # Check if we're stuck - moved less than threshold over minimum time - is_currently_stuck = ( - time_range >= self.stuck_time_threshold - and time_range <= 
self.stuck_detection_window_seconds - and displacement < self.stuck_distance_threshold - and avg_displacement < self.stuck_distance_threshold * 1.5 - ) - - if is_currently_stuck: - logger.warning( - f"Robot appears to be stuck! Total displacement: {displacement:.3f}m, " - f"avg displacement: {avg_displacement:.3f}m over {time_range:.1f}s" - ) - - # Start recovery behavior - self.is_recovery_active = True - self.recovery_start_time = current_time - self.pre_recovery_position = current_position - - # Clear position history to avoid contamination during recovery - self.position_history.clear() - - # Increment recovery attempts - self.recovery_attempts += 1 - logger.warning( - f"Starting recovery attempt {self.recovery_attempts}/{self.max_recovery_attempts}" - ) - - # Check if maximum recovery attempts have been exceeded - if self.recovery_attempts > self.max_recovery_attempts: - logger.error( - f"Maximum recovery attempts ({self.max_recovery_attempts}) exceeded. Navigation failed." - ) - self.navigation_failed = True - - return True - - return False - - def execute_recovery_behavior(self) -> Dict[str, float]: - """ - Execute enhanced recovery behavior when the robot is stuck. 
- - First attempt: Backup for a set duration - - Second+ attempts: Replan to the original goal using global planner - - Returns: - Dict[str, float]: Velocity commands for the recovery behavior - """ - current_time = time.time() - recovery_time = current_time - self.recovery_start_time - - # First recovery attempt: Simple backup behavior - if self.recovery_attempts % 2 == 0: - if recovery_time < self.backup_duration: - logger.warning(f"Recovery attempt 1: backup for {recovery_time:.1f}s") - return {"x_vel": -0.5, "angular_vel": 0.0} # Backup at moderate speed - else: - logger.info("Recovery attempt 1: backup completed") - self.recovery_attempts += 1 - return {"x_vel": 0.0, "angular_vel": 0.0} - - final_goal = self.waypoints_in_absolute[-1] - logger.info( - f"Recovery attempt {self.recovery_attempts}: replanning to final waypoint {final_goal}" - ) - - new_path = self.global_planner_plan(Vector([final_goal[0], final_goal[1]])) - - if new_path is not None: - logger.info("Replanning successful. Setting new waypoints.") - attempts = self.recovery_attempts - self.set_goal_waypoints(new_path, self.goal_theta) - self.recovery_attempts = attempts - self.is_recovery_active = False - self.last_recovery_end_time = current_time - else: - logger.error("Global planner could not find a path to the goal. Recovery failed.") - self.navigation_failed = True - - return {"x_vel": 0.0, "angular_vel": 0.0} - - @rpc - def navigate_to_goal_local( - self, - goal_xy_robot: Tuple[float, float], - goal_theta: Optional[float] = None, - distance: float = 0.0, - timeout: float = 60.0, - stop_event: Optional[threading.Event] = None, - ) -> bool: - """ - Navigates the robot to a goal specified in the robot's local frame - using the local planner. - - Args: - robot: Robot instance to control - goal_xy_robot: Tuple (x, y) representing the goal position relative - to the robot's current position and orientation. - distance: Desired distance to maintain from the goal in meters. 
- If non-zero, the robot will stop this far away from the goal. - timeout: Maximum time (in seconds) allowed to reach the goal. - stop_event: Optional threading.Event to signal when navigation should stop - - Returns: - bool: True if the goal was reached within the timeout, False otherwise. - """ - logger.info( - f"Starting navigation to local goal {goal_xy_robot} with distance {distance}m and timeout {timeout}s." - ) - - self.reset() - - goal_x, goal_y = goal_xy_robot - - # Calculate goal orientation to face the target - if goal_theta is None: - goal_theta = np.arctan2(goal_y, goal_x) - - # If distance is non-zero, adjust the goal to stop at the desired distance - if distance > 0: - # Calculate magnitude of the goal vector - goal_distance = np.sqrt(goal_x**2 + goal_y**2) - - # Only adjust if goal is further than the desired distance - if goal_distance > distance: - goal_x, goal_y = distance_angle_to_goal_xy(goal_distance - distance, goal_theta) - - # Set the goal in the robot's frame with orientation to face the original target - self.set_goal((goal_x, goal_y), is_relative=True, goal_theta=goal_theta) - - # Get control period from robot's local planner for consistent timing - control_period = 1.0 / self.control_frequency - - start_time = time.time() - goal_reached = False - - try: - while time.time() - start_time < timeout and not (stop_event and stop_event.is_set()): - # Check if goal has been reached - if self.is_goal_reached(): - logger.info("Goal reached successfully.") - goal_reached = True - break - - # Check if navigation failed flag is set - if self.navigation_failed: - logger.error("Navigation aborted due to repeated recovery failures.") - goal_reached = False - break - - # Get planned velocity towards the goal - vel_command = self.plan() - x_vel = vel_command.get("x_vel", 0.0) - angular_vel = vel_command.get("angular_vel", 0.0) - - # Send velocity command - self.movecmd.publish(Vector3(x_vel, 0, angular_vel)) - - # Control loop frequency - use robot's 
control frequency - time.sleep(control_period) - - if not goal_reached: - logger.warning( - f"Navigation timed out after {timeout} seconds before reaching goal." - ) - - except KeyboardInterrupt: - logger.info("Navigation to local goal interrupted by user.") - goal_reached = False # Consider interruption as failure - except Exception as e: - logger.error(f"Error during navigation to local goal: {e}") - goal_reached = False # Consider error as failure - finally: - logger.info("Stopping robot after navigation attempt.") - self.movecmd.publish(Vector3(0, 0, 0)) # Stop the robot - - return goal_reached - - @rpc - def navigate_path_local( - self, - path: Path, - timeout: float = 120.0, - goal_theta: Optional[float] = None, - stop_event: Optional[threading.Event] = None, - ) -> bool: - """ - Navigates the robot along a path of waypoints using the waypoint following capability - of the local planner. - - Args: - robot: Robot instance to control - path: Path object containing waypoints in absolute frame - timeout: Maximum time (in seconds) allowed to follow the complete path - goal_theta: Optional final orientation in radians - stop_event: Optional threading.Event to signal when navigation should stop - - Returns: - bool: True if the entire path was successfully followed, False otherwise - """ - logger.info( - f"Starting navigation along path with {len(path)} waypoints and timeout {timeout}s." 
- ) - - self.reset() - print() - # Set the path in the local planner - self.set_goal_waypoints(path, goal_theta=goal_theta) - - # Get control period from robot's local planner for consistent timing - control_period = 1.0 / self.control_frequency - - start_time = time.time() - path_completed = False - - try: - while time.time() - start_time < timeout and not (stop_event and stop_event.is_set()): - # Check if the entire path has been traversed - if self.is_goal_reached(): - logger.info("Path traversed successfully.") - path_completed = True - break - - # Check if navigation failed flag is set - if self.navigation_failed: - logger.error("Navigation aborted due to repeated recovery failures.") - path_completed = False - break - - # Get planned velocity towards the current waypoint target - vel_command = self.plan() - x_vel = vel_command.get("x_vel", 0.0) - angular_vel = vel_command.get("angular_vel", 0.0) - - # Send velocity command - self.movecmd.publish(Vector3(x_vel, 0, angular_vel)) - - # Control loop frequency - use robot's control frequency - time.sleep(control_period) - - if not path_completed: - logger.warning( - f"Path following timed out after {timeout} seconds before completing the path." 
- ) - - except KeyboardInterrupt: - logger.info("Path navigation interrupted by user.") - path_completed = False - except Exception as e: - logger.error(f"Error during path navigation: {e}") - path_completed = False - finally: - logger.info("Stopping robot after path navigation attempt.") - self.movecmd.publish(Vector3(0, 0, 0)) # Stop the robot - - return path_completed - - -def visualize_local_planner_state( - occupancy_grid: np.ndarray, - grid_resolution: float, - grid_origin: Tuple[float, float], - robot_pose: Tuple[float, float, float], - visualization_size: int = 400, - robot_width: float = 0.5, - robot_length: float = 0.7, - map_size_meters: float = 10.0, - goal_xy: Optional[Tuple[float, float]] = None, - goal_theta: Optional[float] = None, - histogram: Optional[np.ndarray] = None, - selected_direction: Optional[float] = None, - waypoints: Optional["Path"] = None, - current_waypoint_index: Optional[int] = None, -) -> np.ndarray: - """Generate a bird's eye view visualization of the local costmap. - Optionally includes VFH histogram, selected direction, and waypoints path. 
- - Args: - occupancy_grid: 2D numpy array of the occupancy grid - grid_resolution: Resolution of the grid in meters/cell - grid_origin: Tuple (x, y) of the grid origin in the odom frame - robot_pose: Tuple (x, y, theta) of the robot pose in the odom frame - visualization_size: Size of the visualization image in pixels - robot_width: Width of the robot in meters - robot_length: Length of the robot in meters - map_size_meters: Size of the map to visualize in meters - goal_xy: Optional tuple (x, y) of the goal position in the odom frame - goal_theta: Optional goal orientation in radians (in odom frame) - histogram: Optional numpy array of the VFH histogram - selected_direction: Optional selected direction angle in radians - waypoints: Optional Path object containing waypoints to visualize - current_waypoint_index: Optional index of the current target waypoint - """ - - robot_x, robot_y, robot_theta = robot_pose - grid_origin_x, grid_origin_y = grid_origin - vis_size = visualization_size - scale = vis_size / map_size_meters - - vis_img = np.ones((vis_size, vis_size, 3), dtype=np.uint8) * 255 - center_x = vis_size // 2 - center_y = vis_size // 2 - - grid_height, grid_width = occupancy_grid.shape - - # Calculate robot position relative to grid origin - robot_rel_x = robot_x - grid_origin_x - robot_rel_y = robot_y - grid_origin_y - robot_cell_x = int(robot_rel_x / grid_resolution) - robot_cell_y = int(robot_rel_y / grid_resolution) - - half_size_cells = int(map_size_meters / grid_resolution / 2) - - # Draw grid cells (using standard occupancy coloring) - for y in range( - max(0, robot_cell_y - half_size_cells), min(grid_height, robot_cell_y + half_size_cells) - ): - for x in range( - max(0, robot_cell_x - half_size_cells), min(grid_width, robot_cell_x + half_size_cells) - ): - cell_rel_x_meters = (x - robot_cell_x) * grid_resolution - cell_rel_y_meters = (y - robot_cell_y) * grid_resolution - - img_x = int(center_x + cell_rel_x_meters * scale) - img_y = int(center_y - 
cell_rel_y_meters * scale) # Flip y-axis - - if 0 <= img_x < vis_size and 0 <= img_y < vis_size: - cell_value = occupancy_grid[y, x] - if cell_value == -1: - color = (200, 200, 200) # Unknown (Light gray) - elif cell_value == 0: - color = (255, 255, 255) # Free (White) - else: # Occupied - # Scale darkness based on occupancy value (0-100) - darkness = 255 - int(155 * (cell_value / 100)) - 100 - color = (darkness, darkness, darkness) # Shades of gray/black - - cell_size_px = max(1, int(grid_resolution * scale)) - cv2.rectangle( - vis_img, - (img_x - cell_size_px // 2, img_y - cell_size_px // 2), - (img_x + cell_size_px // 2, img_y + cell_size_px // 2), - color, - -1, - ) - - # Draw waypoints path if provided - if waypoints is not None and len(waypoints) > 0: - try: - path_points = [] - for i, waypoint in enumerate(waypoints): - # Convert waypoint from odom frame to visualization frame - wp_x, wp_y = waypoint[0], waypoint[1] - wp_rel_x = wp_x - robot_x - wp_rel_y = wp_y - robot_y - - wp_img_x = int(center_x + wp_rel_x * scale) - wp_img_y = int(center_y - wp_rel_y * scale) # Flip y-axis - - if 0 <= wp_img_x < vis_size and 0 <= wp_img_y < vis_size: - path_points.append((wp_img_x, wp_img_y)) - - # Draw each waypoint as a small circle - cv2.circle(vis_img, (wp_img_x, wp_img_y), 3, (0, 128, 0), -1) # Dark green dots - - # Highlight current target waypoint - if current_waypoint_index is not None and i == current_waypoint_index: - cv2.circle(vis_img, (wp_img_x, wp_img_y), 6, (0, 0, 255), 2) # Red circle - - # Connect waypoints with lines to show the path - if len(path_points) > 1: - for i in range(len(path_points) - 1): - cv2.line( - vis_img, path_points[i], path_points[i + 1], (0, 200, 0), 1 - ) # Green line - except Exception as e: - logger.error(f"Error drawing waypoints: {e}") - - # Draw histogram - if histogram is not None: - num_bins = len(histogram) - # Find absolute maximum value (ignoring any negative debug values) - abs_histogram = np.abs(histogram) - 
max_hist_value = np.max(abs_histogram) if np.max(abs_histogram) > 0 else 1.0 - hist_scale = (vis_size / 2) * 0.8 # Scale histogram lines to 80% of half the viz size - - for i in range(num_bins): - # Angle relative to robot's forward direction - angle_relative_to_robot = (i / num_bins) * 2 * math.pi - math.pi - # Angle in the visualization frame (relative to image +X axis) - vis_angle = angle_relative_to_robot + robot_theta - - # Get the value and check if it's a special debug value (negative) - hist_val = histogram[i] - is_debug_value = hist_val < 0 - - # Use absolute value for line length - normalized_val = min(1.0, abs(hist_val) / max_hist_value) - line_length = normalized_val * hist_scale - - # Calculate endpoint using the visualization angle - end_x = int(center_x + line_length * math.cos(vis_angle)) - end_y = int(center_y - line_length * math.sin(vis_angle)) # Flipped Y - - # Color based on value and whether it's a debug value - if is_debug_value: - # Use green for debug values (minimum cost bin) - color = (0, 255, 0) # Green - line_width = 2 # Thicker line for emphasis - else: - # Regular coloring for normal values (blue to red gradient based on obstacle density) - blue = max(0, 255 - int(normalized_val * 255)) - red = min(255, int(normalized_val * 255)) - color = (blue, 0, red) # BGR format: obstacles are redder, clear areas are bluer - line_width = 1 - - cv2.line(vis_img, (center_x, center_y), (end_x, end_y), color, line_width) - - # Draw robot - robot_length_px = int(robot_length * scale) - robot_width_px = int(robot_width * scale) - robot_pts = np.array( - [ - [-robot_length_px / 2, -robot_width_px / 2], - [robot_length_px / 2, -robot_width_px / 2], - [robot_length_px / 2, robot_width_px / 2], - [-robot_length_px / 2, robot_width_px / 2], - ], - dtype=np.float32, - ) - rotation_matrix = np.array( - [ - [math.cos(robot_theta), -math.sin(robot_theta)], - [math.sin(robot_theta), math.cos(robot_theta)], - ] - ) - robot_pts = np.dot(robot_pts, 
rotation_matrix.T) - robot_pts[:, 0] += center_x - robot_pts[:, 1] = center_y - robot_pts[:, 1] # Flip y-axis - cv2.fillPoly( - vis_img, [robot_pts.reshape((-1, 1, 2)).astype(np.int32)], (0, 0, 255) - ) # Red robot - - # Draw robot direction line - front_x = int(center_x + (robot_length_px / 2) * math.cos(robot_theta)) - front_y = int(center_y - (robot_length_px / 2) * math.sin(robot_theta)) - cv2.line(vis_img, (center_x, center_y), (front_x, front_y), (255, 0, 0), 2) # Blue line - - # Draw selected direction - if selected_direction is not None: - # selected_direction is relative to robot frame - # Angle in the visualization frame (relative to image +X axis) - vis_angle_selected = selected_direction + robot_theta - - # Make slightly longer than max histogram line - sel_dir_line_length = (vis_size / 2) * 0.9 - - sel_end_x = int(center_x + sel_dir_line_length * math.cos(vis_angle_selected)) - sel_end_y = int(center_y - sel_dir_line_length * math.sin(vis_angle_selected)) # Flipped Y - - cv2.line( - vis_img, (center_x, center_y), (sel_end_x, sel_end_y), (0, 165, 255), 2 - ) # BGR for Orange - - # Draw goal - if goal_xy is not None: - goal_x, goal_y = goal_xy - goal_rel_x_map = goal_x - robot_x - goal_rel_y_map = goal_y - robot_y - goal_img_x = int(center_x + goal_rel_x_map * scale) - goal_img_y = int(center_y - goal_rel_y_map * scale) # Flip y-axis - if 0 <= goal_img_x < vis_size and 0 <= goal_img_y < vis_size: - cv2.circle(vis_img, (goal_img_x, goal_img_y), 5, (0, 255, 0), -1) # Green circle - cv2.circle(vis_img, (goal_img_x, goal_img_y), 8, (0, 0, 0), 1) # Black outline - - # Draw goal orientation - if goal_theta is not None and goal_xy is not None: - # For waypoint mode, only draw orientation at the final waypoint - if waypoints is not None and len(waypoints) > 0: - # Use the final waypoint position - final_waypoint = waypoints[-1] - goal_x, goal_y = final_waypoint[0], final_waypoint[1] - else: - # Use the current goal position - goal_x, goal_y = goal_xy - - 
goal_rel_x_map = goal_x - robot_x - goal_rel_y_map = goal_y - robot_y - goal_img_x = int(center_x + goal_rel_x_map * scale) - goal_img_y = int(center_y - goal_rel_y_map * scale) # Flip y-axis - - # Calculate goal orientation vector direction in visualization frame - # goal_theta is already in odom frame, need to adjust for visualization orientation - goal_dir_length = 30 # Length of direction indicator in pixels - goal_dir_end_x = int(goal_img_x + goal_dir_length * math.cos(goal_theta)) - goal_dir_end_y = int(goal_img_y - goal_dir_length * math.sin(goal_theta)) # Flip y-axis - - # Draw goal orientation arrow - if 0 <= goal_img_x < vis_size and 0 <= goal_img_y < vis_size: - cv2.arrowedLine( - vis_img, - (goal_img_x, goal_img_y), - (goal_dir_end_x, goal_dir_end_y), - (255, 0, 255), - 4, - ) # Magenta arrow - - # Add scale bar - scale_bar_length_px = int(1.0 * scale) - scale_bar_x = vis_size - scale_bar_length_px - 10 - scale_bar_y = vis_size - 20 - cv2.line( - vis_img, - (scale_bar_x, scale_bar_y), - (scale_bar_x + scale_bar_length_px, scale_bar_y), - (0, 0, 0), - 2, - ) - cv2.putText( - vis_img, "1m", (scale_bar_x, scale_bar_y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 0), 1 - ) - - # Add status info - status_text = [] - if waypoints is not None: - if current_waypoint_index is not None: - status_text.append(f"WP: {current_waypoint_index}/{len(waypoints)}") - else: - status_text.append(f"WPs: {len(waypoints)}") - - y_pos = 20 - for text in status_text: - cv2.putText(vis_img, text, (10, y_pos), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1) - y_pos += 20 - - return vis_img diff --git a/dimos/robot/local_planner/vfh_local_planner.py b/dimos/robot/local_planner/vfh_local_planner.py deleted file mode 100644 index 5945f8bd00..0000000000 --- a/dimos/robot/local_planner/vfh_local_planner.py +++ /dev/null @@ -1,431 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright 2025 Dimensional Inc. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -from typing import Dict, Tuple, Optional, Callable, Any -import cv2 -import logging - -from dimos.utils.logging_config import setup_logger -from dimos.utils.transform_utils import normalize_angle - -from dimos.robot.local_planner.local_planner import BaseLocalPlanner, visualize_local_planner_state -from dimos.types.costmap import Costmap -from dimos.types.vector import Vector, VectorLike - -logger = setup_logger("dimos.robot.unitree.vfh_local_planner", level=logging.DEBUG) - - -class VFHPurePursuitPlanner(BaseLocalPlanner): - """ - A local planner that combines Vector Field Histogram (VFH) for obstacle avoidance - with Pure Pursuit for goal tracking. - """ - - def __init__( - self, - get_costmap: Callable[[], Optional[Costmap]], - safety_threshold: float = 0.8, - histogram_bins: int = 144, - max_linear_vel: float = 0.8, - max_angular_vel: float = 1.0, - lookahead_distance: float = 1.0, - goal_tolerance: float = 0.4, - angle_tolerance: float = 0.1, # ~5.7 degrees - robot_width: float = 0.5, - robot_length: float = 0.7, - visualization_size: int = 400, - control_frequency: float = 10.0, - safe_goal_distance: float = 1.0, - max_recovery_attempts: int = 3, - global_planner_plan: Optional[Callable[[VectorLike], Optional[Any]]] = None, - ): - """ - Initialize the VFH + Pure Pursuit planner. 
- - Args: - get_costmap: Function to get the latest local costmap - get_robot_pose: Function to get the latest robot pose (returning odom object) - move: Function to send velocity commands - safety_threshold: Distance to maintain from obstacles (meters) - histogram_bins: Number of directional bins in the polar histogram - max_linear_vel: Maximum linear velocity (m/s) - max_angular_vel: Maximum angular velocity (rad/s) - lookahead_distance: Lookahead distance for pure pursuit (meters) - goal_tolerance: Distance at which the goal is considered reached (meters) - angle_tolerance: Angle at which the goal orientation is considered reached (radians) - robot_width: Width of the robot for visualization (meters) - robot_length: Length of the robot for visualization (meters) - visualization_size: Size of the visualization image in pixels - control_frequency: Frequency at which the planner is called (Hz) - safe_goal_distance: Distance at which to adjust the goal and ignore obstacles (meters) - max_recovery_attempts: Maximum number of recovery attempts - global_planner_plan: Optional function to get the global plan - """ - # Initialize base class - super().__init__( - get_costmap=get_costmap, - safety_threshold=safety_threshold, - max_linear_vel=max_linear_vel, - max_angular_vel=max_angular_vel, - lookahead_distance=lookahead_distance, - goal_tolerance=goal_tolerance, - angle_tolerance=angle_tolerance, - robot_width=robot_width, - robot_length=robot_length, - visualization_size=visualization_size, - control_frequency=control_frequency, - safe_goal_distance=safe_goal_distance, - max_recovery_attempts=max_recovery_attempts, - global_planner_plan=global_planner_plan, - ) - - # VFH specific parameters - self.histogram_bins = histogram_bins - self.histogram = None - self.selected_direction = None - - # VFH tuning parameters - self.alpha = 0.25 # Histogram smoothing factor - self.obstacle_weight = 5.0 - self.goal_weight = 2.0 - self.prev_direction_weight = 1.0 - 
self.prev_selected_angle = 0.0 - self.prev_linear_vel = 0.0 - self.linear_vel_filter_factor = 0.4 - self.low_speed_nudge = 0.1 - - # Add after other initialization - self.angle_mapping = np.linspace(-np.pi, np.pi, self.histogram_bins, endpoint=False) - self.smoothing_kernel = np.array([self.alpha, (1 - 2 * self.alpha), self.alpha]) - - def _compute_velocity_commands(self) -> Dict[str, float]: - """ - VFH + Pure Pursuit specific implementation of velocity command computation. - - Returns: - Dict[str, float]: Velocity commands with 'x_vel' and 'angular_vel' keys - """ - # Get necessary data for planning - costmap = self._get_costmap() - if costmap is None: - logger.warning("No costmap available for planning") - return {"x_vel": 0.0, "angular_vel": 0.0} - - robot_pos, robot_theta = self._format_robot_pose() - robot_x, robot_y = robot_pos - robot_pose = (robot_x, robot_y, robot_theta) - - # Calculate goal-related parameters - goal_x, goal_y = self.goal_xy - dx = goal_x - robot_x - dy = goal_y - robot_y - goal_distance = np.linalg.norm([dx, dy]) - goal_direction = np.arctan2(dy, dx) - robot_theta - goal_direction = normalize_angle(goal_direction) - - self.histogram = self.build_polar_histogram(costmap, robot_pose) - - # If we're ignoring obstacles near the goal, zero out the histogram - if self.ignore_obstacles: - self.histogram = np.zeros_like(self.histogram) - - self.selected_direction = self.select_direction( - self.goal_weight, - self.obstacle_weight, - self.prev_direction_weight, - self.histogram, - goal_direction, - ) - - # Calculate Pure Pursuit Velocities - linear_vel, angular_vel = self.compute_pure_pursuit(goal_distance, self.selected_direction) - - # Slow down when turning sharply - if abs(self.selected_direction) > 0.25: # ~15 degrees - # Scale from 1.0 (small turn) to 0.5 (sharp turn at 90 degrees or more) - turn_factor = max(0.25, 1.0 - (abs(self.selected_direction) / (np.pi / 2))) - linear_vel *= turn_factor - - # Apply Collision Avoidance Stop - skip if 
ignoring obstacles - if not self.ignore_obstacles and self.check_collision( - self.selected_direction, safety_threshold=0.5 - ): - # Re-select direction prioritizing obstacle avoidance if colliding - self.selected_direction = self.select_direction( - self.goal_weight * 0.2, - self.obstacle_weight, - self.prev_direction_weight * 0.2, - self.histogram, - goal_direction, - ) - linear_vel, angular_vel = self.compute_pure_pursuit( - goal_distance, self.selected_direction - ) - - if self.check_collision(0.0, safety_threshold=self.safety_threshold): - linear_vel = 0.0 - - self.prev_linear_vel = linear_vel - filtered_linear_vel = self.prev_linear_vel * self.linear_vel_filter_factor + linear_vel * ( - 1 - self.linear_vel_filter_factor - ) - - return {"x_vel": filtered_linear_vel, "angular_vel": angular_vel} - - def _smooth_histogram(self, histogram: np.ndarray) -> np.ndarray: - """ - Apply advanced smoothing to the polar histogram to better identify valleys - and reduce noise. - - Args: - histogram: Raw histogram to smooth - - Returns: - np.ndarray: Smoothed histogram - """ - # Apply a windowed average with variable width based on obstacle density - smoothed = np.zeros_like(histogram) - bins = len(histogram) - - # First pass: basic smoothing with a 5-point kernel - # This uses a wider window than the original 3-point smoother - for i in range(bins): - # Compute indices with wrap-around - indices = [(i + j) % bins for j in range(-2, 3)] - # Apply weighted average (more weight to the center) - weights = [0.1, 0.2, 0.4, 0.2, 0.1] # Sum = 1.0 - smoothed[i] = sum(histogram[idx] * weight for idx, weight in zip(indices, weights)) - - # Second pass: peak and valley enhancement - enhanced = np.zeros_like(smoothed) - for i in range(bins): - # Check neighboring values - prev_idx = (i - 1) % bins - next_idx = (i + 1) % bins - - # Enhance valleys (low values) - if smoothed[i] < smoothed[prev_idx] and smoothed[i] < smoothed[next_idx]: - # It's a local minimum - make it even lower - 
enhanced[i] = smoothed[i] * 0.8 - # Enhance peaks (high values) - elif smoothed[i] > smoothed[prev_idx] and smoothed[i] > smoothed[next_idx]: - # It's a local maximum - make it even higher - enhanced[i] = min(1.0, smoothed[i] * 1.2) - else: - enhanced[i] = smoothed[i] - - return enhanced - - def build_polar_histogram(self, costmap: Costmap, robot_pose: Tuple[float, float, float]): - """ - Build a polar histogram of obstacle densities around the robot. - - Args: - costmap: Costmap object with grid and metadata - robot_pose: Tuple (x, y, theta) of the robot pose in the odom frame - - Returns: - np.ndarray: Polar histogram of obstacle densities - """ - - # Get grid and find all obstacle cells - occupancy_grid = costmap.grid - y_indices, x_indices = np.where(occupancy_grid > 0) - if len(y_indices) == 0: # No obstacles - return np.zeros(self.histogram_bins) - - # Get robot position in grid coordinates - robot_x, robot_y, robot_theta = robot_pose - robot_point = costmap.world_to_grid((robot_x, robot_y)) - robot_cell_x, robot_cell_y = robot_point.x, robot_point.y - - # Vectorized distance and angle calculation - dx_cells = x_indices - robot_cell_x - dy_cells = y_indices - robot_cell_y - distances = np.sqrt(dx_cells**2 + dy_cells**2) * costmap.resolution - angles_grid = np.arctan2(dy_cells, dx_cells) - angles_robot = normalize_angle(angles_grid - robot_theta) - - # Convert to bin indices - bin_indices = ((angles_robot + np.pi) / (2 * np.pi) * self.histogram_bins).astype( - int - ) % self.histogram_bins - - # Get obstacle values - obstacle_values = occupancy_grid[y_indices, x_indices] / 100.0 - - # Build histogram - histogram = np.zeros(self.histogram_bins) - mask = distances > 0 - # Weight obstacles by inverse square of distance and cell value - np.add.at(histogram, bin_indices[mask], obstacle_values[mask] / (distances[mask] ** 2)) - - # Apply the enhanced smoothing - return self._smooth_histogram(histogram) - - def select_direction( - self, goal_weight, obstacle_weight, 
prev_direction_weight, histogram, goal_direction - ): - """ - Select best direction based on a simple weighted cost function. - - Args: - goal_weight: Weight for the goal direction component - obstacle_weight: Weight for the obstacle avoidance component - prev_direction_weight: Weight for previous direction consistency - histogram: Polar histogram of obstacle density - goal_direction: Desired direction to goal - - Returns: - float: Selected direction in radians - """ - # Normalize histogram if needed - if np.max(histogram) > 0: - histogram = histogram / np.max(histogram) - - # Calculate costs for each possible direction - angle_diffs = np.abs(normalize_angle(self.angle_mapping - goal_direction)) - prev_diffs = np.abs(normalize_angle(self.angle_mapping - self.prev_selected_angle)) - - # Combine costs with weights - obstacle_costs = obstacle_weight * histogram - goal_costs = goal_weight * angle_diffs - prev_costs = prev_direction_weight * prev_diffs - - total_costs = obstacle_costs + goal_costs + prev_costs - - # Select direction with lowest cost - min_cost_idx = np.argmin(total_costs) - selected_angle = self.angle_mapping[min_cost_idx] - - # Update history for next iteration - self.prev_selected_angle = selected_angle - - return selected_angle - - def compute_pure_pursuit( - self, goal_distance: float, goal_direction: float - ) -> Tuple[float, float]: - """Compute pure pursuit velocities.""" - if goal_distance < self.goal_tolerance: - return 0.0, 0.0 - - lookahead = min(self.lookahead_distance, goal_distance) - linear_vel = min(self.max_linear_vel, goal_distance) - angular_vel = 2.0 * np.sin(goal_direction) / lookahead - angular_vel = max(-self.max_angular_vel, min(angular_vel, self.max_angular_vel)) - - return linear_vel, angular_vel - - def check_collision(self, selected_direction: float, safety_threshold: float = 1.0) -> bool: - """Check if there's an obstacle in the selected direction within safety threshold.""" - # Skip collision check if ignoring obstacles - 
if self.ignore_obstacles: - return False - - # Get the latest costmap and robot pose - costmap = self._get_costmap() - if costmap is None: - return False # No costmap available - - robot_pos, robot_theta = self._format_robot_pose() - robot_x, robot_y = robot_pos - - # Direction in world frame - direction_world = robot_theta + selected_direction - - # Safety distance in cells - safety_cells = int(safety_threshold / costmap.resolution) - - # Get robot position in grid coordinates - robot_point = costmap.world_to_grid((robot_x, robot_y)) - robot_cell_x, robot_cell_y = robot_point.x, robot_point.y - - # Check for obstacles along the selected direction - for dist in range(1, safety_cells + 1): - # Calculate cell position - cell_x = robot_cell_x + int(dist * np.cos(direction_world)) - cell_y = robot_cell_y + int(dist * np.sin(direction_world)) - - # Check if cell is within grid bounds - if not (0 <= cell_x < costmap.width and 0 <= cell_y < costmap.height): - continue - - # Check if cell contains an obstacle (threshold at 50) - if costmap.grid[int(cell_y), int(cell_x)] > 50: - return True - - return False # No collision detected - - def update_visualization(self) -> np.ndarray: - """Generate visualization of the planning state.""" - try: - costmap = self._get_costmap() - if costmap is None: - raise ValueError("Costmap is None") - - robot_pos, robot_theta = self._format_robot_pose() - robot_x, robot_y = robot_pos - robot_pose = (robot_x, robot_y, robot_theta) - - goal_xy = self.goal_xy # This could be a lookahead point or final goal - - # Get the latest histogram and selected direction, if available - histogram = getattr(self, "histogram", None) - selected_direction = getattr(self, "selected_direction", None) - - # Get waypoint data if in waypoint mode - waypoints_to_draw = self.waypoints_in_absolute - current_wp_index_to_draw = ( - self.current_waypoint_index if self.waypoints_in_absolute is not None else None - ) - # Ensure index is valid before passing - if 
waypoints_to_draw is not None and current_wp_index_to_draw is not None: - if not (0 <= current_wp_index_to_draw < len(waypoints_to_draw)): - current_wp_index_to_draw = None # Invalidate index if out of bounds - - return visualize_local_planner_state( - occupancy_grid=costmap.grid, - grid_resolution=costmap.resolution, - grid_origin=(costmap.origin.x, costmap.origin.y), - robot_pose=robot_pose, - goal_xy=goal_xy, # Current target (lookahead or final) - goal_theta=self.goal_theta, # Pass goal orientation if available - visualization_size=self.visualization_size, - robot_width=self.robot_width, - robot_length=self.robot_length, - histogram=histogram, - selected_direction=selected_direction, - waypoints=waypoints_to_draw, # Pass the full path - current_waypoint_index=current_wp_index_to_draw, # Pass the target index - ) - except Exception as e: - logger.error(f"Error during visualization update: {e}") - # Return a blank image with error text - blank = ( - np.ones((self.visualization_size, self.visualization_size, 3), dtype=np.uint8) * 255 - ) - cv2.putText( - blank, - "Viz Error", - (self.visualization_size // 4, self.visualization_size // 2), - cv2.FONT_HERSHEY_SIMPLEX, - 1, - (0, 0, 0), - 2, - ) - return blank diff --git a/dimos/robot/robot.py b/dimos/robot/robot.py deleted file mode 100644 index 58526b5f0c..0000000000 --- a/dimos/robot/robot.py +++ /dev/null @@ -1,435 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Base module for all DIMOS robots. - -This module provides the foundation for all DIMOS robots, including both physical -and simulated implementations, with common functionality for movement, control, -and video streaming. -""" - -from abc import ABC, abstractmethod -import os -from typing import Optional, List, Union, Dict, Any - -from dimos.hardware.interface import HardwareInterface -from dimos.perception.spatial_perception import SpatialMemory -from dimos.manipulation.manipulation_interface import ManipulationInterface -from dimos.types.robot_capabilities import RobotCapability -from dimos.types.vector import Vector -from dimos.utils.logging_config import setup_logger -from dimos.robot.connection_interface import ConnectionInterface - -from dimos.skills.skills import SkillLibrary -from reactivex import Observable, operators as ops -from reactivex.disposable import CompositeDisposable -from reactivex.scheduler import ThreadPoolScheduler - -from dimos.utils.threadpool import get_scheduler -from dimos.utils.reactive import backpressure -from dimos.stream.video_provider import VideoProvider - -logger = setup_logger("dimos.robot.robot") - - -class Robot(ABC): - """Base class for all DIMOS robots. - - This abstract base class defines the common interface and functionality for all - DIMOS robots, whether physical or simulated. It provides methods for movement, - rotation, video streaming, and hardware configuration management. - - Attributes: - agent_config: Configuration for the robot's agent. - hardware_interface: Interface to the robot's hardware components. - ros_control: ROS-based control system for the robot. - output_dir: Directory for storing output files. - disposables: Collection of disposable resources for cleanup. - pool_scheduler: Thread pool scheduler for managing concurrent operations. 
- """ - - def __init__( - self, - hardware_interface: HardwareInterface = None, - connection_interface: ConnectionInterface = None, - output_dir: str = os.path.join(os.getcwd(), "assets", "output"), - pool_scheduler: ThreadPoolScheduler = None, - skill_library: SkillLibrary = None, - spatial_memory_collection: str = "spatial_memory", - new_memory: bool = False, - capabilities: List[RobotCapability] = None, - video_stream: Optional[Observable] = None, - enable_perception: bool = True, - ): - """Initialize a Robot instance. - - Args: - hardware_interface: Interface to the robot's hardware. Defaults to None. - connection_interface: Connection interface for robot control and communication. - output_dir: Directory for storing output files. Defaults to "./assets/output". - pool_scheduler: Thread pool scheduler. If None, one will be created. - skill_library: Skill library instance. If None, one will be created. - spatial_memory_collection: Name of the collection in the ChromaDB database. - new_memory: If True, creates a new spatial memory from scratch. Defaults to False. - capabilities: List of robot capabilities. Defaults to None. - video_stream: Optional video stream. Defaults to None. - enable_perception: If True, enables perception streams and spatial memory. Defaults to True. 
- """ - self.hardware_interface = hardware_interface - self.connection_interface = connection_interface - self.output_dir = output_dir - self.disposables = CompositeDisposable() - self.pool_scheduler = pool_scheduler if pool_scheduler else get_scheduler() - self.skill_library = skill_library if skill_library else SkillLibrary() - self.enable_perception = enable_perception - - # Initialize robot capabilities - self.capabilities = capabilities or [] - - # Create output directory if it doesn't exist - os.makedirs(self.output_dir, exist_ok=True) - logger.info(f"Robot outputs will be saved to: {self.output_dir}") - - # Initialize memory properties - self.memory_dir = os.path.join(self.output_dir, "memory") - os.makedirs(self.memory_dir, exist_ok=True) - - # Initialize spatial memory properties - self.spatial_memory_dir = os.path.join(self.memory_dir, "spatial_memory") - self.spatial_memory_collection = spatial_memory_collection - self.db_path = os.path.join(self.spatial_memory_dir, "chromadb_data") - self.visual_memory_path = os.path.join(self.spatial_memory_dir, "visual_memory.pkl") - - # Create spatial memory directory - os.makedirs(self.spatial_memory_dir, exist_ok=True) - os.makedirs(self.db_path, exist_ok=True) - - # Initialize spatial memory properties - self._video_stream = video_stream - - # Only create video stream if connection interface is available - if self.connection_interface is not None: - # Get video stream - always create this, regardless of enable_perception - self._video_stream = self.get_video_stream(fps=10) # Lower FPS for processing - - # Create SpatialMemory instance only if perception is enabled - if self.enable_perception: - self._spatial_memory = SpatialMemory( - collection_name=self.spatial_memory_collection, - db_path=self.db_path, - visual_memory_path=self.visual_memory_path, - new_memory=new_memory, - output_dir=self.spatial_memory_dir, - video_stream=self._video_stream, - get_pose=self.get_pose, - ) - logger.info("Spatial memory 
initialized") - else: - self._spatial_memory = None - logger.info("Spatial memory disabled (enable_perception=False)") - - # Initialize manipulation interface if the robot has manipulation capability - self._manipulation_interface = None - if RobotCapability.MANIPULATION in self.capabilities: - # Initialize manipulation memory properties if the robot has manipulation capability - self.manipulation_memory_dir = os.path.join(self.memory_dir, "manipulation_memory") - - # Create manipulation memory directory - os.makedirs(self.manipulation_memory_dir, exist_ok=True) - - self._manipulation_interface = ManipulationInterface( - output_dir=self.output_dir, # Use the main output directory - new_memory=new_memory, - ) - logger.info("Manipulation interface initialized") - - def get_video_stream(self, fps: int = 30) -> Observable: - """Get the video stream with rate limiting and frame processing. - - Args: - fps: Frames per second for the video stream. Defaults to 30. - - Returns: - Observable: An observable stream of video frames. - - Raises: - RuntimeError: If no connection interface is available for video streaming. - """ - if self.connection_interface is None: - raise RuntimeError("No connection interface available for video streaming") - - stream = self.connection_interface.get_video_stream(fps) - if stream is None: - raise RuntimeError("No video stream available from connection interface") - - return stream.pipe( - ops.observe_on(self.pool_scheduler), - ) - - def move(self, velocity: Vector, duration: float = 0.0) -> bool: - """Move the robot using velocity commands. - - Args: - velocity: Velocity vector [x, y, yaw] where: - x: Linear velocity in x direction (m/s) - y: Linear velocity in y direction (m/s) - yaw: Angular velocity (rad/s) - duration: Duration to apply command (seconds). If 0, apply once. - - Returns: - bool: True if movement succeeded. - - Raises: - RuntimeError: If no connection interface is available. 
- """ - if self.connection_interface is None: - raise RuntimeError("No connection interface available for movement") - - return self.connection_interface.move(velocity, duration) - - def spin(self, degrees: float, speed: float = 45.0) -> bool: - """Rotate the robot by a specified angle. - - Args: - degrees: Angle to rotate in degrees (positive for counter-clockwise, - negative for clockwise). - speed: Angular speed in degrees/second. Defaults to 45.0. - - Returns: - bool: True if rotation succeeded. - - Raises: - RuntimeError: If no connection interface is available. - """ - if self.connection_interface is None: - raise RuntimeError("No connection interface available for rotation") - - # Convert degrees to radians - import math - - angular_velocity = math.radians(speed) - duration = abs(degrees) / speed if speed > 0 else 0 - - # Set direction based on sign of degrees - if degrees < 0: - angular_velocity = -angular_velocity - - velocity = Vector(0.0, 0.0, angular_velocity) - return self.connection_interface.move(velocity, duration) - - @abstractmethod - def get_pose(self) -> dict: - """ - Get the current pose (position and rotation) of the robot. - - Returns: - Dictionary containing: - - position: Tuple[float, float, float] (x, y, z) - - rotation: Tuple[float, float, float] (roll, pitch, yaw) in radians - """ - pass - - def webrtc_req( - self, - api_id: int, - topic: str = None, - parameter: str = "", - priority: int = 0, - request_id: str = None, - data=None, - timeout: float = 1000.0, - ): - """Send a WebRTC request command to the robot. - - Args: - api_id: The API ID for the command. - topic: The API topic to publish to. Defaults to ROSControl.webrtc_api_topic. - parameter: Additional parameter data. Defaults to "". - priority: Priority of the request. Defaults to 0. - request_id: Unique identifier for the request. If None, one will be generated. - data: Additional data to include with the request. Defaults to None. 
- timeout: Timeout for the request in milliseconds. Defaults to 1000.0. - - Returns: - The result of the WebRTC request. - - Raises: - RuntimeError: If no connection interface with WebRTC capability is available. - """ - if self.connection_interface is None: - raise RuntimeError("No connection interface available for WebRTC commands") - - # WebRTC requests are only available on ROS control interfaces - if hasattr(self.connection_interface, "queue_webrtc_req"): - return self.connection_interface.queue_webrtc_req( - api_id=api_id, - topic=topic, - parameter=parameter, - priority=priority, - request_id=request_id, - data=data, - timeout=timeout, - ) - else: - raise RuntimeError("WebRTC requests not supported by this connection interface") - - def pose_command(self, roll: float, pitch: float, yaw: float) -> bool: - """Send a pose command to the robot. - - Args: - roll: Roll angle in radians. - pitch: Pitch angle in radians. - yaw: Yaw angle in radians. - - Returns: - bool: True if command was sent successfully. - - Raises: - RuntimeError: If no connection interface with pose command capability is available. - """ - # Pose commands are only available on ROS control interfaces - if hasattr(self.connection_interface, "pose_command"): - return self.connection_interface.pose_command(roll, pitch, yaw) - else: - raise RuntimeError("Pose commands not supported by this connection interface") - - def update_hardware_interface(self, new_hardware_interface: HardwareInterface): - """Update the hardware interface with a new configuration. - - Args: - new_hardware_interface: New hardware interface to use for the robot. - """ - self.hardware_interface = new_hardware_interface - - def get_hardware_configuration(self): - """Retrieve the current hardware configuration. - - Returns: - The current hardware configuration from the hardware interface. - - Raises: - AttributeError: If hardware_interface is None. 
- """ - return self.hardware_interface.get_configuration() - - def set_hardware_configuration(self, configuration): - """Set a new hardware configuration. - - Args: - configuration: The new hardware configuration to set. - - Raises: - AttributeError: If hardware_interface is None. - """ - self.hardware_interface.set_configuration(configuration) - - @property - def spatial_memory(self) -> Optional[SpatialMemory]: - """Get the robot's spatial memory. - - Returns: - SpatialMemory: The robot's spatial memory system, or None if perception is disabled. - """ - return self._spatial_memory - - @property - def manipulation_interface(self) -> Optional[ManipulationInterface]: - """Get the robot's manipulation interface. - - Returns: - ManipulationInterface: The robot's manipulation interface or None if not available. - """ - return self._manipulation_interface - - def has_capability(self, capability: RobotCapability) -> bool: - """Check if the robot has a specific capability. - - Args: - capability: The capability to check for - - Returns: - bool: True if the robot has the capability, False otherwise - """ - return capability in self.capabilities - - def get_spatial_memory(self) -> Optional[SpatialMemory]: - """Simple getter for the spatial memory instance. - (For backwards compatibility) - - Returns: - The spatial memory instance or None if not set. - """ - return self._spatial_memory if self._spatial_memory else None - - @property - def video_stream(self) -> Optional[Observable]: - """Get the robot's video stream. - - Returns: - Observable: The robot's video stream or None if not available. - """ - return self._video_stream - - def get_skills(self): - """Get the robot's skill library. - - Returns: - The robot's skill library for adding/managing skills. - """ - return self.skill_library - - def cleanup(self): - """Clean up resources used by the robot. 
- - This method should be called when the robot is no longer needed to - ensure proper release of resources such as ROS connections and - subscriptions. - """ - # Dispose of resources - if self.disposables: - self.disposables.dispose() - - # Clean up connection interface - if self.connection_interface: - self.connection_interface.disconnect() - - self.disposables.dispose() - - -class MockRobot(Robot): - def __init__(self): - super().__init__() - self.ros_control = None - self.hardware_interface = None - self.skill_library = SkillLibrary() - - def my_print(self): - print("Hello, world!") - - -class MockManipulationRobot(Robot): - def __init__(self, skill_library: Optional[SkillLibrary] = None): - video_provider = VideoProvider("webcam", video_source=0) # Default camera - video_stream = backpressure( - video_provider.capture_video_as_observable(realtime=True, fps=30) - ) - - super().__init__( - capabilities=[RobotCapability.MANIPULATION], - video_stream=video_stream, - skill_library=skill_library, - ) - self.camera_intrinsics = [489.33, 367.0, 320.0, 240.0] - self.ros_control = None - self.hardware_interface = None diff --git a/dimos/robot/ros_observable_topic.py b/dimos/robot/ros_observable_topic.py index 697ddff398..ef99ceadee 100644 --- a/dimos/robot/ros_observable_topic.py +++ b/dimos/robot/ros_observable_topic.py @@ -23,8 +23,8 @@ from nav_msgs import msg from dimos.utils.logging_config import setup_logger from dimos.utils.threadpool import get_scheduler -from dimos.types.costmap import Costmap from dimos.types.vector import Vector +from dimos.msgs.nav_msgs import OccupancyGrid from typing import Union, Callable, Any @@ -37,8 +37,7 @@ __all__ = ["ROSObservableTopicAbility", "QOS"] -ConversionType = Costmap -TopicType = Union[ConversionType, msg.OccupancyGrid, msg.Odometry] +TopicType = Union[OccupancyGrid, msg.OccupancyGrid, msg.Odometry] class QOS(enum.Enum): @@ -82,15 +81,15 @@ class ROSObservableTopicAbility: # └──► observe_on(pool) ─► backpressure.latest ─► 
sub3 (slower) # def _maybe_conversion(self, msg_type: TopicType, callback) -> Callable[[TopicType], Any]: - if msg_type == Costmap: - return lambda msg: callback(Costmap.from_msg(msg)) + if msg_type == "Costmap": + return lambda msg: callback(OccupancyGrid.from_msg(msg)) # just for test, not sure if this Vector auto-instantiation is used irl if msg_type == Vector: return lambda msg: callback(Vector.from_msg(msg)) return callback def _sub_msg_type(self, msg_type): - if msg_type == Costmap: + if msg_type == "Costmap": return msg.OccupancyGrid if msg_type == Vector: diff --git a/dimos/robot/unitree_webrtc/connection.py b/dimos/robot/unitree_webrtc/connection.py index 6119cba860..ffead0c2c4 100644 --- a/dimos/robot/unitree_webrtc/connection.py +++ b/dimos/robot/unitree_webrtc/connection.py @@ -30,13 +30,12 @@ from reactivex.subject import Subject from dimos.core import In, Module, Out, rpc -from dimos.msgs.geometry_msgs import Pose, Transform +from dimos.msgs.geometry_msgs import Pose, Transform, Vector3 from dimos.msgs.sensor_msgs import Image from dimos.robot.connection_interface import ConnectionInterface from dimos.robot.unitree_webrtc.type.lidar import LidarMessage from dimos.robot.unitree_webrtc.type.lowstate import LowStateMsg from dimos.robot.unitree_webrtc.type.odometry import Odometry -from dimos.types.vector import Vector from dimos.utils.reactive import backpressure, callback_to_observable VideoMessage: TypeAlias = np.ndarray[tuple[int, int, Literal[3]], np.uint8] @@ -81,7 +80,7 @@ def start_background_loop(): self.thread.start() self.connection_ready.wait() - def move(self, velocity: Vector, duration: float = 0.0) -> bool: + def move(self, velocity: Vector3, duration: float = 0.0) -> bool: """Send movement command to the robot using velocity commands. 
Args: @@ -290,7 +289,7 @@ def stop(self) -> bool: Returns: bool: True if stop command was sent successfully """ - return self.move(Vector(0.0, 0.0, 0.0)) + return self.move(Vector3(0.0, 0.0, 0.0)) def disconnect(self) -> None: """Disconnect from the robot and clean up resources.""" diff --git a/dimos/robot/unitree_webrtc/multiprocess/example_usage.py b/dimos/robot/unitree_webrtc/multiprocess/example_usage.py deleted file mode 100644 index 84620e15f6..0000000000 --- a/dimos/robot/unitree_webrtc/multiprocess/example_usage.py +++ /dev/null @@ -1,169 +0,0 @@ -#!/usr/bin/env python3 - - -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Example usage of UnitreeGo2Light and UnitreeGo2Heavy classes.""" - -import asyncio -import os -import threading - -import reactivex as rx -import reactivex.operators as ops - -from dimos.agents.claude_agent import ClaudeAgent -from dimos.perception.object_detection_stream import ObjectDetectionStream -from dimos.robot.unitree_webrtc.multiprocess.unitree_go2 import UnitreeGo2Light -from dimos.robot.unitree_webrtc.multiprocess.unitree_go2_heavy import UnitreeGo2Heavy -from dimos.robot.unitree_webrtc.unitree_skills import MyUnitreeSkills -from dimos.skills.skills import AbstractRobotSkill, SkillLibrary -from dimos.stream.audio.pipelines import stt, tts -from dimos.utils.reactive import backpressure -from dimos.web.robot_web_interface import RobotWebInterface -from dimos.perception.object_detection_stream import ObjectDetectionStream - - -async def run_light_robot(): - """Example of running the lightweight robot without GPU modules.""" - ip = os.getenv("ROBOT_IP") - - robot = UnitreeGo2Light(ip) - - await robot.start() - - # pose = robot.get_pose() - # print(f"Robot position: {pose['position']}") - # print(f"Robot rotation: {pose['rotation']}") - - # from dimos.msgs.geometry_msgs import Vector3 - - # robot.move(Vector3(0.5, 0, 0), duration=2.0) - - # robot.explore() - - while True: - await asyncio.sleep(1) - - -async def run_heavy_robot(): - """Example of running the heavy robot with full GPU modules.""" - ip = os.getenv("ROBOT_IP") - - # Create heavy robot instance with all features - robot = UnitreeGo2Heavy(ip=ip, new_memory=True, enable_perception=True) - - # Start the robot - await robot.start() - - if robot.spatial_memory: - print("Spatial memory initialized") - - skills = robot.get_skills() - print(f"Available skills: {[skill.__class__.__name__ for skill in skills]}") - - from dimos.types.robot_capabilities import RobotCapability - - if robot.has_capability(RobotCapability.VISION): - print("Robot has vision capability") - - # Start exploration with spatial 
memory recording - print(robot.spatial_memory.query_by_text("kitchen")) - - # robot.frontier_explorer.explore() - - # Create a subject for agent responses - agent_response_subject = rx.subject.Subject() - agent_response_stream = agent_response_subject.pipe(ops.share()) - audio_subject = rx.subject.Subject() - - video_stream = robot.get_video_stream() # WebRTC doesn't use ROS video stream - - # Initialize ObjectDetectionStream with robot - object_detector = ObjectDetectionStream( - camera_intrinsics=robot.camera_intrinsics, - get_pose=robot.get_pose, - video_stream=video_stream, - draw_masks=True, - ) - - # Create visualization stream for web interface - viz_stream = backpressure(object_detector.get_stream()).pipe( - ops.share(), - ops.map(lambda x: x["viz_frame"] if x is not None else None), - ops.filter(lambda x: x is not None), - ) - - # Get tracking visualization streams if available - tracking_streams = {} - if robot.person_tracking_stream: - tracking_streams["person_tracking"] = robot.person_tracking_stream.pipe( - ops.map(lambda x: x.get("viz_frame") if x else None), - ops.filter(lambda x: x is not None), - ) - if robot.object_tracking_stream: - tracking_streams["object_tracking"] = robot.object_tracking_stream.pipe( - ops.map(lambda x: x.get("viz_frame") if x else None), - ops.filter(lambda x: x is not None), - ) - - streams = { - "unitree_video": robot.get_video_stream(), # Changed from get_ros_video_stream to get_video_stream for WebRTC - "object_detection": viz_stream, # Uncommented object detection - **tracking_streams, # Add tracking streams if available - } - text_streams = { - "agent_responses": agent_response_stream, - } - - web_interface = RobotWebInterface( - port=5555, text_streams=text_streams, audio_subject=audio_subject, **streams - ) - - stt_node = stt() - stt_node.consume_audio(audio_subject.pipe(ops.share())) - - agent = ClaudeAgent( - dev_name="test_agent", - # input_query_stream=stt_node.emit_text(), - 
input_query_stream=web_interface.query_stream, - skills=robot.get_skills(), - system_query="You are a helpful robot.", - model_name="claude-3-5-haiku-latest", - thinking_budget_tokens=0, - max_output_tokens_per_request=8192, - # model_name="llama-4-scout-17b-16e-instruct", - ) - agent.get_response_observable().subscribe(lambda x: agent_response_subject.on_next(x)) - - # Start web interface in a separate thread to avoid blocking - web_thread = threading.Thread(target=web_interface.run) - web_thread.daemon = True - web_thread.start() - - # Keep running - while True: - await asyncio.sleep(1) - - -if __name__ == "__main__": - use_heavy = True - - if use_heavy: - print("Running UnitreeGo2Heavy with GPU modules...") - asyncio.run(run_heavy_robot()) - else: - print("Running UnitreeGo2Light without GPU modules...") - asyncio.run(run_light_robot()) diff --git a/dimos/robot/unitree_webrtc/multiprocess/individual_node_example.py b/dimos/robot/unitree_webrtc/multiprocess/individual_node_example.py deleted file mode 100644 index 3bcd9af9ff..0000000000 --- a/dimos/robot/unitree_webrtc/multiprocess/individual_node_example.py +++ /dev/null @@ -1,48 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import functools -import time - -from reactivex import operators as ops - -from dimos import core -from dimos.core import In, Module, Out -from dimos.msgs.geometry_msgs import Vector3 -from dimos.msgs.sensor_msgs import Image -from dimos.protocol import pubsub -from dimos.robot.global_planner import AstarPlanner -from dimos.robot.unitree_webrtc.type.lidar import LidarMessage -from dimos.robot.unitree_webrtc.type.map import Map -from dimos.robot.unitree_webrtc.type.odometry import Odometry -from dimos.types.vector import Vector -from dimos.utils.reactive import backpressure, getter_streaming -from dimos.utils.testing import TimedSensorReplay - - -class DebugModule(Module): - target: In[Vector] = None - - def start(self): - self.target.subscribe(lambda x: print("TARGET", x)) - - -if __name__ == "__main__": - dimos = core.start(1) - debugModule = dimos.deploy(DebugModule) - debugModule.target.transport = core.LCMTransport("/clicked_point", Vector3) - debugModule.start() - time.sleep(1000) diff --git a/dimos/robot/unitree_webrtc/multiprocess/test_unitree_go2_cpu_module.py b/dimos/robot/unitree_webrtc/multiprocess/test_unitree_go2_cpu_module.py deleted file mode 100644 index 98c222d62c..0000000000 --- a/dimos/robot/unitree_webrtc/multiprocess/test_unitree_go2_cpu_module.py +++ /dev/null @@ -1,216 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import asyncio -import threading -import time - -import pytest - -from dimos import core -from dimos.core import Module, Out, rpc -from dimos.msgs.geometry_msgs import Pose, PoseStamped, Vector3 -from dimos.msgs.sensor_msgs import Image -from dimos.protocol import pubsub -from dimos.robot.frontier_exploration.wavefront_frontier_goal_selector import ( - WavefrontFrontierExplorer, -) -from dimos.robot.global_planner import AstarPlanner -from dimos.robot.local_planner.vfh_local_planner import VFHPurePursuitPlanner -from dimos.robot.unitree_webrtc.multiprocess.unitree_go2 import ConnectionModule, ControlModule -from dimos.robot.unitree_webrtc.type.lidar import LidarMessage -from dimos.robot.unitree_webrtc.type.map import Map -from dimos.utils.logging_config import setup_logger - -logger = setup_logger("test_unitree_go2_cpu_module") - -pubsub.lcm.autoconf() - - -class MovementControlModule(Module): - """Simple module to send movement commands for testing.""" - - movecmd: Out[Vector3] = None - - def __init__(self): - super().__init__() - self.commands_sent = [] - - @rpc - def send_move_command(self, x: float, y: float, yaw: float): - """Send a movement command.""" - cmd = Vector3(x, y, yaw) - self.movecmd.publish(cmd) - self.commands_sent.append(cmd) - logger.info(f"Sent move command: x={x}, y={y}, yaw={yaw}") - - @rpc - def send_explore_sequence(self): - """Send a sequence of exploration commands.""" - - def send_commands(): - commands = [ - (0.5, 0.0, 0.0), - (0.0, 0.0, 0.3), - (0.5, 0.0, 0.0), - (0.0, 0.0, -0.3), - (0.3, 0.0, 0.0), - (0.0, 0.0, 0.0), - ] - - for x, y, yaw in commands: - self.send_move_command(x, y, yaw) - time.sleep(0.5) - - thread = threading.Thread(target=send_commands, daemon=True) - thread.start() - - @rpc - def get_command_count(self) -> int: - """Get number of commands sent.""" - return len(self.commands_sent) - - -@pytest.mark.module -class TestUnitreeGo2CPUModule: - @pytest.mark.asyncio - async def 
test_unitree_go2_connection_explore_movement(self): - """Test UnitreeGo2 modules with FakeRTC for exploration and movement without spatial memory.""" - - # Start Dask - dimos = core.start(4) - - try: - # Deploy ConnectionModule with FakeRTC (uses test data) - connection = dimos.deploy( - ConnectionModule, "127.0.0.1" - ) # IP doesn't matter for FakeRTC - - # Configure LCM transports - connection.lidar.transport = core.LCMTransport("/lidar", LidarMessage) - connection.odom.transport = core.LCMTransport("/odom", PoseStamped) - connection.video.transport = core.LCMTransport("/video", Image) - - # Deploy Map module - mapper = dimos.deploy(Map, voxel_size=0.5, global_publish_interval=2.5) - mapper.global_map.transport = core.LCMTransport("/global_map", LidarMessage) - mapper.lidar.connect(connection.lidar) - - # Deploy Local planner - local_planner = dimos.deploy( - VFHPurePursuitPlanner, - get_costmap=connection.get_local_costmap, - ) - local_planner.odom.connect(connection.odom) - local_planner.movecmd.transport = core.LCMTransport("/move", Vector3) - connection.movecmd.connect(local_planner.movecmd) - - # Deploy Global planner - global_planner = dimos.deploy( - AstarPlanner, - get_costmap=mapper.costmap, - get_robot_pos=connection.get_pos, - set_local_nav=local_planner.navigate_path_local, - ) - global_planner.path.transport = core.pLCMTransport("/global_path") - - # Deploy Control module for testing - ctrl = dimos.deploy(ControlModule) - ctrl.plancmd.transport = core.LCMTransport("/global_target", Pose) - global_planner.target.connect(ctrl.plancmd) - - # Deploy movement control module - movement = dimos.deploy(MovementControlModule) - movement.movecmd.transport = core.LCMTransport("/test_move", Vector3) - - # Connect movement commands to connection module as well - connection.movecmd.connect(movement.movecmd) - - # Start all modules - mapper.start() - connection.start() - local_planner.start() - global_planner.start() - - logger.info("All modules started") - - # 
Wait for initialization - await asyncio.sleep(3) - - # Test get methods - odom = connection.get_odom() - assert odom is not None, "Should get odometry" - logger.info(f"Got odometry: position={odom.position}") - - pos = connection.get_pos() - assert pos is not None, "Should get position" - logger.info(f"Got position: {pos}") - - local_costmap = connection.get_local_costmap() - assert local_costmap is not None, "Should get local costmap" - logger.info(f"Got local costmap with shape: {local_costmap.grid.shape}") - - # Test mapper costmap - global_costmap = mapper.costmap() - assert global_costmap is not None, "Should get global costmap" - logger.info(f"Got global costmap with shape: {global_costmap.grid.shape}") - - # Test movement commands - movement.send_move_command(0.5, 0.0, 0.0) - await asyncio.sleep(0.5) - - movement.send_move_command(0.0, 0.0, 0.3) - await asyncio.sleep(0.5) - - movement.send_move_command(0.0, 0.0, 0.0) - await asyncio.sleep(0.5) - - # Check commands were sent - cmd_count = movement.get_command_count() - assert cmd_count == 3, f"Expected 3 commands, got {cmd_count}" - - # Test explore sequence - logger.info("Testing explore sequence") - movement.send_explore_sequence() - - # Wait for sequence to complete - await asyncio.sleep(4) - - # Verify explore commands were sent - final_count = movement.get_command_count() - assert final_count == 9, f"Expected 9 total commands, got {final_count}" - - # Test frontier exploration setup - frontier_explorer = WavefrontFrontierExplorer( - set_goal=global_planner.set_goal, - get_costmap=mapper.costmap, - get_robot_pos=connection.get_pos, - ) - logger.info("Frontier explorer created successfully") - - # Start control module to trigger planning - ctrl.start() - logger.info("Control module started - will trigger planning in 4 seconds") - - await asyncio.sleep(5) - - logger.info("All UnitreeGo2 CPU module tests passed!") - - finally: - dimos.close() - logger.info("Closed Dask cluster") - - -if __name__ == 
"__main__": - pytest.main(["-v", "-s", __file__]) diff --git a/dimos/robot/unitree_webrtc/multiprocess/unitree_go2.py b/dimos/robot/unitree_webrtc/multiprocess/unitree_go2.py deleted file mode 100644 index 37808c6dbb..0000000000 --- a/dimos/robot/unitree_webrtc/multiprocess/unitree_go2.py +++ /dev/null @@ -1,458 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import asyncio -import functools -import logging -import os -import threading -import time -import warnings -from typing import Callable, Optional - -from reactivex import Observable -from reactivex import operators as ops - -import dimos.core.colors as colors -from dimos import core -from dimos.core import In, Module, Out, rpc -from dimos.msgs.geometry_msgs import Pose, PoseStamped, Transform, Vector3 -from dimos.msgs.sensor_msgs import Image -from dimos.perception.spatial_perception import SpatialMemory -from dimos.protocol import pubsub -from dimos.protocol.tf import TF -from dimos.robot.foxglove_bridge import FoxgloveBridge -from dimos.robot.frontier_exploration.wavefront_frontier_goal_selector import ( - WavefrontFrontierExplorer, -) -from dimos.robot.global_planner import AstarPlanner -from dimos.robot.local_planner.vfh_local_planner import VFHPurePursuitPlanner -from dimos.robot.unitree_webrtc.connection import UnitreeWebRTCConnection, VideoMessage -from dimos.robot.unitree_webrtc.type.lidar import LidarMessage -from dimos.robot.unitree_webrtc.type.map import 
Map -from dimos.robot.unitree_webrtc.type.odometry import Odometry -from dimos.types.costmap import Costmap -from dimos.types.vector import Vector -from dimos.utils.data import get_data -from dimos.utils.logging_config import setup_logger -from dimos.utils.reactive import getter_streaming -from dimos.utils.testing import TimedSensorReplay - -logger = setup_logger("dimos.robot.unitree_webrtc.multiprocess.unitree_go2", level=logging.INFO) - -# Suppress verbose loggers -logging.getLogger("aiortc.codecs.h264").setLevel(logging.ERROR) -logging.getLogger("lcm_foxglove_bridge").setLevel(logging.ERROR) -logging.getLogger("websockets.server").setLevel(logging.ERROR) -logging.getLogger("FoxgloveServer").setLevel(logging.ERROR) -logging.getLogger("asyncio").setLevel(logging.ERROR) -logging.getLogger("root").setLevel(logging.WARNING) - -# Suppress warnings -warnings.filterwarnings("ignore", message="coroutine.*was never awaited") -warnings.filterwarnings("ignore", message="H264Decoder.*failed to decode") - - -# can be swapped in for UnitreeWebRTCConnection -class FakeRTC(UnitreeWebRTCConnection): - def __init__(self, *args, **kwargs): - # ensures we download msgs from lfs store - data = get_data("unitree_office_walk") - - def connect(self): ... 
- - def standup(self): - print("standup supressed") - - def liedown(self): - print("liedown supressed") - - @functools.cache - def lidar_stream(self): - print("lidar stream start") - lidar_store = TimedSensorReplay("unitree_office_walk/lidar", autocast=LidarMessage.from_msg) - return lidar_store.stream() - - @functools.cache - def odom_stream(self): - print("odom stream start") - odom_store = TimedSensorReplay("unitree_office_walk/odom", autocast=Odometry.from_msg) - return odom_store.stream() - - @functools.cache - def video_stream(self): - print("video stream start") - video_store = TimedSensorReplay( - "unitree_office_walk/video", autocast=lambda x: Image.from_numpy(x).to_rgb() - ) - return video_store.stream() - - def move(self, vector: Vector): - print("move supressed", vector) - - -class ConnectionModule(FakeRTC, Module): - movecmd: In[Vector3] = None - odom: Out[Vector3] = None - lidar: Out[LidarMessage] = None - video: Out[VideoMessage] = None - ip: str - - _odom: Callable[[], Odometry] - _lidar: Callable[[], LidarMessage] - - @rpc - def move(self, vector: Vector3): - super().move(vector) - - def __init__(self, ip: str, *args, **kwargs): - self.ip = ip - self.tf = TF() - Module.__init__(self, *args, **kwargs) - - @rpc - def start(self): - # Initialize the parent WebRTC connection - super().__init__(self.ip) - self.tf = TF() - # Connect sensor streams to LCM outputs - self.lidar_stream().subscribe(self.lidar.publish) - self.odom_stream().subscribe(self.odom.publish) - self.video_stream().subscribe(self.video.publish) - self.tf_stream().subscribe(self.tf.publish) - - # Connect LCM input to robot movement commands - self.movecmd.subscribe(self.move) - - # Set up streaming getters for latest sensor data - self._odom = getter_streaming(self.odom_stream()) - self._lidar = getter_streaming(self.lidar_stream()) - - @rpc - def get_local_costmap(self) -> Costmap: - return self._lidar().costmap() - - @rpc - def get_odom(self) -> Odometry: - return self._odom() - - 
@rpc - def get_pos(self) -> Vector: - return self._odom().position - - -class ControlModule(Module): - plancmd: Out[Pose] = None - - @rpc - def start(self): - def plancmd(): - time.sleep(4) - print(colors.red("requesting global plan")) - self.plancmd.publish(Pose(0, 0, 0, 0, 0, 0, 1)) - - thread = threading.Thread(target=plancmd, daemon=True) - thread.start() - - -class UnitreeGo2Light: - def __init__( - self, - ip: str, - output_dir: str = os.path.join(os.getcwd(), "assets", "output"), - ): - self.output_dir = output_dir - self.ip = ip - self.dimos = None - self.connection = None - self.mapper = None - self.local_planner = None - self.global_planner = None - self.frontier_explorer = None - self.foxglove_bridge = None - self.ctrl = None - - # Spatial Memory Initialization ====================================== - # Create output directory - os.makedirs(self.output_dir, exist_ok=True) - logger.info(f"Robot outputs will be saved to: {self.output_dir}") - - # Initialize memory directories - self.memory_dir = os.path.join(self.output_dir, "memory") - os.makedirs(self.memory_dir, exist_ok=True) - - # Initialize spatial memory properties - self.spatial_memory_dir = os.path.join(self.memory_dir, "spatial_memory") - self.spatial_memory_collection = "spatial_memory" - self.db_path = os.path.join(self.spatial_memory_dir, "chromadb_data") - self.visual_memory_path = os.path.join(self.spatial_memory_dir, "visual_memory.pkl") - - # Create spatial memory directory - os.makedirs(self.spatial_memory_dir, exist_ok=True) - os.makedirs(self.db_path, exist_ok=True) - - self.spatial_memory_module = None - # ============================================================== - - async def start(self): - self.dimos = core.start(4) - - # Connection Module - Robot sensor data interface via WebRTC =================== - self.connection = self.dimos.deploy(ConnectionModule, self.ip) - - # This enables LCM transport - # Ensures system multicast, udp sizes are auto-adjusted if needed - - # Configure 
ConnectionModule LCM transport outputs for sensor data streams - # OUTPUT: LiDAR point cloud data to /lidar topic - self.connection.lidar.transport = core.LCMTransport("/lidar", LidarMessage) - # OUTPUT: Robot odometry/pose data to /odom topic - self.connection.odom.transport = core.LCMTransport("/odom", PoseStamped) - # OUTPUT: Camera video frames to /video topic - self.connection.video.transport = core.LCMTransport("/video", Image) - # ====================================================================== - # self.connection.tf.transport = core.LCMTransport("/tf", LidarMessage) - - # Map Module - Point cloud accumulation and costmap generation ========= - self.mapper = self.dimos.deploy(Map, voxel_size=0.5, global_publish_interval=2.5) - - # OUTPUT: Accumulated point cloud map to /global_map topic - self.mapper.global_map.transport = core.LCMTransport("/global_map", LidarMessage) - - # Connect ConnectionModule OUTPUT lidar to Map INPUT lidar for point cloud accumulation - self.mapper.lidar.connect(self.connection.lidar) - # ==================================================================== - - # Local planner Module, LCM transport & connection ================ - self.local_planner = self.dimos.deploy( - VFHPurePursuitPlanner, - get_costmap=self.connection.get_local_costmap, - ) - - # Connects odometry LCM stream to BaseLocalPlanner odometry input - self.local_planner.odom.connect(self.connection.odom) - - # Configures BaseLocalPlanner movecmd output to /move LCM topic - self.local_planner.movecmd.transport = core.LCMTransport("/move", Vector3) - - # Connects connection.movecmd input to local_planner.movecmd output - self.connection.movecmd.connect(self.local_planner.movecmd) - # =================================================================== - - # Global Planner Module =============== - self.global_planner = self.dimos.deploy( - AstarPlanner, - get_costmap=self.mapper.costmap, - get_robot_pos=self.connection.get_pos, - 
set_local_nav=self.local_planner.navigate_path_local, - ) - - # Spatial Memory Module ====================================== - self.spatial_memory_module = self.dimos.deploy( - SpatialMemory, - collection_name=self.spatial_memory_collection, - db_path=self.db_path, - visual_memory_path=self.visual_memory_path, - output_dir=self.spatial_memory_dir, - ) - - # Connect video and odometry streams to spatial memory - self.spatial_memory_module.video.connect(self.connection.video) - self.spatial_memory_module.odom.connect(self.connection.odom) - - # Start the spatial memory module - self.spatial_memory_module.start() - - logger.info("Spatial memory module deployed and connected") - # ============================================================== - - # Configure AstarPlanner OUTPUT path: Out[Path] to /global_path LCM topic - self.global_planner.path.transport = core.pLCMTransport("/global_path") - # ====================================== - - # Global Planner Control Module =========================== - # Debug module that sends (0,0,0) goal after 4 second delay - self.ctrl = self.dimos.deploy(ControlModule) - - # Configure ControlModule OUTPUT to publish goal coordinates to /global_target - self.ctrl.plancmd.transport = core.LCMTransport("/global_target", Vector3) - - # Connect ControlModule OUTPUT to AstarPlanner INPUT - triggers A* planning when goal received - self.global_planner.target.connect(self.ctrl.plancmd) - # ========================================== - - # Visualization ============================ - self.foxglove_bridge = FoxgloveBridge() - # ========================================== - - self.frontier_explorer = WavefrontFrontierExplorer( - set_goal=self.global_planner.set_goal, - get_costmap=self.mapper.costmap, - get_robot_pos=self.connection.get_pos, - ) - - # Prints full module IO - print("\n") - for module in [ - self.connection, - self.mapper, - self.local_planner, - self.global_planner, - self.ctrl, - ]: - print(module.io(), "\n") - - # Start modules 
============================= - self.mapper.start() - self.connection.start() - self.local_planner.start() - self.global_planner.start() - self.foxglove_bridge.start() - # self.ctrl.start() # DEBUG - - await asyncio.sleep(2) - print("querying system") - print(self.mapper.costmap()) - logger.info("UnitreeGo2Light initialized and started") - - def get_pose(self) -> dict: - """Get the current pose (position and rotation) of the robot. - - Returns: - Dictionary containing: - - position: Vector (x, y, z) - - rotation: Vector (roll, pitch, yaw) in radians - """ - if not self.connection: - raise RuntimeError("Connection not established. Call start() first.") - odom = self.connection.get_odom() - position = Vector(odom.x, odom.y, odom.z) - rotation = Vector(odom.roll, odom.pitch, odom.yaw) - return {"position": position, "rotation": rotation} - - def move(self, velocity: Vector, duration: float = 0.0) -> bool: - """Move the robot using velocity commands. - - Args: - velocity: Velocity vector [x, y, yaw] - duration: Duration to apply command (seconds) - - Returns: - bool: True if movement succeeded - """ - if not self.connection: - raise RuntimeError("Connection not established. Call start() first.") - self.connection.move(Vector3(velocity.x, velocity.y, velocity.z)) - if duration > 0: - time.sleep(duration) - self.connection.move(Vector3(0, 0, 0)) # Stop - return True - - def explore(self, stop_event=None) -> bool: - """Start autonomous frontier exploration. - - Args: - stop_event: Optional threading.Event to signal when exploration should stop - - Returns: - bool: True if exploration completed successfully - """ - if not self.frontier_explorer: - raise RuntimeError("Frontier explorer not initialized. 
Call start() first.") - return self.frontier_explorer.explore(stop_event=stop_event) - - def standup(self): - """Make the robot stand up.""" - if self.connection and hasattr(self.connection, "standup"): - return self.connection.standup() - - def liedown(self): - """Make the robot lie down.""" - if self.connection and hasattr(self.connection, "liedown"): - return self.connection.liedown() - - @property - def costmap(self): - """Access to the costmap for navigation.""" - if not self.mapper: - raise RuntimeError("Mapper not initialized. Call start() first.") - return self.mapper.costmap - - @property - def spatial_memory(self) -> Optional[SpatialMemory]: - """Get the robot's spatial memory module. - - Returns: - SpatialMemory module instance or None if perception is disabled - """ - return self.spatial_memory_module - - def get_video_stream(self, fps: int = 30) -> Observable: - """Get the video stream with rate limiting and processing. - - Args: - fps: Frames per second for rate limiting - - Returns: - Observable stream of video frames - """ - # Import required modules for LCM subscription - from reactivex import create - from reactivex.disposable import Disposable - - from dimos.msgs.sensor_msgs import Image - from dimos.protocol.pubsub.lcmpubsub import LCM, Topic - - lcm_instance = LCM() - lcm_instance.start() - - topic = Topic("/video", Image) - - def subscribe(observer, scheduler=None): - unsubscribe_fn = lcm_instance.subscribe(topic, lambda msg, _: observer.on_next(msg)) - - return Disposable(unsubscribe_fn) - - return create(subscribe).pipe( - ops.map( - lambda img: img.data if hasattr(img, "data") else img - ), # Convert Image message to numpy array - ops.sample(1.0 / fps), - ) - - -async def run_light_robot(): - """Run the lightweight robot without GPU modules.""" - ip = os.getenv("ROBOT_IP") - pubsub.lcm.autoconf() - - robot = UnitreeGo2Light(ip) - - await robot.start() - - pose = robot.get_pose() - print(f"Robot position: {pose['position']}") - print(f"Robot 
rotation: {pose['rotation']}") - robot.explore() - # Keep the program running - while True: - await asyncio.sleep(1) - - -if __name__ == "__main__": - import os - - print("Running UnitreeGo2Light...") - asyncio.run(run_light_robot()) diff --git a/dimos/robot/unitree_webrtc/multiprocess/unitree_go2_heavy.py b/dimos/robot/unitree_webrtc/multiprocess/unitree_go2_heavy.py deleted file mode 100644 index 235088c478..0000000000 --- a/dimos/robot/unitree_webrtc/multiprocess/unitree_go2_heavy.py +++ /dev/null @@ -1,233 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Heavy version of Unitree Go2 with GPU-required modules.""" - -import asyncio -from typing import Dict, List, Optional - -import numpy as np -from reactivex import Observable -from reactivex import operators as ops -from reactivex.disposable import CompositeDisposable -from reactivex.scheduler import ThreadPoolScheduler - -from dimos import core -from dimos.perception.object_tracker import ObjectTrackingStream -from dimos.perception.person_tracker import PersonTrackingStream -from dimos.robot.unitree_webrtc.multiprocess.unitree_go2 import UnitreeGo2Light -from dimos.robot.unitree_webrtc.unitree_skills import MyUnitreeSkills -from dimos.skills.skills import AbstractRobotSkill, SkillLibrary -from dimos.types.robot_capabilities import RobotCapability -from dimos.utils.logging_config import setup_logger -from dimos.utils.threadpool import get_scheduler - -logger = setup_logger("dimos.robot.unitree_webrtc.multiprocess.unitree_go2_heavy") - - -class UnitreeGo2Heavy(UnitreeGo2Light): - """Heavy version of Unitree Go2 with additional GPU-required modules. - - This class extends UnitreeGo2Light with: - - Spatial memory with ChromaDB - - Person tracking stream - - Object tracking stream - - Skill library integration - - Full perception capabilities - """ - - def __init__( - self, - ip: str, - skill_library: Optional[SkillLibrary] = None, - robot_capabilities: Optional[List[RobotCapability]] = None, - spatial_memory_collection: str = "spatial_memory", - new_memory: bool = True, - enable_perception: bool = True, - pool_scheduler: Optional[ThreadPoolScheduler] = None, - ): - """Initialize Unitree Go2 Heavy with full capabilities. 
- - Args: - ip: IP address of the robot - output_dir: Directory for output files - skill_library: Skill library instance - robot_capabilities: List of robot capabilities - spatial_memory_collection: Collection name for spatial memory - new_memory: Whether to create new spatial memory - enable_perception: Whether to enable perception streams - pool_scheduler: Thread pool scheduler for async operations - """ - super().__init__(ip) - - self.enable_perception = enable_perception - self.disposables = CompositeDisposable() - self.pool_scheduler = pool_scheduler if pool_scheduler else get_scheduler() - - # Initialize capabilities - self.capabilities = robot_capabilities or [ - RobotCapability.LOCOMOTION, - RobotCapability.VISION, - RobotCapability.AUDIO, - ] - - # Camera configuration for Unitree Go2 - self.camera_intrinsics = [819.553492, 820.646595, 625.284099, 336.808987] - self.camera_pitch = np.deg2rad(0) # negative for downward pitch - self.camera_height = 0.44 # meters - - # Initialize skill library - if skill_library is None: - skill_library = MyUnitreeSkills() - self.skill_library = skill_library - - # Initialize spatial memory module (will be deployed after connection is established) - self._video_stream = None - self.new_memory = new_memory - - # Tracking modules (deployed after start) - self.person_tracker_module = None - self.object_tracker_module = None - - # Tracking stream observables for backward compatibility - self.person_tracking_stream = None - self.object_tracking_stream = None - - # References to tracker instances for skills - self.person_tracker = None - self.object_tracker = None - - async def start(self): - """Start the robot modules and initialize heavy components.""" - # First start the lightweight components - await super().start() - - await asyncio.sleep(0.5) - - # Now we have connection publishing to LCM, initialize video stream - self._video_stream = self.get_video_stream(fps=10) # Lower FPS for processing - - if self.enable_perception: - 
self.person_tracker_module = self.dimos.deploy( - PersonTrackingStream, - camera_intrinsics=self.camera_intrinsics, - camera_pitch=self.camera_pitch, - camera_height=self.camera_height, - ) - - # Configure person tracker LCM transport - self.person_tracker_module.video.connect(self.connection.video) - self.person_tracker_module.tracking_data.transport = core.pLCMTransport( - "/person_tracking" - ) - - self.object_tracker_module = self.dimos.deploy( - ObjectTrackingStream, - camera_intrinsics=self.camera_intrinsics, - camera_pitch=self.camera_pitch, - camera_height=self.camera_height, - ) - - # Configure object tracker LCM transport - self.object_tracker_module.video.connect(self.connection.video) - self.object_tracker_module.tracking_data.transport = core.pLCMTransport( - "/object_tracking" - ) - - # Start the tracking modules - self.person_tracker_module.start() - self.object_tracker_module.start() - - # Create Observable streams directly from the tracking outputs - logger.info("Creating Observable streams from tracking outputs") - self.person_tracking_stream = self.person_tracker_module.tracking_data.observable() - self.object_tracking_stream = self.object_tracker_module.tracking_data.observable() - - self.person_tracking_stream.subscribe( - lambda x: logger.debug( - f"Person tracking stream received: {type(x)} with {len(x.get('targets', []))} targets" - ) - ) - self.object_tracking_stream.subscribe( - lambda x: logger.debug( - f"Object tracking stream received: {type(x)} with {len(x.get('targets', []))} targets" - ) - ) - - # Create tracker references for skills to access RPC methods - self.person_tracker = self.person_tracker_module - self.object_tracker = self.object_tracker_module - - logger.info("Person and object tracking modules deployed and connected") - else: - logger.info("Perception disabled or video stream unavailable") - - if self.skill_library is not None: - for skill in self.skill_library: - if isinstance(skill, AbstractRobotSkill): - 
self.skill_library.create_instance(skill.__name__, robot=self) - if isinstance(self.skill_library, MyUnitreeSkills): - self.skill_library._robot = self - self.skill_library.init() - self.skill_library.initialize_skills() - - logger.info("UnitreeGo2Heavy initialized with all modules") - - @property - def video_stream(self) -> Optional[Observable]: - """Get the robot's video stream. - - Returns: - Observable video stream or None if not available - """ - return self._video_stream - - def get_skills(self): - """Get the robot's skill library. - - Returns: - The robot's skill library for adding/managing skills - """ - return self.skill_library - - def has_capability(self, capability: RobotCapability) -> bool: - """Check if the robot has a specific capability. - - Args: - capability: The capability to check for - - Returns: - bool: True if the robot has the capability - """ - return capability in self.capabilities - - def cleanup(self): - """Clean up resources used by the robot.""" - # Dispose of reactive resources - if self.disposables: - self.disposables.dispose() - - # Clean up tracking modules - if self.person_tracker_module: - self.person_tracker_module.cleanup() - self.person_tracker_module = None - if self.object_tracker_module: - self.object_tracker_module.cleanup() - self.object_tracker_module = None - - # Clear references - self.person_tracker = None - self.object_tracker = None - - logger.info("UnitreeGo2Heavy cleanup completed") diff --git a/dimos/robot/unitree_webrtc/multiprocess/unitree_go2_navonly.py b/dimos/robot/unitree_webrtc/multiprocess/unitree_go2_navonly.py deleted file mode 100644 index 3ee42305b2..0000000000 --- a/dimos/robot/unitree_webrtc/multiprocess/unitree_go2_navonly.py +++ /dev/null @@ -1,229 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import functools -import logging -import os -import threading -import time -import warnings -from typing import Callable, Optional - -from reactivex import Observable -from reactivex import operators as ops - -import dimos.core.colors as colors -from dimos import core -from dimos.core import In, Module, Out, rpc -from dimos.msgs.geometry_msgs import Pose, PoseStamped, Transform, Vector3 -from dimos.msgs.nav_msgs import OccupancyGrid, Path -from dimos.msgs.sensor_msgs import Image -from dimos.perception.spatial_perception import SpatialMemory -from dimos.protocol import pubsub -from dimos.protocol.tf import TF -from dimos.robot.foxglove_bridge import FoxgloveBridge -from dimos.robot.frontier_exploration.wavefront_frontier_goal_selector import ( - WavefrontFrontierExplorer, -) -from dimos.robot.global_planner import AstarPlanner -from dimos.robot.local_planner.vfh_local_planner import VFHPurePursuitPlanner -from dimos.robot.unitree_webrtc.connection import UnitreeWebRTCConnection, VideoMessage -from dimos.robot.unitree_webrtc.type.lidar import LidarMessage -from dimos.robot.unitree_webrtc.type.map import Map -from dimos.robot.unitree_webrtc.type.odometry import Odometry -from dimos.types.costmap import Costmap -from dimos.types.vector import Vector -from dimos.utils.data import get_data -from dimos.utils.logging_config import setup_logger -from dimos.utils.reactive import getter_streaming -from dimos.utils.testing import TimedSensorReplay - -logger = setup_logger("dimos.robot.unitree_webrtc.multiprocess.unitree_go2", level=logging.INFO) - -# Suppress 
verbose loggers -logging.getLogger("aiortc.codecs.h264").setLevel(logging.ERROR) -logging.getLogger("lcm_foxglove_bridge").setLevel(logging.ERROR) -logging.getLogger("websockets.server").setLevel(logging.ERROR) -logging.getLogger("FoxgloveServer").setLevel(logging.ERROR) -logging.getLogger("asyncio").setLevel(logging.ERROR) -logging.getLogger("root").setLevel(logging.WARNING) - -# Suppress warnings -warnings.filterwarnings("ignore", message="coroutine.*was never awaited") -warnings.filterwarnings("ignore", message="H264Decoder.*failed to decode") - - -# can be swapped in for UnitreeWebRTCConnection -class FakeRTC(UnitreeWebRTCConnection): - def __init__(self, *args, **kwargs): - # ensures we download msgs from lfs store - data = get_data("unitree_office_walk") - - def connect(self): ... - - def standup(self): - print("standup supressed") - - def liedown(self): - print("liedown supressed") - - @functools.cache - def lidar_stream(self): - print("lidar stream start") - lidar_store = TimedSensorReplay("unitree_office_walk/lidar", autocast=LidarMessage.from_msg) - return lidar_store.stream() - - @functools.cache - def odom_stream(self): - print("odom stream start") - odom_store = TimedSensorReplay("unitree_office_walk/odom", autocast=Odometry.from_msg) - return odom_store.stream() - - @functools.cache - def video_stream(self): - print("video stream start") - video_store = TimedSensorReplay("unitree_office_walk/video", autocast=Image.from_numpy) - return video_store.stream() - - def move(self, vector: Vector): - ... 
- # print("move supressed", vector) - - -class ConnectionModule(FakeRTC, Module): - movecmd: In[Vector3] = None - odom: Out[Vector3] = None - lidar: Out[LidarMessage] = None - video: Out[VideoMessage] = None - ip: str - - _odom: Callable[[], Odometry] - _lidar: Callable[[], LidarMessage] - - @rpc - def move(self, vector: Vector3): - super().move(vector) - - def __init__(self, ip: str, *args, **kwargs): - self.ip = ip - self.tf = TF() - Module.__init__(self, *args, **kwargs) - - @rpc - def start(self): - # Initialize the parent WebRTC connection - super().__init__(self.ip) - # Connect sensor streams to LCM outputs - self.lidar_stream().subscribe(self.lidar.publish) - self.odom_stream().subscribe(self.odom.publish) - # self.video_stream().subscribe(self.video.publish) - self.tf_stream().subscribe(self.tf.publish) - - # Connect LCM input to robot movement commands - self.movecmd.subscribe(self.move) - - # Set up streaming getters for latest sensor data - self._odom = getter_streaming(self.odom_stream()) - self._lidar = getter_streaming(self.lidar_stream()) - - @rpc - def get_local_costmap(self) -> Costmap: - return self._lidar().costmap() - - @rpc - def get_odom(self) -> Odometry: - return self._odom() - - @rpc - def get_pos(self) -> Vector: - return self._odom().position - - -class ControlModule(Module): - plancmd: Out[Pose] = None - - @rpc - def start(self): - def plancmd(): - while True: - time.sleep(0.5) - print(colors.red("requesting global plan")) - self.plancmd.publish( - PoseStamped( - ts=time.time(), - position=(0, 0, 0), - orientation=(0, 0, 0, 1), - ) - ) - - thread = threading.Thread(target=plancmd, daemon=True) - thread.start() - - -class UnitreeGo2Light: - ip: str - - def __init__(self, ip: str): - self.ip = ip - - def start(self): - dimos = core.start(4) - - connection = dimos.deploy(ConnectionModule, self.ip) - connection.lidar.transport = core.LCMTransport("/lidar", LidarMessage) - connection.odom.transport = core.LCMTransport("/odom", PoseStamped) - 
connection.video.transport = core.LCMTransport("/video", Image) - connection.movecmd.transport = core.LCMTransport("/mov", Vector3) - - mapper = dimos.deploy(Map, voxel_size=0.5, global_publish_interval=2.5) - - mapper.global_map.transport = core.LCMTransport("/global_map", LidarMessage) - mapper.global_costmap.transport = core.LCMTransport("/global_costmap", OccupancyGrid) - - mapper.lidar.connect(connection.lidar) - - global_planner = dimos.deploy( - AstarPlanner, - get_costmap=mapper.costmap, - get_robot_pos=connection.get_pos, - set_local_nav=print, - ) - - ctrl = dimos.deploy(ControlModule) - - ctrl.plancmd.transport = core.LCMTransport("/global_target", PoseStamped) - global_planner.path.transport = core.LCMTransport("/global_path", Path) - global_planner.target.connect(ctrl.plancmd) - foxglove_bridge = FoxgloveBridge() - - connection.start() - mapper.start() - global_planner.start() - foxglove_bridge.start() - ctrl.start() - - -if __name__ == "__main__": - import os - - ip = os.getenv("ROBOT_IP") - pubsub.lcm.autoconf() - robot = UnitreeGo2Light(ip) - robot.start() - - while True: - time.sleep(1) diff --git a/dimos/robot/unitree_webrtc/run.py b/dimos/robot/unitree_webrtc/run.py new file mode 100644 index 0000000000..10d45bfb09 --- /dev/null +++ b/dimos/robot/unitree_webrtc/run.py @@ -0,0 +1,188 @@ +#!/usr/bin/env python3 +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +Run script for Unitree Go2 robot with Claude agent integration. +Provides navigation and interaction capabilities with natural language interface. +""" + +import os +import sys +import time +from dotenv import load_dotenv + +import reactivex as rx +import reactivex.operators as ops + +from dimos.robot.unitree_webrtc.unitree_go2 import UnitreeGo2 +from dimos.agents.claude_agent import ClaudeAgent +from dimos.skills.observe_stream import ObserveStream +from dimos.skills.observe import Observe +from dimos.skills.kill_skill import KillSkill +from dimos.skills.navigation import NavigateWithText, GetPose, NavigateToGoal, Explore +from dimos.skills.unitree.unitree_speak import UnitreeSpeak +from dimos.web.robot_web_interface import RobotWebInterface +from dimos.stream.audio.pipelines import stt, tts +from dimos.utils.logging_config import setup_logger + +logger = setup_logger("dimos.robot.unitree_webrtc.run") + +# Load environment variables +load_dotenv() + +# System prompt - loaded from prompt.txt +SYSTEM_PROMPT_PATH = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), + "assets/agent/prompt.txt", +) + + +def main(): + """Main entry point.""" + print("\n" + "=" * 60) + print("Unitree Go2 Robot with Claude Agent") + print("=" * 60) + print("\nThis system integrates:") + print(" - Unitree Go2 quadruped robot") + print(" - WebRTC communication interface") + print(" - Claude AI for natural language understanding") + print(" - Spatial memory and navigation") + print(" - Web interface with text and voice input") + print("\nStarting system...\n") + + # Check for API key + if not os.getenv("ANTHROPIC_API_KEY"): + print("WARNING: ANTHROPIC_API_KEY not found in environment") + print("Please set your API key in .env file or environment") + sys.exit(1) + + # Load system prompt + try: + with open(SYSTEM_PROMPT_PATH, "r") as f: + system_prompt = f.read() + except FileNotFoundError: + logger.error(f"System prompt file not found at 
{SYSTEM_PROMPT_PATH}") + sys.exit(1) + + logger.info("Starting Unitree Go2 Robot with Agent") + + # Create robot instance + robot = UnitreeGo2( + ip=os.getenv("ROBOT_IP"), + ) + + robot.start() + time.sleep(3) + + try: + logger.info("Robot initialized successfully") + + # Set up skill library + skills = robot.get_skills() + # skills.add(ObserveStream) + # skills.add(Observe) + skills.add(KillSkill) + skills.add(NavigateWithText) + skills.add(GetPose) + skills.add(NavigateToGoal) + skills.add(Explore) + + # Create skill instances + skills.create_instance("KillSkill", robot=robot, skill_library=skills) + skills.create_instance("NavigateWithText", robot=robot) + skills.create_instance("GetPose", robot=robot) + skills.create_instance("NavigateToGoal", robot=robot) + skills.create_instance("Explore", robot=robot) + + logger.info(f"Skills registered: {[skill.__name__ for skill in skills.get_class_skills()]}") + + # Set up streams for agent and web interface + agent_response_subject = rx.subject.Subject() + agent_response_stream = agent_response_subject.pipe(ops.share()) + audio_subject = rx.subject.Subject() + + # Set up streams for web interface + streams = {} + + text_streams = { + "agent_responses": agent_response_stream, + } + + # Create web interface first (needed for agent) + try: + web_interface = RobotWebInterface( + port=5555, text_streams=text_streams, audio_subject=audio_subject, **streams + ) + logger.info("Web interface created successfully") + except Exception as e: + logger.error(f"Failed to create web interface: {e}") + raise + + # Set up speech-to-text + stt_node = stt() + stt_node.consume_audio(audio_subject.pipe(ops.share())) + + # Create Claude agent + agent = ClaudeAgent( + dev_name="unitree_go2_agent", + input_query_stream=web_interface.query_stream, # Use text input from web interface + # input_query_stream=stt_node.emit_text(), # Uncomment to use voice input + skills=skills, + system_query=system_prompt, + model_name="claude-3-5-haiku-latest", + 
thinking_budget_tokens=0, + max_output_tokens_per_request=8192, + ) + + # Subscribe to agent responses + agent.get_response_observable().subscribe(lambda x: agent_response_subject.on_next(x)) + + # Set up text-to-speech for agent responses + tts_node = tts() + tts_node.consume_text(agent.get_response_observable()) + + # Create skill instances that need agent reference + skills.create_instance("ObserveStream", robot=robot, agent=agent) + skills.create_instance("Observe", robot=robot, agent=agent) + + logger.info("=" * 60) + logger.info("Unitree Go2 Agent Ready!") + logger.info(f"Web interface available at: http://localhost:5555") + logger.info("You can:") + logger.info(" - Type commands in the web interface") + logger.info(" - Use voice commands") + logger.info(" - Ask the robot to navigate to locations") + logger.info(" - Ask the robot to observe and describe its surroundings") + logger.info(" - Ask the robot to follow people or explore areas") + logger.info("=" * 60) + + # Run web interface (this blocks) + web_interface.run() + + except KeyboardInterrupt: + logger.info("Keyboard interrupt received") + except Exception as e: + logger.error(f"Error running robot: {e}") + import traceback + + traceback.print_exc() + finally: + logger.info("Shutting down...") + # WebRTC robot doesn't have a stop method, just log shutdown + logger.info("Shutdown complete") + + +if __name__ == "__main__": + main() diff --git a/dimos/robot/unitree_webrtc/test_unitree_go2_integration.py b/dimos/robot/unitree_webrtc/test_unitree_go2_integration.py new file mode 100644 index 0000000000..a9f2ce7d25 --- /dev/null +++ b/dimos/robot/unitree_webrtc/test_unitree_go2_integration.py @@ -0,0 +1,199 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio + +import pytest + +from dimos import core +from dimos.core import Module, Out, rpc +from dimos.msgs.geometry_msgs import PoseStamped, Vector3, Quaternion +from dimos.msgs.sensor_msgs import Image +from dimos.msgs.nav_msgs import OccupancyGrid +from dimos.protocol import pubsub +from dimos.navigation.frontier_exploration import WavefrontFrontierExplorer +from dimos.navigation.global_planner import AstarPlanner +from dimos.navigation.local_planner.holonomic_local_planner import HolonomicLocalPlanner +from dimos.navigation.bt_navigator.navigator import BehaviorTreeNavigator +from dimos.robot.unitree_webrtc.unitree_go2 import ConnectionModule +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.map import Map +from dimos.utils.logging_config import setup_logger + +logger = setup_logger("test_unitree_go2_integration") + +pubsub.lcm.autoconf() + + +class MovementControlModule(Module): + """Simple module to send movement commands for testing.""" + + movecmd: Out[Vector3] = None + + def __init__(self): + super().__init__() + self.commands_sent = [] + + @rpc + def send_move_command(self, x: float, y: float, yaw: float): + """Send a movement command.""" + cmd = Vector3(x, y, yaw) + self.movecmd.publish(cmd) + self.commands_sent.append(cmd) + logger.info(f"Sent move command: x={x}, y={y}, yaw={yaw}") + + @rpc + def get_command_count(self) -> int: + """Get number of commands sent.""" + return len(self.commands_sent) + + +@pytest.mark.module +class TestUnitreeGo2CoreModules: + @pytest.mark.asyncio + 
async def test_unitree_go2_navigation_stack(self): + """Test UnitreeGo2 core navigation modules without perception/visualization.""" + + # Start Dask + dimos = core.start(4) + + try: + # Deploy ConnectionModule with playback mode (uses test data) + connection = dimos.deploy( + ConnectionModule, + ip="127.0.0.1", # IP doesn't matter for playback + playback=True, # Enable playback mode + ) + + # Configure LCM transports + connection.lidar.transport = core.LCMTransport("/lidar", LidarMessage) + connection.odom.transport = core.LCMTransport("/odom", PoseStamped) + connection.video.transport = core.LCMTransport("/video", Image) + + # Deploy Map module + mapper = dimos.deploy(Map, voxel_size=0.5, global_publish_interval=2.5) + mapper.global_map.transport = core.LCMTransport("/global_map", LidarMessage) + mapper.global_costmap.transport = core.LCMTransport("/global_costmap", OccupancyGrid) + mapper.local_costmap.transport = core.LCMTransport("/local_costmap", OccupancyGrid) + mapper.lidar.connect(connection.lidar) + + # Deploy navigation stack + global_planner = dimos.deploy(AstarPlanner) + local_planner = dimos.deploy(HolonomicLocalPlanner) + navigator = dimos.deploy(BehaviorTreeNavigator, local_planner=local_planner) + + # Set up transports first + from dimos.msgs.nav_msgs import Path + from dimos_lcm.std_msgs import Bool + + navigator.goal.transport = core.LCMTransport("/navigation_goal", PoseStamped) + navigator.goal_request.transport = core.LCMTransport("/goal_request", PoseStamped) + navigator.goal_reached.transport = core.LCMTransport("/goal_reached", Bool) + navigator.global_costmap.transport = core.LCMTransport("/global_costmap", OccupancyGrid) + global_planner.path.transport = core.LCMTransport("/global_path", Path) + local_planner.cmd_vel.transport = core.LCMTransport("/cmd_vel", Vector3) + + # Configure navigation connections + global_planner.target.connect(navigator.goal) + global_planner.global_costmap.connect(mapper.global_costmap) + 
global_planner.odom.connect(connection.odom) + + local_planner.path.connect(global_planner.path) + local_planner.local_costmap.connect(mapper.local_costmap) + local_planner.odom.connect(connection.odom) + + connection.movecmd.connect(local_planner.cmd_vel) + navigator.odom.connect(connection.odom) + + # Deploy movement control module for testing + movement = dimos.deploy(MovementControlModule) + movement.movecmd.transport = core.LCMTransport("/test_move", Vector3) + connection.movecmd.connect(movement.movecmd) + + # Start all modules + connection.start() + mapper.start() + global_planner.start() + local_planner.start() + navigator.start() + + logger.info("All core modules started") + + # Wait for initialization + await asyncio.sleep(3) + + # Test movement commands + movement.send_move_command(0.5, 0.0, 0.0) + await asyncio.sleep(0.5) + + movement.send_move_command(0.0, 0.0, 0.3) + await asyncio.sleep(0.5) + + movement.send_move_command(0.0, 0.0, 0.0) + await asyncio.sleep(0.5) + + # Check commands were sent + cmd_count = movement.get_command_count() + assert cmd_count == 3, f"Expected 3 commands, got {cmd_count}" + logger.info(f"Successfully sent {cmd_count} movement commands") + + # Test navigation + target_pose = PoseStamped( + frame_id="world", + position=Vector3(2.0, 1.0, 0.0), + orientation=Quaternion(0, 0, 0, 1), + ) + + # Set navigation goal (non-blocking) + try: + navigator.set_goal(target_pose, blocking=False) + logger.info("Navigation goal set") + except Exception as e: + logger.warning(f"Navigation goal setting failed: {e}") + + await asyncio.sleep(2) + + # Cancel navigation + navigator.cancel_goal() + logger.info("Navigation cancelled") + + # Test frontier exploration + frontier_explorer = dimos.deploy(WavefrontFrontierExplorer) + frontier_explorer.costmap.connect(mapper.global_costmap) + frontier_explorer.odometry.connect(connection.odom) + frontier_explorer.goal_request.transport = core.LCMTransport( + "/frontier_goal", PoseStamped + ) + 
frontier_explorer.goal_reached.transport = core.LCMTransport("/frontier_reached", Bool) + frontier_explorer.start() + + # Try to start exploration + result = frontier_explorer.explore() + logger.info(f"Exploration started: {result}") + + await asyncio.sleep(2) + + # Stop exploration + frontier_explorer.stop_exploration() + logger.info("Exploration stopped") + + logger.info("All core navigation tests passed!") + + finally: + dimos.close() + logger.info("Closed Dask cluster") + + +if __name__ == "__main__": + pytest.main(["-v", "-s", __file__]) diff --git a/dimos/robot/unitree_webrtc/multiprocess/test_actors.py b/dimos/robot/unitree_webrtc/testing/test_actors.py similarity index 82% rename from dimos/robot/unitree_webrtc/multiprocess/test_actors.py rename to dimos/robot/unitree_webrtc/testing/test_actors.py index a346859afb..1b42412249 100644 --- a/dimos/robot/unitree_webrtc/multiprocess/test_actors.py +++ b/dimos/robot/unitree_webrtc/testing/test_actors.py @@ -12,27 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import asyncio -import functools import time from typing import Callable import pytest -from reactivex import operators as ops from dimos import core -from dimos.core import In, Module, Out, rpc -from dimos.msgs.geometry_msgs import Vector3 -from dimos.msgs.sensor_msgs import Image -from dimos.protocol import pubsub -from dimos.robot.global_planner import AstarPlanner +from dimos.core import Module, rpc from dimos.robot.unitree_webrtc.type.lidar import LidarMessage -from dimos.robot.unitree_webrtc.type.map import Map from dimos.robot.unitree_webrtc.type.map import Map as Mapper -from dimos.robot.unitree_webrtc.type.odometry import Odometry -from dimos.types.costmap import Costmap -from dimos.types.vector import Vector -from dimos.utils.reactive import backpressure, getter_streaming -from dimos.utils.testing import TimedSensorReplay @pytest.fixture diff --git a/dimos/robot/unitree_webrtc/testing/test_multimock.py b/dimos/robot/unitree_webrtc/testing/test_multimock.py deleted file mode 100644 index 1d64cbd3a0..0000000000 --- a/dimos/robot/unitree_webrtc/testing/test_multimock.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import time -import pytest - -from reactivex import operators as ops - -from dimos.utils.reactive import backpressure -from dimos.robot.unitree_webrtc.testing.helpers import show3d_stream -from dimos.web.websocket_vis.server import WebsocketVis -from dimos.robot.unitree_webrtc.type.map import Map -from dimos.robot.unitree_webrtc.type.lidar import LidarMessage -from dimos.robot.unitree_webrtc.type.odometry import Odometry -from dimos.robot.unitree_webrtc.type.timeseries import to_datetime -from dimos.robot.unitree_webrtc.testing.multimock import Multimock - - -@pytest.mark.needsdata -@pytest.mark.vis -def test_multimock_stream(): - backpressure(Multimock("athens_odom").stream().pipe(ops.map(Odometry.from_msg))).subscribe( - lambda x: print(x) - ) - map = Map() - - def lidarmsg(msg): - frame = LidarMessage.from_msg(msg) - map.add_frame(frame) - return [map, map.costmap.smudge()] - - mapstream = Multimock("athens_lidar").stream().pipe(ops.map(lidarmsg)) - show3d_stream(mapstream.pipe(ops.map(lambda x: x[0])), clearframe=True).run() - time.sleep(5) - - -@pytest.mark.needsdata -def test_clock_mismatch(): - for odometry_raw in Multimock("athens_odom").iterate(): - print( - odometry_raw.ts - to_datetime(odometry_raw.data["data"]["header"]["stamp"]), - odometry_raw.data["data"]["header"]["stamp"], - ) - - -@pytest.mark.needsdata -def test_odom_stream(): - for odometry_raw in Multimock("athens_odom").iterate(): - print(Odometry.from_msg(odometry_raw.data)) - - -@pytest.mark.needsdata -def test_lidar_stream(): - for lidar_raw in Multimock("athens_lidar").iterate(): - lidarmsg = LidarMessage.from_msg(lidar_raw.data) - print(lidarmsg) - print(lidar_raw) - - -@pytest.mark.needsdata -def test_multimock_timeseries(): - odom = Odometry.from_msg(Multimock("athens_odom").load_one(1).data) - lidar_raw = Multimock("athens_lidar").load_one(1).data - lidar = LidarMessage.from_msg(lidar_raw) - map = Map() - map.add_frame(lidar) - print(odom) - print(lidar) - print(lidar_raw) - 
print(map.costmap) - - -@pytest.mark.needsdata -def test_origin_changes(): - for lidar_raw in Multimock("athens_lidar").iterate(): - print(LidarMessage.from_msg(lidar_raw.data).origin) - - -@pytest.mark.needsdata -@pytest.mark.vis -def test_webui_multistream(): - websocket_vis = WebsocketVis() - websocket_vis.start() - - odom_stream = Multimock("athens_odom").stream().pipe(ops.map(Odometry.from_msg)) - lidar_stream = backpressure( - Multimock("athens_lidar").stream().pipe(ops.map(LidarMessage.from_msg)) - ) - - map = Map() - map_stream = map.consume(lidar_stream) - - costmap_stream = map_stream.pipe( - ops.map(lambda x: ["costmap", map.costmap.smudge(preserve_unknown=False)]) - ) - - websocket_vis.connect(costmap_stream) - websocket_vis.connect(odom_stream.pipe(ops.map(lambda pos: ["robot_pos", pos.pos.to_2d()]))) - - show3d_stream(lidar_stream, clearframe=True).run() diff --git a/dimos/robot/unitree_webrtc/multiprocess/test_tooling.py b/dimos/robot/unitree_webrtc/testing/test_tooling.py similarity index 100% rename from dimos/robot/unitree_webrtc/multiprocess/test_tooling.py rename to dimos/robot/unitree_webrtc/testing/test_tooling.py diff --git a/dimos/robot/unitree_webrtc/type/lidar.py b/dimos/robot/unitree_webrtc/type/lidar.py index 03ab277c82..fec56f9f44 100644 --- a/dimos/robot/unitree_webrtc/type/lidar.py +++ b/dimos/robot/unitree_webrtc/type/lidar.py @@ -21,8 +21,6 @@ from dimos.msgs.geometry_msgs import Vector3 from dimos.msgs.sensor_msgs import PointCloud2 from dimos.robot.unitree_webrtc.type.timeseries import to_human_readable -from dimos.types.costmap import Costmap, pointcloud_to_costmap -from dimos.types.vector import Vector class RawLidarPoints(TypedDict): @@ -53,7 +51,7 @@ class LidarMessage(PointCloud2): resolution: float # we lose resolution when encoding PointCloud2 origin: Vector3 raw_msg: Optional[RawLidarMsg] - _costmap: Optional[Costmap] = None + # _costmap: Optional[Costmap] = None # TODO: Fix after costmap migration def __init__(self, 
**kwargs): super().__init__( @@ -116,15 +114,16 @@ def __add__(self, other: "LidarMessage") -> "LidarMessage": def o3d_geometry(self): return self.pointcloud - def costmap(self, voxel_size: float = 0.2) -> Costmap: - if not self._costmap: - down_sampled_pointcloud = self.pointcloud.voxel_down_sample(voxel_size=voxel_size) - inflate_radius_m = 1.0 * voxel_size if voxel_size > self.resolution else 0.0 - grid, origin_xy = pointcloud_to_costmap( - down_sampled_pointcloud, - resolution=self.resolution, - inflate_radius_m=inflate_radius_m, - ) - self._costmap = Costmap(grid=grid, origin=[*origin_xy, 0.0], resolution=self.resolution) - - return self._costmap + # TODO: Fix after costmap migration + # def costmap(self, voxel_size: float = 0.2) -> Costmap: + # if not self._costmap: + # down_sampled_pointcloud = self.pointcloud.voxel_down_sample(voxel_size=voxel_size) + # inflate_radius_m = 1.0 * voxel_size if voxel_size > self.resolution else 0.0 + # grid, origin_xy = pointcloud_to_costmap( + # down_sampled_pointcloud, + # resolution=self.resolution, + # inflate_radius_m=inflate_radius_m, + # ) + # self._costmap = Costmap(grid=grid, origin=[*origin_xy, 0.0], resolution=self.resolution) + # + # return self._costmap diff --git a/dimos/robot/unitree_webrtc/type/map.py b/dimos/robot/unitree_webrtc/type/map.py index b9e820eea8..a674d4d0b7 100644 --- a/dimos/robot/unitree_webrtc/type/map.py +++ b/dimos/robot/unitree_webrtc/type/map.py @@ -32,6 +32,7 @@ class Map(Module): lidar: In[LidarMessage] = None global_map: Out[LidarMessage] = None global_costmap: Out[OccupancyGrid] = None + local_costmap: Out[OccupancyGrid] = None pointcloud: o3d.geometry.PointCloud = o3d.geometry.PointCloud() @@ -59,12 +60,12 @@ def publish(_): occupancygrid = ( OccupancyGrid.from_pointcloud( self.to_lidar_message(), - resolution=0.05, - min_height=0.1, - max_height=2.0, + resolution=self.cost_resolution, + min_height=0.15, + max_height=0.6, ) .inflate(0.1) - .gradient() + .gradient(max_distance=1.0) ) 
self.global_costmap.publish(occupancygrid) @@ -91,20 +92,18 @@ def add_frame(self, frame: LidarMessage) -> "Map": """Voxelise *frame* and splice it into the running map.""" new_pct = frame.pointcloud.voxel_down_sample(voxel_size=self.voxel_size) self.pointcloud = splice_cylinder(self.pointcloud, new_pct, shrink=0.5) - return self - - def consume(self, observable: Observable[LidarMessage]) -> Observable["Map"]: - """Reactive operator that folds a stream of `LidarMessage` into the map.""" - return observable.pipe(ops.map(self.add_frame)) + local_costmap = OccupancyGrid.from_pointcloud( + frame, + resolution=self.cost_resolution, + min_height=0.15, + max_height=0.6, + ).gradient(max_distance=0.25) + self.local_costmap.publish(local_costmap) @property def o3d_geometry(self) -> o3d.geometry.PointCloud: return self.pointcloud - @rpc - def costmap(self) -> OccupancyGrid: - return OccupancyGrid.from_pointcloud(self.to_PointCloud2()) - def splice_sphere( map_pcd: o3d.geometry.PointCloud, @@ -149,21 +148,3 @@ def splice_cylinder( survivors = map_pcd.select_by_index(victims, invert=True) return survivors + patch_pcd - - -def _inflate_lethal(costmap: np.ndarray, radius: int, lethal_val: int = 100) -> np.ndarray: - """Return *costmap* with lethal cells dilated by *radius* grid steps (circular).""" - if radius <= 0 or not np.any(costmap == lethal_val): - return costmap - - mask = costmap == lethal_val - dilated = mask.copy() - for dy in range(-radius, radius + 1): - for dx in range(-radius, radius + 1): - if dx * dx + dy * dy > radius * radius or (dx == 0 and dy == 0): - continue - dilated |= np.roll(mask, shift=(dy, dx), axis=(0, 1)) - - out = costmap.copy() - out[dilated] = lethal_val - return out diff --git a/dimos/robot/unitree_webrtc/type/odometry.py b/dimos/robot/unitree_webrtc/type/odometry.py index 16b826cb87..27d59f2cb8 100644 --- a/dimos/robot/unitree_webrtc/type/odometry.py +++ b/dimos/robot/unitree_webrtc/type/odometry.py @@ -11,22 +11,15 @@ # WITHOUT WARRANTIES OR 
CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import math -from datetime import datetime -from io import BytesIO -from typing import BinaryIO, Literal, TypeAlias, TypedDict +import time +from typing import Literal, TypedDict from scipy.spatial.transform import Rotation as R from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Vector3 from dimos.robot.unitree_webrtc.type.timeseries import ( - EpochLike, Timestamped, - to_datetime, - to_human_readable, ) -from dimos.types.timestamped import to_timestamp -from dimos.types.vector import Vector, VectorLike raw_odometry_msg_sample = { "type": "msg", @@ -104,7 +97,7 @@ def from_msg(cls, msg: RawOdometryMessage) -> "Odometry": pose["orientation"].get("w"), ) - ts = to_timestamp(msg["data"]["header"]["stamp"]) + ts = time.time() return Odometry(position=pos, orientation=rot, ts=ts, frame_id="world") def __repr__(self) -> str: diff --git a/dimos/robot/unitree_webrtc/type/test_map.py b/dimos/robot/unitree_webrtc/type/test_map.py index 9e3141768d..e28df7ad8d 100644 --- a/dimos/robot/unitree_webrtc/type/test_map.py +++ b/dimos/robot/unitree_webrtc/type/test_map.py @@ -14,28 +14,41 @@ import pytest -from dimos.robot.unitree_webrtc.testing.helpers import show3d, show3d_stream +from dimos.robot.unitree_webrtc.testing.helpers import show3d from dimos.robot.unitree_webrtc.testing.mock import Mock from dimos.robot.unitree_webrtc.type.lidar import LidarMessage from dimos.robot.unitree_webrtc.type.map import Map, splice_sphere -from dimos.utils.reactive import backpressure from dimos.utils.testing import SensorReplay @pytest.mark.vis def test_costmap_vis(): map = Map() - for frame in Mock("office").iterate(): + map.start() + mock = Mock("office") + frames = list(mock.iterate()) + + for frame in frames: print(frame) map.add_frame(frame) - costmap = map.costmap - print(costmap) - show3d(costmap.smudge().pointcloud, 
title="Costmap").run() + + # Get global map and costmap + global_map = map.to_lidar_message() + print(f"Global map has {len(global_map.pointcloud.points)} points") + show3d(global_map.pointcloud, title="Global Map").run() @pytest.mark.vis def test_reconstruction_with_realtime_vis(): - show3d_stream(Map().consume(Mock("office").stream(rate_hz=60.0)), clearframe=True).run() + map = Map() + map.start() + mock = Mock("office") + + # Process frames and visualize final map + for frame in mock.iterate(): + map.add_frame(frame) + + show3d(map.pointcloud, title="Reconstructed Map").run() @pytest.mark.vis @@ -48,25 +61,38 @@ def test_splice_vis(): @pytest.mark.vis def test_robot_vis(): - show3d_stream( - Map().consume(backpressure(Mock("office").stream())), - clearframe=True, - title="gloal dynamic map test", - ) + map = Map() + map.start() + mock = Mock("office") + + # Process all frames + for frame in mock.iterate(): + map.add_frame(frame) + + show3d(map.pointcloud, title="global dynamic map test").run() def test_robot_mapping(): - lidar_stream = SensorReplay("office_lidar", autocast=LidarMessage.from_msg) + lidar_replay = SensorReplay("office_lidar", autocast=LidarMessage.from_msg) map = Map(voxel_size=0.5) - # this will block until map has consumed the whole stream - map.consume(lidar_stream.stream()).run() + # Mock the output streams to avoid publishing errors + class MockStream: + def publish(self, msg): + pass # Do nothing + + map.local_costmap = MockStream() + map.global_costmap = MockStream() + map.global_map = MockStream() - # we investigate built map - costmap = map.costmap() + # Process all frames from replay + for frame in lidar_replay.iterate(): + map.add_frame(frame) - assert costmap.grid.shape == (442, 314) + # Check the built map + global_map = map.to_lidar_message() + pointcloud = global_map.pointcloud - assert 70 <= costmap.unknown_percent <= 95 - assert 4 < costmap.free_percent < 10 - assert 1 < costmap.occupied_percent < 15 + # Verify map has points + 
assert len(pointcloud.points) > 0 + print(f"Map contains {len(pointcloud.points)} points") diff --git a/dimos/robot/unitree_webrtc/unitree_go2.py b/dimos/robot/unitree_webrtc/unitree_go2.py index 805c8efb28..0578547760 100644 --- a/dimos/robot/unitree_webrtc/unitree_go2.py +++ b/dimos/robot/unitree_webrtc/unitree_go2.py @@ -1,3 +1,5 @@ +#!/usr/bin/env python3 + # Copyright 2025 Dimensional Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,94 +14,354 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union, Optional, List -import time -import numpy as np + +import functools +import logging import os -from dimos.robot.robot import Robot -from dimos.robot.unitree_webrtc.type.map import Map +import time +import warnings +from typing import Callable, Optional + +from dimos import core +from dimos.core import In, Module, Out, rpc +from dimos.msgs.geometry_msgs import PoseStamped, Transform, Vector3, Quaternion +from dimos.msgs.nav_msgs import OccupancyGrid, Path +from dimos.msgs.sensor_msgs import Image +from dimos.perception.spatial_perception import SpatialMemory +from dimos.protocol import pubsub +from dimos.protocol.tf import TF +from dimos.robot.foxglove_bridge import FoxgloveBridge +from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule +from dimos.navigation.global_planner import AstarPlanner +from dimos.navigation.local_planner.holonomic_local_planner import HolonomicLocalPlanner +from dimos.navigation.bt_navigator.navigator import BehaviorTreeNavigator +from dimos.navigation.frontier_exploration import WavefrontFrontierExplorer from dimos.robot.unitree_webrtc.connection import UnitreeWebRTCConnection -from dimos.robot.global_planner.planner import AstarPlanner -from dimos.utils.reactive import getter_streaming -from dimos.skills.skills import AbstractRobotSkill, SkillLibrary -from go2_webrtc_driver.constants import VUI_COLOR -from 
go2_webrtc_driver.webrtc_driver import WebRTCConnectionMethod -from dimos.perception.person_tracker import PersonTrackingStream -from dimos.perception.object_tracker import ObjectTrackingStream -from dimos.robot.local_planner.local_planner import navigate_path_local -from dimos.robot.local_planner.vfh_local_planner import VFHPurePursuitPlanner -from dimos.types.robot_capabilities import RobotCapability -from dimos.types.vector import Vector +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.map import Map +from dimos.robot.unitree_webrtc.type.odometry import Odometry from dimos.robot.unitree_webrtc.unitree_skills import MyUnitreeSkills -from dimos.robot.frontier_exploration.qwen_frontier_predictor import QwenFrontierPredictor -from dimos.robot.frontier_exploration.wavefront_frontier_goal_selector import ( - WavefrontFrontierExplorer, -) -import threading +from dimos.skills.skills import AbstractRobotSkill, SkillLibrary +from dimos.utils.data import get_data +from dimos.utils.logging_config import setup_logger +from dimos.utils.testing import TimedSensorReplay +from dimos_lcm.std_msgs import Bool + +logger = setup_logger("dimos.robot.unitree_webrtc.unitree_go2", level=logging.INFO) + +# Suppress verbose loggers +logging.getLogger("aiortc.codecs.h264").setLevel(logging.ERROR) +logging.getLogger("lcm_foxglove_bridge").setLevel(logging.ERROR) +logging.getLogger("websockets.server").setLevel(logging.ERROR) +logging.getLogger("FoxgloveServer").setLevel(logging.ERROR) +logging.getLogger("asyncio").setLevel(logging.ERROR) +logging.getLogger("root").setLevel(logging.WARNING) + +# Suppress warnings +warnings.filterwarnings("ignore", message="coroutine.*was never awaited") +warnings.filterwarnings("ignore", message="H264Decoder.*failed to decode") + + +class FakeRTC: + """Fake WebRTC connection for testing with recorded data.""" + + def __init__(self, *args, **kwargs): + data = get_data("unitree_office_walk") + + def 
connect(self): + pass + + def standup(self): + print("standup suppressed") + + def liedown(self): + print("liedown suppressed") + + @functools.cache + def lidar_stream(self): + print("lidar stream start") + lidar_store = TimedSensorReplay("unitree_office_walk/lidar", autocast=LidarMessage.from_msg) + return lidar_store.stream() + + @functools.cache + def odom_stream(self): + print("odom stream start") + odom_store = TimedSensorReplay("unitree_office_walk/odom", autocast=Odometry.from_msg) + return odom_store.stream() + + @functools.cache + def video_stream(self): + print("video stream start") + video_store = TimedSensorReplay( + "unitree_office_walk/video", autocast=lambda x: Image.from_numpy(x).to_rgb() + ) + return video_store.stream() + + def move(self, vector: Vector3, duration: float = 0.0): + pass + + def publish_request(self, topic: str, data: dict): + """Fake publish request for testing.""" + return {"status": "ok", "message": "Fake publish"} -class Color(VUI_COLOR): ... +class ConnectionModule(Module): + """Module that handles robot sensor data and movement commands.""" + movecmd: In[Vector3] = None + odom: Out[PoseStamped] = None + lidar: Out[LidarMessage] = None + video: Out[Image] = None + ip: str + playback: bool + + _odom: PoseStamped = None + _lidar: LidarMessage = None + + def __init__(self, ip: str = None, playback: bool = False, *args, **kwargs): + self.ip = ip + self.playback = playback + self.tf = TF() + self.connection = None + Module.__init__(self, *args, **kwargs) + + @rpc + def start(self): + """Start the connection and subscribe to sensor streams.""" + if self.playback: + self.connection = FakeRTC(self.ip) + else: + self.connection = UnitreeWebRTCConnection(self.ip) + + # Connect sensor streams to outputs + self.connection.lidar_stream().subscribe(self.lidar.publish) + self.connection.odom_stream().subscribe(self._publish_tf) + self.connection.video_stream().subscribe(self.video.publish) + self.movecmd.subscribe(self.move) + + def 
_publish_tf(self, msg): + self._odom = msg + self.odom.publish(msg) + self.tf.publish(Transform.from_pose("base_link", msg)) + camera_link = Transform( + translation=Vector3(0.3, 0.0, 0.0), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + frame_id="base_link", + child_frame_id="camera_link", + ts=time.time(), + ) + self.tf.publish(camera_link) + + @rpc + def get_odom(self) -> Optional[PoseStamped]: + """Get the robot's odometry. + + Returns: + The robot's odometry + """ + return self._odom + + @rpc + def move(self, vector: Vector3, duration: float = 0.0): + """Send movement command to robot.""" + self.connection.move(vector, duration) + + @rpc + def standup(self): + """Make the robot stand up.""" + return self.connection.standup() + + @rpc + def liedown(self): + """Make the robot lie down.""" + return self.connection.liedown() + + @rpc + def publish_request(self, topic: str, data: dict): + """Publish a request to the WebRTC connection. + Args: + topic: The RTC topic to publish to + data: The data dictionary to publish + Returns: + The result of the publish request + """ + return self.connection.publish_request(topic, data) + + +class UnitreeGo2: + """Full Unitree Go2 robot with navigation and perception capabilities.""" -class UnitreeGo2(Robot): def __init__( self, ip: str, - mode: str = "ai", - output_dir: str = os.path.join(os.getcwd(), "assets", "output"), - skill_library: SkillLibrary = None, - robot_capabilities: List[RobotCapability] = None, - spatial_memory_collection: str = "spatial_memory", - new_memory: bool = True, - enable_perception: bool = True, + output_dir: str = None, + websocket_port: int = 7779, + skill_library: Optional[SkillLibrary] = None, + playback: bool = False, ): - """Initialize Unitree Go2 robot with WebRTC control interface. + """Initialize the robot system. Args: - ip: IP address of the robot - mode: Robot mode (ai, etc.) 
- output_dir: Directory for output files + ip: Robot IP address (or None for fake connection) + output_dir: Directory for saving outputs (default: assets/output) + enable_perception: Whether to enable spatial memory/perception + websocket_port: Port for web visualization skill_library: Skill library instance - robot_capabilities: List of robot capabilities - spatial_memory_collection: Collection name for spatial memory - new_memory: Whether to create new spatial memory - enable_perception: Whether to enable perception streams and spatial memory + playback: If True, use recorded data instead of real robot connection """ - # Create WebRTC connection interface - self.webrtc_connection = UnitreeWebRTCConnection( - ip=ip, - mode=mode, + self.ip = ip + self.playback = playback or (ip is None) # Auto-enable playback if no IP provided + self.output_dir = output_dir or os.path.join(os.getcwd(), "assets", "output") + self.websocket_port = websocket_port + + # Initialize skill library + if skill_library is None: + skill_library = MyUnitreeSkills() + self.skill_library = skill_library + + self.dimos = None + self.connection = None + self.mapper = None + self.global_planner = None + self.local_planner = None + self.navigator = None + self.frontier_explorer = None + self.websocket_vis = None + self.foxglove_bridge = None + self.spatial_memory_module = None + + self._setup_directories() + + def _setup_directories(self): + """Setup directories for spatial memory storage.""" + os.makedirs(self.output_dir, exist_ok=True) + logger.info(f"Robot outputs will be saved to: {self.output_dir}") + + # Initialize memory directories + self.memory_dir = os.path.join(self.output_dir, "memory") + os.makedirs(self.memory_dir, exist_ok=True) + + # Initialize spatial memory properties + self.spatial_memory_dir = os.path.join(self.memory_dir, "spatial_memory") + self.spatial_memory_collection = "spatial_memory" + self.db_path = os.path.join(self.spatial_memory_dir, "chromadb_data") + 
self.visual_memory_path = os.path.join(self.spatial_memory_dir, "visual_memory.pkl") + + # Create spatial memory directories + os.makedirs(self.spatial_memory_dir, exist_ok=True) + os.makedirs(self.db_path, exist_ok=True) + + def start(self): + """Start the robot system with all modules.""" + self.dimos = core.start(4) + + self._deploy_connection() + self._deploy_mapping() + self._deploy_navigation() + self._deploy_visualization() + self._deploy_perception() + + self._start_modules() + + logger.info("UnitreeGo2 initialized and started") + logger.info(f"WebSocket visualization available at http://localhost:{self.websocket_port}") + + def _deploy_connection(self): + """Deploy and configure the connection module.""" + self.connection = self.dimos.deploy(ConnectionModule, self.ip, playback=self.playback) + + self.connection.lidar.transport = core.LCMTransport("/lidar", LidarMessage) + self.connection.odom.transport = core.LCMTransport("/odom", PoseStamped) + self.connection.video.transport = core.LCMTransport("/video", Image) + self.connection.movecmd.transport = core.LCMTransport("/cmd_vel", Vector3) + + def _deploy_mapping(self): + """Deploy and configure the mapping module.""" + self.mapper = self.dimos.deploy(Map, voxel_size=0.5, global_publish_interval=2.5) + + self.mapper.global_map.transport = core.LCMTransport("/global_map", LidarMessage) + self.mapper.global_costmap.transport = core.LCMTransport("/global_costmap", OccupancyGrid) + self.mapper.local_costmap.transport = core.LCMTransport("/local_costmap", OccupancyGrid) + + self.mapper.lidar.connect(self.connection.lidar) + + def _deploy_navigation(self): + """Deploy and configure navigation modules.""" + self.global_planner = self.dimos.deploy(AstarPlanner) + self.local_planner = self.dimos.deploy(HolonomicLocalPlanner) + self.navigator = self.dimos.deploy(BehaviorTreeNavigator, local_planner=self.local_planner) + self.frontier_explorer = self.dimos.deploy(WavefrontFrontierExplorer) + + 
self.navigator.goal.transport = core.LCMTransport("/navigation_goal", PoseStamped) + self.navigator.goal_request.transport = core.LCMTransport("/goal_request", PoseStamped) + self.navigator.goal_reached.transport = core.LCMTransport("/goal_reached", Bool) + self.navigator.global_costmap.transport = core.LCMTransport( + "/global_costmap", OccupancyGrid + ) + self.global_planner.path.transport = core.LCMTransport("/global_path", Path) + self.local_planner.cmd_vel.transport = core.LCMTransport("/cmd_vel", Vector3) + self.frontier_explorer.goal_request.transport = core.LCMTransport( + "/goal_request", PoseStamped ) + self.frontier_explorer.goal_reached.transport = core.LCMTransport("/goal_reached", Bool) - print("standing up") - self.webrtc_connection.standup() + self.global_planner.target.connect(self.navigator.goal) - # Initialize WebRTC-specific features - self.lidar_stream = self.webrtc_connection.lidar_stream() - self.odom = getter_streaming(self.webrtc_connection.odom_stream()) - self.map = Map(voxel_size=0.2) - self.map_stream = self.map.consume(self.lidar_stream) - self.lidar_message = getter_streaming(self.lidar_stream) + self.global_planner.global_costmap.connect(self.mapper.global_costmap) + self.global_planner.odom.connect(self.connection.odom) - if skill_library is None: - skill_library = MyUnitreeSkills() + self.local_planner.path.connect(self.global_planner.path) + self.local_planner.local_costmap.connect(self.mapper.local_costmap) + self.local_planner.odom.connect(self.connection.odom) + + self.connection.movecmd.connect(self.local_planner.cmd_vel) + + self.navigator.odom.connect(self.connection.odom) + + self.frontier_explorer.costmap.connect(self.mapper.global_costmap) + self.frontier_explorer.odometry.connect(self.connection.odom) - # Initialize base robot with connection interface - super().__init__( - connection_interface=self.webrtc_connection, - output_dir=output_dir, - skill_library=skill_library, - capabilities=robot_capabilities - or [ - 
RobotCapability.LOCOMOTION, - RobotCapability.VISION, - RobotCapability.AUDIO, - ], - spatial_memory_collection=spatial_memory_collection, - new_memory=new_memory, - enable_perception=enable_perception, + def _deploy_visualization(self): + """Deploy and configure visualization modules.""" + self.websocket_vis = self.dimos.deploy(WebsocketVisModule, port=self.websocket_port) + self.websocket_vis.click_goal.transport = core.LCMTransport("/goal_request", PoseStamped) + + self.websocket_vis.robot_pose.connect(self.connection.odom) + self.websocket_vis.path.connect(self.global_planner.path) + self.websocket_vis.global_costmap.connect(self.mapper.global_costmap) + + self.foxglove_bridge = FoxgloveBridge() + + def _deploy_perception(self): + """Deploy and configure the spatial memory module.""" + self.spatial_memory_module = self.dimos.deploy( + SpatialMemory, + collection_name=self.spatial_memory_collection, + db_path=self.db_path, + visual_memory_path=self.visual_memory_path, + output_dir=self.spatial_memory_dir, ) + self.spatial_memory_module.video.connect(self.connection.video) + self.spatial_memory_module.odom.connect(self.connection.odom) + + logger.info("Spatial memory module deployed and connected") + + def _start_modules(self): + """Start all deployed modules in the correct order.""" + self.connection.start() + self.mapper.start() + self.global_planner.start() + self.local_planner.start() + self.navigator.start() + self.frontier_explorer.start() + self.websocket_vis.start() + self.foxglove_bridge.start() + + if self.spatial_memory_module: + self.spatial_memory_module.start() + + # Initialize skills after connection is established if self.skill_library is not None: for skill in self.skill_library: if isinstance(skill, AbstractRobotSkill): @@ -109,116 +371,92 @@ def __init__( self.skill_library.init() self.skill_library.initialize_skills() - # Camera configuration - self.camera_intrinsics = [819.553492, 820.646595, 625.284099, 336.808987] - self.camera_pitch = 
np.deg2rad(0) # negative for downward pitch - self.camera_height = 0.44 # meters - - # Initialize visual servoing using connection interface - video_stream = self.get_video_stream() - if video_stream is not None and enable_perception: - self.person_tracker = PersonTrackingStream( - camera_intrinsics=self.camera_intrinsics, - camera_pitch=self.camera_pitch, - camera_height=self.camera_height, - ) - self.object_tracker = ObjectTrackingStream( - camera_intrinsics=self.camera_intrinsics, - camera_pitch=self.camera_pitch, - camera_height=self.camera_height, - ) - person_tracking_stream = self.person_tracker.create_stream(video_stream) - object_tracking_stream = self.object_tracker.create_stream(video_stream) - - self.person_tracking_stream = person_tracking_stream - self.object_tracking_stream = object_tracking_stream - else: - # Video stream not available or perception disabled - self.person_tracker = None - self.object_tracker = None - self.person_tracking_stream = None - self.object_tracking_stream = None - - self.global_planner = AstarPlanner( - set_local_nav=lambda path, stop_event=None, goal_theta=None: navigate_path_local( - self, path, timeout=120.0, goal_theta=goal_theta, stop_event=stop_event - ), - get_costmap=lambda: self.map.costmap, - get_robot_pos=lambda: self.odom().pos, - ) + def move(self, vector: Vector3, duration: float = 0.0): + """Send movement command to robot.""" + self.connection.move(vector, duration) - # Initialize the local planner using WebRTC-specific methods - self.local_planner = VFHPurePursuitPlanner( - get_costmap=lambda: self.lidar_message().costmap(), - get_robot_pose=lambda: self.odom(), - move=self.move, # Use the robot's move method directly - robot_width=0.36, # Unitree Go2 width in meters - robot_length=0.6, # Unitree Go2 length in meters - max_linear_vel=0.7, - max_angular_vel=0.65, - lookahead_distance=1.5, - visualization_size=500, # 500x500 pixel visualization - global_planner_plan=self.global_planner.plan, - ) + def 
explore(self) -> bool: + """Start autonomous frontier exploration. - # Initialize frontier exploration - self.frontier_explorer = WavefrontFrontierExplorer( - set_goal=self.global_planner.set_goal, - get_costmap=lambda: self.map.costmap, - get_robot_pos=lambda: self.odom().pos, - ) + Returns: + True if exploration started successfully + """ + return self.frontier_explorer.explore() - # Create the visualization stream at 5Hz - self.local_planner_viz_stream = self.local_planner.create_stream(frequency_hz=5.0) + def navigate_to(self, pose: PoseStamped, blocking: bool = True): + """Navigate to a target pose. - def get_pose(self) -> dict: - """ - Get the current pose (position and rotation) of the robot in the map frame. + Args: + pose: Target pose to navigate to + blocking: If True, block until goal is reached. If False, return immediately. Returns: - Dictionary containing: - - position: Vector (x, y, z) - - rotation: Vector (roll, pitch, yaw) in radians + If blocking=True: True if navigation was successful, False otherwise + If blocking=False: True if goal was accepted, False otherwise """ - position = Vector(self.odom().pos.x, self.odom().pos.y, self.odom().pos.z) - orientation = Vector(self.odom().rot.x, self.odom().rot.y, self.odom().rot.z) - return {"position": position, "rotation": orientation} - def explore(self, stop_event: Optional[threading.Event] = None) -> bool: + logger.info( + f"Navigating to pose: ({pose.position.x:.2f}, {pose.position.y:.2f}, {pose.position.z:.2f})" + ) + return self.navigator.set_goal(pose, blocking=blocking) + + def stop_exploration(self) -> bool: + """Stop autonomous exploration. + + Returns: + True if exploration was stopped """ - Start autonomous frontier exploration. + return self.frontier_explorer.stop_exploration() - Args: - stop_event: Optional threading.Event to signal when exploration should stop + def cancel_navigation(self) -> bool: + """Cancel the current navigation goal. 
Returns: - bool: True if exploration completed successfully, False if stopped or failed + True if goal was cancelled """ - return self.frontier_explorer.explore(stop_event=stop_event) + return self.navigator.cancel_goal() - def odom_stream(self): - """Get the odometry stream from the robot. + @property + def spatial_memory(self) -> Optional[SpatialMemory]: + """Get the robot's spatial memory module. Returns: - Observable stream of robot odometry data containing position and orientation. + SpatialMemory module instance or None if perception is disabled """ - return self.webrtc_connection.odom_stream() + return self.spatial_memory_module - def standup(self): - """Make the robot stand up. + def get_skills(self): + """Get the robot's skill library. - Uses AI mode standup if robot is in AI mode, otherwise uses normal standup. + Returns: + The robot's skill library for adding/managing skills """ - return self.webrtc_connection.standup() + return self.skill_library - def liedown(self): - """Make the robot lie down. + def get_odom(self) -> PoseStamped: + """Get the robot's odometry. - Commands the robot to lie down on the ground. 
+ Returns: + The robot's odometry """ - return self.webrtc_connection.liedown() + return self.connection.get_odom() - @property - def costmap(self): - """Access to the costmap for navigation.""" - return self.map.costmap + +def main(): + """Main entry point.""" + ip = os.getenv("ROBOT_IP") + + pubsub.lcm.autoconf() + + robot = UnitreeGo2(ip=ip, websocket_port=7779, playback=False) + robot.start() + + try: + while True: + time.sleep(1) + except KeyboardInterrupt: + logger.info("Shutting down...") + + +if __name__ == "__main__": + main() diff --git a/dimos/robot/unitree_webrtc/unitree_skills.py b/dimos/robot/unitree_webrtc/unitree_skills.py index f9dfc1eede..ef3b22ebb4 100644 --- a/dimos/robot/unitree_webrtc/unitree_skills.py +++ b/dimos/robot/unitree_webrtc/unitree_skills.py @@ -26,7 +26,7 @@ from dimos.skills.skills import AbstractRobotSkill, AbstractSkill, SkillLibrary from dimos.types.constants import Colors -from dimos.types.vector import Vector +from dimos.msgs.geometry_msgs import Vector3 from go2_webrtc_driver.constants import RTC_TOPIC, SPORT_CMD # Module-level constant for Unitree WebRTC control definitions @@ -231,8 +231,8 @@ def __call__(self): f"{Colors.RESET_COLOR}" ) else: - # Use WebRTC publish_request interface through the robot's webrtc_connection - result = self._robot.webrtc_connection.publish_request( + # Use WebRTC publish_request interface through the robot's connection module + result = self._robot.connection.publish_request( RTC_TOPIC["SPORT_MOD"], {"api_id": self._app_id} ) string = f"{Colors.GREEN_PRINT_COLOR}{self.__class__.__name__} was successful: id={self._app_id}{Colors.RESET_COLOR}" @@ -262,7 +262,8 @@ class Move(AbstractRobotSkill): duration: float = Field(default=0.0, description="How long to move (seconds).") def __call__(self): - return self._robot.move(Vector(self.x, self.y, self.yaw), duration=self.duration) + self._robot.move(Vector3(self.x, self.y, self.yaw), duration=self.duration) + return f"started moving with 
velocity={self.x}, {self.y}, {self.yaw} for {self.duration} seconds" class Wait(AbstractSkill): """Wait for a specified amount of time.""" diff --git a/dimos/skills/navigation.py b/dimos/skills/navigation.py index 6d67ae04f2..e5ead5ab85 100644 --- a/dimos/skills/navigation.py +++ b/dimos/skills/navigation.py @@ -34,7 +34,6 @@ from dimos.utils.logging_config import setup_logger from dimos.models.qwen.video_query import get_bbox_from_qwen_frame from dimos.utils.transform_utils import distance_angle_to_goal_xy -from dimos.robot.local_planner.local_planner import navigate_to_goal_local logger = setup_logger("dimos.skills.semantic_map_skills") @@ -653,27 +652,36 @@ def __call__(self): try: # Start exploration using the robot's explore method - result = self._robot.explore(stop_event=self._stop_event) + result = self._robot.explore() if result: - logger.info("Exploration completed successfully - no more frontiers found") + logger.info("Exploration started successfully") + + # Wait for exploration to complete or timeout + start_time = time.time() + while time.time() - start_time < self.timeout: + if self._stop_event.is_set(): + logger.info("Exploration stopped by user") + self._robot.stop_exploration() + return { + "success": False, + "message": "Exploration stopped by user", + } + time.sleep(0.5) + + # Timeout reached, stop exploration + logger.info(f"Exploration timeout reached after {self.timeout} seconds") + self._robot.stop_exploration() return { "success": True, - "message": "Exploration completed - all accessible areas explored", + "message": f"Exploration ran for {self.timeout} seconds", } else: - if self._stop_event.is_set(): - logger.info("Exploration stopped by user") - return { - "success": False, - "message": "Exploration stopped by user", - } - else: - logger.warning("Exploration did not complete successfully") - return { - "success": False, - "message": "Exploration failed or was interrupted", - } + logger.warning("Failed to start exploration") + return { + 
"success": False, + "message": "Failed to start exploration", + } except Exception as e: error_msg = f"Error during exploration: {e}" @@ -696,4 +704,11 @@ def stop(self): skill_library = self._robot.get_skills() self.unregister_as_running("Explore", skill_library) self._stop_event.set() + + # Stop the robot's exploration if it's running + try: + self._robot.stop_exploration() + except Exception as e: + logger.error(f"Error stopping exploration: {e}") + return "Exploration stopped" diff --git a/dimos/skills/unitree/unitree_speak.py b/dimos/skills/unitree/unitree_speak.py index 05004398f9..f06666c30a 100644 --- a/dimos/skills/unitree/unitree_speak.py +++ b/dimos/skills/unitree/unitree_speak.py @@ -82,9 +82,7 @@ def _webrtc_request(self, api_id: int, parameter: dict = None): request_data = {"api_id": api_id, "parameter": json.dumps(parameter) if parameter else "{}"} - return self._robot.webrtc_connection.publish_request( - RTC_TOPIC["AUDIO_HUB_REQ"], request_data - ) + return self._robot.connection.publish_request(RTC_TOPIC["AUDIO_HUB_REQ"], request_data) def _upload_audio_to_robot(self, audio_data: bytes, filename: str) -> str: try: diff --git a/dimos/types/costmap.py b/dimos/types/costmap.py deleted file mode 100644 index 2d9b1c433e..0000000000 --- a/dimos/types/costmap.py +++ /dev/null @@ -1,534 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import base64 -import pickle -import math -import numpy as np -from typing import Optional -from scipy import ndimage -from dimos.types.ros_polyfill import OccupancyGrid -from scipy.ndimage import binary_dilation -from dimos.types.vector import Vector, VectorLike, x, y, to_vector -import open3d as o3d -from matplotlib.path import Path -from PIL import Image -import cv2 - -DTYPE2STR = { - np.float32: "f32", - np.float64: "f64", - np.int32: "i32", - np.int8: "i8", -} - -STR2DTYPE = {v: k for k, v in DTYPE2STR.items()} - - -class CostValues: - """Standard cost values for occupancy grid cells.""" - - FREE = 0 # Free space - UNKNOWN = -1 # Unknown space - OCCUPIED = 100 # Occupied/lethal space - - -def encode_ndarray(arr: np.ndarray, compress: bool = False): - arr_c = np.ascontiguousarray(arr) - payload = arr_c.tobytes() - b64 = base64.b64encode(payload).decode("ascii") - - return { - "type": "grid", - "shape": arr_c.shape, - "dtype": DTYPE2STR[arr_c.dtype.type], - "data": b64, - } - - -class Costmap: - """Class to hold ROS OccupancyGrid data.""" - - def __init__( - self, - grid: np.ndarray, - origin: VectorLike, - origin_theta: float = 0, - resolution: float = 0.05, - ): - """Initialize Costmap with its core attributes.""" - self.grid = grid - self.resolution = resolution - self.origin = to_vector(origin).to_2d() - self.origin_theta = origin_theta - self.width = self.grid.shape[1] - self.height = self.grid.shape[0] - - def serialize(self) -> tuple: - """Serialize the Costmap instance to a tuple.""" - return { - "type": "costmap", - "grid": encode_ndarray(self.grid), - "origin": self.origin.serialize(), - "resolution": self.resolution, - "origin_theta": self.origin_theta, - } - - @classmethod - def from_msg(cls, costmap_msg: OccupancyGrid) -> "Costmap": - """Create a Costmap instance from a ROS OccupancyGrid message.""" - if costmap_msg is None: - raise Exception("need costmap msg") - - # Extract info from the message - width = costmap_msg.info.width - height = 
costmap_msg.info.height - resolution = costmap_msg.info.resolution - - # Get origin position as a vector-like object - origin = ( - costmap_msg.info.origin.position.x, - costmap_msg.info.origin.position.y, - ) - - # Calculate orientation from quaternion - qx = costmap_msg.info.origin.orientation.x - qy = costmap_msg.info.origin.orientation.y - qz = costmap_msg.info.origin.orientation.z - qw = costmap_msg.info.origin.orientation.w - origin_theta = math.atan2(2.0 * (qw * qz + qx * qy), 1.0 - 2.0 * (qy * qy + qz * qz)) - - # Convert to numpy array - data = np.array(costmap_msg.data, dtype=np.int8) - grid = data.reshape((height, width)) - - return cls( - grid=grid, - resolution=resolution, - origin=origin, - origin_theta=origin_theta, - ) - - def save_pickle(self, pickle_path: str): - """Save costmap to a pickle file. - - Args: - pickle_path: Path to save the pickle file - """ - with open(pickle_path, "wb") as f: - pickle.dump(self, f) - - @classmethod - def from_pickle(cls, pickle_path: str) -> "Costmap": - """Load costmap from a pickle file containing either a Costmap object or constructor arguments.""" - with open(pickle_path, "rb") as f: - data = pickle.load(f) - - # Check if data is already a Costmap object - if isinstance(data, cls): - return data - else: - # Assume it's constructor arguments - costmap = cls(*data) - return costmap - - @classmethod - def create_empty( - cls, width: int = 100, height: int = 100, resolution: float = 0.1 - ) -> "Costmap": - """Create an empty costmap with specified dimensions.""" - return cls( - grid=np.zeros((height, width), dtype=np.int8), - resolution=resolution, - origin=(0.0, 0.0), - origin_theta=0.0, - ) - - def world_to_grid(self, point: VectorLike) -> Vector: - """Convert world coordinates to grid coordinates. 
- - Args: - point: A vector-like object containing X,Y coordinates - - Returns: - Tuple of (grid_x, grid_y) as integers - """ - return (to_vector(point) - self.origin) / self.resolution - - def grid_to_world(self, grid_point: VectorLike) -> Vector: - return to_vector(grid_point) * self.resolution + self.origin - - def is_occupied(self, point: VectorLike, threshold: int = 50) -> bool: - """Check if a position in world coordinates is occupied. - - Args: - point: Vector-like object containing X,Y coordinates - threshold: Cost threshold above which a cell is considered occupied (0-100) - - Returns: - True if position is occupied or out of bounds, False otherwise - """ - grid_point = self.world_to_grid(point) - grid_x, grid_y = int(grid_point.x), int(grid_point.y) - if 0 <= grid_x < self.width and 0 <= grid_y < self.height: - # Consider unknown (-1) as unoccupied for navigation purposes - value = self.grid[grid_y, grid_x] - return value >= threshold - return True # Consider out-of-bounds as occupied - - def get_value(self, point: VectorLike) -> Optional[int]: - point = self.world_to_grid(point) - - if 0 <= point.x < self.width and 0 <= point.y < self.height: - return int(self.grid[int(point.y), int(point.x)]) - return None - - def set_value(self, point: VectorLike, value: int = 0) -> bool: - point = self.world_to_grid(point) - - if 0 <= point.x < self.width and 0 <= point.y < self.height: - self.grid[int(point.y), int(point.x)] = value - return value - return False - - def smudge( - self, - kernel_size: int = 7, - iterations: int = 25, - decay_factor: float = 0.9, - threshold: int = 90, - preserve_unknown: bool = False, - ) -> "Costmap": - """ - Creates a new costmap with expanded obstacles (smudged). 
- - Args: - kernel_size: Size of the convolution kernel for dilation (must be odd) - iterations: Number of dilation iterations - decay_factor: Factor to reduce cost as distance increases (0.0-1.0) - threshold: Minimum cost value to consider as an obstacle for expansion - preserve_unknown: Whether to keep unknown (-1) cells as unknown - - Returns: - A new Costmap instance with expanded obstacles - """ - # Make sure kernel size is odd - if kernel_size % 2 == 0: - kernel_size += 1 - - # Create a copy of the grid for processing - grid_copy = self.grid.copy() - - # Create a mask of unknown cells if needed - unknown_mask = None - if preserve_unknown: - unknown_mask = grid_copy == -1 - # Temporarily replace unknown cells with 0 for processing - # This allows smudging to go over unknown areas - grid_copy[unknown_mask] = 0 - - # Create a mask of cells that are above the threshold - obstacle_mask = grid_copy >= threshold - - # Create a binary map of obstacles - binary_map = obstacle_mask.astype(np.uint8) * 100 - - # Create a circular kernel for dilation (instead of square) - y, x = np.ogrid[ - -kernel_size // 2 : kernel_size // 2 + 1, - -kernel_size // 2 : kernel_size // 2 + 1, - ] - kernel = (x * x + y * y <= (kernel_size // 2) * (kernel_size // 2)).astype(np.uint8) - - # Create distance map using dilation - # Each iteration adds one 'ring' of cells around obstacles - dilated_map = binary_map.copy() - - # Store each layer of dilation with decreasing values - layers = [] - - # First layer is the original obstacle cells - layers.append(binary_map.copy()) - - for i in range(iterations): - # Dilate the binary map - dilated = ndimage.binary_dilation( - dilated_map > 0, structure=kernel, iterations=1 - ).astype(np.uint8) - - # Calculate the new layer (cells that were just added in this iteration) - new_layer = (dilated - (dilated_map > 0).astype(np.uint8)) * 100 - - # Apply decay factor based on distance from obstacle - new_layer = new_layer * (decay_factor ** (i + 1)) - - # Add 
to layers list - layers.append(new_layer) - - # Update dilated map for next iteration - dilated_map = dilated * 100 - - # Combine all layers to create a distance-based cost map - smudged_map = np.zeros_like(grid_copy) - for layer in layers: - # For each cell, keep the maximum value across all layers - smudged_map = np.maximum(smudged_map, layer) - - # Preserve original obstacles - smudged_map[obstacle_mask] = grid_copy[obstacle_mask] - - # When preserve_unknown is true, restore all original unknown cells - # This overlays unknown cells on top of the smudged map - if preserve_unknown and unknown_mask is not None: - smudged_map[unknown_mask] = -1 - - # Ensure cost values are in valid range (0-100) except for unknown (-1) - if preserve_unknown: - valid_cells = ~unknown_mask - smudged_map[valid_cells] = np.clip(smudged_map[valid_cells], 0, 100) - else: - smudged_map = np.clip(smudged_map, 0, 100) - - # Create a new costmap with the smudged grid - return Costmap( - grid=smudged_map.astype(np.int8), - resolution=self.resolution, - origin=self.origin, - origin_theta=self.origin_theta, - ) - - def subsample(self, subsample_factor: int = 2) -> "Costmap": - """ - Create a subsampled (lower resolution) version of the costmap. 
- - Args: - subsample_factor: Factor by which to reduce resolution (e.g., 2 = half resolution, 4 = quarter resolution) - - Returns: - New Costmap instance with reduced resolution - """ - if subsample_factor <= 1: - return self # No subsampling needed - - # Calculate new grid dimensions - new_height = self.height // subsample_factor - new_width = self.width // subsample_factor - - # Create new grid by subsampling - subsampled_grid = np.zeros((new_height, new_width), dtype=self.grid.dtype) - - # Sample every subsample_factor-th point - for i in range(new_height): - for j in range(new_width): - orig_i = i * subsample_factor - orig_j = j * subsample_factor - - # Take a small neighborhood and use the most conservative value - # (prioritize occupied > unknown > free for safety) - neighborhood = self.grid[ - orig_i : min(orig_i + subsample_factor, self.height), - orig_j : min(orig_j + subsample_factor, self.width), - ] - - # Priority: Occupied (100) > Unknown (-1) > Free (0) - if np.any(neighborhood == CostValues.OCCUPIED): - subsampled_grid[i, j] = CostValues.OCCUPIED - elif np.any(neighborhood == CostValues.UNKNOWN): - subsampled_grid[i, j] = CostValues.UNKNOWN - else: - subsampled_grid[i, j] = CostValues.FREE - - # Create new costmap with adjusted resolution and origin - new_resolution = self.resolution * subsample_factor - - return Costmap( - grid=subsampled_grid, - resolution=new_resolution, - origin=self.origin, # Origin stays the same - ) - - @property - def total_cells(self) -> int: - return self.width * self.height - - @property - def occupied_cells(self) -> int: - return np.sum(self.grid >= 0.1) - - @property - def unknown_cells(self) -> int: - return np.sum(self.grid == -1) - - @property - def free_cells(self) -> int: - return self.total_cells - self.occupied_cells - self.unknown_cells - - @property - def free_percent(self) -> float: - return (self.free_cells / self.total_cells) * 100 if self.total_cells > 0 else 0.0 - - @property - def occupied_percent(self) 
-> float: - return (self.occupied_cells / self.total_cells) * 100 if self.total_cells > 0 else 0.0 - - @property - def unknown_percent(self) -> float: - return (self.unknown_cells / self.total_cells) * 100 if self.total_cells > 0 else 0.0 - - def __str__(self) -> str: - """ - Create a string representation of the Costmap. - - Returns: - A formatted string with key costmap information - """ - - cell_info = [ - "▦ Costmap", - f"{self.width}x{self.height}", - f"({self.width * self.resolution:.1f}x{self.height * self.resolution:.1f}m @", - f"{1 / self.resolution:.0f}cm res)", - f"Origin: ({x(self.origin):.2f}, {y(self.origin):.2f})", - f"▣ {self.occupied_percent:.1f}%", - f"□ {self.free_percent:.1f}%", - f"◌ {self.unknown_percent:.1f}%", - ] - - return " ".join(cell_info) - - def costmap_to_image(self, image_path: str) -> None: - """ - Convert costmap to JPEG image with ROS-style coloring. - Free space: light grey, Obstacles: black, Unknown: dark gray - - Args: - image_path: Path to save the JPEG image - """ - # Create image array (height, width, 3 for RGB) - img_array = np.zeros((self.height, self.width, 3), dtype=np.uint8) - - # Apply ROS-style coloring based on costmap values - for i in range(self.height): - for j in range(self.width): - value = self.grid[i, j] - if value == CostValues.FREE: # Free space = light grey (205, 205, 205) - img_array[i, j] = [205, 205, 205] - elif value == CostValues.UNKNOWN: # Unknown = dark gray (128, 128, 128) - img_array[i, j] = [128, 128, 128] - elif value >= CostValues.OCCUPIED: # Occupied/obstacles = black (0, 0, 0) - img_array[i, j] = [0, 0, 0] - else: # Any other values (low cost) = light grey - img_array[i, j] = [205, 205, 205] - - # Flip vertically to match ROS convention (origin at bottom-left) - img_array = np.flipud(img_array) - - # Create PIL image and save as JPEG - img = Image.fromarray(img_array, "RGB") - img.save(image_path, "JPEG", quality=95) - print(f"Costmap image saved to: {image_path}") - - -def 
_inflate_lethal(costmap: np.ndarray, radius: int, lethal_val: int = 100) -> np.ndarray: - """Return *costmap* with lethal cells dilated by *radius* grid steps (circular).""" - if radius <= 0 or not np.any(costmap == lethal_val): - return costmap - - mask = costmap == lethal_val - dilated = mask.copy() - for dy in range(-radius, radius + 1): - for dx in range(-radius, radius + 1): - if dx * dx + dy * dy > radius * radius or (dx == 0 and dy == 0): - continue - dilated |= np.roll(mask, shift=(dy, dx), axis=(0, 1)) - - out = costmap.copy() - out[dilated] = lethal_val - return out - - -def pointcloud_to_costmap( - pcd: o3d.geometry.PointCloud, - *, - resolution: float = 0.05, - ground_z: float = 0.0, - obs_min_height: float = 0.15, - max_height: Optional[float] = 0.5, - inflate_radius_m: Optional[float] = None, - default_unknown: int = -1, - cost_free: int = 0, - cost_lethal: int = 100, -) -> tuple[np.ndarray, np.ndarray]: - """Rasterise *pcd* into a 2-D int8 cost-map with optional obstacle inflation. - - Grid origin is **aligned** to the `resolution` lattice so that when - `resolution == voxel_size` every voxel centroid lands squarely inside a cell - (no alternating blank lines). - """ - - pts = np.asarray(pcd.points, dtype=np.float32) - if pts.size == 0: - return np.full((1, 1), default_unknown, np.int8), np.zeros(2, np.float32) - - # 0. Ceiling filter -------------------------------------------------------- - if max_height is not None: - pts = pts[pts[:, 2] <= max_height] - if pts.size == 0: - return np.full((1, 1), default_unknown, np.int8), np.zeros(2, np.float32) - - # 1. 
Bounding box & aligned origin --------------------------------------- - xy_min = pts[:, :2].min(axis=0) - xy_max = pts[:, :2].max(axis=0) - - # Align origin to the resolution grid (anchor = 0,0) - origin = np.floor(xy_min / resolution) * resolution - - # Grid dimensions (inclusive) ------------------------------------------- - Nx, Ny = (np.ceil((xy_max - origin) / resolution).astype(int) + 1).tolist() - - # 2. Bin points ------------------------------------------------------------ - idx_xy = np.floor((pts[:, :2] - origin) / resolution).astype(np.int32) - np.clip(idx_xy[:, 0], 0, Nx - 1, out=idx_xy[:, 0]) - np.clip(idx_xy[:, 1], 0, Ny - 1, out=idx_xy[:, 1]) - - lin = idx_xy[:, 1] * Nx + idx_xy[:, 0] - z_max = np.full(Nx * Ny, -np.inf, np.float32) - np.maximum.at(z_max, lin, pts[:, 2]) - z_max = z_max.reshape(Ny, Nx) - - # 3. Cost rules ----------------------------------------------------------- - costmap = np.full_like(z_max, default_unknown, np.int8) - known = z_max != -np.inf - costmap[known] = cost_free - - lethal = z_max >= (ground_z + obs_min_height) - costmap[lethal] = cost_lethal - - # 4. Optional inflation ---------------------------------------------------- - if inflate_radius_m and inflate_radius_m > 0: - cells = int(np.ceil(inflate_radius_m / resolution)) - costmap = _inflate_lethal(costmap, cells, lethal_val=cost_lethal) - - return costmap, origin.astype(np.float32) - - -if __name__ == "__main__": - costmap = Costmap.from_pickle("costmapMsg.pickle") - print(costmap) - - # Create a smudged version of the costmap for better planning - smudged_costmap = costmap.smudge( - kernel_size=10, iterations=10, threshold=80, preserve_unknown=False - ) - - print(costmap) diff --git a/dimos/types/path.py b/dimos/types/path.py deleted file mode 100644 index dd1bf3603e..0000000000 --- a/dimos/types/path.py +++ /dev/null @@ -1,419 +0,0 @@ -# Copyright 2025 Dimensional Inc. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from __future__ import annotations - -from typing import Iterator, List, Tuple, TypeVar, Union - -import numpy as np -from dimos_lcm.nav_msgs import Path as LCMPPath - -from dimos.msgs.geometry_msgs import Vector3 -from dimos.types.vector import Vector - -T = TypeVar("T", bound="Path") - - -class Path: - """A class representing a path as a sequence of points.""" - - msg_name = "nav_msgs.Path" - - def lcm_encode(self) -> bytes: - """Encode the path to a bytes representation for LCM.""" - for point in self: - print(point) - - lcmpath = LCMPPath() - lcmpath.header.frame_id = "map" - lcmpath.header.stamp = 0 # Placeholder for timestamp - # Convert points to LCM format - - def __init__( - self, - points: Union[List[Vector3], List[np.ndarray], List[Tuple], np.ndarray, None] = None, - ): - """Initialize a path from a list of points. - - Args: - points: List of Vector3 objects, numpy arrays, tuples, or a 2D numpy array where each row is a point. - If None, creates an empty path. 
- - Examples: - Path([Vector3(1, 2), Vector(3, 4)]) # from Vector objects - Path([(1, 2), (3, 4)]) # from tuples - Path(np.array([[1, 2], [3, 4]])) # from 2D numpy array - """ - if points is None: - self._points = np.zeros((0, 0), dtype=float) - return - - if isinstance(points, np.ndarray) and points.ndim == 2: - # If already a 2D numpy array, use it directly - self._points = points.astype(float) - else: - # Convert various input types to numpy array - converted = [] - for p in points: - if isinstance(p, Vector3) or isinstance(p, Vector): - converted.append(p.data) - else: - converted.append(p) - self._points = np.array(converted, dtype=float) - - def serialize(self) -> dict: - """Serialize the path to a dictionary.""" - return { - "type": "path", - "points": self._points.tolist(), - } - - @property - def points(self) -> np.ndarray: - """Get the path points as a numpy array.""" - return self._points - - def as_vectors(self) -> List[Vector3]: - """Get the path points as Vector3 objects.""" - return [Vector3(p) for p in self._points] - - def append(self, point: Union[Vector3, np.ndarray, Tuple]) -> None: - """Append a point to the path. - - Args: - point: A Vector3, numpy array, or tuple representing a point - """ - if isinstance(point, Vector3): - point_data = point.data - else: - point_data = np.array(point, dtype=float) - - if len(self._points) == 0: - # If empty, create with correct dimensionality - self._points = np.array([point_data]) - else: - self._points = np.vstack((self._points, point_data)) - - def extend(self, points: Union[List[Vector3], List[np.ndarray], List[Tuple], "Path"]) -> None: - """Extend the path with more points. 
- - Args: - points: List of points or another Path object - """ - if isinstance(points, Path): - if len(self._points) == 0: - self._points = points.points.copy() - else: - self._points = np.vstack((self._points, points.points)) - else: - for point in points: - self.append(point) - - def insert(self, index: int, point: Union[Vector3, np.ndarray, Tuple]) -> None: - """Insert a point at a specific index. - - Args: - index: The index at which to insert the point - point: A Vector3, numpy array, or tuple representing a point - """ - if isinstance(point, Vector3): - point_data = point.data - else: - point_data = np.array(point, dtype=float) - - if len(self._points) == 0: - self._points = np.array([point_data]) - else: - self._points = np.insert(self._points, index, point_data, axis=0) - - def remove(self, index: int) -> np.ndarray: - """Remove and return the point at the given index. - - Args: - index: The index of the point to remove - - Returns: - The removed point as a numpy array - """ - point = self._points[index].copy() - self._points = np.delete(self._points, index, axis=0) - return point - - def clear(self) -> None: - """Remove all points from the path.""" - self._points = np.zeros( - (0, self._points.shape[1] if len(self._points) > 0 else 0), dtype=float - ) - - def length(self) -> float: - """Calculate the total length of the path. - - Returns: - The sum of the distances between consecutive points - """ - if len(self._points) < 2: - return 0.0 - - # Efficient vector calculation of consecutive point distances - diff = self._points[1:] - self._points[:-1] - segment_lengths = np.sqrt(np.sum(diff * diff, axis=1)) - return float(np.sum(segment_lengths)) - - def resample(self: T, point_spacing: float) -> T: - """Resample the path with approximately uniform spacing between points. 
- - Args: - point_spacing: The desired distance between consecutive points - - Returns: - A new Path object with resampled points - """ - if len(self._points) < 2 or point_spacing <= 0: - return self.__class__(self._points.copy()) - - resampled_points = [self._points[0].copy()] - accumulated_distance = 0.0 - - for i in range(1, len(self._points)): - current_point = self._points[i] - prev_point = self._points[i - 1] - segment_vector = current_point - prev_point - segment_length = np.linalg.norm(segment_vector) - - if segment_length < 1e-10: - continue - - direction = segment_vector / segment_length - - # Add points along this segment until we reach the end - while accumulated_distance + segment_length >= point_spacing: - # How far along this segment the next point should be - dist_along_segment = point_spacing - accumulated_distance - if dist_along_segment < 0: - break - - # Create the new point - new_point = prev_point + direction * dist_along_segment - resampled_points.append(new_point) - - # Update for next iteration - accumulated_distance = 0 - segment_length -= dist_along_segment - prev_point = new_point - - # Update the accumulated distance for the next segment - accumulated_distance += segment_length - - # Add the last point if it's not already there - if len(self._points) > 1: - last_point = self._points[-1] - if not np.array_equal(resampled_points[-1], last_point): - resampled_points.append(last_point.copy()) - - return self.__class__(np.array(resampled_points)) - - def simplify(self: T, tolerance: float) -> T: - """Simplify the path using the Ramer-Douglas-Peucker algorithm. 
- - Args: - tolerance: The maximum distance a point can deviate from the simplified path - - Returns: - A new simplified Path object - """ - if len(self._points) <= 2: - return self.__class__(self._points.copy()) - - # Implementation of Ramer-Douglas-Peucker algorithm - def rdp(points, epsilon, start, end): - if end <= start + 1: - return [start] - - # Find point with max distance from line - line_vec = points[end] - points[start] - line_length = np.linalg.norm(line_vec) - - if line_length < 1e-10: # If start and end points are the same - # Distance from next point to start point - max_dist = np.linalg.norm(points[start + 1] - points[start]) - max_idx = start + 1 - else: - max_dist = 0 - max_idx = start - - for i in range(start + 1, end): - # Distance from point to line - p_vec = points[i] - points[start] - - # Project p_vec onto line_vec - proj_scalar = np.dot(p_vec, line_vec) / (line_length * line_length) - proj = points[start] + proj_scalar * line_vec - - # Calculate perpendicular distance - dist = np.linalg.norm(points[i] - proj) - - if dist > max_dist: - max_dist = dist - max_idx = i - - # Recursive call - result = [] - if max_dist > epsilon: - result_left = rdp(points, epsilon, start, max_idx) - result_right = rdp(points, epsilon, max_idx, end) - result = result_left + result_right[1:] - else: - result = [start, end] - - return result - - indices = rdp(self._points, tolerance, 0, len(self._points) - 1) - indices.append(len(self._points) - 1) # Make sure the last point is included - indices = sorted(set(indices)) # Remove duplicates and sort - - return self.__class__(self._points[indices]) - - def smooth(self: T, weight: float = 0.5, iterations: int = 1) -> T: - """Smooth the path using a moving average filter. 
- - Args: - weight: How much to weight the neighboring points (0-1) - iterations: Number of smoothing passes to apply - - Returns: - A new smoothed Path object - """ - if len(self._points) <= 2 or weight <= 0 or iterations <= 0: - return self.__class__(self._points.copy()) - - smoothed_points = self._points.copy() - - for _ in range(iterations): - new_points = np.zeros_like(smoothed_points) - new_points[0] = smoothed_points[0] # Keep first point unchanged - - # Apply weighted average to middle points - for i in range(1, len(smoothed_points) - 1): - neighbor_avg = 0.5 * (smoothed_points[i - 1] + smoothed_points[i + 1]) - new_points[i] = (1 - weight) * smoothed_points[i] + weight * neighbor_avg - - new_points[-1] = smoothed_points[-1] # Keep last point unchanged - smoothed_points = new_points - - return self.__class__(smoothed_points) - - def nearest_point_index(self, point: Union[Vector3, np.ndarray, Tuple]) -> int: - """Find the index of the closest point on the path to the given point. - - Args: - point: The reference point - - Returns: - Index of the closest point on the path - """ - if len(self._points) == 0: - raise ValueError("Cannot find nearest point in an empty path") - - if isinstance(point, Vector3): - point_data = point.data - else: - point_data = np.array(point, dtype=float) - - # Calculate squared distances to all points - diff = self._points - point_data - sq_distances = np.sum(diff * diff, axis=1) - - # Return index of minimum distance - return int(np.argmin(sq_distances)) - - def reverse(self: T) -> T: - """Reverse the path direction. 
- - Returns: - A new Path object with points in reverse order - """ - return self.__class__(self._points[::-1].copy()) - - def __len__(self) -> int: - """Return the number of points in the path.""" - return len(self._points) - - def __getitem__(self, idx) -> Union[np.ndarray, Path]: - """Get a point or slice of points from the path.""" - if isinstance(idx, slice): - return self.__class__(self._points[idx]) - return self._points[idx].copy() - - def get_vector(self, idx: int) -> Vector3: - """Get a point at the given index as a Vector3 object.""" - return Vector3(self._points[idx]) - - def last(self) -> Vector3: - """Get the last point in the path as a Vector3 object.""" - if len(self._points) == 0: - return None - return Vector3(self._points[-1]) - - def head(self) -> Vector3: - """Get the first point in the path as a Vector3 object.""" - if len(self._points) == 0: - return None - return Vector3(self._points[0]) - - def tail(self) -> Path: - """Get all points except the first point as a new Path object.""" - if len(self._points) <= 1: - return None - return self.__class__(self._points[1:].copy()) - - def vectors(self) -> Iterator[Vector3]: - """Iterate over the points in the path.""" - for point in self._points: - yield Vector3(*point) - - def __iter__(self) -> Iterator[np.ndarray]: - """Iterate over the points in the path as numpy arrays.""" - for point in self._points: - yield point.copy() - - def __repr__(self) -> str: - """String representation of the path.""" - return f"↶ Path ({len(self._points)} Points)" - - def ipush(self, point: Union[Vector3, np.ndarray, Tuple]) -> "Path": - """Return a new Path with `point` appended.""" - if isinstance(point, Vector3): - p = point.data - else: - p = np.asarray(point, dtype=float) - - if len(self._points) == 0: - new_pts = p.reshape(1, -1) - else: - new_pts = np.vstack((self._points, p)) - return self.__class__(new_pts) - - def iclip_tail(self, max_len: int) -> "Path": - """Return a new Path containing only the last 
`max_len` points.""" - if max_len < 0: - raise ValueError("max_len must be ≥ 0") - return self.__class__(self._points[-max_len:]) - - def __add__(self, other): - """path + point -> path.ipush(point) or path + path -> path.extend(path)""" - if isinstance(other, Path): - new_path = Path(self._points.copy()) - new_path.extend(other) - return new_path - else: - return self.ipush(other) diff --git a/dimos/types/test_path.py b/dimos/types/test_path.py deleted file mode 100644 index 3f69002963..0000000000 --- a/dimos/types/test_path.py +++ /dev/null @@ -1,242 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import numpy as np -import pytest - -from dimos.msgs.geometry_msgs import Vector3 -from dimos.types.path import Path - - -@pytest.fixture -def path(): - return Path([(1, 2, 3), (4, 5, 6), (7, 8, 9)]) - - -@pytest.fixture -def empty_path(): - return Path() - - -def test_init(path): - assert path.length() == 10.392304845413264 - assert len(path) == 3 - assert np.array_equal(path[1], [4.0, 5.0, 6.0]) - - -def test_init_empty(): - empty = Path() - assert len(empty) == 0 - assert empty.length() == 0.0 - - -def test_init_Vector3(): - points = map((lambda p: Vector3(p)), [[1, 2], [3, 4], [5, 6]]) - path = Path(points) - print(path) - - -def test_init_numpy_array(): - points = np.array([[1, 2], [3, 4], [5, 6]]) - path = Path(points) - assert len(path) == 3 - assert np.array_equal(path[0], [1.0, 2.0]) - - -def test_add_path(path): - path2 = Path([(10, 11, 12)]) - result = path + path2 - assert len(result) == 4 - assert np.array_equal(result[3], [10.0, 11.0, 12.0]) - - -def test_add_point(path): - result = path + (10, 11, 12) - assert len(result) == 4 - assert np.array_equal(result[3], [10.0, 11.0, 12.0]) - - -def test_append(path): - original_len = len(path) - path.append((10, 11, 12)) - assert len(path) == original_len + 1 - assert np.array_equal(path[-1], [10.0, 11.0, 12.0]) - - -def test_extend(path): - path2 = Path([(10, 11, 12), (13, 14, 15)]) - original_len = len(path) - path.extend(path2) - assert len(path) == original_len + 2 - assert np.array_equal(path[-1], [13.0, 14.0, 15.0]) - - -def test_insert(path): - path.insert(1, (10, 11, 12)) - assert len(path) == 4 - assert np.array_equal(path[1], [10.0, 11.0, 12.0]) - assert np.array_equal(path[2], [4.0, 5.0, 6.0]) # Original point shifted - - -def test_remove(path): - removed = path.remove(1) - assert len(path) == 2 - assert np.array_equal(removed, [4.0, 5.0, 6.0]) - assert np.array_equal(path[1], [7.0, 8.0, 9.0]) # Next pointhey ca shifted down - - -def test_clear(path): - path.clear() - assert len(path) == 0 - - 
-def test_resample(path): - resampled = path.resample(2.0) - assert len(resampled) >= 2 - # Resampling can create more points than original * 2 for complex paths - assert len(resampled) > 0 - - -def test_simplify(path): - simplified = path.simplify(0.1) - assert len(simplified) <= len(path) - assert len(simplified) >= 2 - - -def test_smooth(path): - smoothed = path.smooth(0.5, 1) - assert len(smoothed) == len(path) - assert np.array_equal(smoothed[0], path[0]) # First point unchanged - assert np.array_equal(smoothed[-1], path[-1]) # Last point unchanged - - -def test_nearest_point_index(path): - idx = path.nearest_point_index((4, 5, 6)) - assert idx == 1 - - idx = path.nearest_point_index((1, 2, 3)) - assert idx == 0 - - -def test_nearest_point_index_empty(): - empty = Path() - with pytest.raises(ValueError): - empty.nearest_point_index((1, 2, 3)) - - -def test_reverse(path): - reversed_path = path.reverse() - assert len(reversed_path) == len(path) - assert np.array_equal(reversed_path[0], path[-1]) - assert np.array_equal(reversed_path[-1], path[0]) - - -def test_getitem_slice(path): - slice_path = path[1:3] - assert isinstance(slice_path, Path) - assert len(slice_path) == 2 - assert np.array_equal(slice_path[0], [4.0, 5.0, 6.0]) - - -def test_get_vector(path): - vector = path.get_vector(1) - assert isinstance(vector, Vector3) - assert vector == Vector3([4.0, 5.0, 6.0]) - - -def test_head_tail_last(path): - head = path.head() - assert isinstance(head, Vector3) - assert head == Vector3([1.0, 2.0, 3.0]) - - last = path.last() - assert isinstance(last, Vector3) - assert last == Vector3([7.0, 8.0, 9.0]) - - tail = path.tail() - assert isinstance(tail, Path) - assert len(tail) == 2 - assert np.array_equal(tail[0], [4.0, 5.0, 6.0]) - - -def test_head_tail_last_empty(): - empty = Path() - assert empty.head() is None - assert empty.last() is None - assert empty.tail() is None - - -def test_iter(path): - arrays = list(path) - assert len(arrays) == 3 - assert 
all(isinstance(arr, np.ndarray) for arr in arrays) - assert np.array_equal(arrays[0], [1.0, 2.0, 3.0]) - - -def test_vectors(path): - vectors = list(path.vectors()) - assert len(vectors) == 3 - assert all(isinstance(v, Vector3) for v in vectors) - assert vectors[0] == Vector3([1.0, 2.0, 3.0]) - - -def test_repr(path): - repr_str = repr(path) - assert "Path" in repr_str - assert "3 Points" in repr_str - - -def test_ipush(path): - new_path = path.ipush((10, 11, 12)) - assert len(new_path) == 4 - assert len(path) == 3 # Original unchanged - assert np.array_equal(new_path[-1], [10.0, 11.0, 12.0]) - - -def test_iclip_tail(path): - clipped = path.iclip_tail(2) - assert len(clipped) == 2 - assert np.array_equal(clipped[0], [4.0, 5.0, 6.0]) - assert np.array_equal(clipped[1], [7.0, 8.0, 9.0]) - - -def test_iclip_tail_negative(): - path = Path([(1, 2, 3)]) - with pytest.raises(ValueError): - path.iclip_tail(-1) - - -def test_serialize(path): - serialized = path.serialize() - assert isinstance(serialized, dict) - assert serialized["type"] == "path" - assert serialized["points"] == [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]] - - -def test_as_vectors(path): - vectors = path.as_vectors() - assert len(vectors) == 3 - assert all(isinstance(v, Vector3) for v in vectors) - assert vectors[0] == Vector3([1.0, 2.0, 3.0]) - - -def test_points_property(path): - points = path.points - assert isinstance(points, np.ndarray) - assert points.shape == (3, 3) - assert np.array_equal(points[0], [1.0, 2.0, 3.0]) - - -# def test_lcm_encode_decode(path): -# print(path.lcm_encode()) diff --git a/dimos/utils/test_transform_utils.py b/dimos/utils/test_transform_utils.py new file mode 100644 index 0000000000..94d15e9ee4 --- /dev/null +++ b/dimos/utils/test_transform_utils.py @@ -0,0 +1,672 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import numpy as np +from scipy.spatial.transform import Rotation as R + +from dimos.utils import transform_utils +from dimos.msgs.geometry_msgs import Pose, Vector3, Quaternion, Transform + + +class TestNormalizeAngle: + def test_normalize_angle_zero(self): + assert transform_utils.normalize_angle(0) == 0 + + def test_normalize_angle_pi(self): + assert np.isclose(transform_utils.normalize_angle(np.pi), np.pi) + + def test_normalize_angle_negative_pi(self): + assert np.isclose(transform_utils.normalize_angle(-np.pi), -np.pi) + + def test_normalize_angle_two_pi(self): + # 2*pi should normalize to 0 + assert np.isclose(transform_utils.normalize_angle(2 * np.pi), 0, atol=1e-10) + + def test_normalize_angle_large_positive(self): + # Large positive angle should wrap to [-pi, pi] + angle = 5 * np.pi + normalized = transform_utils.normalize_angle(angle) + assert -np.pi <= normalized <= np.pi + assert np.isclose(normalized, np.pi) + + def test_normalize_angle_large_negative(self): + # Large negative angle should wrap to [-pi, pi] + angle = -5 * np.pi + normalized = transform_utils.normalize_angle(angle) + assert -np.pi <= normalized <= np.pi + # -5*pi = -pi (odd multiple of pi wraps to -pi) + assert np.isclose(normalized, -np.pi) or np.isclose(normalized, np.pi) + + +# Tests for distance_angle_to_goal_xy removed as function doesn't exist in the module + + +class TestPoseToMatrix: + def test_identity_pose(self): + pose = Pose(Vector3(0, 0, 0), Quaternion(0, 0, 0, 1)) + T = transform_utils.pose_to_matrix(pose) + assert np.allclose(T, np.eye(4)) + + 
def test_translation_only(self): + pose = Pose(Vector3(1, 2, 3), Quaternion(0, 0, 0, 1)) + T = transform_utils.pose_to_matrix(pose) + expected = np.eye(4) + expected[:3, 3] = [1, 2, 3] + assert np.allclose(T, expected) + + def test_rotation_only_90_degrees_z(self): + # 90 degree rotation around z-axis + quat = R.from_euler("z", np.pi / 2).as_quat() + pose = Pose(Vector3(0, 0, 0), Quaternion(quat[0], quat[1], quat[2], quat[3])) + T = transform_utils.pose_to_matrix(pose) + + # Check rotation part + expected_rot = R.from_euler("z", np.pi / 2).as_matrix() + assert np.allclose(T[:3, :3], expected_rot) + + # Check translation is zero + assert np.allclose(T[:3, 3], [0, 0, 0]) + + def test_translation_and_rotation(self): + quat = R.from_euler("xyz", [np.pi / 4, np.pi / 6, np.pi / 3]).as_quat() + pose = Pose(Vector3(5, -3, 2), Quaternion(quat[0], quat[1], quat[2], quat[3])) + T = transform_utils.pose_to_matrix(pose) + + # Check translation + assert np.allclose(T[:3, 3], [5, -3, 2]) + + # Check rotation + expected_rot = R.from_euler("xyz", [np.pi / 4, np.pi / 6, np.pi / 3]).as_matrix() + assert np.allclose(T[:3, :3], expected_rot) + + # Check bottom row + assert np.allclose(T[3, :], [0, 0, 0, 1]) + + def test_zero_norm_quaternion(self): + # Test handling of zero norm quaternion + pose = Pose(Vector3(1, 2, 3), Quaternion(0, 0, 0, 0)) + T = transform_utils.pose_to_matrix(pose) + + # Should use identity rotation + expected = np.eye(4) + expected[:3, 3] = [1, 2, 3] + assert np.allclose(T, expected) + + +class TestMatrixToPose: + def test_identity_matrix(self): + T = np.eye(4) + pose = transform_utils.matrix_to_pose(T) + assert pose.position.x == 0 + assert pose.position.y == 0 + assert pose.position.z == 0 + assert np.isclose(pose.orientation.w, 1) + assert np.isclose(pose.orientation.x, 0) + assert np.isclose(pose.orientation.y, 0) + assert np.isclose(pose.orientation.z, 0) + + def test_translation_only(self): + T = np.eye(4) + T[:3, 3] = [1, 2, 3] + pose = 
transform_utils.matrix_to_pose(T) + assert pose.position.x == 1 + assert pose.position.y == 2 + assert pose.position.z == 3 + assert np.isclose(pose.orientation.w, 1) + + def test_rotation_only(self): + T = np.eye(4) + T[:3, :3] = R.from_euler("z", np.pi / 2).as_matrix() + pose = transform_utils.matrix_to_pose(T) + + # Check position is zero + assert pose.position.x == 0 + assert pose.position.y == 0 + assert pose.position.z == 0 + + # Check rotation + quat = [pose.orientation.x, pose.orientation.y, pose.orientation.z, pose.orientation.w] + recovered_rot = R.from_quat(quat).as_matrix() + assert np.allclose(recovered_rot, T[:3, :3]) + + def test_round_trip_conversion(self): + # Test that pose -> matrix -> pose gives same result + # Use a properly normalized quaternion + quat = R.from_euler("xyz", [0.1, 0.2, 0.3]).as_quat() + original_pose = Pose( + Vector3(1.5, -2.3, 0.7), Quaternion(quat[0], quat[1], quat[2], quat[3]) + ) + T = transform_utils.pose_to_matrix(original_pose) + recovered_pose = transform_utils.matrix_to_pose(T) + + assert np.isclose(recovered_pose.position.x, original_pose.position.x) + assert np.isclose(recovered_pose.position.y, original_pose.position.y) + assert np.isclose(recovered_pose.position.z, original_pose.position.z) + assert np.isclose(recovered_pose.orientation.x, original_pose.orientation.x, atol=1e-6) + assert np.isclose(recovered_pose.orientation.y, original_pose.orientation.y, atol=1e-6) + assert np.isclose(recovered_pose.orientation.z, original_pose.orientation.z, atol=1e-6) + assert np.isclose(recovered_pose.orientation.w, original_pose.orientation.w, atol=1e-6) + + +class TestApplyTransform: + def test_identity_transform(self): + pose = Pose(Vector3(1, 2, 3), Quaternion(0, 0, 0, 1)) + T_identity = np.eye(4) + result = transform_utils.apply_transform(pose, T_identity) + + assert np.isclose(result.position.x, pose.position.x) + assert np.isclose(result.position.y, pose.position.y) + assert np.isclose(result.position.z, 
pose.position.z) + + def test_translation_transform(self): + pose = Pose(Vector3(1, 0, 0), Quaternion(0, 0, 0, 1)) + T = np.eye(4) + T[:3, 3] = [2, 3, 4] + result = transform_utils.apply_transform(pose, T) + + assert np.isclose(result.position.x, 3) # 2 + 1 + assert np.isclose(result.position.y, 3) # 3 + 0 + assert np.isclose(result.position.z, 4) # 4 + 0 + + def test_rotation_transform(self): + pose = Pose(Vector3(1, 0, 0), Quaternion(0, 0, 0, 1)) + T = np.eye(4) + T[:3, :3] = R.from_euler("z", np.pi / 2).as_matrix() # 90 degree rotation + result = transform_utils.apply_transform(pose, T) + + # After 90 degree rotation around z, point (1,0,0) becomes (0,1,0) + assert np.isclose(result.position.x, 0, atol=1e-10) + assert np.isclose(result.position.y, 1) + assert np.isclose(result.position.z, 0) + + def test_transform_with_transform_object(self): + pose = Pose(Vector3(1, 0, 0), Quaternion(0, 0, 0, 1)) + pose.frame_id = "base" + + transform = Transform() + transform.frame_id = "world" + transform.child_frame_id = "base" + transform.translation = Vector3(2, 3, 4) + transform.rotation = Quaternion(0, 0, 0, 1) + + result = transform_utils.apply_transform(pose, transform) + assert np.isclose(result.position.x, 3) + assert np.isclose(result.position.y, 3) + assert np.isclose(result.position.z, 4) + + def test_transform_frame_mismatch_raises(self): + pose = Pose(Vector3(1, 0, 0), Quaternion(0, 0, 0, 1)) + pose.frame_id = "base" + + transform = Transform() + transform.frame_id = "world" + transform.child_frame_id = "different_frame" + transform.translation = Vector3(2, 3, 4) + transform.rotation = Quaternion(0, 0, 0, 1) + + with pytest.raises(ValueError, match="does not match"): + transform_utils.apply_transform(pose, transform) + + +class TestOpticalToRobotFrame: + def test_identity_at_origin(self): + pose = Pose(Vector3(0, 0, 0), Quaternion(0, 0, 0, 1)) + result = transform_utils.optical_to_robot_frame(pose) + assert result.position.x == 0 + assert result.position.y == 0 
+ assert result.position.z == 0 + + def test_position_transformation(self): + # Optical: X=right(1), Y=down(0), Z=forward(0) + pose = Pose(Vector3(1, 0, 0), Quaternion(0, 0, 0, 1)) + result = transform_utils.optical_to_robot_frame(pose) + + # Robot: X=forward(0), Y=left(-1), Z=up(0) + assert np.isclose(result.position.x, 0) # Forward = Camera Z + assert np.isclose(result.position.y, -1) # Left = -Camera X + assert np.isclose(result.position.z, 0) # Up = -Camera Y + + def test_forward_position(self): + # Optical: X=right(0), Y=down(0), Z=forward(2) + pose = Pose(Vector3(0, 0, 2), Quaternion(0, 0, 0, 1)) + result = transform_utils.optical_to_robot_frame(pose) + + # Robot: X=forward(2), Y=left(0), Z=up(0) + assert np.isclose(result.position.x, 2) + assert np.isclose(result.position.y, 0) + assert np.isclose(result.position.z, 0) + + def test_down_position(self): + # Optical: X=right(0), Y=down(3), Z=forward(0) + pose = Pose(Vector3(0, 3, 0), Quaternion(0, 0, 0, 1)) + result = transform_utils.optical_to_robot_frame(pose) + + # Robot: X=forward(0), Y=left(0), Z=up(-3) + assert np.isclose(result.position.x, 0) + assert np.isclose(result.position.y, 0) + assert np.isclose(result.position.z, -3) + + def test_round_trip_optical_robot(self): + original_pose = Pose(Vector3(1, 2, 3), Quaternion(0.1, 0.2, 0.3, 0.9165151389911680)) + robot_pose = transform_utils.optical_to_robot_frame(original_pose) + recovered_pose = transform_utils.robot_to_optical_frame(robot_pose) + + assert np.isclose(recovered_pose.position.x, original_pose.position.x, atol=1e-10) + assert np.isclose(recovered_pose.position.y, original_pose.position.y, atol=1e-10) + assert np.isclose(recovered_pose.position.z, original_pose.position.z, atol=1e-10) + + +class TestRobotToOpticalFrame: + def test_position_transformation(self): + # Robot: X=forward(1), Y=left(0), Z=up(0) + pose = Pose(Vector3(1, 0, 0), Quaternion(0, 0, 0, 1)) + result = transform_utils.robot_to_optical_frame(pose) + + # Optical: X=right(0), 
Y=down(0), Z=forward(1) + assert np.isclose(result.position.x, 0) + assert np.isclose(result.position.y, 0) + assert np.isclose(result.position.z, 1) + + def test_left_position(self): + # Robot: X=forward(0), Y=left(2), Z=up(0) + pose = Pose(Vector3(0, 2, 0), Quaternion(0, 0, 0, 1)) + result = transform_utils.robot_to_optical_frame(pose) + + # Optical: X=right(-2), Y=down(0), Z=forward(0) + assert np.isclose(result.position.x, -2) + assert np.isclose(result.position.y, 0) + assert np.isclose(result.position.z, 0) + + def test_up_position(self): + # Robot: X=forward(0), Y=left(0), Z=up(3) + pose = Pose(Vector3(0, 0, 3), Quaternion(0, 0, 0, 1)) + result = transform_utils.robot_to_optical_frame(pose) + + # Optical: X=right(0), Y=down(-3), Z=forward(0) + assert np.isclose(result.position.x, 0) + assert np.isclose(result.position.y, -3) + assert np.isclose(result.position.z, 0) + + +class TestYawTowardsPoint: + def test_yaw_from_origin(self): + # Point at (1, 0) from origin should have yaw = 0 + position = Vector3(1, 0, 0) + yaw = transform_utils.yaw_towards_point(position) + assert np.isclose(yaw, 0) + + def test_yaw_ninety_degrees(self): + # Point at (0, 1) from origin should have yaw = pi/2 + position = Vector3(0, 1, 0) + yaw = transform_utils.yaw_towards_point(position) + assert np.isclose(yaw, np.pi / 2) + + def test_yaw_negative_ninety_degrees(self): + # Point at (0, -1) from origin should have yaw = -pi/2 + position = Vector3(0, -1, 0) + yaw = transform_utils.yaw_towards_point(position) + assert np.isclose(yaw, -np.pi / 2) + + def test_yaw_forty_five_degrees(self): + # Point at (1, 1) from origin should have yaw = pi/4 + position = Vector3(1, 1, 0) + yaw = transform_utils.yaw_towards_point(position) + assert np.isclose(yaw, np.pi / 4) + + def test_yaw_with_custom_target(self): + # Point at (3, 2) from target (1, 1) + position = Vector3(3, 2, 0) + target = Vector3(1, 1, 0) + yaw = transform_utils.yaw_towards_point(position, target) + # Direction is (2, 1), so yaw 
= atan2(1, 2) + expected = np.arctan2(1, 2) + assert np.isclose(yaw, expected) + + +# Tests for transform_robot_to_map removed as function doesn't exist in the module + + +class TestCreateTransformFrom6DOF: + def test_identity_transform(self): + trans = Vector3(0, 0, 0) + euler = Vector3(0, 0, 0) + T = transform_utils.create_transform_from_6dof(trans, euler) + assert np.allclose(T, np.eye(4)) + + def test_translation_only(self): + trans = Vector3(1, 2, 3) + euler = Vector3(0, 0, 0) + T = transform_utils.create_transform_from_6dof(trans, euler) + + expected = np.eye(4) + expected[:3, 3] = [1, 2, 3] + assert np.allclose(T, expected) + + def test_rotation_only(self): + trans = Vector3(0, 0, 0) + euler = Vector3(np.pi / 4, np.pi / 6, np.pi / 3) + T = transform_utils.create_transform_from_6dof(trans, euler) + + expected_rot = R.from_euler("xyz", [np.pi / 4, np.pi / 6, np.pi / 3]).as_matrix() + assert np.allclose(T[:3, :3], expected_rot) + assert np.allclose(T[:3, 3], [0, 0, 0]) + assert np.allclose(T[3, :], [0, 0, 0, 1]) + + def test_translation_and_rotation(self): + trans = Vector3(5, -3, 2) + euler = Vector3(0.1, 0.2, 0.3) + T = transform_utils.create_transform_from_6dof(trans, euler) + + expected_rot = R.from_euler("xyz", [0.1, 0.2, 0.3]).as_matrix() + assert np.allclose(T[:3, :3], expected_rot) + assert np.allclose(T[:3, 3], [5, -3, 2]) + + def test_small_angles_threshold(self): + trans = Vector3(1, 2, 3) + euler = Vector3(1e-7, 1e-8, 1e-9) # Very small angles + T = transform_utils.create_transform_from_6dof(trans, euler) + + # Should be effectively identity rotation + expected = np.eye(4) + expected[:3, 3] = [1, 2, 3] + assert np.allclose(T, expected, atol=1e-6) + + +class TestInvertTransform: + def test_identity_inverse(self): + T = np.eye(4) + T_inv = transform_utils.invert_transform(T) + assert np.allclose(T_inv, np.eye(4)) + + def test_translation_inverse(self): + T = np.eye(4) + T[:3, 3] = [1, 2, 3] + T_inv = transform_utils.invert_transform(T) + + # Inverse 
should negate translation + expected = np.eye(4) + expected[:3, 3] = [-1, -2, -3] + assert np.allclose(T_inv, expected) + + def test_rotation_inverse(self): + T = np.eye(4) + T[:3, :3] = R.from_euler("z", np.pi / 2).as_matrix() + T_inv = transform_utils.invert_transform(T) + + # Inverse rotation is transpose + expected = np.eye(4) + expected[:3, :3] = R.from_euler("z", -np.pi / 2).as_matrix() + assert np.allclose(T_inv, expected) + + def test_general_transform_inverse(self): + T = np.eye(4) + T[:3, :3] = R.from_euler("xyz", [0.1, 0.2, 0.3]).as_matrix() + T[:3, 3] = [1, 2, 3] + + T_inv = transform_utils.invert_transform(T) + + # T @ T_inv should be identity + result = T @ T_inv + assert np.allclose(result, np.eye(4)) + + # T_inv @ T should also be identity + result2 = T_inv @ T + assert np.allclose(result2, np.eye(4)) + + +class TestComposeTransforms: + def test_no_transforms(self): + result = transform_utils.compose_transforms() + assert np.allclose(result, np.eye(4)) + + def test_single_transform(self): + T = np.eye(4) + T[:3, 3] = [1, 2, 3] + result = transform_utils.compose_transforms(T) + assert np.allclose(result, T) + + def test_two_translations(self): + T1 = np.eye(4) + T1[:3, 3] = [1, 0, 0] + + T2 = np.eye(4) + T2[:3, 3] = [0, 2, 0] + + result = transform_utils.compose_transforms(T1, T2) + + expected = np.eye(4) + expected[:3, 3] = [1, 2, 0] + assert np.allclose(result, expected) + + def test_three_transforms(self): + T1 = np.eye(4) + T1[:3, 3] = [1, 0, 0] + + T2 = np.eye(4) + T2[:3, :3] = R.from_euler("z", np.pi / 2).as_matrix() + + T3 = np.eye(4) + T3[:3, 3] = [1, 0, 0] + + result = transform_utils.compose_transforms(T1, T2, T3) + expected = T1 @ T2 @ T3 + assert np.allclose(result, expected) + + +class TestEulerToQuaternion: + def test_zero_euler(self): + euler = Vector3(0, 0, 0) + quat = transform_utils.euler_to_quaternion(euler) + assert np.isclose(quat.w, 1) + assert np.isclose(quat.x, 0) + assert np.isclose(quat.y, 0) + assert np.isclose(quat.z, 0) + 
+ def test_roll_only(self): + euler = Vector3(np.pi / 2, 0, 0) + quat = transform_utils.euler_to_quaternion(euler) + + # Verify by converting back + recovered = R.from_quat([quat.x, quat.y, quat.z, quat.w]).as_euler("xyz") + assert np.isclose(recovered[0], np.pi / 2) + assert np.isclose(recovered[1], 0) + assert np.isclose(recovered[2], 0) + + def test_pitch_only(self): + euler = Vector3(0, np.pi / 3, 0) + quat = transform_utils.euler_to_quaternion(euler) + + recovered = R.from_quat([quat.x, quat.y, quat.z, quat.w]).as_euler("xyz") + assert np.isclose(recovered[0], 0) + assert np.isclose(recovered[1], np.pi / 3) + assert np.isclose(recovered[2], 0) + + def test_yaw_only(self): + euler = Vector3(0, 0, np.pi / 4) + quat = transform_utils.euler_to_quaternion(euler) + + recovered = R.from_quat([quat.x, quat.y, quat.z, quat.w]).as_euler("xyz") + assert np.isclose(recovered[0], 0) + assert np.isclose(recovered[1], 0) + assert np.isclose(recovered[2], np.pi / 4) + + def test_degrees_mode(self): + euler = Vector3(45, 30, 60) # degrees + quat = transform_utils.euler_to_quaternion(euler, degrees=True) + + recovered = R.from_quat([quat.x, quat.y, quat.z, quat.w]).as_euler("xyz", degrees=True) + assert np.isclose(recovered[0], 45) + assert np.isclose(recovered[1], 30) + assert np.isclose(recovered[2], 60) + + +class TestQuaternionToEuler: + def test_identity_quaternion(self): + quat = Quaternion(0, 0, 0, 1) + euler = transform_utils.quaternion_to_euler(quat) + assert np.isclose(euler.x, 0) + assert np.isclose(euler.y, 0) + assert np.isclose(euler.z, 0) + + def test_90_degree_yaw(self): + # Create quaternion for 90 degree yaw rotation + r = R.from_euler("z", np.pi / 2) + q = r.as_quat() + quat = Quaternion(q[0], q[1], q[2], q[3]) + + euler = transform_utils.quaternion_to_euler(quat) + assert np.isclose(euler.x, 0) + assert np.isclose(euler.y, 0) + assert np.isclose(euler.z, np.pi / 2) + + def test_round_trip_euler_quaternion(self): + original_euler = Vector3(0.3, 0.5, 0.7) + 
quat = transform_utils.euler_to_quaternion(original_euler) + recovered_euler = transform_utils.quaternion_to_euler(quat) + + assert np.isclose(recovered_euler.x, original_euler.x, atol=1e-10) + assert np.isclose(recovered_euler.y, original_euler.y, atol=1e-10) + assert np.isclose(recovered_euler.z, original_euler.z, atol=1e-10) + + def test_degrees_mode(self): + # Create quaternion for 45 degree yaw rotation + r = R.from_euler("z", 45, degrees=True) + q = r.as_quat() + quat = Quaternion(q[0], q[1], q[2], q[3]) + + euler = transform_utils.quaternion_to_euler(quat, degrees=True) + assert np.isclose(euler.x, 0) + assert np.isclose(euler.y, 0) + assert np.isclose(euler.z, 45) + + def test_angle_normalization(self): + # Test that angles are normalized to [-pi, pi] + r = R.from_euler("xyz", [3 * np.pi, -3 * np.pi, 2 * np.pi]) + q = r.as_quat() + quat = Quaternion(q[0], q[1], q[2], q[3]) + + euler = transform_utils.quaternion_to_euler(quat) + assert -np.pi <= euler.x <= np.pi + assert -np.pi <= euler.y <= np.pi + assert -np.pi <= euler.z <= np.pi + + +class TestGetDistance: + def test_same_pose(self): + pose1 = Pose(Vector3(1, 2, 3), Quaternion(0, 0, 0, 1)) + pose2 = Pose(Vector3(1, 2, 3), Quaternion(0.1, 0.2, 0.3, 0.9)) + distance = transform_utils.get_distance(pose1, pose2) + assert np.isclose(distance, 0) + + def test_distance_x_axis(self): + pose1 = Pose(Vector3(0, 0, 0), Quaternion(0, 0, 0, 1)) + pose2 = Pose(Vector3(5, 0, 0), Quaternion(0, 0, 0, 1)) + distance = transform_utils.get_distance(pose1, pose2) + assert np.isclose(distance, 5) + + def test_distance_y_axis(self): + pose1 = Pose(Vector3(0, 0, 0), Quaternion(0, 0, 0, 1)) + pose2 = Pose(Vector3(0, 3, 0), Quaternion(0, 0, 0, 1)) + distance = transform_utils.get_distance(pose1, pose2) + assert np.isclose(distance, 3) + + def test_distance_z_axis(self): + pose1 = Pose(Vector3(0, 0, 0), Quaternion(0, 0, 0, 1)) + pose2 = Pose(Vector3(0, 0, 4), Quaternion(0, 0, 0, 1)) + distance = transform_utils.get_distance(pose1, 
pose2) + assert np.isclose(distance, 4) + + def test_3d_distance(self): + pose1 = Pose(Vector3(0, 0, 0), Quaternion(0, 0, 0, 1)) + pose2 = Pose(Vector3(3, 4, 0), Quaternion(0, 0, 0, 1)) + distance = transform_utils.get_distance(pose1, pose2) + assert np.isclose(distance, 5) # 3-4-5 triangle + + def test_negative_coordinates(self): + pose1 = Pose(Vector3(-1, -2, -3), Quaternion(0, 0, 0, 1)) + pose2 = Pose(Vector3(1, 2, 3), Quaternion(0, 0, 0, 1)) + distance = transform_utils.get_distance(pose1, pose2) + expected = np.sqrt(4 + 16 + 36) # sqrt(56) + assert np.isclose(distance, expected) + + +class TestRetractDistance: + def test_retract_along_negative_z(self): + # Default case: gripper approaches along -z axis + # Positive distance moves away from the surface (opposite to approach direction) + target_pose = Pose(Vector3(0, 0, 1), Quaternion(0, 0, 0, 1)) + retracted = transform_utils.retract_distance(target_pose, 0.5) + + # Moving along -z approach vector with positive distance = retracting upward + # Since approach is -z and we retract (positive distance), we move in +z + assert np.isclose(retracted.position.x, 0) + assert np.isclose(retracted.position.y, 0) + assert np.isclose(retracted.position.z, 0.5) # 1 + 0.5 * (-1) = 0.5 + + # Orientation should remain unchanged + assert retracted.orientation.x == target_pose.orientation.x + assert retracted.orientation.y == target_pose.orientation.y + assert retracted.orientation.z == target_pose.orientation.z + assert retracted.orientation.w == target_pose.orientation.w + + def test_retract_with_rotation(self): + # Test with a rotated pose (90 degrees around x-axis) + r = R.from_euler("x", np.pi / 2) + q = r.as_quat() + target_pose = Pose(Vector3(0, 0, 1), Quaternion(q[0], q[1], q[2], q[3])) + + retracted = transform_utils.retract_distance(target_pose, 0.5) + + # After 90 degree rotation around x, -z becomes +y + assert np.isclose(retracted.position.x, 0) + assert np.isclose(retracted.position.y, 0.5) # Move along +y + assert 
np.isclose(retracted.position.z, 1) + + def test_retract_negative_distance(self): + # Negative distance should move forward (toward the approach direction) + target_pose = Pose(Vector3(0, 0, 1), Quaternion(0, 0, 0, 1)) + retracted = transform_utils.retract_distance(target_pose, -0.3) + + # Moving along -z approach vector with negative distance = moving downward + assert np.isclose(retracted.position.x, 0) + assert np.isclose(retracted.position.y, 0) + assert np.isclose(retracted.position.z, 1.3) # 1 + (-0.3) * (-1) = 1.3 + + def test_retract_arbitrary_pose(self): + # Test with arbitrary position and rotation + r = R.from_euler("xyz", [0.1, 0.2, 0.3]) + q = r.as_quat() + target_pose = Pose(Vector3(5, 3, 2), Quaternion(q[0], q[1], q[2], q[3])) + + distance = 1.0 + retracted = transform_utils.retract_distance(target_pose, distance) + + # Verify the distance between original and retracted is as expected + # (approximately, due to the approach vector direction) + T_target = transform_utils.pose_to_matrix(target_pose) + rotation_matrix = T_target[:3, :3] + approach_vector = rotation_matrix @ np.array([0, 0, -1]) + + expected_x = target_pose.position.x + distance * approach_vector[0] + expected_y = target_pose.position.y + distance * approach_vector[1] + expected_z = target_pose.position.z + distance * approach_vector[2] + + assert np.isclose(retracted.position.x, expected_x) + assert np.isclose(retracted.position.y, expected_y) + assert np.isclose(retracted.position.z, expected_z) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/dimos/utils/transform_utils.py b/dimos/utils/transform_utils.py index eaedbcecf3..0b93b9a0f3 100644 --- a/dimos/utils/transform_utils.py +++ b/dimos/utils/transform_utils.py @@ -14,12 +14,8 @@ import numpy as np from typing import Tuple -import logging from scipy.spatial.transform import Rotation as R - -from dimos_lcm.geometry_msgs import Pose, Point, Vector3, Quaternion - -logger = logging.getLogger(__name__) +from 
dimos.msgs.geometry_msgs import Pose, Vector3, Quaternion, Transform def normalize_angle(angle: float) -> float: @@ -27,11 +23,6 @@ def normalize_angle(angle: float) -> float: return np.arctan2(np.sin(angle), np.cos(angle)) -def distance_angle_to_goal_xy(distance: float, angle: float) -> Tuple[float, float]: - """Convert distance and angle to goal x, y in robot frame""" - return distance * np.cos(angle), distance * np.sin(angle) - - def pose_to_matrix(pose: Pose) -> np.ndarray: """ Convert pose to 4x4 homogeneous transform matrix. @@ -76,7 +67,7 @@ def matrix_to_pose(T: np.ndarray) -> Pose: Pose object with position and orientation (quaternion) """ # Extract position - pos = Point(T[0, 3], T[1, 3], T[2, 3]) + pos = Vector3(T[0, 3], T[1, 3], T[2, 3]) # Extract rotation matrix and convert to quaternion Rot = T[:3, :3] @@ -88,7 +79,7 @@ def matrix_to_pose(T: np.ndarray) -> Pose: return Pose(pos, orientation) -def apply_transform(pose: Pose, transform_matrix: np.ndarray) -> Pose: +def apply_transform(pose: Pose, transform: np.ndarray | Transform) -> Pose: """ Apply a transformation matrix to a pose. 
@@ -99,11 +90,18 @@ def apply_transform(pose: Pose, transform_matrix: np.ndarray) -> Pose: Returns: Transformed pose """ + if isinstance(transform, Transform): + if transform.child_frame_id != pose.frame_id: + raise ValueError( + f"Transform child_frame_id {transform.child_frame_id} does not match pose frame_id {pose.frame_id}" + ) + transform = pose_to_matrix(transform.to_pose()) + # Convert pose to matrix T_pose = pose_to_matrix(pose) # Apply transform - T_result = transform_matrix @ T_pose + T_result = transform @ T_pose # Convert back to pose return matrix_to_pose(T_result) @@ -156,7 +154,7 @@ def optical_to_robot_frame(pose: Pose) -> Pose: quat_robot = R.from_matrix(R_robot).as_quat() # [x, y, z, w] return Pose( - Point(robot_x, robot_y, robot_z), + Vector3(robot_x, robot_y, robot_z), Quaternion(quat_robot[0], quat_robot[1], quat_robot[2], quat_robot[3]), ) @@ -198,12 +196,12 @@ def robot_to_optical_frame(pose: Pose) -> Pose: quat_optical = R.from_matrix(R_optical).as_quat() # [x, y, z, w] return Pose( - Point(optical_x, optical_y, optical_z), + Vector3(optical_x, optical_y, optical_z), Quaternion(quat_optical[0], quat_optical[1], quat_optical[2], quat_optical[3]), ) -def yaw_towards_point(position: Point, target_point: Point = None) -> float: +def yaw_towards_point(position: Vector3, target_point: Vector3 = None) -> float: """ Calculate yaw angle from target point to position (away from target). This is commonly used for object orientation in grasping applications. 
@@ -217,66 +215,12 @@ def yaw_towards_point(position: Point, target_point: Point = None) -> float: Yaw angle in radians pointing from target_point to position """ if target_point is None: - target_point = Point(0.0, 0.0, 0.0) + target_point = Vector3(0.0, 0.0, 0.0) direction_x = position.x - target_point.x direction_y = position.y - target_point.y return np.arctan2(direction_y, direction_x) -def transform_robot_to_map( - robot_position: Point, robot_rotation: Vector3, position: Point, rotation: Vector3 -) -> Tuple[Point, Vector3]: - """Transform position and rotation from robot frame to map frame. - - Args: - robot_position: Current robot position in map frame - robot_rotation: Current robot rotation in map frame - position: Position in robot frame as Point (x, y, z) - rotation: Rotation in robot frame as Vector3 (roll, pitch, yaw) in radians - - Returns: - Tuple of (transformed_position, transformed_rotation) where: - - transformed_position: Point (x, y, z) in map frame - - transformed_rotation: Vector3 (roll, pitch, yaw) in map frame - - Example: - obj_pos_robot = Point(1.0, 0.5, 0.0) # 1m forward, 0.5m left of robot - obj_rot_robot = Vector3(0.0, 0.0, 0.0) # No rotation relative to robot - - map_pos, map_rot = transform_robot_to_map(robot_position, robot_rotation, obj_pos_robot, obj_rot_robot) - """ - # Extract robot pose components - robot_pos = robot_position - robot_rot = robot_rotation - - # Robot position and orientation in map frame - robot_x, robot_y, robot_z = robot_pos.x, robot_pos.y, robot_pos.z - robot_yaw = robot_rot.z # yaw is rotation around z-axis - - # Position in robot frame - pos_x, pos_y, pos_z = position.x, position.y, position.z - - # Apply 2D transformation (rotation + translation) for x,y coordinates - cos_yaw = np.cos(robot_yaw) - sin_yaw = np.sin(robot_yaw) - - # Transform position from robot frame to map frame - map_x = robot_x + cos_yaw * pos_x - sin_yaw * pos_y - map_y = robot_y + sin_yaw * pos_x + cos_yaw * pos_y - map_z = robot_z + 
pos_z # Z translation (assume flat ground) - - # Transform rotation from robot frame to map frame - rot_roll, rot_pitch, rot_yaw = rotation.x, rotation.y, rotation.z - map_roll = robot_rot.x + rot_roll # Add robot's roll - map_pitch = robot_rot.y + rot_pitch # Add robot's pitch - map_yaw_rot = normalize_angle(robot_yaw + rot_yaw) # Add robot's yaw and normalize - - transformed_position = Point(map_x, map_y, map_z) - transformed_rotation = Vector3(map_roll, map_pitch, map_yaw_rot) - - return transformed_position, transformed_rotation - - def create_transform_from_6dof(translation: Vector3, euler_angles: Vector3) -> np.ndarray: """ Create a 4x4 transformation matrix from 6DOF parameters. @@ -396,3 +340,39 @@ def get_distance(pose1: Pose, pose2: Pose) -> float: dz = pose1.position.z - pose2.position.z return np.linalg.norm(np.array([dx, dy, dz])) + + +def retract_distance(target_pose: Pose, distance: float) -> Pose: + """ + Apply distance offset to target pose along its approach direction. + + This is commonly used in grasping to retract the gripper by a certain distance + along the approach vector before or after grasping. 
+ + Args: + target_pose: Target pose (e.g., grasp pose) + distance: Distance to offset along the approach direction (meters) + + Returns: + Target pose offset by the specified distance along its approach direction + """ + # Convert pose to transformation matrix to extract rotation + T_target = pose_to_matrix(target_pose) + rotation_matrix = T_target[:3, :3] + + # Define the approach vector based on the target pose orientation + # Assuming the gripper approaches along its local -z axis (common for downward grasps) + # You can change this to [1, 0, 0] for x-axis or [0, 1, 0] for y-axis based on your gripper + approach_vector_local = np.array([0, 0, -1]) + + # Transform approach vector to world coordinates + approach_vector_world = rotation_matrix @ approach_vector_local + + # Apply offset along the approach direction + offset_position = Vector3( + target_pose.position.x + distance * approach_vector_world[0], + target_pose.position.y + distance * approach_vector_world[1], + target_pose.position.z + distance * approach_vector_world[2], + ) + + return Pose(position=offset_position, orientation=target_pose.orientation) diff --git a/dimos/web/websocket_vis/costmap_viz.py b/dimos/web/websocket_vis/costmap_viz.py new file mode 100644 index 0000000000..a1c6944d2b --- /dev/null +++ b/dimos/web/websocket_vis/costmap_viz.py @@ -0,0 +1,65 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Simple costmap wrapper for visualization purposes. 
+This is a minimal implementation to support websocket visualization. +""" + +import numpy as np +from typing import Optional +from dimos.msgs.nav_msgs import OccupancyGrid + + +class CostmapViz: + """A wrapper around OccupancyGrid for visualization compatibility.""" + + def __init__(self, occupancy_grid: Optional[OccupancyGrid] = None): + """Initialize from an OccupancyGrid.""" + self.occupancy_grid = occupancy_grid + + @property + def data(self) -> Optional[np.ndarray]: + """Get the costmap data as a numpy array.""" + if self.occupancy_grid is not None: + return self.occupancy_grid.grid + return None + + @property + def width(self) -> int: + """Get the width of the costmap.""" + if self.occupancy_grid is not None: + return self.occupancy_grid.width + return 0 + + @property + def height(self) -> int: + """Get the height of the costmap.""" + if self.occupancy_grid is not None: + return self.occupancy_grid.height + return 0 + + @property + def resolution(self) -> float: + """Get the resolution of the costmap.""" + if self.occupancy_grid is not None: + return self.occupancy_grid.resolution + return 1.0 + + @property + def origin(self): + """Get the origin pose of the costmap.""" + if self.occupancy_grid is not None: + return self.occupancy_grid.origin + return None diff --git a/dimos/web/websocket_vis/helpers.py b/dimos/web/websocket_vis/helpers.py deleted file mode 100644 index 80601a2dbe..0000000000 --- a/dimos/web/websocket_vis/helpers.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -from abc import ABC -from typing import Tuple, Callable -from dimos.types.path import Path -from dimos.types.vector import Vector - -import reactivex as rx -from reactivex import operators as ops -from reactivex.observable import Observable -from reactivex.subject import Subject -from dimos.web.websocket_vis.types import Drawable - - -class Visualizable(ABC): - """ - Base class for objects that can provide visualization data. - """ - - def vis_stream(self) -> Observable[Tuple[str, Drawable]]: - if not hasattr(self, "_vis_subject"): - self._vis_subject = Subject() - return self._vis_subject - - def vis(self, name: str, drawable: Drawable) -> None: - if not hasattr(self, "_vis_subject"): - return - self._vis_subject.on_next((name, drawable)) - - -def vector_stream( - name: str, pos: Callable[[], Vector], update_interval=0.1, precision=0.25, history=10 -) -> Observable[Tuple[str, Drawable]]: - return rx.interval(update_interval).pipe( - ops.map(lambda _: pos()), - ops.distinct_until_changed( - comparer=lambda a, b: (a - b).length() < precision, - ), - ops.scan( - lambda hist, cur: hist.ipush(cur).iclip_tail(history), - seed=Path(), - ), - ops.flat_map(lambda path: rx.from_([(f"{name}_hst", path), (name, path.last())])), - ) diff --git a/dimos/web/websocket_vis/path_history.py b/dimos/web/websocket_vis/path_history.py new file mode 100644 index 0000000000..2bfa66a956 --- /dev/null +++ b/dimos/web/websocket_vis/path_history.py @@ -0,0 +1,76 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Simple path history class for visualization purposes. +This is a minimal implementation to support websocket visualization. +""" + +from typing import List, Optional, Union +from dimos.msgs.geometry_msgs import Vector3 + + +class PathHistory: + """A simple container for storing a history of positions for visualization.""" + + def __init__(self, points: Optional[List[Union[Vector3, tuple, list]]] = None): + """Initialize with optional list of points.""" + self.points: List[Vector3] = [] + if points: + for p in points: + if isinstance(p, Vector3): + self.points.append(p) + else: + self.points.append(Vector3(*p)) + + def ipush(self, point: Union[Vector3, tuple, list]) -> "PathHistory": + """Add a point to the history (in-place) and return self.""" + if isinstance(point, Vector3): + self.points.append(point) + else: + self.points.append(Vector3(*point)) + return self + + def iclip_tail(self, max_length: int) -> "PathHistory": + """Keep only the last max_length points (in-place) and return self.""" + if max_length > 0 and len(self.points) > max_length: + self.points = self.points[-max_length:] + return self + + def last(self) -> Optional[Vector3]: + """Return the last point in the history, or None if empty.""" + return self.points[-1] if self.points else None + + def length(self) -> float: + """Calculate the total length of the path.""" + if len(self.points) < 2: + return 0.0 + + total = 0.0 + for i in range(1, len(self.points)): + p1 = self.points[i - 1] + p2 = self.points[i] + dx = p2.x - p1.x + dy = p2.y - p1.y + dz = p2.z - p1.z + total += (dx * dx + 
dy * dy + dz * dz) ** 0.5 + return total + + def __len__(self) -> int: + """Return the number of points in the history.""" + return len(self.points) + + def __getitem__(self, index: int) -> Vector3: + """Get a point by index.""" + return self.points[index] diff --git a/dimos/web/websocket_vis/server.py b/dimos/web/websocket_vis/server.py deleted file mode 100644 index a7aca2da2b..0000000000 --- a/dimos/web/websocket_vis/server.py +++ /dev/null @@ -1,236 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import socketio -import uvicorn -import threading -import os -import sys -import asyncio -from typing import Tuple -from starlette.routing import Route -from starlette.responses import HTMLResponse -from starlette.applications import Starlette -from starlette.staticfiles import StaticFiles -from dimos.web.websocket_vis.types import Drawable -from reactivex import Observable -import concurrent.futures - - -async def serve_index(request): - # Read the index.html file directly - index_path = os.path.join(os.path.dirname(__file__), "static", "index.html") - with open(index_path, "r") as f: - content = f.read() - return HTMLResponse(content) - - -# Create global socketio server -sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*") - -# Create Starlette app with route for root path -routes = [Route("/", serve_index)] -starlette_app = Starlette(routes=routes) - - -static_dir = os.path.join(os.path.dirname(__file__), "static") -starlette_app.mount("/", StaticFiles(directory=static_dir), name="static") - -# Create the ASGI app -app = socketio.ASGIApp(sio, starlette_app) - -main_state = { - "status": "idle", - "connected_clients": 0, -} - - -@sio.event -async def connect(sid, environ): - print(f"Client connected: {sid}") - await update_state({"connected_clients": main_state["connected_clients"] + 1}) - await sio.emit("full_state", main_state, room=sid) - - -@sio.event -async def disconnect(sid): - print(f"Client disconnected: {sid}") - await update_state({"connected_clients": main_state["connected_clients"] - 1}) - - -@sio.event -async def message(sid, data): - # print(f"Message received from {sid}: {data}") - # Call WebsocketVis.handle_message if there's an active instance - if hasattr(sio, "vis_instance") and sio.vis_instance: - msgtype = data.get("type", "unknown") - sio.vis_instance.handle_message(msgtype, data) - # await sio.emit("message", {"response": "Server received your message"}, room=sid) - - -# Deep merge function for nested dictionaries 
-def deep_merge(source, destination): - """ - Deep merge two dictionaries recursively. - Updates destination in-place with values from source. - Lists are replaced, not merged. - """ - for key, value in source.items(): - if key in destination and isinstance(destination[key], dict) and isinstance(value, dict): - # If both values are dictionaries, recursively deep merge them - deep_merge(value, destination[key]) - else: - # Otherwise, just update the value - destination[key] = value - return destination - - -# Utility function to update state and broadcast to all clients -async def update_state(new_data): - """Update main_state and broadcast only the new data to all connected clients""" - # Deep merge the new data into main_state - deep_merge(new_data, main_state) - # Broadcast only the new data to all connected clients - await sio.emit("state_update", new_data) - - -class WebsocketVis: - def __init__(self, port=7779, use_reload=False, msg_handler=None): - self.port = port - self.server = None - self.server_thread = None - self.sio = sio # Use the global sio instance - self.use_reload = use_reload - self.main_state = main_state # Reference to global main_state - self.msg_handler = msg_handler - self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) - - # Store reference to this instance on the sio object for message handling - sio.vis_instance = self - - def handle_message(self, msgtype, msg): - """Handle incoming messages from the client""" - if self.msg_handler: - self.msg_handler(msgtype, msg) - else: - print("No message handler defined. 
Ignoring message.") - - def start(self): - # If reload is requested, run in main thread - if self.use_reload: - print("Starting server with hot reload in main thread") - uvicorn.run( - "server:app", # Use import string for reload to work - host="0.0.0.0", - port=self.port, - reload=True, - reload_dirs=[os.path.dirname(__file__)], - ) - return self - - # Otherwise, run in background thread - else: - print("Starting server in background thread") - self.server_thread = threading.Thread( - target=uvicorn.run, - kwargs={ - "app": app, # Use direct app object for thread mode - "host": "0.0.0.0", - "port": self.port, - }, - daemon=True, - ) - self.server_thread.start() - return self - - def process_drawable(self, drawable: Drawable): - """Process a drawable object and return a dictionary representation""" - if isinstance(drawable, tuple): - obj, config = drawable - return [obj.serialize(), config] - else: - return drawable.serialize() - - def connect(self, obs: Observable[Tuple[str, Drawable]], window_name: str = "main"): - """Connect to an Observable stream and update state on new data""" - - def new_update(data): - [name, drawable] = data - self.update_state({"draw": {name: self.process_drawable(drawable)}}) - - return obs.subscribe( - on_next=new_update, - on_error=lambda e: print(f"Error in stream: {e}"), - on_completed=lambda: print("Stream completed"), - ) - - def stop(self): - if self.server_thread and self.server_thread.is_alive(): - self.server_thread.join() - if hasattr(self, "_executor"): - self._executor.shutdown(wait=True) - if hasattr(self.sio, "disconnect"): - try: - asyncio.run(self.sio.disconnect()) - except: - pass - - async def update_state_async(self, new_data): - """Update main_state and broadcast to all connected clients""" - await update_state(new_data) - - def update_state(self, new_data): - """Thread-safe wrapper for update_state""" - try: - # Try to get the current event loop - loop = asyncio.get_running_loop() - # If we're in an event loop, 
schedule the coroutine - future = asyncio.run_coroutine_threadsafe(update_state(new_data), loop) - return future.result(timeout=5.0) # Add timeout to prevent blocking - except RuntimeError: - # No event loop running, this is likely called from a different thread - # Use the executor to run in a controlled manner - def run_async(): - try: - return asyncio.run(update_state(new_data)) - except Exception as e: - print(f"Error updating state: {e}") - return None - - future = self._executor.submit(run_async) - try: - return future.result(timeout=5.0) - except concurrent.futures.TimeoutError: - print("Warning: State update timed out") - return None - - -# Test timer function that updates state with current Unix time -async def start_time_counter(server): - """Start a background task that updates state with current Unix time every second""" - import time - - while True: - # Update state with current Unix timestamp - await server.update_state_async({"time": int(time.time())}) - # Wait for 1 second - await asyncio.sleep(1) - - -# For direct execution with uvicorn CLI -if __name__ == "__main__": - # Check if --reload flag is passed - use_reload = "--reload" in sys.argv - server = WebsocketVis(port=7778, use_reload=use_reload) - server_instance = server.start() diff --git a/dimos/web/websocket_vis/types.py b/dimos/web/websocket_vis/types.py deleted file mode 100644 index ea00099de5..0000000000 --- a/dimos/web/websocket_vis/types.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Union, Iterable, Tuple, TypedDict - -from dimos.types.vector import Vector -from dimos.types.path import Path -from dimos.types.costmap import Costmap - - -class VectorDrawConfig(TypedDict, total=False): - color: str - width: float - style: str # "solid", "dashed", etc. - - -class PathDrawConfig(TypedDict, total=False): - color: str - width: float - style: str - fill: bool - - -class CostmapDrawConfig(TypedDict, total=False): - colormap: str - opacity: float - scale: float - - -Drawable = Union[ - Vector, - Path, - Costmap, - Tuple[Vector, VectorDrawConfig], - Tuple[Path, PathDrawConfig], - Tuple[Costmap, CostmapDrawConfig], -] - -Drawables = Iterable[Drawable] diff --git a/dimos/web/websocket_vis/websocket_vis_module.py b/dimos/web/websocket_vis/websocket_vis_module.py new file mode 100644 index 0000000000..02f56b8460 --- /dev/null +++ b/dimos/web/websocket_vis/websocket_vis_module.py @@ -0,0 +1,247 @@ +#!/usr/bin/env python3 + +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +WebSocket Visualization Module for Dimos navigation and mapping. 
+""" + +import asyncio +import concurrent.futures +import os +import threading +from typing import Any, Dict, Optional + +import socketio +import uvicorn +from starlette.applications import Starlette +from starlette.responses import HTMLResponse +from starlette.routing import Route +from starlette.staticfiles import StaticFiles + +from dimos.core import Module, In, Out, rpc +from dimos.msgs.geometry_msgs import PoseStamped +from dimos.msgs.nav_msgs import OccupancyGrid, Path +from dimos.utils.logging_config import setup_logger + +logger = setup_logger("dimos.web.websocket_vis") + + +class WebsocketVisModule(Module): + """ + WebSocket-based visualization module for real-time navigation data. + + This module provides a web interface for visualizing: + - Robot position and orientation + - Navigation paths + - Costmaps + - Interactive goal setting via mouse clicks + + Inputs: + - robot_pose: Current robot position + - path: Navigation path + - global_costmap: Global costmap for visualization + + Outputs: + - click_goal: Goal position from user clicks + """ + + # LCM inputs + robot_pose: In[PoseStamped] = None + path: In[Path] = None + global_costmap: In[OccupancyGrid] = None + + # LCM outputs + click_goal: Out[PoseStamped] = None + + def __init__(self, port: int = 7779, **kwargs): + """Initialize the WebSocket visualization module. 
+ + Args: + port: Port to run the web server on + """ + super().__init__(**kwargs) + + self.port = port + self.server_thread: Optional[threading.Thread] = None + self.sio: Optional[socketio.AsyncServer] = None + self.app = None + self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + + # Visualization state + self.vis_state = { + "draw": {}, # Client expects visualization data under 'draw' key + "connected_clients": 0, + "status": "running", + } + self.state_lock = threading.Lock() + + logger.info(f"WebSocket visualization module initialized on port {port}") + + @rpc + def start(self): + """Start the WebSocket server and subscribe to inputs.""" + # Create the server + self._create_server() + + # Start the server in a background thread + self.server_thread = threading.Thread(target=self._run_server, daemon=True) + self.server_thread.start() + + # Subscribe to inputs + self.robot_pose.subscribe(self._on_robot_pose) + self.path.subscribe(self._on_path) + self.global_costmap.subscribe(self._on_global_costmap) + + logger.info(f"WebSocket server started on http://localhost:{self.port}") + + @rpc + def stop(self): + """Stop the WebSocket server.""" + if self._executor: + self._executor.shutdown(wait=True) + logger.info("WebSocket visualization module stopped") + + def _create_server(self): + """Create the SocketIO server and Starlette app.""" + # Create SocketIO server + self.sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*") + + # Create Starlette app + async def serve_index(request): + index_path = os.path.join(os.path.dirname(__file__), "static", "index.html") + with open(index_path, "r") as f: + content = f.read() + return HTMLResponse(content) + + routes = [Route("/", serve_index)] + starlette_app = Starlette(routes=routes) + + # Mount static files + static_dir = os.path.join(os.path.dirname(__file__), "static") + starlette_app.mount("/", StaticFiles(directory=static_dir), name="static") + + # Create ASGI app + self.app = 
socketio.ASGIApp(self.sio, starlette_app) + + # Register SocketIO event handlers + @self.sio.event + async def connect(sid, environ): + logger.info(f"Client connected: {sid}") + with self.state_lock: + self.vis_state["connected_clients"] += 1 + current_state = dict(self.vis_state) + # Send current state to new client + await self.sio.emit("full_state", current_state, room=sid) + + @self.sio.event + async def disconnect(sid): + logger.info(f"Client disconnected: {sid}") + with self.state_lock: + self.vis_state["connected_clients"] -= 1 + + @self.sio.event + async def message(sid, data): + """Handle messages from the client.""" + msg_type = data.get("type") + + if msg_type == "click": + # Convert click to navigation goal + position = data.get("position", []) + if isinstance(position, list) and len(position) >= 2: + goal = PoseStamped( + position=(position[0], position[1], 0), + orientation=(0, 0, 0, 1), # Default orientation + frame_id="world", + ) + self.click_goal.publish(goal) + logger.info( + f"Click goal published: ({goal.position.x:.2f}, {goal.position.y:.2f})" + ) + + def _run_server(self): + """Run the uvicorn server.""" + uvicorn.run( + self.app, + host="0.0.0.0", + port=self.port, + log_level="error", # Reduce verbosity + ) + + def _on_robot_pose(self, msg: PoseStamped): + """Handle robot pose updates.""" + pose_data = {"type": "vector", "c": [msg.position.x, msg.position.y, msg.position.z]} + self._update_state({"draw": {"robot_pos": pose_data}}) + + def _on_path(self, msg: Path): + """Handle path updates.""" + points = [] + for pose in msg.poses: + points.append([pose.position.x, pose.position.y]) + path_data = {"type": "path", "points": points} + self._update_state({"draw": {"path": path_data}}) + + def _on_global_costmap(self, msg: OccupancyGrid): + """Handle global costmap updates.""" + costmap_data = self._process_costmap(msg) + self._update_state({"draw": {"costmap": costmap_data}}) + + def _process_costmap(self, costmap: OccupancyGrid) -> Dict[str, 
Any]: + """Convert OccupancyGrid to visualization format.""" + import base64 + import numpy as np + + # Convert grid data to base64 encoded string + grid_bytes = costmap.grid.astype(np.float32).tobytes() + grid_base64 = base64.b64encode(grid_bytes).decode("ascii") + + return { + "type": "costmap", + "grid": { + "type": "grid", + "shape": [costmap.height, costmap.width], + "dtype": "f32", + "compressed": False, + "data": grid_base64, + }, + "origin": { + "type": "vector", + "c": [costmap.origin.position.x, costmap.origin.position.y, 0], + }, + "resolution": costmap.resolution, + "origin_theta": 0, # Assuming no rotation for now + } + + def _update_state(self, new_data: Dict[str, Any]): + """Update visualization state and broadcast to clients.""" + with self.state_lock: + # If updating draw data, merge it properly + if "draw" in new_data: + if "draw" not in self.vis_state: + self.vis_state["draw"] = {} + self.vis_state["draw"].update(new_data["draw"]) + else: + self.vis_state.update(new_data) + + # Broadcast update asynchronously + def broadcast(): + try: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(self.sio.emit("state_update", new_data)) + except Exception as e: + logger.error(f"Failed to broadcast state update: {e}") + + self._executor.submit(broadcast) diff --git a/tests/colmap_test.py b/tests/colmap_test.py deleted file mode 100644 index e1f451a7dc..0000000000 --- a/tests/colmap_test.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -import tests.test_header -import os -import sys - -# ----- - -# Now try to import -from dimos.environment.colmap_environment import COLMAPEnvironment - -env = COLMAPEnvironment() -env.initialize_from_video("data/IMG_1525.MOV", "data/frames") diff --git a/tests/test_spatial_memory.py b/tests/test_spatial_memory.py index b400749cb4..4a7b72701a 100644 --- a/tests/test_spatial_memory.py +++ b/tests/test_spatial_memory.py @@ -28,18 +28,29 @@ import tests.test_header -from dimos.robot.unitree.unitree_go2 import UnitreeGo2 -from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl +# from dimos.robot.unitree_webrtc.unitree_go2 import UnitreeGo2 # Uncomment when properly configured from dimos.perception.spatial_perception import SpatialMemory +from dimos.types.vector import Vector +from dimos.msgs.geometry_msgs import Vector3, Quaternion -def extract_position(transform): - """Extract position coordinates from a transform message""" +def extract_pose_data(transform): + """Extract position and rotation from a transform message""" if transform is None: - return (0, 0, 0) + return None, None pos = transform.transform.translation - return (pos.x, pos.y, pos.z) + rot = transform.transform.rotation + + # Convert to Vector3 objects expected by SpatialMemory + position = Vector3(x=pos.x, y=pos.y, z=pos.z) + + # Convert quaternion to euler angles for rotation vector + quat = Quaternion(x=rot.x, y=rot.y, z=rot.z, w=rot.w) + euler = quat.to_euler() + rotation = Vector3(x=euler.x, y=euler.y, z=euler.z) + + return position, rotation def setup_persistent_chroma_db(db_path="chromadb_data"): @@ -65,26 +76,22 @@ def setup_persistent_chroma_db(db_path="chromadb_data"): def main(): print("Starting spatial memory test...") - # Initialize ROS control and robot - ros_control = UnitreeROSControl(node_name="spatial_memory_test", mock_connection=False) - - robot = 
UnitreeGo2(ros_control=ros_control, ip=os.getenv("ROBOT_IP")) - # Create counters for tracking frame_count = 0 transform_count = 0 stored_count = 0 - print("Setting up video stream...") - video_stream = robot.get_ros_video_stream() + print("Note: This test requires proper robot connection setup.") + print("Please ensure video_stream and transform_stream are properly configured.") - # Create transform stream at 1 Hz - print("Setting up transform stream...") - transform_stream = ros_control.get_transform_stream( - child_frame="map", - parent_frame="base_link", - rate_hz=1.0, # 1 transform per second - ) + # These need to be set up based on your specific robot configuration + video_stream = None # TODO: Set up video stream from robot + transform_stream = None # TODO: Set up transform stream from robot + + if video_stream is None or transform_stream is None: + print("\nWARNING: Video or transform streams not configured.") + print("Exiting test. Please configure streams properly.") + return # Setup output directory for visual memory visual_memory_dir = "/home/stash/dimensional/dimos/assets/test_spatial_memory" @@ -125,9 +132,11 @@ def main(): ops.map( lambda pair: { "frame": pair[0], # First element is the frame - "position": extract_position(pair[1]), # Second element is the transform + "position": extract_pose_data(pair[1])[0], # Position as Vector3 + "rotation": extract_pose_data(pair[1])[1], # Rotation as Vector3 } - ) + ), + ops.filter(lambda data: data["position"] is not None and data["rotation"] is not None), ) # Process with spatial memory @@ -140,7 +149,12 @@ def on_stored_frame(result): if not result.get("stored", True) == False: stored_count += 1 pos = result["position"] - print(f"\nStored frame #{stored_count} at ({pos[0]:.2f}, {pos[1]:.2f}, {pos[2]:.2f})") + if isinstance(pos, tuple): + print( + f"\nStored frame #{stored_count} at ({pos[0]:.2f}, {pos[1]:.2f}, {pos[2]:.2f})" + ) + else: + print(f"\nStored frame #{stored_count} at position {pos}") # Save the 
frame to the assets directory if "frame" in result: @@ -184,6 +198,12 @@ def on_stored_frame(result): saved_path = spatial_memory.vector_db.visual_memory.save("visual_memory.pkl") print(f"Saved {spatial_memory.vector_db.visual_memory.count()} images to disk at {saved_path}") + # Final cleanup + print("Performing final cleanup...") + spatial_memory.cleanup() + + print("Test completed successfully") + def visualize_spatial_memory_with_objects( spatial_memory, objects, output_filename="spatial_memory_map.png" @@ -206,11 +226,17 @@ def visualize_spatial_memory_with_objects( return # Extract coordinates from all stored locations - if len(locations[0]) >= 3: - x_coords = [loc[0] for loc in locations] - y_coords = [loc[1] for loc in locations] - else: - x_coords, y_coords = zip(*locations) + x_coords = [] + y_coords = [] + for loc in locations: + if isinstance(loc, dict): + x_coords.append(loc.get("pos_x", 0)) + y_coords.append(loc.get("pos_y", 0)) + elif isinstance(loc, (tuple, list)) and len(loc) >= 2: + x_coords.append(loc[0]) + y_coords.append(loc[1]) + else: + print(f"Unknown location format: {loc}") # Create figure plt.figure(figsize=(12, 10)) @@ -240,9 +266,10 @@ def visualize_spatial_memory_with_objects( if isinstance(metadata, list) and metadata: metadata = metadata[0] - if isinstance(metadata, dict) and "x" in metadata and "y" in metadata: - x = metadata.get("x", 0) - y = metadata.get("y", 0) + if isinstance(metadata, dict): + # New metadata format uses pos_x, pos_y + x = metadata.get("pos_x", metadata.get("x", 0)) + y = metadata.get("pos_y", metadata.get("y", 0)) # Store coordinates for this object object_coords[obj] = (x, y) @@ -281,17 +308,6 @@ def visualize_spatial_memory_with_objects( return object_coords - # Final cleanup - print("Performing final cleanup...") - spatial_memory.cleanup() - - try: - robot.cleanup() - except Exception as e: - print(f"Error during robot cleanup: {e}") - - print("Test completed successfully") - if __name__ == "__main__": main() 
diff --git a/tests/test_zed_module.py b/tests/test_zed_module.py index fbc99a54a4..a8c5691b59 100644 --- a/tests/test_zed_module.py +++ b/tests/test_zed_module.py @@ -229,7 +229,6 @@ async def test_zed_module(): # Print module info logger.info("ZED Module configured:") - print(zed.io().result()) # Start ZED module logger.info("Starting ZED module...")