From 54b727f98735d84fc931f4d19f80165d12ca92d1 Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Wed, 14 May 2025 13:24:16 -0700 Subject: [PATCH 01/89] initial implementation of pointcloud filtering and segmentation --- dimos-lcm | 1 + dimos/perception/pointcloud/pointcloud_seg.py | 338 ++++++++++++++++++ 2 files changed, 339 insertions(+) create mode 160000 dimos-lcm create mode 100644 dimos/perception/pointcloud/pointcloud_seg.py diff --git a/dimos-lcm b/dimos-lcm new file mode 160000 index 0000000000..403afa2fdb --- /dev/null +++ b/dimos-lcm @@ -0,0 +1 @@ +Subproject commit 403afa2fdba3232d98719f426fbd8d7d94e0e549 diff --git a/dimos/perception/pointcloud/pointcloud_seg.py b/dimos/perception/pointcloud/pointcloud_seg.py new file mode 100644 index 0000000000..0968a0b338 --- /dev/null +++ b/dimos/perception/pointcloud/pointcloud_seg.py @@ -0,0 +1,338 @@ +import numpy as np +import cv2 +import yaml +import os +import sys +from PIL import Image, ImageDraw +from dimos.perception.segmentation import Sam2DSegmenter +from dimos.perception.pointcloud.utils import ( + load_camera_matrix_from_yaml, + create_masked_point_cloud, + o3d_point_cloud_to_numpy, + rotation_to_o3d +) +from dimos.perception.pointcloud.cuboid_fit import fit_cuboid, visualize_fit +import torch +import open3d as o3d + +class PointcloudSegmentation: + def __init__( + self, + model_path="FastSAM-s.pt", + device="cuda", + color_intrinsics=None, + depth_intrinsics=None, + enable_tracking=True, + enable_analysis=True, + ): + """ + Initialize processor to segment objects in RGB images and extract their point clouds. + + Args: + model_path: Path to the FastSAM model + device: Computation device ("cuda" or "cpu") + color_intrinsics: Path to YAML file or list with color camera intrinsics [fx, fy, cx, cy] + depth_intrinsics: Path to YAML file or list with depth camera intrinsics [fx, fy, cx, cy] + enable_tracking: Whether to enable object tracking + enable_analysis: Whether to enable object analysis (labels, etc.) 
+ min_analysis_interval: Minimum interval between analysis runs in seconds + """ + # Initialize segmenter + self.segmenter = Sam2DSegmenter( + model_path=model_path, + device=device, + use_tracker=enable_tracking, + use_analyzer=enable_analysis, + ) + + # Store settings + self.enable_tracking = enable_tracking + self.enable_analysis = enable_analysis + + # Load camera matrices + self.color_camera_matrix = load_camera_matrix_from_yaml(color_intrinsics) + self.depth_camera_matrix = load_camera_matrix_from_yaml(depth_intrinsics) + + def generate_color_from_id(self, track_id): + """Generate a consistent color for a given tracking ID.""" + np.random.seed(track_id) + color = np.random.randint(0, 255, 3) + np.random.seed(None) + return color + + def process_images(self, color_img, depth_img, fit_3d_cuboids=True): + """ + Process color and depth images to segment objects and extract point clouds. + Uses Open3D for point cloud processing. + + Args: + color_img: RGB image as numpy array (H, W, 3) + depth_img: Depth image as numpy array (H, W) in meters + fit_3d_cuboids: Whether to fit 3D cuboids to each object + + Returns: + dict: Dictionary containing: + - viz_image: Visualization image with detections + - objects: List of dicts for each object with: + - mask: Segmentation mask (H, W, bool) + - bbox: Bounding box [x1, y1, x2, y2] + - target_id: Tracking ID + - confidence: Detection confidence + - name: Object name (if analyzer enabled) + - point_cloud: Open3D point cloud object + - point_cloud_numpy: Nx6 array of XYZRGB points (for compatibility) + - color: RGB color for visualization + - cuboid_params: Cuboid parameters (if fit_3d_cuboids=True) + """ + if self.depth_camera_matrix is None: + raise ValueError("Depth camera matrix must be provided to process images") + + # Run segmentation + masks, bboxes, target_ids, probs, names = self.segmenter.process_image(color_img) + print(f"Found {len(masks)} segmentation masks") + + # Run analysis if enabled + if 
self.enable_analysis: + self.segmenter.run_analysis(color_img, bboxes, target_ids) + names = self.segmenter.get_object_names(target_ids, names) + + # Create visualization image + viz_img = self.segmenter.visualize_results( + color_img.copy(), + masks, + bboxes, + target_ids, + probs, + names + ) + + # Process each object + objects = [] + for i, (mask, bbox, target_id, prob, name) in enumerate(zip(masks, bboxes, target_ids, probs, names)): + # Convert mask to numpy if it's a tensor + if hasattr(mask, 'cpu'): + mask = mask.cpu().numpy() + + # Ensure mask is proper boolean array with correct dimensions + mask = mask.astype(bool) + + # Ensure mask has the same shape as the depth image + if mask.shape != depth_img.shape[:2]: + print(f"Warning: Mask shape {mask.shape} doesn't match depth image shape {depth_img.shape[:2]}") + if len(mask.shape) > 2: + # If mask has extra dimensions, take the first channel + mask = mask[:,:,0] if mask.shape[2] > 0 else mask[:,:,0] + + # If shapes still don't match, try to resize the mask + if mask.shape != depth_img.shape[:2]: + mask = cv2.resize(mask.astype(np.uint8), + (depth_img.shape[1], depth_img.shape[0]), + interpolation=cv2.INTER_NEAREST).astype(bool) + + try: + # Create point cloud using Open3D + pcd = create_masked_point_cloud( + color_img, + depth_img, + mask, + self.depth_camera_matrix, + depth_scale=1.0 # Assuming depth is already in meters + ) + + # Skip if no points + if len(np.asarray(pcd.points)) == 0: + print(f"Skipping object {i+1}: No points in point cloud") + continue + + # Generate color for visualization + rgb_color = self.generate_color_from_id(target_id) + + # Create object data + obj_data = { + "mask": mask, + "bbox": bbox, + "target_id": target_id, + "confidence": float(prob), + "name": name if name else "", + "point_cloud": pcd, + "point_cloud_numpy": o3d_point_cloud_to_numpy(pcd), + "color": rgb_color + } + + # Fit 3D cuboid if requested + if fit_3d_cuboids: + points = np.asarray(pcd.points) + cuboid_params = 
fit_cuboid(points) + obj_data["cuboid_params"] = cuboid_params + + # Update visualization with cuboid if available + if cuboid_params is not None and self.color_camera_matrix is not None: + viz_img = visualize_fit(viz_img, cuboid_params, self.color_camera_matrix) + + objects.append(obj_data) + + except Exception as e: + print(f"Error processing object {i+1}: {e}") + continue + + # Clean up GPU memory if using CUDA + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + return { + "viz_image": viz_img, + "objects": objects + } + + def cleanup(self): + """Clean up resources.""" + if hasattr(self, 'segmenter'): + self.segmenter.cleanup() + +def main(): + """ + Main function to test the PointcloudSegmentation class with data from rgbd_data folder. + """ + + def find_first_image(directory): + """Find the first image file in the given directory.""" + image_extensions = ['.jpg', '.jpeg', '.png', '.bmp'] + for filename in sorted(os.listdir(directory)): + if any(filename.lower().endswith(ext) for ext in image_extensions): + return os.path.join(directory, filename) + return None + + # Define paths + script_dir = os.path.dirname(os.path.abspath(__file__)) + dimos_dir = os.path.abspath(os.path.join(script_dir, "../../../")) + data_dir = os.path.join(dimos_dir, "assets/rgbd_data") + + color_info_path = os.path.join(data_dir, "color_camera_info.yaml") + depth_info_path = os.path.join(data_dir, "depth_camera_info.yaml") + + color_dir = os.path.join(data_dir, "color") + depth_dir = os.path.join(data_dir, "depth") + + # Find first color and depth images + color_img_path = find_first_image(color_dir) + depth_img_path = find_first_image(depth_dir) + + if not color_img_path or not depth_img_path: + print(f"Error: Could not find color or depth images in {data_dir}") + return + + print(f"Found color image: {color_img_path}") + print(f"Found depth image: {depth_img_path}") + + # Load images + color_img = cv2.imread(color_img_path) + if color_img is None: + print(f"Error: Could not 
load color image from {color_img_path}") + return + + color_img = cv2.cvtColor(color_img, cv2.COLOR_BGR2RGB) # Convert to RGB + + depth_img = cv2.imread(depth_img_path, cv2.IMREAD_UNCHANGED) + if depth_img is None: + print(f"Error: Could not load depth image from {depth_img_path}") + return + + # Convert depth to meters if needed (adjust scale as needed for your data) + if depth_img.dtype == np.uint16: + # Convert from mm to meters for typical depth cameras + depth_img = depth_img.astype(np.float32) / 1000.0 + + # Verify image shapes for debugging + print(f"Color image shape: {color_img.shape}") + print(f"Depth image shape: {depth_img.shape}") + + # Initialize segmentation with direct camera matrices + seg = PointcloudSegmentation( + model_path="FastSAM-s.pt", # Adjust path as needed + device="cuda" if torch.cuda.is_available() else "cpu", + color_intrinsics=color_info_path, + depth_intrinsics=depth_info_path, + enable_tracking=False, + enable_analysis=True + ) + + # Process images + print("Processing images...") + try: + results = seg.process_images(color_img, depth_img, fit_3d_cuboids=True) + + # Show segmentation results using PIL instead of OpenCV + viz_img = results["viz_image"] + + # Convert OpenCV image (BGR) to PIL image (RGB) + pil_img = Image.fromarray(cv2.cvtColor(viz_img, cv2.COLOR_BGR2RGB)) + + # Display the image using PIL + pil_img.show(title="Segmentation Results") + + # Add a short pause to ensure the image has time to display + import time + time.sleep(0.5) + + print(f"Found {len(results['objects'])} objects with valid point clouds") + + # Visualize all point clouds in a single window + all_pcds = [] + for i, obj in enumerate(results['objects']): + pcd = obj['point_cloud'] + + # Optionally add axis-aligned bounding box visualization + if 'cuboid_params' in obj and obj['cuboid_params'] is not None: + cuboid = obj['cuboid_params'] + + # Create oriented bounding box using the rotation matrix instead of axis-aligned box + center = cuboid['center'] + 
dimensions = cuboid['dimensions'] + rotation = rotation_to_o3d(cuboid['rotation']) + + # Create oriented bounding box + obb = o3d.geometry.OrientedBoundingBox( + center=center, + R=rotation, + extent=dimensions + ) + obb.color = [1, 0, 0] # Red bounding box + all_pcds.append(obb) + + # Add a small coordinate frame at the center of each object to show orientation + coord_frame = o3d.geometry.TriangleMesh.create_coordinate_frame( + size=min(dimensions) * 0.5, + origin=center + ) + all_pcds.append(coord_frame) + + # Add the point cloud + all_pcds.append(pcd) + + # Add coordinate frame at origin + coordinate_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1) + all_pcds.append(coordinate_frame) + + # Show point clouds + if all_pcds: + o3d.visualization.draw_geometries(all_pcds, + window_name="Segmented Objects", + width=1280, + height=720, + left=50, + top=50) + else: + print("No objects with valid point clouds found.") + + except Exception as e: + print(f"Error during processing: {str(e)}") + import traceback + traceback.print_exc() + + # Clean up resources + seg.cleanup() + print("Done!") + + +if __name__ == "__main__": + main() From 94487cbaa8b6761db131de4ab8606c6328832106 Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Wed, 14 May 2025 13:39:32 -0700 Subject: [PATCH 02/89] small bug fix --- dimos/perception/pointcloud/pointcloud_seg.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dimos/perception/pointcloud/pointcloud_seg.py b/dimos/perception/pointcloud/pointcloud_seg.py index 0968a0b338..a915cbd69e 100644 --- a/dimos/perception/pointcloud/pointcloud_seg.py +++ b/dimos/perception/pointcloud/pointcloud_seg.py @@ -9,7 +9,6 @@ load_camera_matrix_from_yaml, create_masked_point_cloud, o3d_point_cloud_to_numpy, - rotation_to_o3d ) from dimos.perception.pointcloud.cuboid_fit import fit_cuboid, visualize_fit import torch @@ -288,7 +287,7 @@ def find_first_image(directory): # Create oriented bounding box using the rotation matrix 
instead of axis-aligned box center = cuboid['center'] dimensions = cuboid['dimensions'] - rotation = rotation_to_o3d(cuboid['rotation']) + rotation = cuboid['rotation'] # Create oriented bounding box obb = o3d.geometry.OrientedBoundingBox( From 07acf385eced7897f94ee4588f5990d66db45e92 Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Tue, 20 May 2025 10:54:59 -0700 Subject: [PATCH 03/89] added basic RANSAC plane remove algorithm --- dimos/perception/pointcloud/pointcloud_seg.py | 28 +++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/dimos/perception/pointcloud/pointcloud_seg.py b/dimos/perception/pointcloud/pointcloud_seg.py index a915cbd69e..75607d65be 100644 --- a/dimos/perception/pointcloud/pointcloud_seg.py +++ b/dimos/perception/pointcloud/pointcloud_seg.py @@ -9,6 +9,8 @@ load_camera_matrix_from_yaml, create_masked_point_cloud, o3d_point_cloud_to_numpy, + create_o3d_point_cloud_from_rgbd, + segment_and_remove_plane ) from dimos.perception.pointcloud.cuboid_fit import fit_cuboid, visualize_fit import torch @@ -82,6 +84,8 @@ def process_images(self, color_img, depth_img, fit_3d_cuboids=True): - point_cloud_numpy: Nx6 array of XYZRGB points (for compatibility) - color: RGB color for visualization - cuboid_params: Cuboid parameters (if fit_3d_cuboids=True) + - raw_point_cloud: Open3D point cloud object + - plane_removed_point_cloud: Open3D point cloud object with dominant plane removed """ if self.depth_camera_matrix is None: raise ValueError("Depth camera matrix must be provided to process images") @@ -173,6 +177,9 @@ def process_images(self, color_img, depth_img, fit_3d_cuboids=True): except Exception as e: print(f"Error processing object {i+1}: {e}") continue + + raw_point_cloud = create_o3d_point_cloud_from_rgbd(color_img, depth_img, self.depth_camera_matrix) + plane_removed_point_cloud = segment_and_remove_plane(raw_point_cloud) # Clean up GPU memory if using CUDA if torch.cuda.is_available(): @@ -180,7 +187,9 @@ def 
process_images(self, color_img, depth_img, fit_3d_cuboids=True): return { "viz_image": viz_img, - "objects": objects + "objects": objects, + "raw_point_cloud": raw_point_cloud, + "plane_removed_point_cloud": plane_removed_point_cloud } def cleanup(self): @@ -322,7 +331,22 @@ def find_first_image(directory): top=50) else: print("No objects with valid point clouds found.") - + + # Show raw point cloud + o3d.visualization.draw_geometries([results['raw_point_cloud']], + window_name="Raw Point Cloud", + width=1280, + height=720, + left=50, + top=50) + + # Show plane removed point cloud + o3d.visualization.draw_geometries([results['plane_removed_point_cloud']], + window_name="Plane Removed Point Cloud", + width=1280, + height=720, + left=50, + top=50) except Exception as e: print(f"Error during processing: {str(e)}") import traceback From 1ad42298f8ac52af289f5585fce1397bef3e8eac Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Mon, 2 Jun 2025 00:40:31 -0700 Subject: [PATCH 04/89] refactored and cleanup pointcloud filtering --- dimos/perception/pointcloud/pointcloud_seg.py | 280 +++++++++--------- 1 file changed, 135 insertions(+), 145 deletions(-) diff --git a/dimos/perception/pointcloud/pointcloud_seg.py b/dimos/perception/pointcloud/pointcloud_seg.py index 75607d65be..de3297b7c9 100644 --- a/dimos/perception/pointcloud/pointcloud_seg.py +++ b/dimos/perception/pointcloud/pointcloud_seg.py @@ -5,113 +5,99 @@ import sys from PIL import Image, ImageDraw from dimos.perception.segmentation import Sam2DSegmenter +from dimos.types.segmentation import SegmentationType from dimos.perception.pointcloud.utils import ( load_camera_matrix_from_yaml, create_masked_point_cloud, o3d_point_cloud_to_numpy, create_o3d_point_cloud_from_rgbd, - segment_and_remove_plane ) from dimos.perception.pointcloud.cuboid_fit import fit_cuboid, visualize_fit import torch import open3d as o3d -class PointcloudSegmentation: +class PointcloudFiltering: def __init__( self, - model_path="FastSAM-s.pt", - 
device="cuda", color_intrinsics=None, depth_intrinsics=None, - enable_tracking=True, - enable_analysis=True, + enable_statistical_filtering=True, + enable_cuboid_fitting=True, + color_weight=0.3, + statistical_neighbors=20, + statistical_std_ratio=2.0, ): """ - Initialize processor to segment objects in RGB images and extract their point clouds. + Initialize processor to filter point clouds from segmented objects. Args: - model_path: Path to the FastSAM model - device: Computation device ("cuda" or "cpu") color_intrinsics: Path to YAML file or list with color camera intrinsics [fx, fy, cx, cy] depth_intrinsics: Path to YAML file or list with depth camera intrinsics [fx, fy, cx, cy] - enable_tracking: Whether to enable object tracking - enable_analysis: Whether to enable object analysis (labels, etc.) - min_analysis_interval: Minimum interval between analysis runs in seconds + enable_statistical_filtering: Whether to apply statistical outlier filtering + enable_cuboid_fitting: Whether to fit 3D cuboids to objects + color_weight: Weight for blending generated color with original color (0.0 = original, 1.0 = generated) + statistical_neighbors: Number of neighbors for statistical filtering + statistical_std_ratio: Standard deviation ratio for statistical filtering """ - # Initialize segmenter - self.segmenter = Sam2DSegmenter( - model_path=model_path, - device=device, - use_tracker=enable_tracking, - use_analyzer=enable_analysis, - ) - # Store settings - self.enable_tracking = enable_tracking - self.enable_analysis = enable_analysis + self.enable_statistical_filtering = enable_statistical_filtering + self.enable_cuboid_fitting = enable_cuboid_fitting + self.color_weight = color_weight + self.statistical_neighbors = statistical_neighbors + self.statistical_std_ratio = statistical_std_ratio # Load camera matrices self.color_camera_matrix = load_camera_matrix_from_yaml(color_intrinsics) self.depth_camera_matrix = load_camera_matrix_from_yaml(depth_intrinsics) - def 
generate_color_from_id(self, track_id): - """Generate a consistent color for a given tracking ID.""" - np.random.seed(track_id) + def generate_color_from_id(self, object_id): + """Generate a consistent color for a given object ID.""" + np.random.seed(object_id) color = np.random.randint(0, 255, 3) np.random.seed(None) return color - def process_images(self, color_img, depth_img, fit_3d_cuboids=True): + def process_images(self, color_img, depth_img, segmentation_result): """ - Process color and depth images to segment objects and extract point clouds. - Uses Open3D for point cloud processing. + Process color and depth images with segmentation results to create filtered point clouds. Args: color_img: RGB image as numpy array (H, W, 3) depth_img: Depth image as numpy array (H, W) in meters - fit_3d_cuboids: Whether to fit 3D cuboids to each object + segmentation_result: SegmentationType object containing masks and metadata Returns: dict: Dictionary containing: - - viz_image: Visualization image with detections - objects: List of dicts for each object with: + - object_id: Object tracking ID - mask: Segmentation mask (H, W, bool) - bbox: Bounding box [x1, y1, x2, y2] - - target_id: Tracking ID - confidence: Detection confidence - - name: Object name (if analyzer enabled) - - point_cloud: Open3D point cloud object + - label: Object label/name + - point_cloud: Open3D point cloud object (filtered and colored) - point_cloud_numpy: Nx6 array of XYZRGB points (for compatibility) - color: RGB color for visualization - - cuboid_params: Cuboid parameters (if fit_3d_cuboids=True) - - raw_point_cloud: Open3D point cloud object - - plane_removed_point_cloud: Open3D point cloud object with dominant plane removed + - cuboid_params: Cuboid parameters (if enabled) + - filtering_stats: Filtering statistics (if filtering enabled) """ if self.depth_camera_matrix is None: raise ValueError("Depth camera matrix must be provided to process images") - # Run segmentation - masks, bboxes, 
target_ids, probs, names = self.segmenter.process_image(color_img) - print(f"Found {len(masks)} segmentation masks") - - # Run analysis if enabled - if self.enable_analysis: - self.segmenter.run_analysis(color_img, bboxes, target_ids) - names = self.segmenter.get_object_names(target_ids, names) - - # Create visualization image - viz_img = self.segmenter.visualize_results( - color_img.copy(), - masks, - bboxes, - target_ids, - probs, - names - ) + # Extract masks and metadata from segmentation result + masks = segmentation_result.masks + metadata = segmentation_result.metadata + objects_metadata = metadata.get('objects', []) # Process each object objects = [] - for i, (mask, bbox, target_id, prob, name) in enumerate(zip(masks, bboxes, target_ids, probs, names)): + for i, mask in enumerate(masks): + # Get object metadata if available + obj_meta = objects_metadata[i] if i < len(objects_metadata) else {} + object_id = obj_meta.get('object_id', i) + bbox = obj_meta.get('bbox', [0, 0, 0, 0]) + confidence = obj_meta.get('prob', 1.0) + label = obj_meta.get('label', '') + # Convert mask to numpy if it's a tensor if hasattr(mask, 'cpu'): mask = mask.cpu().numpy() @@ -121,12 +107,9 @@ def process_images(self, color_img, depth_img, fit_3d_cuboids=True): # Ensure mask has the same shape as the depth image if mask.shape != depth_img.shape[:2]: - print(f"Warning: Mask shape {mask.shape} doesn't match depth image shape {depth_img.shape[:2]}") if len(mask.shape) > 2: - # If mask has extra dimensions, take the first channel mask = mask[:,:,0] if mask.shape[2] > 0 else mask[:,:,0] - # If shapes still don't match, try to resize the mask if mask.shape != depth_img.shape[:2]: mask = cv2.resize(mask.astype(np.uint8), (depth_img.shape[1], depth_img.shape[0]), @@ -139,67 +122,88 @@ def process_images(self, color_img, depth_img, fit_3d_cuboids=True): depth_img, mask, self.depth_camera_matrix, - depth_scale=1.0 # Assuming depth is already in meters + depth_scale=1.0 ) # Skip if no points if 
len(np.asarray(pcd.points)) == 0: - print(f"Skipping object {i+1}: No points in point cloud") continue # Generate color for visualization - rgb_color = self.generate_color_from_id(target_id) + rgb_color = self.generate_color_from_id(object_id) + + # Apply weighted colored mask to the point cloud + if len(np.asarray(pcd.colors)) > 0: + original_colors = np.asarray(pcd.colors) + generated_color = np.array(rgb_color) / 255.0 + colored_mask = (1.0 - self.color_weight) * original_colors + self.color_weight * generated_color + colored_mask = np.clip(colored_mask, 0.0, 1.0) + pcd.colors = o3d.utility.Vector3dVector(colored_mask) + + # Apply statistical outlier filtering if enabled + filtering_stats = None + if self.enable_statistical_filtering: + num_points_before = len(np.asarray(pcd.points)) + pcd_filtered, outlier_indices = pcd.remove_statistical_outlier( + nb_neighbors=self.statistical_neighbors, + std_ratio=self.statistical_std_ratio + ) + num_points_after = len(np.asarray(pcd_filtered.points)) + num_outliers_removed = num_points_before - num_points_after + + pcd = pcd_filtered + + filtering_stats = { + "points_before": num_points_before, + "points_after": num_points_after, + "outliers_removed": num_outliers_removed, + "outlier_percentage": 100.0 * num_outliers_removed / num_points_before if num_points_before > 0 else 0 + } # Create object data obj_data = { + "object_id": object_id, "mask": mask, "bbox": bbox, - "target_id": target_id, - "confidence": float(prob), - "name": name if name else "", + "confidence": float(confidence), + "label": label, "point_cloud": pcd, "point_cloud_numpy": o3d_point_cloud_to_numpy(pcd), - "color": rgb_color + "color": rgb_color, } - # Fit 3D cuboid if requested - if fit_3d_cuboids: + # Add optional data if available + if filtering_stats is not None: + obj_data["filtering_stats"] = filtering_stats + + # Fit 3D cuboid if enabled + if self.enable_cuboid_fitting: points = np.asarray(pcd.points) cuboid_params = fit_cuboid(points) - 
obj_data["cuboid_params"] = cuboid_params - - # Update visualization with cuboid if available - if cuboid_params is not None and self.color_camera_matrix is not None: - viz_img = visualize_fit(viz_img, cuboid_params, self.color_camera_matrix) + if cuboid_params is not None: + obj_data["cuboid_params"] = cuboid_params objects.append(obj_data) except Exception as e: - print(f"Error processing object {i+1}: {e}") continue - raw_point_cloud = create_o3d_point_cloud_from_rgbd(color_img, depth_img, self.depth_camera_matrix) - plane_removed_point_cloud = segment_and_remove_plane(raw_point_cloud) - # Clean up GPU memory if using CUDA if torch.cuda.is_available(): torch.cuda.empty_cache() return { - "viz_image": viz_img, "objects": objects, - "raw_point_cloud": raw_point_cloud, - "plane_removed_point_cloud": plane_removed_point_cloud } def cleanup(self): """Clean up resources.""" - if hasattr(self, 'segmenter'): - self.segmenter.cleanup() + if torch.cuda.is_available(): + torch.cuda.empty_cache() def main(): """ - Main function to test the PointcloudSegmentation class with data from rgbd_data folder. + Main function to test the PointcloudFiltering class with data from rgbd_data folder. 
""" def find_first_image(directory): @@ -229,132 +233,118 @@ def find_first_image(directory): print(f"Error: Could not find color or depth images in {data_dir}") return - print(f"Found color image: {color_img_path}") - print(f"Found depth image: {depth_img_path}") - # Load images color_img = cv2.imread(color_img_path) if color_img is None: print(f"Error: Could not load color image from {color_img_path}") return - color_img = cv2.cvtColor(color_img, cv2.COLOR_BGR2RGB) # Convert to RGB + color_img = cv2.cvtColor(color_img, cv2.COLOR_BGR2RGB) depth_img = cv2.imread(depth_img_path, cv2.IMREAD_UNCHANGED) if depth_img is None: print(f"Error: Could not load depth image from {depth_img_path}") return - # Convert depth to meters if needed (adjust scale as needed for your data) + # Convert depth to meters if needed if depth_img.dtype == np.uint16: - # Convert from mm to meters for typical depth cameras depth_img = depth_img.astype(np.float32) / 1000.0 - # Verify image shapes for debugging - print(f"Color image shape: {color_img.shape}") - print(f"Depth image shape: {depth_img.shape}") - - # Initialize segmentation with direct camera matrices - seg = PointcloudSegmentation( - model_path="FastSAM-s.pt", # Adjust path as needed + # Run segmentation + segmenter = Sam2DSegmenter( + model_path="FastSAM-s.pt", device="cuda" if torch.cuda.is_available() else "cpu", + use_tracker=False, + use_analyzer=True + ) + + masks, bboxes, target_ids, probs, names = segmenter.process_image(color_img) + segmenter.run_analysis(color_img, bboxes, target_ids) + names = segmenter.get_object_names(target_ids, names) + + # Create metadata + objects_metadata = [] + for i in range(len(bboxes)): + obj_data = { + "object_id": target_ids[i] if i < len(target_ids) else i, + "bbox": bboxes[i], + "prob": probs[i] if i < len(probs) else 1.0, + "label": names[i] if i < len(names) else "", + } + objects_metadata.append(obj_data) + + metadata = { + "frame": color_img, + "objects": objects_metadata + } + + 
numpy_masks = [mask.cpu().numpy() if hasattr(mask, 'cpu') else mask for mask in masks] + segmentation_result = SegmentationType(masks=numpy_masks, metadata=metadata) + + # Initialize filtering pipeline + filter_pipeline = PointcloudFiltering( color_intrinsics=color_info_path, depth_intrinsics=depth_info_path, - enable_tracking=False, - enable_analysis=True + enable_statistical_filtering=True, + enable_cuboid_fitting=True, + color_weight=0.3, + statistical_neighbors=20, + statistical_std_ratio=2.0, ) - # Process images - print("Processing images...") + # Process images through filtering pipeline try: - results = seg.process_images(color_img, depth_img, fit_3d_cuboids=True) - - # Show segmentation results using PIL instead of OpenCV - viz_img = results["viz_image"] + results = filter_pipeline.process_images(color_img, depth_img, segmentation_result) - # Convert OpenCV image (BGR) to PIL image (RGB) - pil_img = Image.fromarray(cv2.cvtColor(viz_img, cv2.COLOR_BGR2RGB)) - - # Display the image using PIL - pil_img.show(title="Segmentation Results") - - # Add a short pause to ensure the image has time to display - import time - time.sleep(0.5) - - print(f"Found {len(results['objects'])} objects with valid point clouds") - - # Visualize all point clouds in a single window + # Visualize filtered point clouds all_pcds = [] for i, obj in enumerate(results['objects']): pcd = obj['point_cloud'] - # Optionally add axis-aligned bounding box visualization + # Add cuboid visualization if available if 'cuboid_params' in obj and obj['cuboid_params'] is not None: cuboid = obj['cuboid_params'] - - # Create oriented bounding box using the rotation matrix instead of axis-aligned box center = cuboid['center'] dimensions = cuboid['dimensions'] rotation = cuboid['rotation'] - # Create oriented bounding box obb = o3d.geometry.OrientedBoundingBox( center=center, R=rotation, extent=dimensions ) - obb.color = [1, 0, 0] # Red bounding box + obb.color = [1, 0, 0] all_pcds.append(obb) - # Add a 
small coordinate frame at the center of each object to show orientation coord_frame = o3d.geometry.TriangleMesh.create_coordinate_frame( size=min(dimensions) * 0.5, origin=center ) all_pcds.append(coord_frame) - # Add the point cloud all_pcds.append(pcd) # Add coordinate frame at origin coordinate_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1) all_pcds.append(coordinate_frame) - # Show point clouds + # Show filtered point clouds if all_pcds: o3d.visualization.draw_geometries(all_pcds, - window_name="Segmented Objects", + window_name="Filtered Point Clouds", width=1280, height=720, left=50, top=50) - else: - print("No objects with valid point clouds found.") - - # Show raw point cloud - o3d.visualization.draw_geometries([results['raw_point_cloud']], - window_name="Raw Point Cloud", - width=1280, - height=720, - left=50, - top=50) - - # Show plane removed point cloud - o3d.visualization.draw_geometries([results['plane_removed_point_cloud']], - window_name="Plane Removed Point Cloud", - width=1280, - height=720, - left=50, - top=50) + except Exception as e: print(f"Error during processing: {str(e)}") import traceback traceback.print_exc() # Clean up resources - seg.cleanup() - print("Done!") + segmenter.cleanup() + filter_pipeline.cleanup() if __name__ == "__main__": From cad533021d67836f11b82f9ca94732b280ef3948 Mon Sep 17 00:00:00 2001 From: alexlin2 <44330195+alexlin2@users.noreply.github.com> Date: Mon, 2 Jun 2025 07:45:16 +0000 Subject: [PATCH 05/89] CI code cleanup --- dimos/perception/pointcloud/pointcloud_seg.py | 189 +++++++++--------- 1 file changed, 94 insertions(+), 95 deletions(-) diff --git a/dimos/perception/pointcloud/pointcloud_seg.py b/dimos/perception/pointcloud/pointcloud_seg.py index de3297b7c9..6c6c60c262 100644 --- a/dimos/perception/pointcloud/pointcloud_seg.py +++ b/dimos/perception/pointcloud/pointcloud_seg.py @@ -16,6 +16,7 @@ import torch import open3d as o3d + class PointcloudFiltering: def __init__( self, @@ -29,7 
+30,7 @@ def __init__( ): """ Initialize processor to filter point clouds from segmented objects. - + Args: color_intrinsics: Path to YAML file or list with color camera intrinsics [fx, fy, cx, cy] depth_intrinsics: Path to YAML file or list with depth camera intrinsics [fx, fy, cx, cy] @@ -45,27 +46,27 @@ def __init__( self.color_weight = color_weight self.statistical_neighbors = statistical_neighbors self.statistical_std_ratio = statistical_std_ratio - + # Load camera matrices self.color_camera_matrix = load_camera_matrix_from_yaml(color_intrinsics) self.depth_camera_matrix = load_camera_matrix_from_yaml(depth_intrinsics) - + def generate_color_from_id(self, object_id): """Generate a consistent color for a given object ID.""" np.random.seed(object_id) color = np.random.randint(0, 255, 3) np.random.seed(None) return color - + def process_images(self, color_img, depth_img, segmentation_result): """ Process color and depth images with segmentation results to create filtered point clouds. 
- + Args: color_img: RGB image as numpy array (H, W, 3) depth_img: Depth image as numpy array (H, W) in meters segmentation_result: SegmentationType object containing masks and metadata - + Returns: dict: Dictionary containing: - objects: List of dicts for each object with: @@ -82,84 +83,86 @@ def process_images(self, color_img, depth_img, segmentation_result): """ if self.depth_camera_matrix is None: raise ValueError("Depth camera matrix must be provided to process images") - + # Extract masks and metadata from segmentation result masks = segmentation_result.masks metadata = segmentation_result.metadata - objects_metadata = metadata.get('objects', []) - + objects_metadata = metadata.get("objects", []) + # Process each object objects = [] for i, mask in enumerate(masks): # Get object metadata if available obj_meta = objects_metadata[i] if i < len(objects_metadata) else {} - object_id = obj_meta.get('object_id', i) - bbox = obj_meta.get('bbox', [0, 0, 0, 0]) - confidence = obj_meta.get('prob', 1.0) - label = obj_meta.get('label', '') - + object_id = obj_meta.get("object_id", i) + bbox = obj_meta.get("bbox", [0, 0, 0, 0]) + confidence = obj_meta.get("prob", 1.0) + label = obj_meta.get("label", "") + # Convert mask to numpy if it's a tensor - if hasattr(mask, 'cpu'): + if hasattr(mask, "cpu"): mask = mask.cpu().numpy() - + # Ensure mask is proper boolean array with correct dimensions mask = mask.astype(bool) - + # Ensure mask has the same shape as the depth image if mask.shape != depth_img.shape[:2]: if len(mask.shape) > 2: - mask = mask[:,:,0] if mask.shape[2] > 0 else mask[:,:,0] - + mask = mask[:, :, 0] if mask.shape[2] > 0 else mask[:, :, 0] + if mask.shape != depth_img.shape[:2]: - mask = cv2.resize(mask.astype(np.uint8), - (depth_img.shape[1], depth_img.shape[0]), - interpolation=cv2.INTER_NEAREST).astype(bool) - + mask = cv2.resize( + mask.astype(np.uint8), + (depth_img.shape[1], depth_img.shape[0]), + interpolation=cv2.INTER_NEAREST, + ).astype(bool) + try: # 
Create point cloud using Open3D pcd = create_masked_point_cloud( - color_img, - depth_img, - mask, - self.depth_camera_matrix, - depth_scale=1.0 + color_img, depth_img, mask, self.depth_camera_matrix, depth_scale=1.0 ) - + # Skip if no points if len(np.asarray(pcd.points)) == 0: continue - + # Generate color for visualization rgb_color = self.generate_color_from_id(object_id) - + # Apply weighted colored mask to the point cloud if len(np.asarray(pcd.colors)) > 0: original_colors = np.asarray(pcd.colors) generated_color = np.array(rgb_color) / 255.0 - colored_mask = (1.0 - self.color_weight) * original_colors + self.color_weight * generated_color + colored_mask = ( + 1.0 - self.color_weight + ) * original_colors + self.color_weight * generated_color colored_mask = np.clip(colored_mask, 0.0, 1.0) pcd.colors = o3d.utility.Vector3dVector(colored_mask) - + # Apply statistical outlier filtering if enabled filtering_stats = None if self.enable_statistical_filtering: num_points_before = len(np.asarray(pcd.points)) pcd_filtered, outlier_indices = pcd.remove_statistical_outlier( nb_neighbors=self.statistical_neighbors, - std_ratio=self.statistical_std_ratio + std_ratio=self.statistical_std_ratio, ) num_points_after = len(np.asarray(pcd_filtered.points)) num_outliers_removed = num_points_before - num_points_after - + pcd = pcd_filtered - + filtering_stats = { "points_before": num_points_before, "points_after": num_points_after, "outliers_removed": num_outliers_removed, - "outlier_percentage": 100.0 * num_outliers_removed / num_points_before if num_points_before > 0 else 0 + "outlier_percentage": 100.0 * num_outliers_removed / num_points_before + if num_points_before > 0 + else 0, } - + # Create object data obj_data = { "object_id": object_id, @@ -171,36 +174,37 @@ def process_images(self, color_img, depth_img, segmentation_result): "point_cloud_numpy": o3d_point_cloud_to_numpy(pcd), "color": rgb_color, } - + # Add optional data if available if filtering_stats is not None: 
obj_data["filtering_stats"] = filtering_stats - + # Fit 3D cuboid if enabled if self.enable_cuboid_fitting: points = np.asarray(pcd.points) cuboid_params = fit_cuboid(points) if cuboid_params is not None: obj_data["cuboid_params"] = cuboid_params - + objects.append(obj_data) - + except Exception as e: continue # Clean up GPU memory if using CUDA if torch.cuda.is_available(): torch.cuda.empty_cache() - + return { "objects": objects, } - + def cleanup(self): """Clean up resources.""" if torch.cuda.is_available(): torch.cuda.empty_cache() + def main(): """ Main function to test the PointcloudFiltering class with data from rgbd_data folder. @@ -208,7 +212,7 @@ def main(): def find_first_image(directory): """Find the first image file in the given directory.""" - image_extensions = ['.jpg', '.jpeg', '.png', '.bmp'] + image_extensions = [".jpg", ".jpeg", ".png", ".bmp"] for filename in sorted(os.listdir(directory)): if any(filename.lower().endswith(ext) for ext in image_extensions): return os.path.join(directory, filename) @@ -218,50 +222,50 @@ def find_first_image(directory): script_dir = os.path.dirname(os.path.abspath(__file__)) dimos_dir = os.path.abspath(os.path.join(script_dir, "../../../")) data_dir = os.path.join(dimos_dir, "assets/rgbd_data") - + color_info_path = os.path.join(data_dir, "color_camera_info.yaml") depth_info_path = os.path.join(data_dir, "depth_camera_info.yaml") - + color_dir = os.path.join(data_dir, "color") depth_dir = os.path.join(data_dir, "depth") - + # Find first color and depth images color_img_path = find_first_image(color_dir) depth_img_path = find_first_image(depth_dir) - + if not color_img_path or not depth_img_path: print(f"Error: Could not find color or depth images in {data_dir}") return - + # Load images color_img = cv2.imread(color_img_path) if color_img is None: print(f"Error: Could not load color image from {color_img_path}") return - + color_img = cv2.cvtColor(color_img, cv2.COLOR_BGR2RGB) - + depth_img = 
cv2.imread(depth_img_path, cv2.IMREAD_UNCHANGED) if depth_img is None: print(f"Error: Could not load depth image from {depth_img_path}") return - + # Convert depth to meters if needed if depth_img.dtype == np.uint16: depth_img = depth_img.astype(np.float32) / 1000.0 - + # Run segmentation segmenter = Sam2DSegmenter( model_path="FastSAM-s.pt", device="cuda" if torch.cuda.is_available() else "cpu", use_tracker=False, - use_analyzer=True + use_analyzer=True, ) - + masks, bboxes, target_ids, probs, names = segmenter.process_image(color_img) segmenter.run_analysis(color_img, bboxes, target_ids) names = segmenter.get_object_names(target_ids, names) - + # Create metadata objects_metadata = [] for i in range(len(bboxes)): @@ -272,15 +276,12 @@ def find_first_image(directory): "label": names[i] if i < len(names) else "", } objects_metadata.append(obj_data) - - metadata = { - "frame": color_img, - "objects": objects_metadata - } - - numpy_masks = [mask.cpu().numpy() if hasattr(mask, 'cpu') else mask for mask in masks] + + metadata = {"frame": color_img, "objects": objects_metadata} + + numpy_masks = [mask.cpu().numpy() if hasattr(mask, "cpu") else mask for mask in masks] segmentation_result = SegmentationType(masks=numpy_masks, metadata=metadata) - + # Initialize filtering pipeline filter_pipeline = PointcloudFiltering( color_intrinsics=color_info_path, @@ -291,57 +292,55 @@ def find_first_image(directory): statistical_neighbors=20, statistical_std_ratio=2.0, ) - + # Process images through filtering pipeline try: results = filter_pipeline.process_images(color_img, depth_img, segmentation_result) - + # Visualize filtered point clouds all_pcds = [] - for i, obj in enumerate(results['objects']): - pcd = obj['point_cloud'] - + for i, obj in enumerate(results["objects"]): + pcd = obj["point_cloud"] + # Add cuboid visualization if available - if 'cuboid_params' in obj and obj['cuboid_params'] is not None: - cuboid = obj['cuboid_params'] - center = cuboid['center'] - dimensions = 
cuboid['dimensions'] - rotation = cuboid['rotation'] - - obb = o3d.geometry.OrientedBoundingBox( - center=center, - R=rotation, - extent=dimensions - ) + if "cuboid_params" in obj and obj["cuboid_params"] is not None: + cuboid = obj["cuboid_params"] + center = cuboid["center"] + dimensions = cuboid["dimensions"] + rotation = cuboid["rotation"] + + obb = o3d.geometry.OrientedBoundingBox(center=center, R=rotation, extent=dimensions) obb.color = [1, 0, 0] all_pcds.append(obb) - + coord_frame = o3d.geometry.TriangleMesh.create_coordinate_frame( - size=min(dimensions) * 0.5, - origin=center + size=min(dimensions) * 0.5, origin=center ) all_pcds.append(coord_frame) - + all_pcds.append(pcd) - + # Add coordinate frame at origin coordinate_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1) all_pcds.append(coordinate_frame) - + # Show filtered point clouds if all_pcds: - o3d.visualization.draw_geometries(all_pcds, - window_name="Filtered Point Clouds", - width=1280, - height=720, - left=50, - top=50) - + o3d.visualization.draw_geometries( + all_pcds, + window_name="Filtered Point Clouds", + width=1280, + height=720, + left=50, + top=50, + ) + except Exception as e: print(f"Error during processing: {str(e)}") import traceback + traceback.print_exc() - + # Clean up resources segmenter.cleanup() filter_pipeline.cleanup() From 54b6a6f78d9cb13f9ecc112229a00130606435aa Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Mon, 2 Jun 2025 17:22:17 -0700 Subject: [PATCH 06/89] cleanup pointcloud filtering --- dimos/perception/pointcloud/pointcloud_seg.py | 350 ------------------ 1 file changed, 350 deletions(-) delete mode 100644 dimos/perception/pointcloud/pointcloud_seg.py diff --git a/dimos/perception/pointcloud/pointcloud_seg.py b/dimos/perception/pointcloud/pointcloud_seg.py deleted file mode 100644 index 6c6c60c262..0000000000 --- a/dimos/perception/pointcloud/pointcloud_seg.py +++ /dev/null @@ -1,350 +0,0 @@ -import numpy as np -import cv2 -import yaml -import os 
-import sys -from PIL import Image, ImageDraw -from dimos.perception.segmentation import Sam2DSegmenter -from dimos.types.segmentation import SegmentationType -from dimos.perception.pointcloud.utils import ( - load_camera_matrix_from_yaml, - create_masked_point_cloud, - o3d_point_cloud_to_numpy, - create_o3d_point_cloud_from_rgbd, -) -from dimos.perception.pointcloud.cuboid_fit import fit_cuboid, visualize_fit -import torch -import open3d as o3d - - -class PointcloudFiltering: - def __init__( - self, - color_intrinsics=None, - depth_intrinsics=None, - enable_statistical_filtering=True, - enable_cuboid_fitting=True, - color_weight=0.3, - statistical_neighbors=20, - statistical_std_ratio=2.0, - ): - """ - Initialize processor to filter point clouds from segmented objects. - - Args: - color_intrinsics: Path to YAML file or list with color camera intrinsics [fx, fy, cx, cy] - depth_intrinsics: Path to YAML file or list with depth camera intrinsics [fx, fy, cx, cy] - enable_statistical_filtering: Whether to apply statistical outlier filtering - enable_cuboid_fitting: Whether to fit 3D cuboids to objects - color_weight: Weight for blending generated color with original color (0.0 = original, 1.0 = generated) - statistical_neighbors: Number of neighbors for statistical filtering - statistical_std_ratio: Standard deviation ratio for statistical filtering - """ - # Store settings - self.enable_statistical_filtering = enable_statistical_filtering - self.enable_cuboid_fitting = enable_cuboid_fitting - self.color_weight = color_weight - self.statistical_neighbors = statistical_neighbors - self.statistical_std_ratio = statistical_std_ratio - - # Load camera matrices - self.color_camera_matrix = load_camera_matrix_from_yaml(color_intrinsics) - self.depth_camera_matrix = load_camera_matrix_from_yaml(depth_intrinsics) - - def generate_color_from_id(self, object_id): - """Generate a consistent color for a given object ID.""" - np.random.seed(object_id) - color = 
np.random.randint(0, 255, 3) - np.random.seed(None) - return color - - def process_images(self, color_img, depth_img, segmentation_result): - """ - Process color and depth images with segmentation results to create filtered point clouds. - - Args: - color_img: RGB image as numpy array (H, W, 3) - depth_img: Depth image as numpy array (H, W) in meters - segmentation_result: SegmentationType object containing masks and metadata - - Returns: - dict: Dictionary containing: - - objects: List of dicts for each object with: - - object_id: Object tracking ID - - mask: Segmentation mask (H, W, bool) - - bbox: Bounding box [x1, y1, x2, y2] - - confidence: Detection confidence - - label: Object label/name - - point_cloud: Open3D point cloud object (filtered and colored) - - point_cloud_numpy: Nx6 array of XYZRGB points (for compatibility) - - color: RGB color for visualization - - cuboid_params: Cuboid parameters (if enabled) - - filtering_stats: Filtering statistics (if filtering enabled) - """ - if self.depth_camera_matrix is None: - raise ValueError("Depth camera matrix must be provided to process images") - - # Extract masks and metadata from segmentation result - masks = segmentation_result.masks - metadata = segmentation_result.metadata - objects_metadata = metadata.get("objects", []) - - # Process each object - objects = [] - for i, mask in enumerate(masks): - # Get object metadata if available - obj_meta = objects_metadata[i] if i < len(objects_metadata) else {} - object_id = obj_meta.get("object_id", i) - bbox = obj_meta.get("bbox", [0, 0, 0, 0]) - confidence = obj_meta.get("prob", 1.0) - label = obj_meta.get("label", "") - - # Convert mask to numpy if it's a tensor - if hasattr(mask, "cpu"): - mask = mask.cpu().numpy() - - # Ensure mask is proper boolean array with correct dimensions - mask = mask.astype(bool) - - # Ensure mask has the same shape as the depth image - if mask.shape != depth_img.shape[:2]: - if len(mask.shape) > 2: - mask = mask[:, :, 0] if 
mask.shape[2] > 0 else mask[:, :, 0] - - if mask.shape != depth_img.shape[:2]: - mask = cv2.resize( - mask.astype(np.uint8), - (depth_img.shape[1], depth_img.shape[0]), - interpolation=cv2.INTER_NEAREST, - ).astype(bool) - - try: - # Create point cloud using Open3D - pcd = create_masked_point_cloud( - color_img, depth_img, mask, self.depth_camera_matrix, depth_scale=1.0 - ) - - # Skip if no points - if len(np.asarray(pcd.points)) == 0: - continue - - # Generate color for visualization - rgb_color = self.generate_color_from_id(object_id) - - # Apply weighted colored mask to the point cloud - if len(np.asarray(pcd.colors)) > 0: - original_colors = np.asarray(pcd.colors) - generated_color = np.array(rgb_color) / 255.0 - colored_mask = ( - 1.0 - self.color_weight - ) * original_colors + self.color_weight * generated_color - colored_mask = np.clip(colored_mask, 0.0, 1.0) - pcd.colors = o3d.utility.Vector3dVector(colored_mask) - - # Apply statistical outlier filtering if enabled - filtering_stats = None - if self.enable_statistical_filtering: - num_points_before = len(np.asarray(pcd.points)) - pcd_filtered, outlier_indices = pcd.remove_statistical_outlier( - nb_neighbors=self.statistical_neighbors, - std_ratio=self.statistical_std_ratio, - ) - num_points_after = len(np.asarray(pcd_filtered.points)) - num_outliers_removed = num_points_before - num_points_after - - pcd = pcd_filtered - - filtering_stats = { - "points_before": num_points_before, - "points_after": num_points_after, - "outliers_removed": num_outliers_removed, - "outlier_percentage": 100.0 * num_outliers_removed / num_points_before - if num_points_before > 0 - else 0, - } - - # Create object data - obj_data = { - "object_id": object_id, - "mask": mask, - "bbox": bbox, - "confidence": float(confidence), - "label": label, - "point_cloud": pcd, - "point_cloud_numpy": o3d_point_cloud_to_numpy(pcd), - "color": rgb_color, - } - - # Add optional data if available - if filtering_stats is not None: - 
obj_data["filtering_stats"] = filtering_stats - - # Fit 3D cuboid if enabled - if self.enable_cuboid_fitting: - points = np.asarray(pcd.points) - cuboid_params = fit_cuboid(points) - if cuboid_params is not None: - obj_data["cuboid_params"] = cuboid_params - - objects.append(obj_data) - - except Exception as e: - continue - - # Clean up GPU memory if using CUDA - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - return { - "objects": objects, - } - - def cleanup(self): - """Clean up resources.""" - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - -def main(): - """ - Main function to test the PointcloudFiltering class with data from rgbd_data folder. - """ - - def find_first_image(directory): - """Find the first image file in the given directory.""" - image_extensions = [".jpg", ".jpeg", ".png", ".bmp"] - for filename in sorted(os.listdir(directory)): - if any(filename.lower().endswith(ext) for ext in image_extensions): - return os.path.join(directory, filename) - return None - - # Define paths - script_dir = os.path.dirname(os.path.abspath(__file__)) - dimos_dir = os.path.abspath(os.path.join(script_dir, "../../../")) - data_dir = os.path.join(dimos_dir, "assets/rgbd_data") - - color_info_path = os.path.join(data_dir, "color_camera_info.yaml") - depth_info_path = os.path.join(data_dir, "depth_camera_info.yaml") - - color_dir = os.path.join(data_dir, "color") - depth_dir = os.path.join(data_dir, "depth") - - # Find first color and depth images - color_img_path = find_first_image(color_dir) - depth_img_path = find_first_image(depth_dir) - - if not color_img_path or not depth_img_path: - print(f"Error: Could not find color or depth images in {data_dir}") - return - - # Load images - color_img = cv2.imread(color_img_path) - if color_img is None: - print(f"Error: Could not load color image from {color_img_path}") - return - - color_img = cv2.cvtColor(color_img, cv2.COLOR_BGR2RGB) - - depth_img = cv2.imread(depth_img_path, cv2.IMREAD_UNCHANGED) - 
if depth_img is None: - print(f"Error: Could not load depth image from {depth_img_path}") - return - - # Convert depth to meters if needed - if depth_img.dtype == np.uint16: - depth_img = depth_img.astype(np.float32) / 1000.0 - - # Run segmentation - segmenter = Sam2DSegmenter( - model_path="FastSAM-s.pt", - device="cuda" if torch.cuda.is_available() else "cpu", - use_tracker=False, - use_analyzer=True, - ) - - masks, bboxes, target_ids, probs, names = segmenter.process_image(color_img) - segmenter.run_analysis(color_img, bboxes, target_ids) - names = segmenter.get_object_names(target_ids, names) - - # Create metadata - objects_metadata = [] - for i in range(len(bboxes)): - obj_data = { - "object_id": target_ids[i] if i < len(target_ids) else i, - "bbox": bboxes[i], - "prob": probs[i] if i < len(probs) else 1.0, - "label": names[i] if i < len(names) else "", - } - objects_metadata.append(obj_data) - - metadata = {"frame": color_img, "objects": objects_metadata} - - numpy_masks = [mask.cpu().numpy() if hasattr(mask, "cpu") else mask for mask in masks] - segmentation_result = SegmentationType(masks=numpy_masks, metadata=metadata) - - # Initialize filtering pipeline - filter_pipeline = PointcloudFiltering( - color_intrinsics=color_info_path, - depth_intrinsics=depth_info_path, - enable_statistical_filtering=True, - enable_cuboid_fitting=True, - color_weight=0.3, - statistical_neighbors=20, - statistical_std_ratio=2.0, - ) - - # Process images through filtering pipeline - try: - results = filter_pipeline.process_images(color_img, depth_img, segmentation_result) - - # Visualize filtered point clouds - all_pcds = [] - for i, obj in enumerate(results["objects"]): - pcd = obj["point_cloud"] - - # Add cuboid visualization if available - if "cuboid_params" in obj and obj["cuboid_params"] is not None: - cuboid = obj["cuboid_params"] - center = cuboid["center"] - dimensions = cuboid["dimensions"] - rotation = cuboid["rotation"] - - obb = 
o3d.geometry.OrientedBoundingBox(center=center, R=rotation, extent=dimensions) - obb.color = [1, 0, 0] - all_pcds.append(obb) - - coord_frame = o3d.geometry.TriangleMesh.create_coordinate_frame( - size=min(dimensions) * 0.5, origin=center - ) - all_pcds.append(coord_frame) - - all_pcds.append(pcd) - - # Add coordinate frame at origin - coordinate_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1) - all_pcds.append(coordinate_frame) - - # Show filtered point clouds - if all_pcds: - o3d.visualization.draw_geometries( - all_pcds, - window_name="Filtered Point Clouds", - width=1280, - height=720, - left=50, - top=50, - ) - - except Exception as e: - print(f"Error during processing: {str(e)}") - import traceback - - traceback.print_exc() - - # Clean up resources - segmenter.cleanup() - filter_pipeline.cleanup() - - -if __name__ == "__main__": - main() From b7575d4034d236169601ae224314a4fe7a7ba65e Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Thu, 5 Jun 2025 18:18:56 -0700 Subject: [PATCH 07/89] added grasp generation to pipeline --- ...est_manipulation_pipeline_visualization.py | 187 ++++++++++++++++++ 1 file changed, 187 insertions(+) create mode 100644 tests/test_manipulation_pipeline_visualization.py diff --git a/tests/test_manipulation_pipeline_visualization.py b/tests/test_manipulation_pipeline_visualization.py new file mode 100644 index 0000000000..a97ed473cd --- /dev/null +++ b/tests/test_manipulation_pipeline_visualization.py @@ -0,0 +1,187 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test manipulation pipeline with direct visualization and grasp data output.""" + +import os +import sys +import cv2 +import numpy as np +import time +import argparse +import matplotlib.pyplot as plt +import open3d as o3d +from typing import Dict, List +import threading +from reactivex import Observable, operators as ops +from reactivex.subject import Subject + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from dimos.perception.manip_aio_pipeline import ManipulationPipeline +from dimos.perception.grasp_generation.utils import visualize_grasps_3d +from dimos.perception.pointcloud.utils import load_camera_matrix_from_yaml +from dimos.utils.logging_config import setup_logger + +logger = setup_logger("test_pipeline_viz") + + +def load_first_frame(data_dir: str): + """Load first RGB-D frame and camera intrinsics.""" + # Load images + color_img = cv2.imread(os.path.join(data_dir, "color", "00000.png")) + color_img = cv2.cvtColor(color_img, cv2.COLOR_BGR2RGB) + + depth_img = cv2.imread(os.path.join(data_dir, "depth", "00000.png"), cv2.IMREAD_ANYDEPTH) + if depth_img.dtype == np.uint16: + depth_img = depth_img.astype(np.float32) / 1000.0 + # Load intrinsics + camera_matrix = load_camera_matrix_from_yaml(os.path.join(data_dir, "color_camera_info.yaml")) + intrinsics = [ + camera_matrix[0, 0], + camera_matrix[1, 1], + camera_matrix[0, 2], + camera_matrix[1, 2], + ] + + return color_img, depth_img, intrinsics + + +def create_point_cloud(color_img, depth_img, intrinsics): + """Create Open3D point cloud.""" + fx, fy, cx, cy = intrinsics + height, width = depth_img.shape + + o3d_intrinsics = o3d.camera.PinholeCameraIntrinsic(width, height, fx, fy, cx, cy) + color_o3d = o3d.geometry.Image(color_img) + depth_o3d = o3d.geometry.Image((depth_img * 1000).astype(np.uint16)) + + rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth( + 
color_o3d, depth_o3d, depth_scale=1000.0, convert_rgb_to_intensity=False + ) + + return o3d.geometry.PointCloud.create_from_rgbd_image(rgbd, o3d_intrinsics) + + +def run_pipeline(color_img, depth_img, intrinsics, wait_time=5.0): + """Run pipeline and collect results.""" + # Create pipeline + pipeline = ManipulationPipeline( + camera_intrinsics=intrinsics, + grasp_server_url="ws://10.0.0.125:8000/ws/grasp", + enable_grasp_generation=True, + ) + + # Create single-frame stream + subject = Subject() + streams = pipeline.create_streams(subject) + + # Debug: print available streams + print(f"Available streams: {list(streams.keys())}") + + # Collect results + results = {} + + def collect(key): + def on_next(value): + results[key] = value + logger.info(f"Received {key}") + + return on_next + + # Subscribe to streams + for key, stream in streams.items(): + if stream: + stream.pipe(ops.take(1)).subscribe(on_next=collect(key)) + + # Send frame + threading.Timer( + 0.5, + lambda: subject.on_next({"rgb": color_img, "depth": depth_img, "timestamp": time.time()}), + ).start() + + # Wait for results + time.sleep(wait_time) + + # If grasp generation is enabled, also check for latest grasps + if pipeline.latest_grasps: + results["grasps"] = pipeline.latest_grasps + logger.info(f"Retrieved latest grasps: {len(pipeline.latest_grasps)} grasps") + + pipeline.cleanup() + + return results + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--data-dir", default="assets/rgbd_data") + parser.add_argument("--wait-time", type=float, default=5.0) + args = parser.parse_args() + + # Load data + color_img, depth_img, intrinsics = load_first_frame(args.data_dir) + logger.info(f"Loaded images: color {color_img.shape}, depth {depth_img.shape}") + + # Run pipeline + results = run_pipeline(color_img, depth_img, intrinsics, args.wait_time) + + # Debug: Print what we received + print(f"\n✅ Pipeline Results:") + print(f" Available streams: {list(results.keys())}") + + if 
"filtered_objects" in results and results["filtered_objects"]: + print(f" Objects detected: {len(results['filtered_objects'])}") + + # Print grasp summary + if "grasps" in results and results["grasps"]: + total_grasps = 0 + best_score = 0 + for grasp in results["grasps"]: + score = grasp.get("score", 0) + if score > best_score: + best_score = score + total_grasps += 1 + print(f" Grasps generated: {total_grasps} (best score: {best_score:.3f})") + else: + print(" Grasps: None generated") + + # Visualize 2D results + fig, axes = plt.subplots(1, 2, figsize=(12, 6)) + + if "detection_viz" in results and results["detection_viz"] is not None: + axes[0].imshow(results["detection_viz"]) + axes[0].set_title("Object Detection") + axes[0].axis("off") + + if "pointcloud_viz" in results and results["pointcloud_viz"] is not None: + axes[1].imshow(results["pointcloud_viz"]) + axes[1].set_title("Point Cloud Overlay") + axes[1].axis("off") + + plt.tight_layout() + plt.show() + + # 3D visualization with grasps + if "grasps" in results and results["grasps"]: + pcd = create_point_cloud(color_img, depth_img, intrinsics) + all_grasps = results["grasps"] + + if all_grasps: + logger.info(f"Visualizing {len(all_grasps)} grasps in 3D") + visualize_grasps_3d(pcd, all_grasps) + + +if __name__ == "__main__": + main() From 42840ddefe81eb2cebceb58521e9e127e59ad6f5 Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Thu, 5 Jun 2025 20:53:41 -0700 Subject: [PATCH 08/89] I'm so f**ing tired after this --- ...est_manipulation_perception_pipeline.py.py | 167 ++++++++++++++++ ...est_manipulation_pipeline_visualization.py | 187 ------------------ 2 files changed, 167 insertions(+), 187 deletions(-) create mode 100644 tests/test_manipulation_perception_pipeline.py.py delete mode 100644 tests/test_manipulation_pipeline_visualization.py diff --git a/tests/test_manipulation_perception_pipeline.py.py b/tests/test_manipulation_perception_pipeline.py.py new file mode 100644 index 0000000000..9a2bc9d371 --- /dev/null 
+++ b/tests/test_manipulation_perception_pipeline.py.py @@ -0,0 +1,167 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import sys +import time +import threading +from reactivex import operators as ops + +import tests.test_header + +from pyzed import sl +from dimos.stream.stereo_camera_streams.zed import ZEDCameraStream +from dimos.web.robot_web_interface import RobotWebInterface +from dimos.utils.logging_config import logger +from dimos.perception.manip_aio_pipeline import ManipulationPipeline + + +def monitor_grasps(pipeline): + """Monitor and print grasp updates in a separate thread.""" + print(" Grasp monitor started...") + + last_grasp_count = 0 + last_update_time = time.time() + + while True: + try: + # Get latest grasps using the getter function + grasps = pipeline.get_latest_grasps(timeout=0.5) + current_time = time.time() + + if grasps is not None: + current_count = len(grasps) + if current_count != last_grasp_count: + print(f" Grasps received: {current_count} (at {time.strftime('%H:%M:%S')})") + if current_count > 0: + best_score = max(grasp.get("score", 0.0) for grasp in grasps) + print(f" Best grasp score: {best_score:.3f}") + last_grasp_count = current_count + last_update_time = current_time + else: + # Show periodic "still waiting" message + if current_time - last_update_time > 10.0: + print(f" Still waiting for grasps... 
({time.strftime('%H:%M:%S')})") + last_update_time = current_time + + time.sleep(1.0) # Check every second + + except Exception as e: + print(f" Error in grasp monitor: {e}") + time.sleep(2.0) + + +def main(): + """Test point cloud filtering with grasp generation using ManipulationPipeline.""" + print(" Testing point cloud filtering + grasp generation with ManipulationPipeline...") + + # Configuration + min_confidence = 0.6 + web_port = 5555 + grasp_server_url = "ws://10.0.0.125:8000/ws/grasp" + + try: + # Initialize ZED camera stream + zed_stream = ZEDCameraStream(resolution=sl.RESOLUTION.HD1080, fps=10) + + # Get camera intrinsics + camera_intrinsics_dict = zed_stream.get_camera_info() + camera_intrinsics = [ + camera_intrinsics_dict["fx"], + camera_intrinsics_dict["fy"], + camera_intrinsics_dict["cx"], + camera_intrinsics_dict["cy"], + ] + + # Create the concurrent manipulation pipeline WITH grasp generation + pipeline = ManipulationPipeline( + camera_intrinsics=camera_intrinsics, + min_confidence=min_confidence, + max_objects=10, + grasp_server_url=grasp_server_url, + enable_grasp_generation=True, # Enable grasp generation + ) + + # Create ZED stream + zed_frame_stream = zed_stream.create_stream().pipe(ops.share()) + + # Create concurrent processing streams + streams = pipeline.create_streams(zed_frame_stream) + detection_viz_stream = streams["detection_viz"] + pointcloud_viz_stream = streams["pointcloud_viz"] + grasps_stream = streams.get("grasps") # Get grasp stream if available + grasp_overlay_stream = streams.get("grasp_overlay") # Get grasp overlay stream if available + + except ImportError: + print("Error: ZED SDK not installed. 
Please install pyzed package.") + sys.exit(1) + except RuntimeError as e: + print(f"Error: Failed to open ZED camera: {e}") + sys.exit(1) + + try: + # Set up web interface with concurrent visualization streams + print("Initializing web interface...") + web_interface = RobotWebInterface( + port=web_port, + object_detection=detection_viz_stream, + pointcloud_stream=pointcloud_viz_stream, + grasp_overlay_stream=grasp_overlay_stream, + ) + + # Start grasp monitoring in background thread + grasp_monitor_thread = threading.Thread( + target=monitor_grasps, args=(pipeline,), daemon=True + ) + grasp_monitor_thread.start() + + print(f"\n Point Cloud + Grasp Generation Test Running:") + print(f" Web Interface: http://localhost:{web_port}") + print(f" Object Detection View: RGB with bounding boxes") + print(f" Point Cloud View: Depth with colored point clouds and 3D bounding boxes") + print(f" Confidence threshold: {min_confidence}") + print(f" Grasp server: {grasp_server_url}") + print(f" Available streams: {list(streams.keys())}") + print("\nPress Ctrl+C to stop the test\n") + + # Start web server (blocking call) + web_interface.run() + + except KeyboardInterrupt: + print("\nTest interrupted by user") + except Exception as e: + print(f"Error during test: {e}") + finally: + print("Cleaning up resources...") + if "zed_stream" in locals(): + zed_stream.cleanup() + if "pipeline" in locals(): + pipeline.cleanup() + print("Test completed") + + +if __name__ == "__main__": + main() diff --git a/tests/test_manipulation_pipeline_visualization.py b/tests/test_manipulation_pipeline_visualization.py deleted file mode 100644 index a97ed473cd..0000000000 --- a/tests/test_manipulation_pipeline_visualization.py +++ /dev/null @@ -1,187 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Test manipulation pipeline with direct visualization and grasp data output.""" - -import os -import sys -import cv2 -import numpy as np -import time -import argparse -import matplotlib.pyplot as plt -import open3d as o3d -from typing import Dict, List -import threading -from reactivex import Observable, operators as ops -from reactivex.subject import Subject - -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -from dimos.perception.manip_aio_pipeline import ManipulationPipeline -from dimos.perception.grasp_generation.utils import visualize_grasps_3d -from dimos.perception.pointcloud.utils import load_camera_matrix_from_yaml -from dimos.utils.logging_config import setup_logger - -logger = setup_logger("test_pipeline_viz") - - -def load_first_frame(data_dir: str): - """Load first RGB-D frame and camera intrinsics.""" - # Load images - color_img = cv2.imread(os.path.join(data_dir, "color", "00000.png")) - color_img = cv2.cvtColor(color_img, cv2.COLOR_BGR2RGB) - - depth_img = cv2.imread(os.path.join(data_dir, "depth", "00000.png"), cv2.IMREAD_ANYDEPTH) - if depth_img.dtype == np.uint16: - depth_img = depth_img.astype(np.float32) / 1000.0 - # Load intrinsics - camera_matrix = load_camera_matrix_from_yaml(os.path.join(data_dir, "color_camera_info.yaml")) - intrinsics = [ - camera_matrix[0, 0], - camera_matrix[1, 1], - camera_matrix[0, 2], - camera_matrix[1, 2], - ] - - return color_img, depth_img, intrinsics - - -def create_point_cloud(color_img, depth_img, intrinsics): - """Create Open3D point cloud.""" - fx, fy, cx, cy = 
intrinsics - height, width = depth_img.shape - - o3d_intrinsics = o3d.camera.PinholeCameraIntrinsic(width, height, fx, fy, cx, cy) - color_o3d = o3d.geometry.Image(color_img) - depth_o3d = o3d.geometry.Image((depth_img * 1000).astype(np.uint16)) - - rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth( - color_o3d, depth_o3d, depth_scale=1000.0, convert_rgb_to_intensity=False - ) - - return o3d.geometry.PointCloud.create_from_rgbd_image(rgbd, o3d_intrinsics) - - -def run_pipeline(color_img, depth_img, intrinsics, wait_time=5.0): - """Run pipeline and collect results.""" - # Create pipeline - pipeline = ManipulationPipeline( - camera_intrinsics=intrinsics, - grasp_server_url="ws://10.0.0.125:8000/ws/grasp", - enable_grasp_generation=True, - ) - - # Create single-frame stream - subject = Subject() - streams = pipeline.create_streams(subject) - - # Debug: print available streams - print(f"Available streams: {list(streams.keys())}") - - # Collect results - results = {} - - def collect(key): - def on_next(value): - results[key] = value - logger.info(f"Received {key}") - - return on_next - - # Subscribe to streams - for key, stream in streams.items(): - if stream: - stream.pipe(ops.take(1)).subscribe(on_next=collect(key)) - - # Send frame - threading.Timer( - 0.5, - lambda: subject.on_next({"rgb": color_img, "depth": depth_img, "timestamp": time.time()}), - ).start() - - # Wait for results - time.sleep(wait_time) - - # If grasp generation is enabled, also check for latest grasps - if pipeline.latest_grasps: - results["grasps"] = pipeline.latest_grasps - logger.info(f"Retrieved latest grasps: {len(pipeline.latest_grasps)} grasps") - - pipeline.cleanup() - - return results - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--data-dir", default="assets/rgbd_data") - parser.add_argument("--wait-time", type=float, default=5.0) - args = parser.parse_args() - - # Load data - color_img, depth_img, intrinsics = load_first_frame(args.data_dir) - 
logger.info(f"Loaded images: color {color_img.shape}, depth {depth_img.shape}") - - # Run pipeline - results = run_pipeline(color_img, depth_img, intrinsics, args.wait_time) - - # Debug: Print what we received - print(f"\n✅ Pipeline Results:") - print(f" Available streams: {list(results.keys())}") - - if "filtered_objects" in results and results["filtered_objects"]: - print(f" Objects detected: {len(results['filtered_objects'])}") - - # Print grasp summary - if "grasps" in results and results["grasps"]: - total_grasps = 0 - best_score = 0 - for grasp in results["grasps"]: - score = grasp.get("score", 0) - if score > best_score: - best_score = score - total_grasps += 1 - print(f" Grasps generated: {total_grasps} (best score: {best_score:.3f})") - else: - print(" Grasps: None generated") - - # Visualize 2D results - fig, axes = plt.subplots(1, 2, figsize=(12, 6)) - - if "detection_viz" in results and results["detection_viz"] is not None: - axes[0].imshow(results["detection_viz"]) - axes[0].set_title("Object Detection") - axes[0].axis("off") - - if "pointcloud_viz" in results and results["pointcloud_viz"] is not None: - axes[1].imshow(results["pointcloud_viz"]) - axes[1].set_title("Point Cloud Overlay") - axes[1].axis("off") - - plt.tight_layout() - plt.show() - - # 3D visualization with grasps - if "grasps" in results and results["grasps"]: - pcd = create_point_cloud(color_img, depth_img, intrinsics) - all_grasps = results["grasps"] - - if all_grasps: - logger.info(f"Visualizing {len(all_grasps)} grasps in 3D") - visualize_grasps_3d(pcd, all_grasps) - - -if __name__ == "__main__": - main() From d65932afcb061831c9edbb7109191712eb25292f Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Tue, 17 Jun 2025 01:07:54 -0700 Subject: [PATCH 09/89] added SAM2 support for segmentation, added manipulation perception processer with any streaming --- dimos/perception/segmentation/sam_2d_seg.py | 51 +++++++++++++++++---- 1 file changed, 41 insertions(+), 10 deletions(-) diff --git 
a/dimos/perception/segmentation/sam_2d_seg.py b/dimos/perception/segmentation/sam_2d_seg.py index fcf27584e6..e1b2b9755f 100644 --- a/dimos/perception/segmentation/sam_2d_seg.py +++ b/dimos/perception/segmentation/sam_2d_seg.py @@ -47,6 +47,7 @@ def __init__( use_tracker=True, use_analyzer=True, use_rich_labeling=False, + model_type="auto", # "auto", "fastsam", "sam2" ): self.device = device if is_cuda_available(): @@ -62,6 +63,24 @@ def __init__( self.use_tracker = use_tracker self.use_analyzer = use_analyzer self.use_rich_labeling = use_rich_labeling + + # Determine model type automatically if needed + if model_type == "auto": + if "sam2" in model_path.lower() or "sam_2" in model_path.lower(): + self.model_type = "sam2" + elif "fastsam" in model_path.lower(): + self.model_type = "fastsam" + else: + # Default to FastSAM for backward compatibility + self.model_type = "fastsam" + else: + self.model_type = model_type + + # Initialize the appropriate model + if self.model_type == "sam2": + self.model = SAM(model_path) + else: + self.model = FastSAM(model_path) module_dir = os.path.dirname(__file__) self.tracker_config = os.path.join(module_dir, "config", "custom_tracker.yaml") @@ -94,16 +113,28 @@ def __init__( def process_image(self, image): """Process an image and return segmentation results.""" - results = self.model.track( - source=image, - device=self.device, - retina_masks=True, - conf=0.6, - iou=0.9, - persist=True, - verbose=False, - tracker=self.tracker_config, - ) + if self.model_type == "sam2": + # For SAM 2, use segment everything mode + results = self.model( + source=image, + device=self.device, + save=False, + conf=0.4, + iou=0.9, + verbose=False, + ) + else: + # For FastSAM, use the original tracking approach + results = self.model.track( + source=image, + device=self.device, + retina_masks=True, + conf=0.6, + iou=0.9, + persist=True, + verbose=False, + tracker=self.tracker_config, + ) if len(results) > 0: # Get initial segmentation results From 
6a16523a03ca8a7ed03adf431196bacf9fb7ee23 Mon Sep 17 00:00:00 2001 From: alexlin2 <44330195+alexlin2@users.noreply.github.com> Date: Tue, 17 Jun 2025 08:08:48 +0000 Subject: [PATCH 10/89] CI code cleanup --- dimos/perception/segmentation/sam_2d_seg.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dimos/perception/segmentation/sam_2d_seg.py b/dimos/perception/segmentation/sam_2d_seg.py index e1b2b9755f..91831d947e 100644 --- a/dimos/perception/segmentation/sam_2d_seg.py +++ b/dimos/perception/segmentation/sam_2d_seg.py @@ -63,7 +63,7 @@ def __init__( self.use_tracker = use_tracker self.use_analyzer = use_analyzer self.use_rich_labeling = use_rich_labeling - + # Determine model type automatically if needed if model_type == "auto": if "sam2" in model_path.lower() or "sam_2" in model_path.lower(): @@ -75,7 +75,7 @@ def __init__( self.model_type = "fastsam" else: self.model_type = model_type - + # Initialize the appropriate model if self.model_type == "sam2": self.model = SAM(model_path) From 1dbc232b536357a1fab9e37c6d2db8f94df60c75 Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Wed, 25 Jun 2025 14:59:16 -0700 Subject: [PATCH 11/89] supports contact graspnet --- .../pointcloud/pointcloud_filtering.py | 2 +- tests/manipulation_pipeline_demo.ipynb | 839 ++++++++++++++++++ 2 files changed, 840 insertions(+), 1 deletion(-) create mode 100644 tests/manipulation_pipeline_demo.ipynb diff --git a/dimos/perception/pointcloud/pointcloud_filtering.py b/dimos/perception/pointcloud/pointcloud_filtering.py index 3de2f3ae6a..47d351bd14 100644 --- a/dimos/perception/pointcloud/pointcloud_filtering.py +++ b/dimos/perception/pointcloud/pointcloud_filtering.py @@ -292,7 +292,7 @@ def process_images( pcd = self._apply_color_mask(pcd, rgb_color) # Apply subsampling to control point cloud size - pcd = self._apply_subsampling(pcd) + # pcd = self._apply_subsampling(pcd) # Apply filtering (optional based on flags) pcd_filtered = self._apply_filtering(pcd) diff --git 
a/tests/manipulation_pipeline_demo.ipynb b/tests/manipulation_pipeline_demo.ipynb new file mode 100644 index 0000000000..df43a7c6ac --- /dev/null +++ b/tests/manipulation_pipeline_demo.ipynb @@ -0,0 +1,839 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Manipulation Pipeline Demo with ContactGraspNet\n", + "\n", + "This notebook demonstrates the complete manipulation pipeline including:\n", + "- Object detection (Detic)\n", + "- Semantic segmentation (SAM/FastSAM)\n", + "- Point cloud processing\n", + "- 6-DoF grasp generation (ContactGraspNet)\n", + "- 3D visualization\n", + "\n", + "---" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Setup and Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Jupyter environment detected. Enabling Open3D WebVisualizer.\n", + "[Open3D INFO] WebRTC GUI backend enabled.\n", + "[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.\n", + "\u2705 All imports successful!\n" + ] + } + ], + "source": [ + "import os\n", + "import sys\n", + "import cv2\n", + "import numpy as np\n", + "import time\n", + "import matplotlib\n", + "\n", + "# Configure matplotlib backend\n", + "try:\n", + " matplotlib.use(\"TkAgg\")\n", + "except:\n", + " try:\n", + " matplotlib.use(\"Qt5Agg\")\n", + " except:\n", + " matplotlib.use(\"Agg\")\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import open3d as o3d\n", + "from typing import Dict, List\n", + "\n", + "# Add project root to path\n", + "sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(\"__file__\"))))\n", + "\n", + "# Import DIMOS modules\n", + "from dimos.perception.pointcloud.utils import visualize_clustered_point_clouds, visualize_voxel_grid\n", + "from dimos.perception.manip_aio_processer import ManipulationProcessor\n", + "from dimos.perception.pointcloud.utils import 
load_camera_matrix_from_yaml, visualize_pcd\n", + "from dimos.utils.logging_config import setup_logger\n", + "\n", + "# Import ContactGraspNet visualization\n", + "from dimos.models.manipulation.contact_graspnet_pytorch.contact_graspnet_pytorch.visualization_utils_o3d import (\n", + " visualize_grasps,\n", + ")\n", + "\n", + "logger = setup_logger(\"manipulation_pipeline_demo\")\n", + "print(\"\u2705 All imports successful!\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Configuration" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Configuration:\n", + " data_dir: /home/alex-lin/dimos/assets/rgbd_data\n", + " enable_grasp_generation: True\n", + " enable_segmentation: True\n", + " segmentation_model: FastSAM-x.pt\n", + " min_confidence: 0.6\n", + " max_objects: 20\n", + " show_3d_visualizations: True\n", + " save_results: True\n" + ] + } + ], + "source": [ + "# Configuration parameters\n", + "CONFIG = {\n", + " \"data_dir\": \"/home/alex-lin/dimos/assets/rgbd_data\",\n", + " \"enable_grasp_generation\": True,\n", + " \"enable_segmentation\": True,\n", + " \"segmentation_model\": \"FastSAM-x.pt\", # or \"sam2_b.pt\"\n", + " \"min_confidence\": 0.6,\n", + " \"max_objects\": 20,\n", + " \"show_3d_visualizations\": True,\n", + " \"save_results\": True,\n", + "}\n", + "\n", + "print(f\"Configuration:\")\n", + "for key, value in CONFIG.items():\n", + " print(f\" {key}: {value}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. 
Data Loading Functions" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u2705 Data loading functions defined!\n" + ] + } + ], + "source": [ + "def load_first_frame(data_dir: str):\n", + " \"\"\"Load first RGB-D frame and camera intrinsics.\"\"\"\n", + " # Load images\n", + " color_img = cv2.imread(os.path.join(data_dir, \"color\", \"00000.png\"))\n", + " color_img = cv2.cvtColor(color_img, cv2.COLOR_BGR2RGB)\n", + "\n", + " depth_img = cv2.imread(os.path.join(data_dir, \"depth\", \"00000.png\"), cv2.IMREAD_ANYDEPTH)\n", + " if depth_img.dtype == np.uint16:\n", + " depth_img = depth_img.astype(np.float32) / 1000.0\n", + "\n", + " # Load intrinsics\n", + " camera_matrix = load_camera_matrix_from_yaml(os.path.join(data_dir, \"color_camera_info.yaml\"))\n", + " intrinsics = [\n", + " camera_matrix[0, 0], # fx\n", + " camera_matrix[1, 1], # fy\n", + " camera_matrix[0, 2], # cx\n", + " camera_matrix[1, 2], # cy\n", + " ]\n", + "\n", + " return color_img, depth_img, intrinsics\n", + "\n", + "\n", + "def create_point_cloud(color_img, depth_img, intrinsics):\n", + " \"\"\"Create Open3D point cloud for reference.\"\"\"\n", + " fx, fy, cx, cy = intrinsics\n", + " height, width = depth_img.shape\n", + "\n", + " o3d_intrinsics = o3d.camera.PinholeCameraIntrinsic(width, height, fx, fy, cx, cy)\n", + " color_o3d = o3d.geometry.Image(color_img)\n", + " depth_o3d = o3d.geometry.Image((depth_img * 1000).astype(np.uint16))\n", + "\n", + " rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth(\n", + " color_o3d, depth_o3d, depth_scale=1000.0, convert_rgb_to_intensity=False\n", + " )\n", + "\n", + " return o3d.geometry.PointCloud.create_from_rgbd_image(rgbd, o3d_intrinsics)\n", + "\n", + "\n", + "print(\"\u2705 Data loading functions defined!\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. 
Load RGB-D Data" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2025-06-25 13:29:47,127 - manipulation_pipeline_demo - INFO - Loaded images: color (720, 1280, 3), depth (720, 1280)\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Camera intrinsics: fx=749.3, fy=748.6, cx=639.4, cy=357.2\n" + ] + } + ], + "source": [ + "# Load data\n", + "color_img, depth_img, intrinsics = load_first_frame(CONFIG[\"data_dir\"])\n", + "logger.info(f\"Loaded images: color {color_img.shape}, depth {depth_img.shape}\")\n", + "\n", + "# Display input images\n", + "fig, axes = plt.subplots(1, 2, figsize=(12, 5))\n", + "\n", + "axes[0].imshow(color_img)\n", + "axes[0].set_title(\"RGB Image\")\n", + "axes[0].axis(\"off\")\n", + "\n", + "# Colorize depth for visualization\n", + "depth_colorized = cv2.applyColorMap(\n", + " cv2.convertScaleAbs(depth_img, alpha=255.0 / depth_img.max()), cv2.COLORMAP_JET\n", + ")\n", + "depth_colorized = cv2.cvtColor(depth_colorized, cv2.COLOR_BGR2RGB)\n", + "axes[1].imshow(depth_colorized)\n", + "axes[1].set_title(\"Depth Image\")\n", + "axes[1].axis(\"off\")\n", + "\n", + "plt.tight_layout()\n", + "plt.show()\n", + "\n", + "print(\n", + " f\"Camera intrinsics: fx={intrinsics[0]:.1f}, fy={intrinsics[1]:.1f}, cx={intrinsics[2]:.1f}, cy={intrinsics[3]:.1f}\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. 
Initialize Manipulation Processor" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/alex-lin/miniconda3/envs/contact-graspnet/lib/python3.10/site-packages/timm/models/layers/__init__.py:48: FutureWarning: Importing from timm.models.layers is deprecated, please import via timm.layers\n", + " warnings.warn(f\"Importing from {__name__} is deprecated, please import via timm.layers\", FutureWarning)\n", + "/home/alex-lin/miniconda3/envs/contact-graspnet/lib/python3.10/site-packages/timm/models/helpers.py:7: FutureWarning: Importing from timm.models.helpers is deprecated, please import via timm.models\n", + " warnings.warn(f\"Importing from {__name__} is deprecated, please import via timm.models\", FutureWarning)\n", + "/home/alex-lin/miniconda3/envs/contact-graspnet/lib/python3.10/site-packages/timm/models/registry.py:4: FutureWarning: Importing from timm.models.registry is deprecated, please import via timm.models\n", + " warnings.warn(f\"Importing from {__name__} is deprecated, please import via timm.models\", FutureWarning)\n", + "/home/alex-lin/miniconda3/envs/contact-graspnet/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. 
(Triggered internally at /opt/conda/conda-bld/pytorch_1702400410390/work/aten/src/ATen/native/TensorShape.cpp:3526.)\n", + " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Resetting zs_weight /home/alex-lin/dimos/dimos/perception/detection2d/../../models/Detic/datasets/metadata/lvis_v1_clip_a+cname.npy\n", + "Downloading https://github.com/ultralytics/assets/releases/download/v8.3.0/FastSAM-x.pt to 'FastSAM-x.pt'...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 138M/138M [00:03<00:00, 41.5MB/s] \n", + "\u001b[32m2025-06-25 13:30:01,134 - dimos.perception.grasp_generation - INFO - Initializing ContactGraspNet on device: cuda\u001b[0m\n", + "\u001b[32m2025-06-25 13:30:01,141 - dimos.perception.grasp_generation - INFO - Loaded config from dimos/models/manipulation/contact_graspnet_pytorch/checkpoints\u001b[0m\n", + "\u001b[32m2025-06-25 13:30:01,164 - dimos.perception.grasp_generation - INFO - ContactGraspNet model initialized\u001b[0m\n", + "\u001b[32m2025-06-25 13:30:01,165 - dimos.perception.grasp_generation - INFO - ContactGraspNet grasp generator initialized successfully\u001b[0m\n", + "\u001b[32m2025-06-25 13:30:01,165 - dimos.perception.manip_aio_processor - INFO - ContactGraspNet generator initialized successfully\u001b[0m\n", + "\u001b[32m2025-06-25 13:30:01,165 - dimos.perception.manip_aio_processor - INFO - Initialized ManipulationProcessor with confidence=0.6, grasp_generation=True\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "model func: \n", + "\u2705 ManipulationProcessor initialized successfully!\n" + ] + } + ], + "source": [ + "# Create processor with ContactGraspNet enabled\n", + "processor = ManipulationProcessor(\n", + " camera_intrinsics=intrinsics,\n", + " min_confidence=CONFIG[\"min_confidence\"],\n", + " 
max_objects=CONFIG[\"max_objects\"],\n", + " enable_grasp_generation=CONFIG[\"enable_grasp_generation\"],\n", + " enable_segmentation=CONFIG[\"enable_segmentation\"],\n", + " segmentation_model=CONFIG[\"segmentation_model\"],\n", + ")\n", + "\n", + "print(\"\u2705 ManipulationProcessor initialized successfully!\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6. Run Processing Pipeline" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\ud83d\udd04 Processing frame through pipeline...\n", + "DBSCAN clustering found 11 clusters from 28067 points\n", + "Created voxel grid with 2220 voxels (voxel_size=0.02)\n", + "using local regions\n", + "Extracted Region Cube Size: 0.311576783657074\n", + "Extracted Region Cube Size: 0.445679247379303\n", + "Extracted Region Cube Size: 0.24130240082740784\n", + "Extracted Region Cube Size: 0.2\n", + "Extracted Region Cube Size: 0.46059030294418335\n", + "Extracted Region Cube Size: 0.2357255220413208\n", + "Extracted Region Cube Size: 0.3680998980998993\n", + "Extracted Region Cube Size: 0.2\n", + "Extracted Region Cube Size: 0.2\n", + "Extracted Region Cube Size: 0.24357137084007263\n", + "Extracted Region Cube Size: 0.2\n", + "Extracted Region Cube Size: 0.2409430295228958\n", + "Extracted Region Cube Size: 0.2\n", + "Extracted Region Cube Size: 0.2\n", + "Extracted Region Cube Size: 0.23709678649902344\n", + "Extracted Region Cube Size: 0.2\n", + "Extracted Region Cube Size: 0.5130001306533813\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/alex-lin/miniconda3/envs/contact-graspnet/lib/python3.10/site-packages/numpy/core/fromnumeric.py:3504: RuntimeWarning: Mean of empty slice.\n", + " return _methods._mean(a, axis=axis, dtype=dtype,\n", + "/home/alex-lin/miniconda3/envs/contact-graspnet/lib/python3.10/site-packages/numpy/core/_methods.py:121: 
RuntimeWarning: invalid value encountered in divide\n", + " ret = um.true_divide(\n", + "\u001b[32m2025-06-25 13:30:19,727 - dimos.perception.grasp_generation - INFO - Generated 3400 grasps across 17 objects in 12.91s\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u2705 Processing completed in 14.768s\n" + ] + } + ], + "source": [ + "# Process single frame\n", + "print(\"\ud83d\udd04 Processing frame through pipeline...\")\n", + "start_time = time.time()\n", + "\n", + "results = processor.process_frame(color_img, depth_img)\n", + "\n", + "processing_time = time.time() - start_time\n", + "print(f\"\u2705 Processing completed in {processing_time:.3f}s\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 7. Results Summary" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\ud83d\udcca PROCESSING RESULTS SUMMARY\n", + "==================================================\n", + "Available results: ['detection2d_objects', 'detection_viz', 'segmentation2d_objects', 'segmentation_viz', 'detected_objects', 'all_objects', 'full_pointcloud', 'misc_clusters', 'misc_voxel_grid', 'pointcloud_viz', 'detected_pointcloud_viz', 'misc_pointcloud_viz', 'grasps', 'processing_time', 'timing_breakdown']\n", + "Total processing time: 14.768s\n", + "\n", + "\u23f1\ufe0f Timing breakdown:\n", + " Detection: 0.550s\n", + " Segmentation: 0.733s\n", + " Point cloud: 0.144s\n", + " Misc extraction: 0.371s\n", + "\n", + "\ud83c\udfaf Object Detection:\n", + " Detection objects: 13\n", + " All objects processed: 18\n", + "\n", + "\ud83e\udde9 Background Analysis:\n", + " Misc clusters: 11 clusters with 26,692 total points\n", + "\n", + "\ud83e\udd16 ContactGraspNet Results:\n", + " Total grasps: 3400\n", + " Best score: 0.911\n", + " Objects with grasps: 17\n", + "\n", + 
"==================================================\n" + ] + } + ], + "source": [ + "# Print comprehensive results summary\n", + "print(f\"\\n\ud83d\udcca PROCESSING RESULTS SUMMARY\")\n", + "print(f\"\" + \"=\" * 50)\n", + "print(f\"Available results: {list(results.keys())}\")\n", + "print(f\"Total processing time: {results.get('processing_time', 0):.3f}s\")\n", + "\n", + "# Show timing breakdown\n", + "if \"timing_breakdown\" in results:\n", + " breakdown = results[\"timing_breakdown\"]\n", + " print(f\"\\n\u23f1\ufe0f Timing breakdown:\")\n", + " print(f\" Detection: {breakdown.get('detection', 0):.3f}s\")\n", + " print(f\" Segmentation: {breakdown.get('segmentation', 0):.3f}s\")\n", + " print(f\" Point cloud: {breakdown.get('pointcloud', 0):.3f}s\")\n", + " print(f\" Misc extraction: {breakdown.get('misc_extraction', 0):.3f}s\")\n", + "\n", + "# Object counts\n", + "detected_count = len(results.get(\"detected_objects\", []))\n", + "all_count = len(results.get(\"all_objects\", []))\n", + "print(f\"\\n\ud83c\udfaf Object Detection:\")\n", + "print(f\" Detection objects: {detected_count}\")\n", + "print(f\" All objects processed: {all_count}\")\n", + "\n", + "# Misc clusters info\n", + "if \"misc_clusters\" in results and results[\"misc_clusters\"]:\n", + " cluster_count = len(results[\"misc_clusters\"])\n", + " total_misc_points = sum(len(np.asarray(cluster.points)) for cluster in results[\"misc_clusters\"])\n", + " print(f\"\\n\ud83e\udde9 Background Analysis:\")\n", + " print(f\" Misc clusters: {cluster_count} clusters with {total_misc_points:,} total points\")\n", + "else:\n", + " print(f\"\\n\ud83e\udde9 Background Analysis: No clusters found\")\n", + "\n", + "# ContactGraspNet grasp summary\n", + "if \"grasps\" in results and results[\"grasps\"]:\n", + " grasp_data = results[\"grasps\"]\n", + " if isinstance(grasp_data, dict):\n", + " pred_grasps = grasp_data.get(\"pred_grasps_cam\", {})\n", + " scores = grasp_data.get(\"scores\", {})\n", + "\n", + " 
total_grasps = 0\n", + " best_score = 0\n", + "\n", + " for obj_id, obj_grasps in pred_grasps.items():\n", + " num_grasps = len(obj_grasps) if hasattr(obj_grasps, \"__len__\") else 0\n", + " total_grasps += num_grasps\n", + "\n", + " if obj_id in scores and len(scores[obj_id]) > 0:\n", + " obj_best_score = max(scores[obj_id])\n", + " if obj_best_score > best_score:\n", + " best_score = obj_best_score\n", + "\n", + " print(f\"\\n\ud83e\udd16 ContactGraspNet Results:\")\n", + " print(f\" Total grasps: {total_grasps}\")\n", + " print(f\" Best score: {best_score:.3f}\")\n", + " print(f\" Objects with grasps: {len(pred_grasps)}\")\n", + " else:\n", + " print(f\"\\n\ud83e\udd16 ContactGraspNet Results: Invalid format\")\n", + "else:\n", + " print(f\"\\n\ud83e\udd16 ContactGraspNet Results: No grasps generated\")\n", + "\n", + "print(\"\\n\" + \"=\" * 50)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 8. 2D Visualization Results" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\ud83d\udcf8 Results saved to: manipulation_results.png\n" + ] + } + ], + "source": [ + "# Collect available visualizations\n", + "viz_configs = []\n", + "\n", + "if \"detection_viz\" in results and results[\"detection_viz\"] is not None:\n", + " viz_configs.append((\"detection_viz\", \"Object Detection\"))\n", + "\n", + "if \"segmentation_viz\" in results and results[\"segmentation_viz\"] is not None:\n", + " viz_configs.append((\"segmentation_viz\", \"Semantic Segmentation\"))\n", + "\n", + "if \"pointcloud_viz\" in results and results[\"pointcloud_viz\"] is not None:\n", + " viz_configs.append((\"pointcloud_viz\", \"All Objects Point Cloud\"))\n", + "\n", + "if \"detected_pointcloud_viz\" in results and results[\"detected_pointcloud_viz\"] is not None:\n", + " viz_configs.append((\"detected_pointcloud_viz\", \"Detection Objects Point Cloud\"))\n", + "\n", 
+ "if \"misc_pointcloud_viz\" in results and results[\"misc_pointcloud_viz\"] is not None:\n", + " viz_configs.append((\"misc_pointcloud_viz\", \"Misc/Background Points\"))\n", + "\n", + "# Create visualization grid\n", + "if viz_configs:\n", + " num_plots = len(viz_configs)\n", + "\n", + " if num_plots <= 3:\n", + " fig, axes = plt.subplots(1, num_plots, figsize=(6 * num_plots, 5))\n", + " else:\n", + " rows = 2\n", + " cols = (num_plots + 1) // 2\n", + " fig, axes = plt.subplots(rows, cols, figsize=(6 * cols, 5 * rows))\n", + "\n", + " # Ensure axes is always iterable\n", + " if num_plots == 1:\n", + " axes = [axes]\n", + " elif num_plots > 2:\n", + " axes = axes.flatten()\n", + "\n", + " # Plot each result\n", + " for i, (key, title) in enumerate(viz_configs):\n", + " axes[i].imshow(results[key])\n", + " axes[i].set_title(title, fontsize=12, fontweight=\"bold\")\n", + " axes[i].axis(\"off\")\n", + "\n", + " # Hide unused subplots\n", + " if num_plots > 3:\n", + " for i in range(num_plots, len(axes)):\n", + " axes[i].axis(\"off\")\n", + "\n", + " plt.tight_layout()\n", + "\n", + " if CONFIG[\"save_results\"]:\n", + " output_path = \"manipulation_results.png\"\n", + " plt.savefig(output_path, dpi=150, bbox_inches=\"tight\")\n", + " print(f\"\ud83d\udcf8 Results saved to: {output_path}\")\n", + "\n", + " plt.show()\n", + "else:\n", + " print(\"\u26a0\ufe0f No 2D visualization results to display\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 9. 
3D ContactGraspNet Visualization" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\ud83c\udfaf Launching 3D visualization with 3400 ContactGraspNet grasps...\n", + "\ud83d\udccc Note: Close the 3D window to continue with the notebook\n", + "Visualizing...\n", + "\u2705 3D grasp visualization completed!\n" + ] + } + ], + "source": [ + "# 3D ContactGraspNet visualization\n", + "if (\n", + " CONFIG[\"show_3d_visualizations\"]\n", + " and \"grasps\" in results\n", + " and results[\"grasps\"]\n", + " and \"full_pointcloud\" in results\n", + "):\n", + " grasp_data = results[\"grasps\"]\n", + " full_pcd = results[\"full_pointcloud\"]\n", + "\n", + " if isinstance(grasp_data, dict) and full_pcd is not None:\n", + " try:\n", + " # Extract ContactGraspNet data\n", + " pred_grasps_cam = grasp_data.get(\"pred_grasps_cam\", {})\n", + " scores = grasp_data.get(\"scores\", {})\n", + " contact_pts = grasp_data.get(\"contact_pts\", {})\n", + " gripper_openings = grasp_data.get(\"gripper_openings\", {})\n", + "\n", + " # Check if we have valid grasp data\n", + " total_grasps = (\n", + " sum(len(grasps) for grasps in pred_grasps_cam.values()) if pred_grasps_cam else 0\n", + " )\n", + "\n", + " if total_grasps > 0:\n", + " print(\n", + " f\"\ud83c\udfaf Launching 3D visualization with {total_grasps} ContactGraspNet grasps...\"\n", + " )\n", + " print(\"\ud83d\udccc Note: Close the 3D window to continue with the notebook\")\n", + "\n", + " # Use ContactGraspNet's native visualization - pass dictionaries directly\n", + " visualize_grasps(\n", + " full_pcd,\n", + " pred_grasps_cam, # Pass dictionary directly\n", + " scores, # Pass dictionary directly\n", + " gripper_openings=gripper_openings,\n", + " )\n", + "\n", + " print(\"\u2705 3D grasp visualization completed!\")\n", + " else:\n", + " print(\"\u26a0\ufe0f No valid grasps to visualize in 3D\")\n", + "\n", + " except 
Exception as e:\n", + " print(f\"\u274c Error in ContactGraspNet 3D visualization: {e}\")\n", + " print(\" Skipping 3D grasp visualization\")\n", + "else:\n", + " if not CONFIG[\"show_3d_visualizations\"]:\n", + " print(\"\u23ed\ufe0f 3D visualizations disabled in config\")\n", + " else:\n", + " print(\"\u26a0\ufe0f ContactGraspNet grasp generation disabled or no results\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 10. Additional 3D Visualizations" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 10.1 Full Scene Point Cloud" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if (\n", + " CONFIG[\"show_3d_visualizations\"]\n", + " and \"full_pointcloud\" in results\n", + " and results[\"full_pointcloud\"] is not None\n", + "):\n", + " full_pcd = results[\"full_pointcloud\"]\n", + " num_points = len(np.asarray(full_pcd.points))\n", + " print(f\"\ud83c\udf0d Launching full scene point cloud visualization ({num_points:,} points)...\")\n", + " print(\"\ud83d\udccc Note: Close the 3D window to continue\")\n", + "\n", + " try:\n", + " visualize_pcd(\n", + " full_pcd,\n", + " window_name=\"Full Scene Point Cloud\",\n", + " point_size=2.0,\n", + " show_coordinate_frame=True,\n", + " )\n", + " print(\"\u2705 Full point cloud visualization completed!\")\n", + " except (KeyboardInterrupt, EOFError):\n", + " print(\"\u23ed\ufe0f Full point cloud visualization skipped\")\n", + "else:\n", + " print(\"\u26a0\ufe0f No full point cloud available for visualization\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 10.2 Background/Misc Clusters" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if CONFIG[\"show_3d_visualizations\"] and \"misc_clusters\" in results and results[\"misc_clusters\"]:\n", + " misc_clusters = results[\"misc_clusters\"]\n", + " 
cluster_count = len(misc_clusters)\n", + " total_misc_points = sum(len(np.asarray(cluster.points)) for cluster in misc_clusters)\n", + "\n", + " print(\n", + " f\"\ud83e\udde9 Launching misc/background clusters visualization ({cluster_count} clusters, {total_misc_points:,} points)...\"\n", + " )\n", + " print(\"\ud83d\udccc Note: Close the 3D window to continue\")\n", + "\n", + " try:\n", + " visualize_clustered_point_clouds(\n", + " misc_clusters,\n", + " window_name=\"Misc/Background Clusters (DBSCAN)\",\n", + " point_size=3.0,\n", + " show_coordinate_frame=True,\n", + " )\n", + " print(\"\u2705 Misc clusters visualization completed!\")\n", + " except (KeyboardInterrupt, EOFError):\n", + " print(\"\u23ed\ufe0f Misc clusters visualization skipped\")\n", + "else:\n", + " print(\"\u26a0\ufe0f No misc clusters available for visualization\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 10.3 Voxel Grid Visualization" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if (\n", + " CONFIG[\"show_3d_visualizations\"]\n", + " and \"misc_voxel_grid\" in results\n", + " and results[\"misc_voxel_grid\"] is not None\n", + "):\n", + " misc_voxel_grid = results[\"misc_voxel_grid\"]\n", + " voxel_count = len(misc_voxel_grid.get_voxels())\n", + "\n", + " print(f\"\ud83d\udce6 Launching voxel grid visualization ({voxel_count} voxels)...\")\n", + " print(\"\ud83d\udccc Note: Close the 3D window to continue\")\n", + "\n", + " try:\n", + " visualize_voxel_grid(\n", + " misc_voxel_grid,\n", + " window_name=\"Misc/Background Voxel Grid\",\n", + " show_coordinate_frame=True,\n", + " )\n", + " print(\"\u2705 Voxel grid visualization completed!\")\n", + " except (KeyboardInterrupt, EOFError):\n", + " print(\"\u23ed\ufe0f Voxel grid visualization skipped\")\n", + " except Exception as e:\n", + " print(f\"\u274c Error in voxel grid visualization: {e}\")\n", + "else:\n", + " print(\"\u26a0\ufe0f 
No voxel grid available for visualization\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 11. Cleanup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Clean up resources\n", + "processor.cleanup()\n", + "print(\"\u2705 Pipeline cleanup completed!\")\n", + "print(\"\\n\ud83c\udf89 Manipulation pipeline demo finished successfully!\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "\n", + "## Summary\n", + "\n", + "This notebook demonstrated the complete DIMOS manipulation pipeline:\n", + "\n", + "1. **Object Detection**: Using Detic for 2D object detection\n", + "2. **Semantic Segmentation**: Using SAM/FastSAM for detailed segmentation\n", + "3. **Point Cloud Processing**: Converting RGB-D to 3D point clouds with filtering\n", + "4. **Background Analysis**: DBSCAN clustering of miscellaneous points\n", + "5. **Grasp Generation**: ContactGraspNet for 6-DoF grasp pose estimation\n", + "6. 
**Visualization**: Comprehensive 2D and 3D visualizations\n", + "\n", + "The pipeline is designed for real-time robotic manipulation tasks and provides rich visual feedback for debugging and analysis.\n", + "\n", + "### Key Features:\n", + "- \u2705 Modular design with clean interfaces\n", + "- \u2705 Real-time performance optimizations\n", + "- \u2705 Comprehensive error handling\n", + "- \u2705 Rich visualization capabilities\n", + "- \u2705 ContactGraspNet integration for state-of-the-art grasp generation\n", + "\n", + "### Next Steps:\n", + "- Integrate with robotic control systems\n", + "- Add grasp execution and feedback\n", + "- Implement multi-frame tracking\n", + "- Add custom object vocabularies\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "contact-graspnet", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.18" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} From 0b03fd99ddf3cb6e6bf649bbc23985a84aecb197 Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Wed, 25 Jun 2025 16:33:12 -0700 Subject: [PATCH 12/89] added parsing of contact graspnet results into dict --- tests/manipulation_pipeline_demo.ipynb | 186 ++++++++++++++++--------- 1 file changed, 119 insertions(+), 67 deletions(-) diff --git a/tests/manipulation_pipeline_demo.ipynb b/tests/manipulation_pipeline_demo.ipynb index df43a7c6ac..01470c6355 100644 --- a/tests/manipulation_pipeline_demo.ipynb +++ b/tests/manipulation_pipeline_demo.ipynb @@ -87,7 +87,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -133,7 +133,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -195,14 +195,14 @@ }, { "cell_type": "code", - 
"execution_count": 10, + "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2025-06-25 13:29:47,127 - manipulation_pipeline_demo - INFO - Loaded images: color (720, 1280, 3), depth (720, 1280)\u001b[0m\n" + "\u001b[32m2025-06-25 15:04:05,846 - manipulation_pipeline_demo - INFO - Loaded images: color (720, 1280, 3), depth (720, 1280)\u001b[0m\n" ] }, { @@ -251,7 +251,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -265,7 +265,13 @@ "/home/alex-lin/miniconda3/envs/contact-graspnet/lib/python3.10/site-packages/timm/models/registry.py:4: FutureWarning: Importing from timm.models.registry is deprecated, please import via timm.models\n", " warnings.warn(f\"Importing from {__name__} is deprecated, please import via timm.models\", FutureWarning)\n", "/home/alex-lin/miniconda3/envs/contact-graspnet/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. 
(Triggered internally at /opt/conda/conda-bld/pytorch_1702400410390/work/aten/src/ATen/native/TensorShape.cpp:3526.)\n", - " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n" + " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n", + "\u001b[32m2025-06-25 15:04:11,530 - dimos.perception.grasp_generation - INFO - Initializing ContactGraspNet on device: cuda\u001b[0m\n", + "\u001b[32m2025-06-25 15:04:11,541 - dimos.perception.grasp_generation - INFO - Loaded config from dimos/models/manipulation/contact_graspnet_pytorch/checkpoints\u001b[0m\n", + "\u001b[32m2025-06-25 15:04:11,565 - dimos.perception.grasp_generation - INFO - ContactGraspNet model initialized\u001b[0m\n", + "\u001b[32m2025-06-25 15:04:11,566 - dimos.perception.grasp_generation - INFO - ContactGraspNet grasp generator initialized successfully\u001b[0m\n", + "\u001b[32m2025-06-25 15:04:11,566 - dimos.perception.manip_aio_processor - INFO - ContactGraspNet generator initialized successfully\u001b[0m\n", + "\u001b[32m2025-06-25 15:04:11,567 - dimos.perception.manip_aio_processor - INFO - Initialized ManipulationProcessor with confidence=0.6, grasp_generation=True\u001b[0m\n" ] }, { @@ -273,26 +279,6 @@ "output_type": "stream", "text": [ "Resetting zs_weight /home/alex-lin/dimos/dimos/perception/detection2d/../../models/Detic/datasets/metadata/lvis_v1_clip_a+cname.npy\n", - "Downloading https://github.com/ultralytics/assets/releases/download/v8.3.0/FastSAM-x.pt to 'FastSAM-x.pt'...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 138M/138M [00:03<00:00, 41.5MB/s] \n", - "\u001b[32m2025-06-25 13:30:01,134 - dimos.perception.grasp_generation - INFO - Initializing ContactGraspNet on device: cuda\u001b[0m\n", - "\u001b[32m2025-06-25 13:30:01,141 - dimos.perception.grasp_generation - INFO - Loaded config from 
dimos/models/manipulation/contact_graspnet_pytorch/checkpoints\u001b[0m\n", - "\u001b[32m2025-06-25 13:30:01,164 - dimos.perception.grasp_generation - INFO - ContactGraspNet model initialized\u001b[0m\n", - "\u001b[32m2025-06-25 13:30:01,165 - dimos.perception.grasp_generation - INFO - ContactGraspNet grasp generator initialized successfully\u001b[0m\n", - "\u001b[32m2025-06-25 13:30:01,165 - dimos.perception.manip_aio_processor - INFO - ContactGraspNet generator initialized successfully\u001b[0m\n", - "\u001b[32m2025-06-25 13:30:01,165 - dimos.perception.manip_aio_processor - INFO - Initialized ManipulationProcessor with confidence=0.6, grasp_generation=True\u001b[0m\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ "model func: \n", "\u2705 ManipulationProcessor initialized successfully!\n" ] @@ -321,7 +307,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -329,44 +315,59 @@ "output_type": "stream", "text": [ "\ud83d\udd04 Processing frame through pipeline...\n", - "DBSCAN clustering found 11 clusters from 28067 points\n", - "Created voxel grid with 2220 voxels (voxel_size=0.02)\n", + "DBSCAN clustering found 13 clusters from 26536 points\n", + "Created voxel grid with 2074 voxels (voxel_size=0.02)\n", "using local regions\n", - "Extracted Region Cube Size: 0.311576783657074\n", - "Extracted Region Cube Size: 0.445679247379303\n", - "Extracted Region Cube Size: 0.24130240082740784\n", - "Extracted Region Cube Size: 0.2\n", - "Extracted Region Cube Size: 0.46059030294418335\n", - "Extracted Region Cube Size: 0.2357255220413208\n", - "Extracted Region Cube Size: 0.3680998980998993\n", + "Extracted Region Cube Size: 0.3148665130138397\n", + "Extracted Region Cube Size: 0.4740000367164612\n", + "Extracted Region Cube Size: 0.2676139771938324\n", "Extracted Region Cube Size: 0.2\n", + "Extracted Region Cube Size: 0.4960000514984131\n", + "Extracted Region Cube Size: 
0.30400002002716064\n", + "Extracted Region Cube Size: 0.38946154713630676\n", + "Extracted Region Cube Size: 0.2087651789188385\n", "Extracted Region Cube Size: 0.2\n", - "Extracted Region Cube Size: 0.24357137084007263\n", "Extracted Region Cube Size: 0.2\n", - "Extracted Region Cube Size: 0.2409430295228958\n", + "Extracted Region Cube Size: 0.24777960777282715\n", "Extracted Region Cube Size: 0.2\n", + "Extracted Region Cube Size: 0.2502080202102661\n", "Extracted Region Cube Size: 0.2\n", - "Extracted Region Cube Size: 0.23709678649902344\n", + "Extracted Region Cube Size: 0.3400000333786011\n", + "Extracted Region Cube Size: 0.22946105897426605\n", "Extracted Region Cube Size: 0.2\n", - "Extracted Region Cube Size: 0.5130001306533813\n" + "Extracted Region Cube Size: 0.5360000133514404\n", + "Generated 18 grasps for object 2\n", + "Generated 44 grasps for object 3\n", + "Generated 14 grasps for object 4\n", + "Generated 6 grasps for object 7\n", + "Generated 9 grasps for object 8\n", + "Generated 15 grasps for object 9\n", + "Generated 25 grasps for object 10\n", + "Generated 25 grasps for object 11\n", + "Generated 16 grasps for object 14\n", + "Generated 3 grasps for object 15\n", + "Generated 13 grasps for object 16\n", + "Generated 15 grasps for object 19\n", + "Generated 12 grasps for object 27\n", + "Generated 17 grasps for object 29\n", + "Generated 19 grasps for object 31\n", + "Generated 19 grasps for object 32\n", + "Generated 3 grasps for object 33\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "/home/alex-lin/miniconda3/envs/contact-graspnet/lib/python3.10/site-packages/numpy/core/fromnumeric.py:3504: RuntimeWarning: Mean of empty slice.\n", - " return _methods._mean(a, axis=axis, dtype=dtype,\n", - "/home/alex-lin/miniconda3/envs/contact-graspnet/lib/python3.10/site-packages/numpy/core/_methods.py:121: RuntimeWarning: invalid value encountered in divide\n", - " ret = um.true_divide(\n", - "\u001b[32m2025-06-25 13:30:19,727 - 
dimos.perception.grasp_generation - INFO - Generated 3400 grasps across 17 objects in 12.91s\u001b[0m\n" + "\u001b[32m2025-06-25 15:04:30,107 - dimos.perception.grasp_generation - INFO - Generated 296 grasps across 18 objects in 14.69s\u001b[0m\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "\u2705 Processing completed in 14.768s\n" + "Generated 23 grasps for object 37\n", + "\u2705 Processing completed in 18.517s\n" ] } ], @@ -390,7 +391,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -401,25 +402,25 @@ "\ud83d\udcca PROCESSING RESULTS SUMMARY\n", "==================================================\n", "Available results: ['detection2d_objects', 'detection_viz', 'segmentation2d_objects', 'segmentation_viz', 'detected_objects', 'all_objects', 'full_pointcloud', 'misc_clusters', 'misc_voxel_grid', 'pointcloud_viz', 'detected_pointcloud_viz', 'misc_pointcloud_viz', 'grasps', 'processing_time', 'timing_breakdown']\n", - "Total processing time: 14.768s\n", + "Total processing time: 18.517s\n", "\n", "\u23f1\ufe0f Timing breakdown:\n", - " Detection: 0.550s\n", - " Segmentation: 0.733s\n", - " Point cloud: 0.144s\n", - " Misc extraction: 0.371s\n", + " Detection: 0.529s\n", + " Segmentation: 0.720s\n", + " Point cloud: 1.837s\n", + " Misc extraction: 0.385s\n", "\n", "\ud83c\udfaf Object Detection:\n", " Detection objects: 13\n", " All objects processed: 18\n", "\n", "\ud83e\udde9 Background Analysis:\n", - " Misc clusters: 11 clusters with 26,692 total points\n", + " Misc clusters: 13 clusters with 25,628 total points\n", "\n", "\ud83e\udd16 ContactGraspNet Results:\n", - " Total grasps: 3400\n", - " Best score: 0.911\n", - " Objects with grasps: 17\n", + " Total grasps: 296\n", + " Best score: 0.798\n", + " Objects with grasps: 18\n", "\n", "==================================================\n" ] @@ -497,7 +498,7 @@ }, { "cell_type": "code", - "execution_count": 14, + 
"execution_count": 8, "metadata": {}, "outputs": [ { @@ -576,14 +577,14 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "\ud83c\udfaf Launching 3D visualization with 3400 ContactGraspNet grasps...\n", + "\ud83c\udfaf Launching 3D visualization with 296 ContactGraspNet grasps...\n", "\ud83d\udccc Note: Close the 3D window to continue with the notebook\n", "Visualizing...\n", "\u2705 3D grasp visualization completed!\n" @@ -658,9 +659,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\ud83c\udf0d Launching full scene point cloud visualization (526,100 points)...\n", + "\ud83d\udccc Note: Close the 3D window to continue\n", + "Visualizing point cloud with 526100 points\n", + "\u2705 Full point cloud visualization completed!\n" + ] + } + ], "source": [ "if (\n", " CONFIG[\"show_3d_visualizations\"]\n", @@ -695,9 +707,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\ud83e\udde9 Launching misc/background clusters visualization (13 clusters, 25,628 points)...\n", + "\ud83d\udccc Note: Close the 3D window to continue\n", + "Visualizing 13 clusters with 25628 total points\n", + "\u2705 Misc clusters visualization completed!\n" + ] + } + ], "source": [ "if CONFIG[\"show_3d_visualizations\"] and \"misc_clusters\" in results and results[\"misc_clusters\"]:\n", " misc_clusters = results[\"misc_clusters\"]\n", @@ -732,9 +755,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\ud83d\udce6 Launching voxel grid visualization (2074 
voxels)...\n", + "\ud83d\udccc Note: Close the 3D window to continue\n", + "Visualizing voxel grid with 2074 voxels\n", + "\u2705 Voxel grid visualization completed!\n" + ] + } + ], "source": [ "if (\n", " CONFIG[\"show_3d_visualizations\"]\n", @@ -771,9 +805,27 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2025-06-25 15:05:01,624 - dimos.perception.grasp_generation - INFO - ContactGraspNet grasp generator cleaned up\u001b[0m\n", + "\u001b[32m2025-06-25 15:05:01,626 - dimos.perception.manip_aio_processor - INFO - ManipulationProcessor cleaned up\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u2705 Pipeline cleanup completed!\n", + "\n", + "\ud83c\udf89 Manipulation pipeline demo finished successfully!\n" + ] + } + ], "source": [ "# Clean up resources\n", "processor.cleanup()\n", From 4d29bc3b6ddb913335a3af20076510c2844bc719 Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Wed, 25 Jun 2025 22:12:34 -0700 Subject: [PATCH 13/89] added anygrasp and contact graspnet support --- dimos/perception/pointcloud/pointcloud_filtering.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dimos/perception/pointcloud/pointcloud_filtering.py b/dimos/perception/pointcloud/pointcloud_filtering.py index 47d351bd14..3de2f3ae6a 100644 --- a/dimos/perception/pointcloud/pointcloud_filtering.py +++ b/dimos/perception/pointcloud/pointcloud_filtering.py @@ -292,7 +292,7 @@ def process_images( pcd = self._apply_color_mask(pcd, rgb_color) # Apply subsampling to control point cloud size - # pcd = self._apply_subsampling(pcd) + pcd = self._apply_subsampling(pcd) # Apply filtering (optional based on flags) pcd_filtered = self._apply_filtering(pcd) From 4ed8ab11a6bd437df2d89f54d09d5a6c9fe0f35e Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Thu, 26 Jun 2025 02:45:47 -0700 Subject: [PATCH 
14/89] zed frames saving --- tests/manipulation_pipeline_demo.ipynb | 891 ------------------------- 1 file changed, 891 deletions(-) delete mode 100644 tests/manipulation_pipeline_demo.ipynb diff --git a/tests/manipulation_pipeline_demo.ipynb b/tests/manipulation_pipeline_demo.ipynb deleted file mode 100644 index 01470c6355..0000000000 --- a/tests/manipulation_pipeline_demo.ipynb +++ /dev/null @@ -1,891 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Manipulation Pipeline Demo with ContactGraspNet\n", - "\n", - "This notebook demonstrates the complete manipulation pipeline including:\n", - "- Object detection (Detic)\n", - "- Semantic segmentation (SAM/FastSAM)\n", - "- Point cloud processing\n", - "- 6-DoF grasp generation (ContactGraspNet)\n", - "- 3D visualization\n", - "\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 1. Setup and Imports" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Jupyter environment detected. 
Enabling Open3D WebVisualizer.\n", - "[Open3D INFO] WebRTC GUI backend enabled.\n", - "[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.\n", - "\u2705 All imports successful!\n" - ] - } - ], - "source": [ - "import os\n", - "import sys\n", - "import cv2\n", - "import numpy as np\n", - "import time\n", - "import matplotlib\n", - "\n", - "# Configure matplotlib backend\n", - "try:\n", - " matplotlib.use(\"TkAgg\")\n", - "except:\n", - " try:\n", - " matplotlib.use(\"Qt5Agg\")\n", - " except:\n", - " matplotlib.use(\"Agg\")\n", - "\n", - "import matplotlib.pyplot as plt\n", - "import open3d as o3d\n", - "from typing import Dict, List\n", - "\n", - "# Add project root to path\n", - "sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(\"__file__\"))))\n", - "\n", - "# Import DIMOS modules\n", - "from dimos.perception.pointcloud.utils import visualize_clustered_point_clouds, visualize_voxel_grid\n", - "from dimos.perception.manip_aio_processer import ManipulationProcessor\n", - "from dimos.perception.pointcloud.utils import load_camera_matrix_from_yaml, visualize_pcd\n", - "from dimos.utils.logging_config import setup_logger\n", - "\n", - "# Import ContactGraspNet visualization\n", - "from dimos.models.manipulation.contact_graspnet_pytorch.contact_graspnet_pytorch.visualization_utils_o3d import (\n", - " visualize_grasps,\n", - ")\n", - "\n", - "logger = setup_logger(\"manipulation_pipeline_demo\")\n", - "print(\"\u2705 All imports successful!\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2. 
Configuration" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Configuration:\n", - " data_dir: /home/alex-lin/dimos/assets/rgbd_data\n", - " enable_grasp_generation: True\n", - " enable_segmentation: True\n", - " segmentation_model: FastSAM-x.pt\n", - " min_confidence: 0.6\n", - " max_objects: 20\n", - " show_3d_visualizations: True\n", - " save_results: True\n" - ] - } - ], - "source": [ - "# Configuration parameters\n", - "CONFIG = {\n", - " \"data_dir\": \"/home/alex-lin/dimos/assets/rgbd_data\",\n", - " \"enable_grasp_generation\": True,\n", - " \"enable_segmentation\": True,\n", - " \"segmentation_model\": \"FastSAM-x.pt\", # or \"sam2_b.pt\"\n", - " \"min_confidence\": 0.6,\n", - " \"max_objects\": 20,\n", - " \"show_3d_visualizations\": True,\n", - " \"save_results\": True,\n", - "}\n", - "\n", - "print(f\"Configuration:\")\n", - "for key, value in CONFIG.items():\n", - " print(f\" {key}: {value}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 3. 
Data Loading Functions" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u2705 Data loading functions defined!\n" - ] - } - ], - "source": [ - "def load_first_frame(data_dir: str):\n", - " \"\"\"Load first RGB-D frame and camera intrinsics.\"\"\"\n", - " # Load images\n", - " color_img = cv2.imread(os.path.join(data_dir, \"color\", \"00000.png\"))\n", - " color_img = cv2.cvtColor(color_img, cv2.COLOR_BGR2RGB)\n", - "\n", - " depth_img = cv2.imread(os.path.join(data_dir, \"depth\", \"00000.png\"), cv2.IMREAD_ANYDEPTH)\n", - " if depth_img.dtype == np.uint16:\n", - " depth_img = depth_img.astype(np.float32) / 1000.0\n", - "\n", - " # Load intrinsics\n", - " camera_matrix = load_camera_matrix_from_yaml(os.path.join(data_dir, \"color_camera_info.yaml\"))\n", - " intrinsics = [\n", - " camera_matrix[0, 0], # fx\n", - " camera_matrix[1, 1], # fy\n", - " camera_matrix[0, 2], # cx\n", - " camera_matrix[1, 2], # cy\n", - " ]\n", - "\n", - " return color_img, depth_img, intrinsics\n", - "\n", - "\n", - "def create_point_cloud(color_img, depth_img, intrinsics):\n", - " \"\"\"Create Open3D point cloud for reference.\"\"\"\n", - " fx, fy, cx, cy = intrinsics\n", - " height, width = depth_img.shape\n", - "\n", - " o3d_intrinsics = o3d.camera.PinholeCameraIntrinsic(width, height, fx, fy, cx, cy)\n", - " color_o3d = o3d.geometry.Image(color_img)\n", - " depth_o3d = o3d.geometry.Image((depth_img * 1000).astype(np.uint16))\n", - "\n", - " rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth(\n", - " color_o3d, depth_o3d, depth_scale=1000.0, convert_rgb_to_intensity=False\n", - " )\n", - "\n", - " return o3d.geometry.PointCloud.create_from_rgbd_image(rgbd, o3d_intrinsics)\n", - "\n", - "\n", - "print(\"\u2705 Data loading functions defined!\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 4. 
Load RGB-D Data" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[32m2025-06-25 15:04:05,846 - manipulation_pipeline_demo - INFO - Loaded images: color (720, 1280, 3), depth (720, 1280)\u001b[0m\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Camera intrinsics: fx=749.3, fy=748.6, cx=639.4, cy=357.2\n" - ] - } - ], - "source": [ - "# Load data\n", - "color_img, depth_img, intrinsics = load_first_frame(CONFIG[\"data_dir\"])\n", - "logger.info(f\"Loaded images: color {color_img.shape}, depth {depth_img.shape}\")\n", - "\n", - "# Display input images\n", - "fig, axes = plt.subplots(1, 2, figsize=(12, 5))\n", - "\n", - "axes[0].imshow(color_img)\n", - "axes[0].set_title(\"RGB Image\")\n", - "axes[0].axis(\"off\")\n", - "\n", - "# Colorize depth for visualization\n", - "depth_colorized = cv2.applyColorMap(\n", - " cv2.convertScaleAbs(depth_img, alpha=255.0 / depth_img.max()), cv2.COLORMAP_JET\n", - ")\n", - "depth_colorized = cv2.cvtColor(depth_colorized, cv2.COLOR_BGR2RGB)\n", - "axes[1].imshow(depth_colorized)\n", - "axes[1].set_title(\"Depth Image\")\n", - "axes[1].axis(\"off\")\n", - "\n", - "plt.tight_layout()\n", - "plt.show()\n", - "\n", - "print(\n", - " f\"Camera intrinsics: fx={intrinsics[0]:.1f}, fy={intrinsics[1]:.1f}, cx={intrinsics[2]:.1f}, cy={intrinsics[3]:.1f}\"\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 5. 
Initialize Manipulation Processor" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/alex-lin/miniconda3/envs/contact-graspnet/lib/python3.10/site-packages/timm/models/layers/__init__.py:48: FutureWarning: Importing from timm.models.layers is deprecated, please import via timm.layers\n", - " warnings.warn(f\"Importing from {__name__} is deprecated, please import via timm.layers\", FutureWarning)\n", - "/home/alex-lin/miniconda3/envs/contact-graspnet/lib/python3.10/site-packages/timm/models/helpers.py:7: FutureWarning: Importing from timm.models.helpers is deprecated, please import via timm.models\n", - " warnings.warn(f\"Importing from {__name__} is deprecated, please import via timm.models\", FutureWarning)\n", - "/home/alex-lin/miniconda3/envs/contact-graspnet/lib/python3.10/site-packages/timm/models/registry.py:4: FutureWarning: Importing from timm.models.registry is deprecated, please import via timm.models\n", - " warnings.warn(f\"Importing from {__name__} is deprecated, please import via timm.models\", FutureWarning)\n", - "/home/alex-lin/miniconda3/envs/contact-graspnet/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. 
(Triggered internally at /opt/conda/conda-bld/pytorch_1702400410390/work/aten/src/ATen/native/TensorShape.cpp:3526.)\n", - " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n", - "\u001b[32m2025-06-25 15:04:11,530 - dimos.perception.grasp_generation - INFO - Initializing ContactGraspNet on device: cuda\u001b[0m\n", - "\u001b[32m2025-06-25 15:04:11,541 - dimos.perception.grasp_generation - INFO - Loaded config from dimos/models/manipulation/contact_graspnet_pytorch/checkpoints\u001b[0m\n", - "\u001b[32m2025-06-25 15:04:11,565 - dimos.perception.grasp_generation - INFO - ContactGraspNet model initialized\u001b[0m\n", - "\u001b[32m2025-06-25 15:04:11,566 - dimos.perception.grasp_generation - INFO - ContactGraspNet grasp generator initialized successfully\u001b[0m\n", - "\u001b[32m2025-06-25 15:04:11,566 - dimos.perception.manip_aio_processor - INFO - ContactGraspNet generator initialized successfully\u001b[0m\n", - "\u001b[32m2025-06-25 15:04:11,567 - dimos.perception.manip_aio_processor - INFO - Initialized ManipulationProcessor with confidence=0.6, grasp_generation=True\u001b[0m\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Resetting zs_weight /home/alex-lin/dimos/dimos/perception/detection2d/../../models/Detic/datasets/metadata/lvis_v1_clip_a+cname.npy\n", - "model func: \n", - "\u2705 ManipulationProcessor initialized successfully!\n" - ] - } - ], - "source": [ - "# Create processor with ContactGraspNet enabled\n", - "processor = ManipulationProcessor(\n", - " camera_intrinsics=intrinsics,\n", - " min_confidence=CONFIG[\"min_confidence\"],\n", - " max_objects=CONFIG[\"max_objects\"],\n", - " enable_grasp_generation=CONFIG[\"enable_grasp_generation\"],\n", - " enable_segmentation=CONFIG[\"enable_segmentation\"],\n", - " segmentation_model=CONFIG[\"segmentation_model\"],\n", - ")\n", - "\n", - "print(\"\u2705 ManipulationProcessor initialized successfully!\")" - ] - }, - { - "cell_type": "markdown", - "metadata": 
{}, - "source": [ - "## 6. Run Processing Pipeline" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\ud83d\udd04 Processing frame through pipeline...\n", - "DBSCAN clustering found 13 clusters from 26536 points\n", - "Created voxel grid with 2074 voxels (voxel_size=0.02)\n", - "using local regions\n", - "Extracted Region Cube Size: 0.3148665130138397\n", - "Extracted Region Cube Size: 0.4740000367164612\n", - "Extracted Region Cube Size: 0.2676139771938324\n", - "Extracted Region Cube Size: 0.2\n", - "Extracted Region Cube Size: 0.4960000514984131\n", - "Extracted Region Cube Size: 0.30400002002716064\n", - "Extracted Region Cube Size: 0.38946154713630676\n", - "Extracted Region Cube Size: 0.2087651789188385\n", - "Extracted Region Cube Size: 0.2\n", - "Extracted Region Cube Size: 0.2\n", - "Extracted Region Cube Size: 0.24777960777282715\n", - "Extracted Region Cube Size: 0.2\n", - "Extracted Region Cube Size: 0.2502080202102661\n", - "Extracted Region Cube Size: 0.2\n", - "Extracted Region Cube Size: 0.3400000333786011\n", - "Extracted Region Cube Size: 0.22946105897426605\n", - "Extracted Region Cube Size: 0.2\n", - "Extracted Region Cube Size: 0.5360000133514404\n", - "Generated 18 grasps for object 2\n", - "Generated 44 grasps for object 3\n", - "Generated 14 grasps for object 4\n", - "Generated 6 grasps for object 7\n", - "Generated 9 grasps for object 8\n", - "Generated 15 grasps for object 9\n", - "Generated 25 grasps for object 10\n", - "Generated 25 grasps for object 11\n", - "Generated 16 grasps for object 14\n", - "Generated 3 grasps for object 15\n", - "Generated 13 grasps for object 16\n", - "Generated 15 grasps for object 19\n", - "Generated 12 grasps for object 27\n", - "Generated 17 grasps for object 29\n", - "Generated 19 grasps for object 31\n", - "Generated 19 grasps for object 32\n", - "Generated 3 grasps for object 33\n" - ] - }, - { 
- "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[32m2025-06-25 15:04:30,107 - dimos.perception.grasp_generation - INFO - Generated 296 grasps across 18 objects in 14.69s\u001b[0m\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Generated 23 grasps for object 37\n", - "\u2705 Processing completed in 18.517s\n" - ] - } - ], - "source": [ - "# Process single frame\n", - "print(\"\ud83d\udd04 Processing frame through pipeline...\")\n", - "start_time = time.time()\n", - "\n", - "results = processor.process_frame(color_img, depth_img)\n", - "\n", - "processing_time = time.time() - start_time\n", - "print(f\"\u2705 Processing completed in {processing_time:.3f}s\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 7. Results Summary" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "\ud83d\udcca PROCESSING RESULTS SUMMARY\n", - "==================================================\n", - "Available results: ['detection2d_objects', 'detection_viz', 'segmentation2d_objects', 'segmentation_viz', 'detected_objects', 'all_objects', 'full_pointcloud', 'misc_clusters', 'misc_voxel_grid', 'pointcloud_viz', 'detected_pointcloud_viz', 'misc_pointcloud_viz', 'grasps', 'processing_time', 'timing_breakdown']\n", - "Total processing time: 18.517s\n", - "\n", - "\u23f1\ufe0f Timing breakdown:\n", - " Detection: 0.529s\n", - " Segmentation: 0.720s\n", - " Point cloud: 1.837s\n", - " Misc extraction: 0.385s\n", - "\n", - "\ud83c\udfaf Object Detection:\n", - " Detection objects: 13\n", - " All objects processed: 18\n", - "\n", - "\ud83e\udde9 Background Analysis:\n", - " Misc clusters: 13 clusters with 25,628 total points\n", - "\n", - "\ud83e\udd16 ContactGraspNet Results:\n", - " Total grasps: 296\n", - " Best score: 0.798\n", - " Objects with grasps: 18\n", - "\n", - 
"==================================================\n" - ] - } - ], - "source": [ - "# Print comprehensive results summary\n", - "print(f\"\\n\ud83d\udcca PROCESSING RESULTS SUMMARY\")\n", - "print(f\"\" + \"=\" * 50)\n", - "print(f\"Available results: {list(results.keys())}\")\n", - "print(f\"Total processing time: {results.get('processing_time', 0):.3f}s\")\n", - "\n", - "# Show timing breakdown\n", - "if \"timing_breakdown\" in results:\n", - " breakdown = results[\"timing_breakdown\"]\n", - " print(f\"\\n\u23f1\ufe0f Timing breakdown:\")\n", - " print(f\" Detection: {breakdown.get('detection', 0):.3f}s\")\n", - " print(f\" Segmentation: {breakdown.get('segmentation', 0):.3f}s\")\n", - " print(f\" Point cloud: {breakdown.get('pointcloud', 0):.3f}s\")\n", - " print(f\" Misc extraction: {breakdown.get('misc_extraction', 0):.3f}s\")\n", - "\n", - "# Object counts\n", - "detected_count = len(results.get(\"detected_objects\", []))\n", - "all_count = len(results.get(\"all_objects\", []))\n", - "print(f\"\\n\ud83c\udfaf Object Detection:\")\n", - "print(f\" Detection objects: {detected_count}\")\n", - "print(f\" All objects processed: {all_count}\")\n", - "\n", - "# Misc clusters info\n", - "if \"misc_clusters\" in results and results[\"misc_clusters\"]:\n", - " cluster_count = len(results[\"misc_clusters\"])\n", - " total_misc_points = sum(len(np.asarray(cluster.points)) for cluster in results[\"misc_clusters\"])\n", - " print(f\"\\n\ud83e\udde9 Background Analysis:\")\n", - " print(f\" Misc clusters: {cluster_count} clusters with {total_misc_points:,} total points\")\n", - "else:\n", - " print(f\"\\n\ud83e\udde9 Background Analysis: No clusters found\")\n", - "\n", - "# ContactGraspNet grasp summary\n", - "if \"grasps\" in results and results[\"grasps\"]:\n", - " grasp_data = results[\"grasps\"]\n", - " if isinstance(grasp_data, dict):\n", - " pred_grasps = grasp_data.get(\"pred_grasps_cam\", {})\n", - " scores = grasp_data.get(\"scores\", {})\n", - "\n", - " 
total_grasps = 0\n", - " best_score = 0\n", - "\n", - " for obj_id, obj_grasps in pred_grasps.items():\n", - " num_grasps = len(obj_grasps) if hasattr(obj_grasps, \"__len__\") else 0\n", - " total_grasps += num_grasps\n", - "\n", - " if obj_id in scores and len(scores[obj_id]) > 0:\n", - " obj_best_score = max(scores[obj_id])\n", - " if obj_best_score > best_score:\n", - " best_score = obj_best_score\n", - "\n", - " print(f\"\\n\ud83e\udd16 ContactGraspNet Results:\")\n", - " print(f\" Total grasps: {total_grasps}\")\n", - " print(f\" Best score: {best_score:.3f}\")\n", - " print(f\" Objects with grasps: {len(pred_grasps)}\")\n", - " else:\n", - " print(f\"\\n\ud83e\udd16 ContactGraspNet Results: Invalid format\")\n", - "else:\n", - " print(f\"\\n\ud83e\udd16 ContactGraspNet Results: No grasps generated\")\n", - "\n", - "print(\"\\n\" + \"=\" * 50)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 8. 2D Visualization Results" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\ud83d\udcf8 Results saved to: manipulation_results.png\n" - ] - } - ], - "source": [ - "# Collect available visualizations\n", - "viz_configs = []\n", - "\n", - "if \"detection_viz\" in results and results[\"detection_viz\"] is not None:\n", - " viz_configs.append((\"detection_viz\", \"Object Detection\"))\n", - "\n", - "if \"segmentation_viz\" in results and results[\"segmentation_viz\"] is not None:\n", - " viz_configs.append((\"segmentation_viz\", \"Semantic Segmentation\"))\n", - "\n", - "if \"pointcloud_viz\" in results and results[\"pointcloud_viz\"] is not None:\n", - " viz_configs.append((\"pointcloud_viz\", \"All Objects Point Cloud\"))\n", - "\n", - "if \"detected_pointcloud_viz\" in results and results[\"detected_pointcloud_viz\"] is not None:\n", - " viz_configs.append((\"detected_pointcloud_viz\", \"Detection Objects Point Cloud\"))\n", - "\n", 
- "if \"misc_pointcloud_viz\" in results and results[\"misc_pointcloud_viz\"] is not None:\n", - " viz_configs.append((\"misc_pointcloud_viz\", \"Misc/Background Points\"))\n", - "\n", - "# Create visualization grid\n", - "if viz_configs:\n", - " num_plots = len(viz_configs)\n", - "\n", - " if num_plots <= 3:\n", - " fig, axes = plt.subplots(1, num_plots, figsize=(6 * num_plots, 5))\n", - " else:\n", - " rows = 2\n", - " cols = (num_plots + 1) // 2\n", - " fig, axes = plt.subplots(rows, cols, figsize=(6 * cols, 5 * rows))\n", - "\n", - " # Ensure axes is always iterable\n", - " if num_plots == 1:\n", - " axes = [axes]\n", - " elif num_plots > 2:\n", - " axes = axes.flatten()\n", - "\n", - " # Plot each result\n", - " for i, (key, title) in enumerate(viz_configs):\n", - " axes[i].imshow(results[key])\n", - " axes[i].set_title(title, fontsize=12, fontweight=\"bold\")\n", - " axes[i].axis(\"off\")\n", - "\n", - " # Hide unused subplots\n", - " if num_plots > 3:\n", - " for i in range(num_plots, len(axes)):\n", - " axes[i].axis(\"off\")\n", - "\n", - " plt.tight_layout()\n", - "\n", - " if CONFIG[\"save_results\"]:\n", - " output_path = \"manipulation_results.png\"\n", - " plt.savefig(output_path, dpi=150, bbox_inches=\"tight\")\n", - " print(f\"\ud83d\udcf8 Results saved to: {output_path}\")\n", - "\n", - " plt.show()\n", - "else:\n", - " print(\"\u26a0\ufe0f No 2D visualization results to display\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 9. 
3D ContactGraspNet Visualization" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\ud83c\udfaf Launching 3D visualization with 296 ContactGraspNet grasps...\n", - "\ud83d\udccc Note: Close the 3D window to continue with the notebook\n", - "Visualizing...\n", - "\u2705 3D grasp visualization completed!\n" - ] - } - ], - "source": [ - "# 3D ContactGraspNet visualization\n", - "if (\n", - " CONFIG[\"show_3d_visualizations\"]\n", - " and \"grasps\" in results\n", - " and results[\"grasps\"]\n", - " and \"full_pointcloud\" in results\n", - "):\n", - " grasp_data = results[\"grasps\"]\n", - " full_pcd = results[\"full_pointcloud\"]\n", - "\n", - " if isinstance(grasp_data, dict) and full_pcd is not None:\n", - " try:\n", - " # Extract ContactGraspNet data\n", - " pred_grasps_cam = grasp_data.get(\"pred_grasps_cam\", {})\n", - " scores = grasp_data.get(\"scores\", {})\n", - " contact_pts = grasp_data.get(\"contact_pts\", {})\n", - " gripper_openings = grasp_data.get(\"gripper_openings\", {})\n", - "\n", - " # Check if we have valid grasp data\n", - " total_grasps = (\n", - " sum(len(grasps) for grasps in pred_grasps_cam.values()) if pred_grasps_cam else 0\n", - " )\n", - "\n", - " if total_grasps > 0:\n", - " print(\n", - " f\"\ud83c\udfaf Launching 3D visualization with {total_grasps} ContactGraspNet grasps...\"\n", - " )\n", - " print(\"\ud83d\udccc Note: Close the 3D window to continue with the notebook\")\n", - "\n", - " # Use ContactGraspNet's native visualization - pass dictionaries directly\n", - " visualize_grasps(\n", - " full_pcd,\n", - " pred_grasps_cam, # Pass dictionary directly\n", - " scores, # Pass dictionary directly\n", - " gripper_openings=gripper_openings,\n", - " )\n", - "\n", - " print(\"\u2705 3D grasp visualization completed!\")\n", - " else:\n", - " print(\"\u26a0\ufe0f No valid grasps to visualize in 3D\")\n", - "\n", - " except 
Exception as e:\n", - " print(f\"\u274c Error in ContactGraspNet 3D visualization: {e}\")\n", - " print(\" Skipping 3D grasp visualization\")\n", - "else:\n", - " if not CONFIG[\"show_3d_visualizations\"]:\n", - " print(\"\u23ed\ufe0f 3D visualizations disabled in config\")\n", - " else:\n", - " print(\"\u26a0\ufe0f ContactGraspNet grasp generation disabled or no results\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 10. Additional 3D Visualizations" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 10.1 Full Scene Point Cloud" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\ud83c\udf0d Launching full scene point cloud visualization (526,100 points)...\n", - "\ud83d\udccc Note: Close the 3D window to continue\n", - "Visualizing point cloud with 526100 points\n", - "\u2705 Full point cloud visualization completed!\n" - ] - } - ], - "source": [ - "if (\n", - " CONFIG[\"show_3d_visualizations\"]\n", - " and \"full_pointcloud\" in results\n", - " and results[\"full_pointcloud\"] is not None\n", - "):\n", - " full_pcd = results[\"full_pointcloud\"]\n", - " num_points = len(np.asarray(full_pcd.points))\n", - " print(f\"\ud83c\udf0d Launching full scene point cloud visualization ({num_points:,} points)...\")\n", - " print(\"\ud83d\udccc Note: Close the 3D window to continue\")\n", - "\n", - " try:\n", - " visualize_pcd(\n", - " full_pcd,\n", - " window_name=\"Full Scene Point Cloud\",\n", - " point_size=2.0,\n", - " show_coordinate_frame=True,\n", - " )\n", - " print(\"\u2705 Full point cloud visualization completed!\")\n", - " except (KeyboardInterrupt, EOFError):\n", - " print(\"\u23ed\ufe0f Full point cloud visualization skipped\")\n", - "else:\n", - " print(\"\u26a0\ufe0f No full point cloud available for visualization\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ 
- "### 10.2 Background/Misc Clusters" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\ud83e\udde9 Launching misc/background clusters visualization (13 clusters, 25,628 points)...\n", - "\ud83d\udccc Note: Close the 3D window to continue\n", - "Visualizing 13 clusters with 25628 total points\n", - "\u2705 Misc clusters visualization completed!\n" - ] - } - ], - "source": [ - "if CONFIG[\"show_3d_visualizations\"] and \"misc_clusters\" in results and results[\"misc_clusters\"]:\n", - " misc_clusters = results[\"misc_clusters\"]\n", - " cluster_count = len(misc_clusters)\n", - " total_misc_points = sum(len(np.asarray(cluster.points)) for cluster in misc_clusters)\n", - "\n", - " print(\n", - " f\"\ud83e\udde9 Launching misc/background clusters visualization ({cluster_count} clusters, {total_misc_points:,} points)...\"\n", - " )\n", - " print(\"\ud83d\udccc Note: Close the 3D window to continue\")\n", - "\n", - " try:\n", - " visualize_clustered_point_clouds(\n", - " misc_clusters,\n", - " window_name=\"Misc/Background Clusters (DBSCAN)\",\n", - " point_size=3.0,\n", - " show_coordinate_frame=True,\n", - " )\n", - " print(\"\u2705 Misc clusters visualization completed!\")\n", - " except (KeyboardInterrupt, EOFError):\n", - " print(\"\u23ed\ufe0f Misc clusters visualization skipped\")\n", - "else:\n", - " print(\"\u26a0\ufe0f No misc clusters available for visualization\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 10.3 Voxel Grid Visualization" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\ud83d\udce6 Launching voxel grid visualization (2074 voxels)...\n", - "\ud83d\udccc Note: Close the 3D window to continue\n", - "Visualizing voxel grid with 2074 voxels\n", - "\u2705 Voxel grid visualization completed!\n" - 
] - } - ], - "source": [ - "if (\n", - " CONFIG[\"show_3d_visualizations\"]\n", - " and \"misc_voxel_grid\" in results\n", - " and results[\"misc_voxel_grid\"] is not None\n", - "):\n", - " misc_voxel_grid = results[\"misc_voxel_grid\"]\n", - " voxel_count = len(misc_voxel_grid.get_voxels())\n", - "\n", - " print(f\"\ud83d\udce6 Launching voxel grid visualization ({voxel_count} voxels)...\")\n", - " print(\"\ud83d\udccc Note: Close the 3D window to continue\")\n", - "\n", - " try:\n", - " visualize_voxel_grid(\n", - " misc_voxel_grid,\n", - " window_name=\"Misc/Background Voxel Grid\",\n", - " show_coordinate_frame=True,\n", - " )\n", - " print(\"\u2705 Voxel grid visualization completed!\")\n", - " except (KeyboardInterrupt, EOFError):\n", - " print(\"\u23ed\ufe0f Voxel grid visualization skipped\")\n", - " except Exception as e:\n", - " print(f\"\u274c Error in voxel grid visualization: {e}\")\n", - "else:\n", - " print(\"\u26a0\ufe0f No voxel grid available for visualization\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 11. 
Cleanup" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[32m2025-06-25 15:05:01,624 - dimos.perception.grasp_generation - INFO - ContactGraspNet grasp generator cleaned up\u001b[0m\n", - "\u001b[32m2025-06-25 15:05:01,626 - dimos.perception.manip_aio_processor - INFO - ManipulationProcessor cleaned up\u001b[0m\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u2705 Pipeline cleanup completed!\n", - "\n", - "\ud83c\udf89 Manipulation pipeline demo finished successfully!\n" - ] - } - ], - "source": [ - "# Clean up resources\n", - "processor.cleanup()\n", - "print(\"\u2705 Pipeline cleanup completed!\")\n", - "print(\"\\n\ud83c\udf89 Manipulation pipeline demo finished successfully!\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "---\n", - "\n", - "## Summary\n", - "\n", - "This notebook demonstrated the complete DIMOS manipulation pipeline:\n", - "\n", - "1. **Object Detection**: Using Detic for 2D object detection\n", - "2. **Semantic Segmentation**: Using SAM/FastSAM for detailed segmentation\n", - "3. **Point Cloud Processing**: Converting RGB-D to 3D point clouds with filtering\n", - "4. **Background Analysis**: DBSCAN clustering of miscellaneous points\n", - "5. **Grasp Generation**: ContactGraspNet for 6-DoF grasp pose estimation\n", - "6. 
**Visualization**: Comprehensive 2D and 3D visualizations\n", - "\n", - "The pipeline is designed for real-time robotic manipulation tasks and provides rich visual feedback for debugging and analysis.\n", - "\n", - "### Key Features:\n", - "- \u2705 Modular design with clean interfaces\n", - "- \u2705 Real-time performance optimizations\n", - "- \u2705 Comprehensive error handling\n", - "- \u2705 Rich visualization capabilities\n", - "- \u2705 ContactGraspNet integration for state-of-the-art grasp generation\n", - "\n", - "### Next Steps:\n", - "- Integrate with robotic control systems\n", - "- Add grasp execution and feedback\n", - "- Implement multi-frame tracking\n", - "- Add custom object vocabularies\n" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "contact-graspnet", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.18" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} From 6665b28d7e793ad917be1c72a78c3fd7e0f6e517 Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Thu, 26 Jun 2025 16:03:52 -0700 Subject: [PATCH 15/89] zed driver changes, saving 3d pointcloud --- tests/test_manipulation_perception_pipeline.py.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_manipulation_perception_pipeline.py.py b/tests/test_manipulation_perception_pipeline.py.py index 9a2bc9d371..8b333ec310 100644 --- a/tests/test_manipulation_perception_pipeline.py.py +++ b/tests/test_manipulation_perception_pipeline.py.py @@ -81,7 +81,7 @@ def main(): # Configuration min_confidence = 0.6 web_port = 5555 - grasp_server_url = "ws://10.0.0.125:8000/ws/grasp" + grasp_server_url = "ws://18.224.39.74:8000/ws/grasp" try: # Initialize ZED camera stream From e554fca7c69a357356474a679dc1ecc0c34a9634 Mon Sep 17 
00:00:00 2001 From: alexlin2 Date: Wed, 9 Jul 2025 14:07:58 -0700 Subject: [PATCH 16/89] fixes --- dimos/perception/segmentation/sam_2d_seg.py | 51 ++++----------------- 1 file changed, 10 insertions(+), 41 deletions(-) diff --git a/dimos/perception/segmentation/sam_2d_seg.py b/dimos/perception/segmentation/sam_2d_seg.py index 91831d947e..fcf27584e6 100644 --- a/dimos/perception/segmentation/sam_2d_seg.py +++ b/dimos/perception/segmentation/sam_2d_seg.py @@ -47,7 +47,6 @@ def __init__( use_tracker=True, use_analyzer=True, use_rich_labeling=False, - model_type="auto", # "auto", "fastsam", "sam2" ): self.device = device if is_cuda_available(): @@ -64,24 +63,6 @@ def __init__( self.use_analyzer = use_analyzer self.use_rich_labeling = use_rich_labeling - # Determine model type automatically if needed - if model_type == "auto": - if "sam2" in model_path.lower() or "sam_2" in model_path.lower(): - self.model_type = "sam2" - elif "fastsam" in model_path.lower(): - self.model_type = "fastsam" - else: - # Default to FastSAM for backward compatibility - self.model_type = "fastsam" - else: - self.model_type = model_type - - # Initialize the appropriate model - if self.model_type == "sam2": - self.model = SAM(model_path) - else: - self.model = FastSAM(model_path) - module_dir = os.path.dirname(__file__) self.tracker_config = os.path.join(module_dir, "config", "custom_tracker.yaml") @@ -113,28 +94,16 @@ def __init__( def process_image(self, image): """Process an image and return segmentation results.""" - if self.model_type == "sam2": - # For SAM 2, use segment everything mode - results = self.model( - source=image, - device=self.device, - save=False, - conf=0.4, - iou=0.9, - verbose=False, - ) - else: - # For FastSAM, use the original tracking approach - results = self.model.track( - source=image, - device=self.device, - retina_masks=True, - conf=0.6, - iou=0.9, - persist=True, - verbose=False, - tracker=self.tracker_config, - ) + results = self.model.track( + source=image, + 
device=self.device, + retina_masks=True, + conf=0.6, + iou=0.9, + persist=True, + verbose=False, + tracker=self.tracker_config, + ) if len(results) > 0: # Get initial segmentation results From b7c7a57769e97e6dfe2d3e57bea613efebd3d57b Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Wed, 9 Jul 2025 17:08:10 -0700 Subject: [PATCH 17/89] first commit --- dimos/manipulation/ibvs/detection3d.py | 226 ++++++++++++++++++ .../manip_aio_pipeline.py | 0 .../manip_aio_processer.py | 0 dimos/perception/detection2d/detic_2d_det.py | 2 +- dimos/perception/pointcloud/utils.py | 86 ++++++- tests/test_ibvs.py | 166 +++++++++++++ ...est_manipulation_perception_pipeline.py.py | 2 +- ...test_manipulation_pipeline_single_frame.py | 2 +- ..._manipulation_pipeline_single_frame_lcm.py | 2 +- tests/test_pointcloud_filtering.py | 2 +- 10 files changed, 482 insertions(+), 6 deletions(-) create mode 100644 dimos/manipulation/ibvs/detection3d.py rename dimos/{perception => manipulation}/manip_aio_pipeline.py (100%) rename dimos/{perception => manipulation}/manip_aio_processer.py (100%) create mode 100644 tests/test_ibvs.py diff --git a/dimos/manipulation/ibvs/detection3d.py b/dimos/manipulation/ibvs/detection3d.py new file mode 100644 index 0000000000..5cb5352e90 --- /dev/null +++ b/dimos/manipulation/ibvs/detection3d.py @@ -0,0 +1,226 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class Detection3DProcessor:
    """
    Real-time 3D detection processor optimized for speed.

    Uses Sam2DSegmenter (FastSAM) to produce per-object masks, then derives a
    3D centroid and a camera-to-object direction vector for each mask from
    the depth image.
    """

    def __init__(
        self,
        camera_intrinsics: List[float],  # [fx, fy, cx, cy]
        min_confidence: float = 0.6,
        min_points: int = 30,  # Reduced for speed
        max_depth: float = 5.0,  # Reduced for typical manipulation scenarios
    ):
        """
        Initialize the real-time 3D detection processor.

        Args:
            camera_intrinsics: [fx, fy, cx, cy] camera parameters
            min_confidence: Minimum detection confidence; detections scoring
                below it are dropped by process_frame
            min_points: Minimum 3D points required for valid detection
            max_depth: Maximum valid depth in meters
        """
        self.camera_intrinsics = camera_intrinsics
        self.min_points = min_points
        self.max_depth = max_depth
        self.min_confidence = min_confidence

        # The segmenter is created with both tracking and analysis disabled.
        # NOTE(review): the device is picked from OpenCV's CUDA device count;
        # confirm this matches how the segmenter itself detects CUDA.
        self.detector = Sam2DSegmenter(
            use_tracker=False,
            use_analyzer=False,
            device="cuda" if cv2.cuda.getCudaEnabledDeviceCount() > 0 else "cpu",
        )

        logger.info(
            f"Initialized Detection3DProcessor with Sam segmenter, confidence={min_confidence}, "
            f"min_points={min_points}, max_depth={max_depth}m"
        )

    def process_frame(self, rgb_image: np.ndarray, depth_image: np.ndarray) -> Dict[str, Any]:
        """
        Process a single RGB-D frame to extract 3D object detections.

        Args:
            rgb_image: RGB image (H, W, 3)
            depth_image: Depth image (H, W) in meters

        Returns:
            Dictionary containing:
                - detections: List of detection dictionaries with:
                    - bbox: 2D bounding box [x1, y1, x2, y2]
                    - class_name: Object class name
                    - confidence: Detection confidence
                    - track_id: ID reported by the segmenter
                    - has_3d: True when a 3D pose could be computed
                    - centroid: 3D centroid [x, y, z] in camera frame (has_3d only)
                    - orientation: Unit vector from camera to object (has_3d only)
                    - num_points: Number of valid 3D points (has_3d only)
                - processing_time: Total processing time in seconds
        """
        start_time = time.time()

        # The segmenter is fed a BGR (OpenCV-ordered) image.
        bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)

        masks, bboxes, track_ids, probs, names = self.detector.process_image(bgr_image)

        # Test emptiness explicitly: `not masks` raises for multi-element
        # arrays/tensors ("truth value is ambiguous").
        if masks is None or len(masks) == 0:
            return {"detections": [], "processing_time": time.time() - start_time}

        # Masks may be torch tensors (possibly on GPU); bring them to numpy.
        numpy_masks = [mask.cpu().numpy() if hasattr(mask, "cpu") else mask for mask in masks]

        poses = extract_centroids_from_masks(
            rgb_image=rgb_image,
            depth_image=depth_image,
            masks=numpy_masks,
            camera_intrinsics=self.camera_intrinsics,
            min_points=self.min_points,
            max_depth=self.max_depth,
        )

        # Poses are only returned for masks with enough valid depth, so index
        # them by their original mask position.
        pose_dict = {p["mask_idx"]: p for p in poses}

        detections = []
        for i, (bbox, name, prob, track_id) in enumerate(zip(bboxes, names, probs, track_ids)):
            confidence = float(prob)
            # Apply the configured threshold (it was previously stored but unused).
            if confidence < self.min_confidence:
                continue

            detection = {
                "bbox": bbox.tolist() if isinstance(bbox, np.ndarray) else bbox,
                "class_name": name,
                "confidence": confidence,
                "track_id": track_id,
            }

            pose = pose_dict.get(i)
            if pose is not None:
                detection["centroid"] = pose["centroid"].tolist()
                detection["orientation"] = pose["orientation"].tolist()
                detection["num_points"] = pose["num_points"]
                detection["has_3d"] = True
            else:
                detection["has_3d"] = False

            detections.append(detection)

        return {"detections": detections, "processing_time": time.time() - start_time}

    def visualize_detections(
        self, rgb_image: np.ndarray, detections: List[Dict[str, Any]], show_3d: bool = True
    ) -> np.ndarray:
        """
        Fast visualization of detections with optional 3D info using plot_results.

        Args:
            rgb_image: Original RGB image
            detections: List of detection dictionaries
            show_3d: Whether to draw the projected 3D centroids

        Returns:
            Visualization image (a copy of the input when there is nothing to draw)
        """
        if not detections:
            return rgb_image.copy()

        # Unpack the detection dicts into the parallel lists plot_results expects.
        bboxes = [det["bbox"] for det in detections]
        track_ids = [det.get("track_id", i) for i, det in enumerate(detections)]
        class_ids = list(range(len(detections)))  # Use indices as class IDs
        confidences = [det["confidence"] for det in detections]
        names = [det["class_name"] for det in detections]

        viz = plot_results(rgb_image, bboxes, track_ids, class_ids, confidences, names)

        if show_3d:
            fx, fy, cx, cy = self.camera_intrinsics
            for det in detections:
                if not det.get("has_3d", False):
                    continue

                centroid = np.array(det["centroid"])
                # Only points in front of the camera can be projected.
                if centroid[2] > 0:
                    u = int(centroid[0] * fx / centroid[2] + cx)
                    v = int(centroid[1] * fy / centroid[2] + cy)

                    # Filled dot with a white outline
                    cv2.circle(viz, (u, v), 6, (255, 0, 0), -1)
                    cv2.circle(viz, (u, v), 8, (255, 255, 255), 2)

        return viz

    def get_closest_detection(
        self, detections: List[Dict[str, Any]], class_filter: Optional[str] = None
    ) -> Optional[Dict[str, Any]]:
        """
        Get the closest detection with valid 3D data.

        Args:
            detections: List of detections
            class_filter: Optional class name to filter by

        Returns:
            The detection with the smallest centroid Z, or None if no
            detection has 3D data (and matches the filter)
        """
        valid_detections = [
            d
            for d in detections
            if d.get("has_3d", False) and (class_filter is None or d["class_name"] == class_filter)
        ]

        if not valid_detections:
            return None

        # Closest = smallest depth (Z coordinate)
        return min(valid_detections, key=lambda d: d["centroid"][2])

    def cleanup(self):
        """Clean up resources."""
        if hasattr(self.detector, "cleanup"):
            self.detector.cleanup()
        logger.info("Detection3DProcessor cleaned up")
def extract_centroids_from_masks(
    rgb_image: np.ndarray,
    depth_image: np.ndarray,
    masks: List[np.ndarray],
    camera_intrinsics: Union[List[float], np.ndarray],
    min_points: int = 10,
    max_depth: float = 10.0,
) -> List[Dict[str, Any]]:
    """
    Extract 3D centroids and orientations from segmentation masks.

    Args:
        rgb_image: RGB image (H, W, 3). Not read by this function; kept in
            the signature for callers that pass it.
        depth_image: Depth image (H, W) in meters
        masks: List of masks (H, W); nonzero pixels belong to the object.
            None or all-zero entries are skipped.
        camera_intrinsics: Camera parameters as a 4-element sequence/array
            [fx, fy, cx, cy] or as a camera matrix (2-D, e.g. 3x3)
        min_points: Minimum number of valid 3D points required for a detection
        max_depth: Maximum valid depth in meters

    Returns:
        List of dictionaries (one per mask that has enough valid depth) with:
            - centroid: 3D centroid position [x, y, z] in camera frame
            - orientation: Normalized direction vector from camera to centroid
            - num_points: Number of valid 3D points
            - mask_idx: Index of the mask in the input list

    Raises:
        ValueError: If camera_intrinsics is neither 4 values nor a matrix.
    """
    # Normalise the intrinsics so lists, tuples and arrays behave the same.
    intrinsics = np.asarray(camera_intrinsics, dtype=float)
    if intrinsics.shape == (4,):
        fx, fy, cx, cy = intrinsics
    elif intrinsics.ndim == 2 and intrinsics.shape[0] >= 2 and intrinsics.shape[1] >= 3:
        fx = intrinsics[0, 0]
        fy = intrinsics[1, 1]
        cx = intrinsics[0, 2]
        cy = intrinsics[1, 2]
    else:
        raise ValueError(
            "camera_intrinsics must be [fx, fy, cx, cy] or a camera matrix, "
            f"got shape {intrinsics.shape}"
        )

    results = []

    for mask_idx, mask in enumerate(masks):
        if mask is None or mask.sum() == 0:
            continue

        # Pixel coordinates (row = y, column = x) covered by the mask
        y_coords, x_coords = np.where(mask)
        depths = depth_image[y_coords, x_coords]

        # Keep only finite depths inside (0, max_depth)
        valid_mask = (depths > 0) & (depths < max_depth) & np.isfinite(depths)
        num_valid = int(valid_mask.sum())
        if num_valid < min_points:
            continue

        valid_x = x_coords[valid_mask]
        valid_y = y_coords[valid_mask]
        valid_z = depths[valid_mask]

        # Back-project pixels to 3D camera-frame points (pinhole model)
        X = (valid_x - cx) * valid_z / fx
        Y = (valid_y - cy) * valid_z / fy
        Z = valid_z

        centroid = np.array([np.mean(X), np.mean(Y), np.mean(Z)])

        # Direction from the camera origin (0, 0, 0) to the centroid.
        # Every kept depth is > 0, so the norm cannot be zero.
        orientation = centroid / np.linalg.norm(centroid)

        results.append(
            {
                "centroid": centroid,
                "orientation": orientation,
                "num_points": num_valid,
                "mask_idx": mask_idx,
            }
        )

    return results
def main():
    """
    Run the Detection3D processor on live ZED camera frames.

    Opens the camera, builds a Detection3DProcessor from the left-camera
    intrinsics, then loops: capture, detect, draw, display. Press 'q' to
    quit or 's' to save the current visualization. The camera, detector and
    OpenCV windows are released on every exit path once the camera is open.
    """
    print("Starting Detection3D test with ZED camera...")

    # Initialize ZED camera
    print("Initializing ZED camera...")
    zed_camera = ZEDCamera(
        camera_id=0,
        resolution=sl.RESOLUTION.HD720,  # 1280x720 for good performance
        depth_mode=sl.DEPTH_MODE.NEURAL,  # Best quality depth
        fps=30,
    )

    # Open camera
    if not zed_camera.open():
        print("Failed to open ZED camera!")
        return

    # Everything after a successful open() runs inside try/finally so that a
    # failure during setup (e.g. while constructing the detector) still
    # closes the camera.
    detector = None
    frame_count = 0

    try:
        # Get camera intrinsics
        camera_info = zed_camera.get_camera_info()
        left_cam = camera_info.get("left_cam", {})

        # Extract intrinsics [fx, fy, cx, cy]; fall back to rough defaults
        # when a key is missing.
        intrinsics = [
            left_cam.get("fx", 700),
            left_cam.get("fy", 700),
            left_cam.get("cx", 640),
            left_cam.get("cy", 360),
        ]

        print(
            f"Camera intrinsics: fx={intrinsics[0]:.1f}, fy={intrinsics[1]:.1f}, "
            f"cx={intrinsics[2]:.1f}, cy={intrinsics[3]:.1f}"
        )

        # Initialize Detection3D processor
        print("Initializing Detection3D processor...")
        detector = Detection3DProcessor(
            camera_intrinsics=intrinsics,
            min_confidence=0.5,  # Lower threshold for more detections
            min_points=20,  # Lower for better real-time performance
            max_depth=3.0,  # Limit to 3 meters
        )

        print("\nStarting detection loop...")
        print("Press 'q' to quit, 's' to save current frame")

        while True:
            # Capture frame
            left_img, right_img, depth = zed_camera.capture_frame()

            if left_img is None or depth is None:
                print("Failed to capture frame")
                continue

            # Convert BGR to RGB for detection
            rgb_img = cv2.cvtColor(left_img, cv2.COLOR_BGR2RGB)

            # Process frame
            results = detector.process_frame(rgb_img, depth)
            detections = results["detections"]

            # Create visualization
            viz = detector.visualize_detections(rgb_img, detections, show_3d=True)

            # Convert back to BGR for OpenCV display
            viz_bgr = cv2.cvtColor(viz, cv2.COLOR_RGB2BGR)

            # Add info text
            info_text = [
                f"Frame: {frame_count}",
                f"Detections: {len(detections)}",
                f"3D Valid: {sum(1 for d in detections if d.get('has_3d', False))}",
                f"Time: {results['processing_time'] * 1000:.1f}ms",
            ]

            y_offset = 20
            for text in info_text:
                cv2.putText(
                    viz_bgr, text, (10, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2
                )
                y_offset += 25

            # Find closest detection
            closest = detector.get_closest_detection(detections)
            if closest:
                text = f"Closest: {closest['class_name']} @ {closest['centroid'][2]:.2f}m"
                cv2.putText(
                    viz_bgr, text, (10, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2
                )

            # Display
            cv2.imshow("Detection3D Test", viz_bgr)

            # Handle key press
            key = cv2.waitKey(1) & 0xFF
            if key == ord("q"):
                break
            elif key == ord("s"):
                # Save current frame
                cv2.imwrite(f"detection3d_frame_{frame_count:04d}.png", viz_bgr)
                print(f"Saved frame {frame_count}")

            frame_count += 1

            # Print detections every 30 frames
            if frame_count % 30 == 0:
                print(f"\nFrame {frame_count}:")
                for det in detections:
                    if det.get("has_3d", False):
                        print(f"  - {det['class_name']}: {det['centroid'][2]:.2f}m away")

    except KeyboardInterrupt:
        print("\nInterrupted by user")

    finally:
        # Cleanup
        print("\nCleaning up...")
        cv2.destroyAllWindows()
        if detector is not None:
            detector.cleanup()
        zed_camera.close()
        print("Done!")
b/tests/test_manipulation_perception_pipeline.py.py @@ -36,7 +36,7 @@ from dimos.stream.stereo_camera_streams.zed import ZEDCameraStream from dimos.web.robot_web_interface import RobotWebInterface from dimos.utils.logging_config import logger -from dimos.perception.manip_aio_pipeline import ManipulationPipeline +from dimos.manipulation.manip_aio_pipeline import ManipulationPipeline def monitor_grasps(pipeline): diff --git a/tests/test_manipulation_pipeline_single_frame.py b/tests/test_manipulation_pipeline_single_frame.py index 061eb9035e..629ba4dbee 100644 --- a/tests/test_manipulation_pipeline_single_frame.py +++ b/tests/test_manipulation_pipeline_single_frame.py @@ -34,7 +34,7 @@ import open3d as o3d from dimos.perception.pointcloud.utils import visualize_clustered_point_clouds, visualize_voxel_grid -from dimos.perception.manip_aio_processer import ManipulationProcessor +from dimos.manipulation.manip_aio_processer import ManipulationProcessor from dimos.perception.pointcloud.utils import ( load_camera_matrix_from_yaml, visualize_pcd, diff --git a/tests/test_manipulation_pipeline_single_frame_lcm.py b/tests/test_manipulation_pipeline_single_frame_lcm.py index 635f82c9c9..7b57887ddc 100644 --- a/tests/test_manipulation_pipeline_single_frame_lcm.py +++ b/tests/test_manipulation_pipeline_single_frame_lcm.py @@ -42,7 +42,7 @@ from lcm_msgs.sensor_msgs import CameraInfo as LCMCameraInfo from dimos.perception.pointcloud.utils import visualize_clustered_point_clouds, visualize_voxel_grid -from dimos.perception.manip_aio_processer import ManipulationProcessor +from dimos.manipulation.manip_aio_processer import ManipulationProcessor from dimos.perception.grasp_generation.utils import visualize_grasps_3d from dimos.perception.pointcloud.utils import visualize_pcd from dimos.utils.logging_config import setup_logger diff --git a/tests/test_pointcloud_filtering.py b/tests/test_pointcloud_filtering.py index 308b4fc6ac..57a1cb5b00 100644 --- a/tests/test_pointcloud_filtering.py 
+++ b/tests/test_pointcloud_filtering.py @@ -23,7 +23,7 @@ from dimos.stream.stereo_camera_streams.zed import ZEDCameraStream from dimos.web.robot_web_interface import RobotWebInterface from dimos.utils.logging_config import logger -from dimos.perception.manip_aio_pipeline import ManipulationPipeline +from dimos.manipulation.manip_aio_pipeline import ManipulationPipeline def main(): From 5c86643d12419412b959d1cde224197db8f64ac2 Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Sat, 12 Jul 2025 00:28:37 -0700 Subject: [PATCH 18/89] PBVS fully working in the correct coordinate frame --- dimos/hardware/zed_camera.py | 212 ++++++++- dimos/manipulation/ibvs/detection3d.py | 234 ++++++--- dimos/manipulation/ibvs/pbvs.py | 499 ++++++++++++++++++++ dimos/manipulation/ibvs/utils.py | 332 +++++++++++++ dimos/perception/common/utils.py | 35 +- dimos/perception/segmentation/sam_2d_seg.py | 2 +- tests/test_ibvs.py | 260 ++++++---- 7 files changed, 1417 insertions(+), 157 deletions(-) create mode 100644 dimos/manipulation/ibvs/pbvs.py create mode 100644 dimos/manipulation/ibvs/utils.py diff --git a/dimos/hardware/zed_camera.py b/dimos/hardware/zed_camera.py index b93d2577e6..ba936cec3a 100644 --- a/dimos/hardware/zed_camera.py +++ b/dimos/hardware/zed_camera.py @@ -84,6 +84,12 @@ def __init__( self.point_cloud = sl.Mat() self.confidence_map = sl.Mat() + # Positional tracking + self.tracking_enabled = False + self.tracking_params = sl.PositionalTrackingParameters() + self.camera_pose = sl.Pose() + self.sensors_data = sl.SensorsData() + self.is_opened = False def open(self) -> bool: @@ -109,12 +115,160 @@ def open(self) -> bool: logger.error(f"Error opening ZED camera: {e}") return False - def close(self): - """Close the ZED camera.""" - if self.is_opened: - self.zed.close() - self.is_opened = False - logger.info("ZED camera closed") + def enable_positional_tracking( + self, + enable_area_memory: bool = False, + enable_pose_smoothing: bool = True, + enable_imu_fusion: bool = True, + 
set_floor_as_origin: bool = False, + initial_world_transform: Optional[sl.Transform] = None, + ) -> bool: + """ + Enable positional tracking on the ZED camera. + + Args: + enable_area_memory: Enable area learning to correct tracking drift + enable_pose_smoothing: Enable pose smoothing + enable_imu_fusion: Enable IMU fusion if available + set_floor_as_origin: Set the floor as origin (useful for robotics) + initial_world_transform: Initial world transform + + Returns: + True if tracking enabled successfully + """ + if not self.is_opened: + logger.error("ZED camera not opened") + return False + + try: + # Configure tracking parameters + self.tracking_params.enable_area_memory = enable_area_memory + self.tracking_params.enable_pose_smoothing = enable_pose_smoothing + self.tracking_params.enable_imu_fusion = enable_imu_fusion + self.tracking_params.set_floor_as_origin = set_floor_as_origin + + if initial_world_transform is not None: + self.tracking_params.initial_world_transform = initial_world_transform + + # Enable tracking + err = self.zed.enable_positional_tracking(self.tracking_params) + if err != sl.ERROR_CODE.SUCCESS: + logger.error(f"Failed to enable positional tracking: {err}") + return False + + self.tracking_enabled = True + logger.info("Positional tracking enabled successfully") + return True + + except Exception as e: + logger.error(f"Error enabling positional tracking: {e}") + return False + + def disable_positional_tracking(self): + """Disable positional tracking.""" + if self.tracking_enabled: + self.zed.disable_positional_tracking() + self.tracking_enabled = False + logger.info("Positional tracking disabled") + + def get_pose( + self, reference_frame: sl.REFERENCE_FRAME = sl.REFERENCE_FRAME.WORLD + ) -> Optional[Dict[str, Any]]: + """ + Get the current camera pose. 
+ + Args: + reference_frame: Reference frame (WORLD or CAMERA) + + Returns: + Dictionary containing: + - position: [x, y, z] in meters + - rotation: [x, y, z, w] quaternion + - euler_angles: [roll, pitch, yaw] in radians + - timestamp: Pose timestamp in nanoseconds + - confidence: Tracking confidence (0-100) + - valid: Whether pose is valid + """ + if not self.tracking_enabled: + logger.error("Positional tracking not enabled") + return None + + try: + # Get current pose + tracking_state = self.zed.get_position(self.camera_pose, reference_frame) + + if tracking_state == sl.POSITIONAL_TRACKING_STATE.OK: + # Extract translation + translation = self.camera_pose.get_translation().get() + + # Extract rotation (quaternion) + rotation = self.camera_pose.get_orientation().get() + + # Get Euler angles + euler = self.camera_pose.get_euler_angles() + + return { + "position": translation.tolist(), + "rotation": rotation.tolist(), # [x, y, z, w] + "euler_angles": euler.tolist(), # [roll, pitch, yaw] + "timestamp": self.camera_pose.timestamp.get_nanoseconds(), + "confidence": self.camera_pose.pose_confidence, + "valid": True, + "tracking_state": str(tracking_state), + } + else: + logger.warning(f"Tracking state: {tracking_state}") + return {"valid": False, "tracking_state": str(tracking_state)} + + except Exception as e: + logger.error(f"Error getting pose: {e}") + return None + + def get_imu_data(self) -> Optional[Dict[str, Any]]: + """ + Get IMU sensor data if available. 
+ + Returns: + Dictionary containing: + - orientation: IMU orientation quaternion [x, y, z, w] + - angular_velocity: [x, y, z] in rad/s + - linear_acceleration: [x, y, z] in m/s² + - timestamp: IMU data timestamp + """ + if not self.is_opened: + logger.error("ZED camera not opened") + return None + + try: + # Get sensors data synchronized with images + if ( + self.zed.get_sensors_data(self.sensors_data, sl.TIME_REFERENCE.IMAGE) + == sl.ERROR_CODE.SUCCESS + ): + imu = self.sensors_data.get_imu_data() + + # Get IMU orientation + imu_orientation = imu.get_pose().get_orientation().get() + + # Get angular velocity + angular_vel = imu.get_angular_velocity() + + # Get linear acceleration + linear_accel = imu.get_linear_acceleration() + + return { + "orientation": imu_orientation.tolist(), + "angular_velocity": angular_vel.tolist(), + "linear_acceleration": linear_accel.tolist(), + "timestamp": self.sensors_data.timestamp.get_nanoseconds(), + "temperature": self.sensors_data.temperature.get(sl.SENSOR_LOCATION.IMU), + } + else: + return None + + except Exception as e: + logger.error(f"Error getting IMU data: {e}") + return None def capture_frame( self, @@ -211,6 +365,52 @@ def capture_pointcloud(self) -> Optional[o3d.geometry.PointCloud]: logger.error(f"Error capturing point cloud: {e}") return None + def capture_frame_with_pose( + self, + ) -> Tuple[ + Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray], Optional[Dict[str, Any]] + ]: + """ + Capture a frame with synchronized pose data. 
+ + Returns: + Tuple of (left_image, right_image, depth_map, pose_data) + """ + if not self.is_opened: + logger.error("ZED camera not opened") + return None, None, None, None + + try: + # Grab frame + if self.zed.grab(self.runtime_params) == sl.ERROR_CODE.SUCCESS: + # Get images and depth + left_img, right_img, depth = self.capture_frame() + + # Get synchronized pose if tracking is enabled + pose_data = None + if self.tracking_enabled: + pose_data = self.get_pose() + + return left_img, right_img, depth, pose_data + else: + logger.warning("Failed to grab frame from ZED camera") + return None, None, None, None + + except Exception as e: + logger.error(f"Error capturing frame with pose: {e}") + return None, None, None, None + + def close(self): + """Close the ZED camera.""" + if self.is_opened: + # Disable tracking if enabled + if self.tracking_enabled: + self.disable_positional_tracking() + + self.zed.close() + self.is_opened = False + logger.info("ZED camera closed") + def get_camera_info(self) -> Dict[str, Any]: """Get ZED camera information and calibration parameters.""" if not self.is_opened: diff --git a/dimos/manipulation/ibvs/detection3d.py b/dimos/manipulation/ibvs/detection3d.py index 5cb5352e90..aca0169bf6 100644 --- a/dimos/manipulation/ibvs/detection3d.py +++ b/dimos/manipulation/ibvs/detection3d.py @@ -24,7 +24,12 @@ from dimos.utils.logging_config import setup_logger from dimos.perception.segmentation.sam_2d_seg import Sam2DSegmenter from dimos.perception.pointcloud.utils import extract_centroids_from_masks -from dimos.perception.detection2d.utils import plot_results +from dimos.perception.detection2d.utils import plot_results, calculate_object_size_from_bbox + +from dimos.types.pose import Pose +from dimos.types.vector import Vector +from dimos.types.manipulation import ObjectData +from dimos.manipulation.ibvs.utils import estimate_object_depth logger = setup_logger("dimos.perception.detection3d") @@ -34,15 +39,15 @@ class Detection3DProcessor: 
Real-time 3D detection processor optimized for speed. Uses Sam (FastSAM) for segmentation and mask generation, then extracts - 3D centroids and orientations from depth data. + 3D centroids from depth data. """ def __init__( self, camera_intrinsics: List[float], # [fx, fy, cx, cy] min_confidence: float = 0.6, - min_points: int = 30, # Reduced for speed - max_depth: float = 5.0, # Reduced for typical manipulation scenarios + min_points: int = 30, + max_depth: float = 5.0, ): """ Initialize the real-time 3D detection processor. @@ -72,25 +77,20 @@ def __init__( f"min_points={min_points}, max_depth={max_depth}m" ) - def process_frame(self, rgb_image: np.ndarray, depth_image: np.ndarray) -> Dict[str, Any]: + def process_frame( + self, rgb_image: np.ndarray, depth_image: np.ndarray, camera_pose: Optional[Any] = None + ) -> Dict[str, Any]: """ Process a single RGB-D frame to extract 3D object detections. - Optimized for real-time performance. Args: rgb_image: RGB image (H, W, 3) depth_image: Depth image (H, W) in meters + camera_pose: Optional camera pose in world frame (Pose object in ZED coordinates) Returns: Dictionary containing: - - detections: List of detection dictionaries with: - - bbox: 2D bounding box [x1, y1, x2, y2] - - class_name: Object class name - - confidence: Detection confidence - - centroid: 3D centroid [x, y, z] in camera frame - - orientation: Unit vector from camera to object - - num_points: Number of valid 3D points - - track_id: Tracking ID + - detections: List of ObjectData objects with 3D pose information - processing_time: Total processing time in seconds """ start_time = time.time() @@ -128,37 +128,117 @@ def process_frame(self, rgb_image: np.ndarray, depth_image: np.ndarray) -> Dict[ pose_dict = {p["mask_idx"]: p for p in poses} for i, (bbox, name, prob, track_id) in enumerate(zip(bboxes, names, probs, track_ids)): - detection = { + # Create ObjectData object + obj_data: ObjectData = { + "object_id": track_id, "bbox": bbox.tolist() if 
isinstance(bbox, np.ndarray) else bbox, - "class_name": name, "confidence": float(prob), - "track_id": track_id, + "label": name, + "movement_tolerance": 1.0, # Default to freely movable + "segmentation_mask": numpy_masks[i] if i < len(numpy_masks) else np.array([]), } # Add 3D pose if available if i in pose_dict: pose = pose_dict[i] - detection["centroid"] = pose["centroid"].tolist() - detection["orientation"] = pose["orientation"].tolist() - detection["num_points"] = pose["num_points"] - detection["has_3d"] = True - else: - detection["has_3d"] = False - - detections.append(detection) + obj_cam_pos = pose["centroid"] + + # Set depth and position in camera frame + obj_data["depth"] = float(obj_cam_pos[2]) + + obj_data["rotation"] = None + + # Calculate object size from bbox and depth + width_m, height_m = calculate_object_size_from_bbox( + bbox, obj_cam_pos[2], self.camera_intrinsics + ) + + # Calculate depth dimension using segmentation mask + depth_m = estimate_object_depth( + depth_image, numpy_masks[i] if i < len(numpy_masks) else None, bbox + ) + + obj_data["size"] = { + "width": max(width_m, 0.01), # Minimum 1cm width + "height": max(height_m, 0.01), # Minimum 1cm height + "depth": max(depth_m, 0.01), # Minimum 1cm depth + } + + # Extract average color from the region + x1, y1, x2, y2 = map(int, bbox) + roi = rgb_image[y1:y2, x1:x2] + if roi.size > 0: + avg_color = np.mean(roi.reshape(-1, 3), axis=0) + obj_data["color"] = avg_color.astype(np.uint8) + else: + obj_data["color"] = np.array([128, 128, 128], dtype=np.uint8) + + # Transform to world frame if camera pose is available + if camera_pose is not None: + world_pos = self._transform_to_world(obj_cam_pos, camera_pose) + obj_data["world_position"] = world_pos + obj_data["position"] = world_pos # Use world position + else: + # If no camera pose, use camera coordinates + obj_data["position"] = Vector(obj_cam_pos[0], obj_cam_pos[1], obj_cam_pos[2]) + + detections.append(obj_data) return {"detections": 
detections, "processing_time": time.time() - start_time} + def _transform_to_world(self, obj_pos: np.ndarray, camera_pose: Pose) -> Vector: + """ + Transform object position from camera frame to world frame (ZED coordinates). + + Args: + obj_pos: Object position in camera frame [x, y, z] + camera_pose: Camera pose in world frame + + Returns: + Object position in world frame as Vector + """ + # Simple transformation: rotate and translate + roll = camera_pose.rot.x + pitch = camera_pose.rot.y + yaw = camera_pose.rot.z + + # Create rotation matrices + cos_roll = np.cos(roll) + sin_roll = np.sin(roll) + R_x = np.array([[1, 0, 0], [0, cos_roll, -sin_roll], [0, sin_roll, cos_roll]]) + + cos_pitch = np.cos(pitch) + sin_pitch = np.sin(pitch) + R_y = np.array([[cos_pitch, 0, sin_pitch], [0, 1, 0], [-sin_pitch, 0, cos_pitch]]) + + cos_yaw = np.cos(yaw) + sin_yaw = np.sin(yaw) + R_z = np.array([[cos_yaw, -sin_yaw, 0], [sin_yaw, cos_yaw, 0], [0, 0, 1]]) + + # Combined rotation (ZYX convention) + rot_matrix = R_z @ R_y @ R_x + + # Rotate object position + rotated_pos = rot_matrix @ obj_pos + + # Translate by camera position + world_pos = camera_pose.pos + Vector(rotated_pos[0], rotated_pos[1], rotated_pos[2]) + + return world_pos + def visualize_detections( - self, rgb_image: np.ndarray, detections: List[Dict[str, Any]], show_3d: bool = True + self, + rgb_image: np.ndarray, + detections: List[ObjectData], + pbvs_controller: Optional[Any] = None, ) -> np.ndarray: """ - Fast visualization of detections with optional 3D info using plot_results. + Visualize detections with 3D position overlay next to bounding boxes. 
Args: rgb_image: Original RGB image - detections: List of detection dictionaries - show_3d: Whether to show 3D centroids and orientations + detections: List of ObjectData objects + pbvs_controller: Optional PBVS controller to get robot frame coordinates Returns: Visualization image @@ -168,56 +248,102 @@ def visualize_detections( # Extract data for plot_results function bboxes = [det["bbox"] for det in detections] - track_ids = [det.get("track_id", i) for i, det in enumerate(detections)] - class_ids = [i for i in range(len(detections))] # Use indices as class IDs + track_ids = [det.get("object_id", i) for i, det in enumerate(detections)] + class_ids = [i for i in range(len(detections))] confidences = [det["confidence"] for det in detections] - names = [det["class_name"] for det in detections] + names = [det["label"] for det in detections] - # Use plot_results for basic visualization (bboxes and labels) + # Use plot_results for basic visualization viz = plot_results(rgb_image, bboxes, track_ids, class_ids, confidences, names) - # Add 3D centroids if requested - if show_3d: - for det in detections: - if det.get("has_3d", False): - # Project and draw centroid - centroid = np.array(det["centroid"]) - fx, fy, cx, cy = self.camera_intrinsics - - if centroid[2] > 0: - u = int(centroid[0] * fx / centroid[2] + cx) - v = int(centroid[1] * fy / centroid[2] + cy) - - # Draw centroid circle - cv2.circle(viz, (u, v), 6, (255, 0, 0), -1) - cv2.circle(viz, (u, v), 8, (255, 255, 255), 2) + # Add 3D position overlay next to bounding boxes + fx, fy, cx, cy = self.camera_intrinsics + + for det in detections: + if "position" in det and "bbox" in det: + # Get position to display (robot frame if available, otherwise world frame) + world_position = det["position"] + display_position = world_position + frame_label = "" + + # Check if we should display robot frame coordinates + if pbvs_controller and pbvs_controller.manipulator_origin is not None: + robot_frame_data = 
pbvs_controller.get_object_pose_robot_frame(world_position) + if robot_frame_data: + display_position, _ = robot_frame_data + frame_label = "[R]" # Robot frame indicator + + bbox = det["bbox"] + + if isinstance(display_position, Vector): + display_xyz = np.array( + [display_position.x, display_position.y, display_position.z] + ) + else: + display_xyz = np.array( + [display_position["x"], display_position["y"], display_position["z"]] + ) + + # Get bounding box coordinates + x1, y1, x2, y2 = map(int, bbox) + + # Add position text next to bounding box (top-right corner) + pos_text = f"{frame_label}({display_xyz[0]:.2f}, {display_xyz[1]:.2f}, {display_xyz[2]:.2f})" + text_x = x2 + 5 # Right edge of bbox + small offset + text_y = y1 + 15 # Top edge of bbox + small offset + + # Add background rectangle for better readability + text_size = cv2.getTextSize(pos_text, cv2.FONT_HERSHEY_SIMPLEX, 0.4, 1)[0] + cv2.rectangle( + viz, + (text_x - 2, text_y - text_size[1] - 2), + (text_x + text_size[0] + 2, text_y + 2), + (0, 0, 0), + -1, + ) + + cv2.putText( + viz, + pos_text, + (text_x, text_y), + cv2.FONT_HERSHEY_SIMPLEX, + 0.4, + (255, 255, 255), + 1, + ) return viz def get_closest_detection( - self, detections: List[Dict[str, Any]], class_filter: Optional[str] = None - ) -> Optional[Dict[str, Any]]: + self, detections: List[ObjectData], class_filter: Optional[str] = None + ) -> Optional[ObjectData]: """ Get the closest detection with valid 3D data. 
Args: - detections: List of detections + detections: List of ObjectData objects class_filter: Optional class name to filter by Returns: - Closest detection or None + Closest ObjectData or None """ valid_detections = [ d for d in detections - if d.get("has_3d", False) and (class_filter is None or d["class_name"] == class_filter) + if "position" in d and (class_filter is None or d["label"] == class_filter) ] if not valid_detections: return None # Sort by depth (Z coordinate) - return min(valid_detections, key=lambda d: d["centroid"][2]) + def get_z_coord(d): + pos = d["position"] + if isinstance(pos, Vector): + return abs(pos.z) + return abs(pos["z"]) + + return min(valid_detections, key=get_z_coord) def cleanup(self): """Clean up resources.""" diff --git a/dimos/manipulation/ibvs/pbvs.py b/dimos/manipulation/ibvs/pbvs.py new file mode 100644 index 0000000000..7fbf828535 --- /dev/null +++ b/dimos/manipulation/ibvs/pbvs.py @@ -0,0 +1,499 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Position-Based Visual Servoing (PBVS) controller for eye-in-hand configuration. +Works with manipulator frame origin and proper robot arm conventions. 
+""" + +import numpy as np +from typing import Optional, Tuple, Dict, Any, List +import cv2 + +from dimos.types.pose import Pose +from dimos.types.vector import Vector +from dimos.utils.logging_config import setup_logger +from dimos.manipulation.ibvs.utils import ( + pose_to_transform_matrix, + apply_transform, + zed_to_robot_convention, + calculate_yaw_to_origin, +) + +logger = setup_logger("dimos.manipulation.pbvs") + + +class PBVSController: + """ + Position-Based Visual Servoing controller for eye-in-hand cameras. + Supports manipulator frame origin and robot arm conventions. + + Handles: + - Position and orientation error computation + - Velocity command generation with gain control + - Automatic target tracking across frames + - Frame transformations from ZED to robot conventions + """ + + def __init__( + self, + position_gain: float = 0.5, + rotation_gain: float = 0.3, + max_velocity: float = 0.1, # m/s + max_angular_velocity: float = 0.5, # rad/s + target_tolerance: float = 0.05, # 5cm + tracking_distance_threshold: float = 0.1, # 10cm for target tracking + ): + """ + Initialize PBVS controller. 
+ + Args: + position_gain: Proportional gain for position control + rotation_gain: Proportional gain for rotation control + max_velocity: Maximum linear velocity command magnitude (m/s) + max_angular_velocity: Maximum angular velocity command magnitude (rad/s) + target_tolerance: Distance threshold for considering target reached (m) + tracking_distance_threshold: Max distance for target association (m) + """ + self.position_gain = position_gain + self.rotation_gain = rotation_gain + self.max_velocity = max_velocity + self.max_angular_velocity = max_angular_velocity + self.target_tolerance = target_tolerance + self.tracking_distance_threshold = tracking_distance_threshold + + # State variables + self.current_target = None + self.last_position_error = None + self.last_rotation_error = None + self.last_velocity_cmd = None + self.last_angular_velocity_cmd = None + self.last_target_reached = False + + # Manipulator frame origin + self.manipulator_origin = None # Transform matrix from world to manipulator frame + self.manipulator_origin_pose = None # Original pose for reference + + logger.info( + f"Initialized PBVS controller: pos_gain={position_gain}, rot_gain={rotation_gain}, " + f"max_vel={max_velocity}m/s, max_ang_vel={max_angular_velocity}rad/s, " + f"target_tolerance={target_tolerance}m" + ) + + def set_manipulator_origin(self, camera_pose: Pose): + """ + Set the manipulator frame origin based on current camera pose. + This establishes the robot arm coordinate frame. 
+ + Args: + camera_pose: Current camera pose in ZED world frame + """ + self.manipulator_origin_pose = camera_pose + + # Create transform matrix from ZED world to manipulator origin + # This is the inverse of the camera pose at origin + T_world_to_origin = pose_to_transform_matrix(camera_pose) + self.manipulator_origin = np.linalg.inv(T_world_to_origin) + + logger.info( + f"Set manipulator origin at pose: pos=({camera_pose.pos.x:.3f}, " + f"{camera_pose.pos.y:.3f}, {camera_pose.pos.z:.3f})" + ) + + # Update current target if exists + if self.current_target and "position" in self.current_target: + self._update_target_robot_frame() + + def _update_target_robot_frame(self): + """Update current target with robot frame coordinates.""" + if not self.current_target or "position" not in self.current_target: + return + + # Get target position in ZED world frame + target_pos = self.current_target["position"] + target_pose_zed = Pose(target_pos, Vector(0, 0, 0)) + + # Transform to manipulator frame + target_pose_manip = apply_transform(target_pose_zed, self.manipulator_origin) + + # Convert to robot convention + target_pose_robot = zed_to_robot_convention(target_pose_manip) + + # Calculate orientation pointing at origin (in robot frame) + yaw_to_origin = calculate_yaw_to_origin(target_pose_robot.pos) + + # Update target with robot frame pose + self.current_target["robot_position"] = target_pose_robot.pos + self.current_target["robot_rotation"] = Vector(0.0, 0.0, yaw_to_origin) # Level grasp + + def set_target(self, target_object: Dict[str, Any]) -> bool: + """ + Set a new target object for servoing. + Requires manipulator origin to be set. 
+ + Args: + target_object: Object dict with at least 'position' field + + Returns: + True if target was set successfully, False if no origin set + """ + # Require origin to be set + if self.manipulator_origin is None: + logger.warning("Cannot set target: No manipulator origin set") + return False + + if target_object and "position" in target_object: + self.current_target = target_object + + # Update to robot frame + self._update_target_robot_frame() + + logger.info(f"New target set: ID {target_object.get('object_id', 'unknown')}") + return True + return False + + def clear_target(self): + """Clear the current target.""" + self.current_target = None + self.last_position_error = None + self.last_rotation_error = None + self.last_velocity_cmd = None + self.last_angular_velocity_cmd = None + self.last_target_reached = False + logger.info("Target cleared") + + def update_target_tracking(self, new_detections: List[Dict[str, Any]]) -> bool: + """ + Update target by matching to closest object in new detections. 
+ + Args: + new_detections: List of newly detected objects + + Returns: + True if target was successfully tracked, False if lost + """ + if not self.current_target or "position" not in self.current_target: + return False + + if not new_detections: + logger.debug("No detections for target tracking") + return False + + # Get current target position (in ZED world frame for matching) + target_pos = self.current_target["position"] + if isinstance(target_pos, Vector): + target_xyz = np.array([target_pos.x, target_pos.y, target_pos.z]) + else: + target_xyz = np.array([target_pos["x"], target_pos["y"], target_pos["z"]]) + + # Find closest match + min_distance = float("inf") + best_match = None + + for detection in new_detections: + if "position" not in detection: + continue + + det_pos = detection["position"] + if isinstance(det_pos, Vector): + det_xyz = np.array([det_pos.x, det_pos.y, det_pos.z]) + else: + det_xyz = np.array([det_pos["x"], det_pos["y"], det_pos["z"]]) + + distance = np.linalg.norm(target_xyz - det_xyz) + + if distance < min_distance and distance < self.tracking_distance_threshold: + min_distance = distance + best_match = detection + + if best_match: + self.current_target = best_match + # Update to robot frame + self._update_target_robot_frame() + return True + return False + + def compute_control( + self, camera_pose: Pose, new_detections: Optional[List[Dict[str, Any]]] = None + ) -> Tuple[Optional[Vector], Optional[Vector], bool, bool]: + """ + Compute PBVS control with position and orientation servoing. 
+ + Args: + camera_pose: Current camera pose in ZED world frame + new_detections: Optional new detections for target tracking + + Returns: + Tuple of (velocity_command, angular_velocity_command, target_reached, has_target) + - velocity_command: Linear velocity vector or None if no target + - angular_velocity_command: Angular velocity vector or None if no target + - target_reached: True if within target tolerance + - has_target: True if currently tracking a target + """ + # Check if we have a target and origin + if not self.current_target or "position" not in self.current_target: + return None, None, False, False + + if self.manipulator_origin is None: + logger.warning("Cannot compute control: No manipulator origin set") + return None, None, False, False + + # Try to update target tracking if new detections provided + if new_detections is not None: + self.update_target_tracking(new_detections) + + # Transform camera pose to robot frame + camera_pose_manip = apply_transform(camera_pose, self.manipulator_origin) + camera_pose_robot = zed_to_robot_convention(camera_pose_manip) + + # Get target in robot frame + target_pos = self.current_target.get("robot_position") + target_rot = self.current_target.get("robot_rotation", Vector(0, 0, 0)) + + if target_pos is None: + # Shouldn't happen but handle gracefully + self._update_target_robot_frame() + target_pos = self.current_target.get("robot_position", Vector(0, 0, 0)) + + # Calculate position error (target - camera) + error = target_pos - camera_pose_robot.pos + self.last_position_error = error + + # Compute velocity command with proportional control + velocity_cmd = error * self.position_gain + + # Limit velocity magnitude + vel_magnitude = np.linalg.norm([velocity_cmd.x, velocity_cmd.y, velocity_cmd.z]) + if vel_magnitude > self.max_velocity: + scale = self.max_velocity / vel_magnitude + velocity_cmd = velocity_cmd * scale + + self.last_velocity_cmd = velocity_cmd + + # Compute angular velocity for orientation control + 
angular_velocity_cmd = self._compute_angular_velocity(target_rot, camera_pose_robot) + + # Check if target reached + error_magnitude = np.linalg.norm([error.x, error.y, error.z]) + target_reached = error_magnitude < self.target_tolerance + self.last_target_reached = target_reached + + # Clear target only if it's reached + if target_reached: + logger.info( + f"Target reached! Clearing target ID {self.current_target.get('object_id', 'unknown')}" + ) + self.clear_target() + + return velocity_cmd, angular_velocity_cmd, target_reached, True + + def _compute_angular_velocity(self, target_rot: Vector, current_pose: Pose) -> Vector: + """ + Compute angular velocity commands for orientation control. + Aims for level grasping with appropriate yaw. + + Args: + target_rot: Target orientation (roll, pitch, yaw) + current_pose: Current camera/EE pose + + Returns: + Angular velocity command as Vector + """ + # Calculate rotation errors + roll_error = target_rot.x - current_pose.rot.x + pitch_error = target_rot.y - current_pose.rot.y + yaw_error = target_rot.z - current_pose.rot.z + + # Normalize yaw error to [-pi, pi] + while yaw_error > np.pi: + yaw_error -= 2 * np.pi + while yaw_error < -np.pi: + yaw_error += 2 * np.pi + + self.last_rotation_error = Vector(roll_error, pitch_error, yaw_error) + + # Apply proportional control + angular_velocity = Vector( + roll_error * self.rotation_gain, + pitch_error * self.rotation_gain, + yaw_error * self.rotation_gain, + ) + + # Limit angular velocity magnitude + ang_vel_magnitude = np.sqrt( + angular_velocity.x**2 + angular_velocity.y**2 + angular_velocity.z**2 + ) + if ang_vel_magnitude > self.max_angular_velocity: + scale = self.max_angular_velocity / ang_vel_magnitude + angular_velocity = angular_velocity * scale + + self.last_angular_velocity_cmd = angular_velocity + + return angular_velocity + + def get_camera_pose_robot_frame(self, camera_pose_zed: Pose) -> Optional[Pose]: + """ + Get camera pose in robot frame coordinates. 
+ + Args: + camera_pose_zed: Camera pose in ZED world frame + + Returns: + Camera pose in robot frame or None if no origin set + """ + if self.manipulator_origin is None: + return None + + camera_pose_manip = apply_transform(camera_pose_zed, self.manipulator_origin) + return zed_to_robot_convention(camera_pose_manip) + + def get_object_pose_robot_frame( + self, object_pos_zed: Vector + ) -> Optional[Tuple[Vector, Vector]]: + """ + Get object pose in robot frame coordinates with orientation. + + Args: + object_pos_zed: Object position in ZED world frame + + Returns: + Tuple of (position, rotation) in robot frame or None if no origin set + """ + if self.manipulator_origin is None: + return None + + # Transform position + obj_pose_zed = Pose(object_pos_zed, Vector(0, 0, 0)) + obj_pose_manip = apply_transform(obj_pose_zed, self.manipulator_origin) + obj_pose_robot = zed_to_robot_convention(obj_pose_manip) + + # Calculate orientation pointing at origin + yaw_to_origin = calculate_yaw_to_origin(obj_pose_robot.pos) + orientation = Vector(0.0, 0.0, yaw_to_origin) # Level grasp + + return obj_pose_robot.pos, orientation + + def create_status_overlay( + self, image: np.ndarray, camera_intrinsics: Optional[list] = None + ) -> np.ndarray: + """ + Create PBVS status overlay on image. 
+ + Args: + image: Input image + camera_intrinsics: Optional [fx, fy, cx, cy] (not used) + + Returns: + Image with PBVS status overlay + """ + viz_img = image.copy() + height, width = image.shape[:2] + + # Status panel + if self.current_target: + panel_height = 140 # Increased for rotation display + panel_y = height - panel_height + overlay = viz_img.copy() + cv2.rectangle(overlay, (0, panel_y), (width, height), (0, 0, 0), -1) + viz_img = cv2.addWeighted(viz_img, 0.7, overlay, 0.3, 0) + + # Status text + y = panel_y + 20 + cv2.putText( + viz_img, "PBVS Status", (10, y), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2 + ) + + # Add frame info + frame_text = ( + "Frame: Robot" if self.manipulator_origin is not None else "Frame: ZED World" + ) + cv2.putText( + viz_img, frame_text, (200, y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 0), 1 + ) + + if self.last_position_error: + error_mag = np.linalg.norm( + [ + self.last_position_error.x, + self.last_position_error.y, + self.last_position_error.z, + ] + ) + color = (0, 255, 0) if self.last_target_reached else (0, 255, 255) + + cv2.putText( + viz_img, + f"Pos Error: {error_mag:.3f}m ({error_mag * 100:.1f}cm)", + (10, y + 25), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + color, + 1, + ) + + cv2.putText( + viz_img, + f"XYZ: ({self.last_position_error.x:.3f}, {self.last_position_error.y:.3f}, {self.last_position_error.z:.3f})", + (10, y + 45), + cv2.FONT_HERSHEY_SIMPLEX, + 0.4, + (200, 200, 200), + 1, + ) + + if self.last_velocity_cmd: + cv2.putText( + viz_img, + f"Lin Vel: ({self.last_velocity_cmd.x:.2f}, {self.last_velocity_cmd.y:.2f}, {self.last_velocity_cmd.z:.2f})m/s", + (10, y + 65), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + (255, 200, 0), + 1, + ) + + if self.last_rotation_error: + cv2.putText( + viz_img, + f"Rot Error: ({self.last_rotation_error.x:.2f}, {self.last_rotation_error.y:.2f}, {self.last_rotation_error.z:.2f})rad", + (10, y + 85), + cv2.FONT_HERSHEY_SIMPLEX, + 0.4, + (200, 200, 200), + 1, + ) + + if 
self.last_angular_velocity_cmd: + cv2.putText( + viz_img, + f"Ang Vel: ({self.last_angular_velocity_cmd.x:.2f}, {self.last_angular_velocity_cmd.y:.2f}, {self.last_angular_velocity_cmd.z:.2f})rad/s", + (10, y + 105), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + (255, 200, 0), + 1, + ) + + if self.last_target_reached: + cv2.putText( + viz_img, + "TARGET REACHED", + (width - 150, y + 25), + cv2.FONT_HERSHEY_SIMPLEX, + 0.6, + (0, 255, 0), + 2, + ) + + return viz_img diff --git a/dimos/manipulation/ibvs/utils.py b/dimos/manipulation/ibvs/utils.py new file mode 100644 index 0000000000..cca1acbab5 --- /dev/null +++ b/dimos/manipulation/ibvs/utils.py @@ -0,0 +1,332 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from typing import Dict, Any, Optional, Tuple, List +from dimos.types.pose import Pose +from dimos.types.vector import Vector +import cv2 + + +def parse_zed_pose(zed_pose_data: Dict[str, Any]) -> Optional[Pose]: + """ + Parse ZED pose data dictionary into a Pose object. 
def _rotation_matrix_zyx(roll, pitch, yaw) -> np.ndarray:
    """Return the 3x3 rotation matrix ``Rz(yaw) @ Ry(pitch) @ Rx(roll)``."""
    cr, sr = np.cos(roll), np.sin(roll)
    cp, sp = np.cos(pitch), np.sin(pitch)
    cy, sy = np.cos(yaw), np.sin(yaw)

    rot_x = np.array([[1, 0, 0], [0, cr, -sr], [0, sr, cr]])
    rot_y = np.array([[cp, 0, sp], [0, 1, 0], [-sp, 0, cp]])
    rot_z = np.array([[cy, -sy, 0], [sy, cy, 0], [0, 0, 1]])

    return rot_z @ rot_y @ rot_x


def _euler_zyx_from_matrix(R: np.ndarray):
    """Return ``(roll, pitch, yaw)`` extracted from a 3x3 ZYX rotation matrix."""
    roll = np.arctan2(R[2, 1], R[2, 2])
    pitch = np.arctan2(-R[2, 0], np.sqrt(R[2, 1] ** 2 + R[2, 2] ** 2))
    yaw = np.arctan2(R[1, 0], R[0, 0])
    return roll, pitch, yaw


def pose_to_transform_matrix(pose: Pose) -> np.ndarray:
    """
    Build the 4x4 homogeneous transform that corresponds to a pose.

    Args:
        pose: Pose whose ``pos`` holds the translation and whose ``rot`` holds
            Euler angles as (roll, pitch, yaw)

    Returns:
        4x4 matrix with the ZYX rotation in the upper-left 3x3 block and the
        translation in the last column
    """
    transform = np.eye(4)
    transform[:3, :3] = _rotation_matrix_zyx(pose.rot.x, pose.rot.y, pose.rot.z)
    transform[:3, 3] = [pose.pos.x, pose.pos.y, pose.pos.z]
    return transform


def transform_matrix_to_pose(T: np.ndarray) -> Pose:
    """
    Turn a 4x4 homogeneous transform back into a Pose.

    Args:
        T: 4x4 transformation matrix

    Returns:
        Pose whose position is the last column of ``T`` and whose rotation is
        the (roll, pitch, yaw) extracted from the upper-left 3x3 block
    """
    translation = Vector(T[0, 3], T[1, 3], T[2, 3])
    roll, pitch, yaw = _euler_zyx_from_matrix(T[:3, :3])
    return Pose(translation, Vector(roll, pitch, yaw))


def apply_transform(pose: Pose, transform_matrix: np.ndarray) -> Pose:
    """
    Left-multiply a pose by a transformation matrix.

    Args:
        pose: Input pose
        transform_matrix: 4x4 transformation matrix to apply

    Returns:
        Pose corresponding to ``transform_matrix @ pose``
    """
    return transform_matrix_to_pose(transform_matrix @ pose_to_transform_matrix(pose))


def zed_to_robot_convention(pose: Pose) -> Pose:
    """
    Re-express a pose given in ZED camera axes in robot arm axes.

    ZED camera axes:  X right, Y down, Z forward (away from camera)
    Robot/ROS axes:   X forward, Y left, Z up

    Args:
        pose: Pose in ZED camera convention

    Returns:
        Pose in robot arm convention
    """
    # Rows map ZED axes onto robot axes:
    #   X_robot = Z_zed, Y_robot = -X_zed, Z_robot = -Y_zed
    axes = np.array(
        [
            [0, 0, 1],
            [-1, 0, 0],
            [0, -1, 0],
        ]
    )

    # Position follows the same axis mapping, written out component-wise.
    position = Vector(pose.pos.z, -pose.pos.x, -pose.pos.y)

    # Rotation: change of basis of the ZED rotation matrix, then back to Euler angles.
    rot_zed = _rotation_matrix_zyx(pose.rot.x, pose.rot.y, pose.rot.z)
    rot_robot = axes @ rot_zed @ axes.T
    roll, pitch, yaw = _euler_zyx_from_matrix(rot_robot)

    # Normalize angles to [-pi, pi]
    roll = np.arctan2(np.sin(roll), np.cos(roll))
    pitch = np.arctan2(np.sin(pitch), np.cos(pitch))
    yaw = np.arctan2(np.sin(yaw), np.cos(yaw))

    return Pose(position, Vector(roll, pitch, yaw))
robot_to_zed_convention(pose: Pose) -> Pose: + """ + Convert pose from robot arm convention to ZED camera convention. + This is the inverse of zed_to_robot_convention. + + Args: + pose: Pose in robot arm convention + + Returns: + Pose in ZED camera convention + """ + # Position transformation (inverse) + zed_x = -pose.pos.y # Right = -Left + zed_y = -pose.pos.z # Down = -Up + zed_z = pose.pos.x # Forward = Forward + + # Rotation transformation using rotation matrices + # First, create rotation matrix from Robot Euler angles + roll_robot, pitch_robot, yaw_robot = pose.rot.x, pose.rot.y, pose.rot.z + + # Create rotation matrix for Robot frame (ZYX convention) + cr, sr = np.cos(roll_robot), np.sin(roll_robot) + cp, sp = np.cos(pitch_robot), np.sin(pitch_robot) + cy, sy = np.cos(yaw_robot), np.sin(yaw_robot) + + # Roll (X), Pitch (Y), Yaw (Z) - ZYX convention + R_x = np.array([[1, 0, 0], [0, cr, -sr], [0, sr, cr]]) + + R_y = np.array([[cp, 0, sp], [0, 1, 0], [-sp, 0, cp]]) + + R_z = np.array([[cy, -sy, 0], [sy, cy, 0], [0, 0, 1]]) + + R_robot = R_z @ R_y @ R_x + + # Coordinate frame transformation matrix from Robot to ZED (inverse of ZED to Robot) + # This is the transpose of the forward transformation + T_frame_inv = np.array( + [ + [0, -1, 0], # X_zed = -Y_robot + [0, 0, -1], # Y_zed = -Z_robot + [1, 0, 0], + ] + ) # Z_zed = X_robot + + # Transform the rotation matrix + R_zed = T_frame_inv @ R_robot @ T_frame_inv.T + + # Extract Euler angles from ZED rotation matrix + # Using ZYX convention for ZED frame as well + zed_roll = np.arctan2(R_zed[2, 1], R_zed[2, 2]) + zed_pitch = np.arctan2(-R_zed[2, 0], np.sqrt(R_zed[2, 1] ** 2 + R_zed[2, 2] ** 2)) + zed_yaw = np.arctan2(R_zed[1, 0], R_zed[0, 0]) + + # Normalize angles + zed_roll = np.arctan2(np.sin(zed_roll), np.cos(zed_roll)) + zed_pitch = np.arctan2(np.sin(zed_pitch), np.cos(zed_pitch)) + zed_yaw = np.arctan2(np.sin(zed_yaw), np.cos(zed_yaw)) + + return Pose(Vector(zed_x, zed_y, zed_z), Vector(zed_roll, zed_pitch, 
def estimate_object_depth(
    depth_image: np.ndarray, segmentation_mask: Optional[np.ndarray], bbox: List[float]
) -> float:
    """
    Estimate object depth dimension using segmentation mask and depth data.

    When a usable mask is available, the estimate is the spread between the
    10th and 90th percentile of the valid masked depth samples, clamped to
    [0.02, 0.5]. Otherwise a coarse value is chosen from the bounding-box area.

    Args:
        depth_image: Depth image in meters
        segmentation_mask: Binary segmentation mask for the object, either the
            size of the depth image or already cropped to the bounding box
        bbox: Bounding box [x1, y1, x2, y2]

    Returns:
        Estimated object depth in meters (plain Python float)
    """
    x1, y1, x2, y2 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])

    # Quick bounds check
    if x2 <= x1 or y2 <= y1:
        return 0.05

    # Clamp the box to the image: negative indices would otherwise wrap around
    # in slicing and select the wrong (usually empty) region.
    img_h, img_w = depth_image.shape[:2]
    cx1, cx2 = max(x1, 0), min(x2, img_w)
    cy1, cy2 = max(y1, 0), min(y2, img_h)
    roi_visible = cx2 > cx1 and cy2 > cy1

    if roi_visible and segmentation_mask is not None and segmentation_mask.size > 0:
        roi_depth = depth_image[cy1:cy2, cx1:cx2]
        roi_shape = roi_depth.shape[:2]

        # A mask that already has the ROI's shape is used as-is, otherwise it is
        # cropped with the same window as the depth image.
        mask_roi = (
            segmentation_mask
            if segmentation_mask.shape == roi_shape
            else segmentation_mask[cy1:cy2, cx1:cx2]
        )

        # Only proceed when mask and depth ROI line up; a mismatched mask falls
        # through to the area heuristic instead of raising on boolean indexing.
        if mask_roi.shape == roi_shape:
            samples = roi_depth[mask_roi > 0]
            # Drop invalid depth readings (NaN, inf, non-positive) so they cannot
            # turn the percentiles into NaN or stretch the range.
            samples = samples[np.isfinite(samples) & (samples > 0)]

            if samples.size > 10:
                depth_10, depth_90 = np.percentile(samples, [10, 90])
                return float(np.clip(depth_90 - depth_10, 0.02, 0.5))

    # Fallback: coarse estimate from the (unclamped) bounding-box area
    bbox_area = (x2 - x1) * (y2 - y1)
    if bbox_area > 10000:
        return 0.15
    if bbox_area > 5000:
        return 0.10
    return 0.05
def point_in_bbox(point: Tuple[int, int], bbox: List[float]) -> bool:
    """
    Tell whether a point lies within a bounding box (edges included).

    Args:
        point: (x, y) coordinates
        bbox: Bounding box [x1, y1, x2, y2]

    Returns:
        True if point is inside bbox
    """
    px, py = point
    left, top, right, bottom = bbox
    inside_x = left <= px <= right
    inside_y = top <= py <= bottom
    return inside_x and inside_y


def find_clicked_object(click_point: Tuple[int, int], objects: List[Any]) -> Optional[Any]:
    """
    Return the first object whose bounding box contains the click.

    Objects without a 'bbox' entry are skipped.

    Args:
        click_point: (x, y) coordinates of mouse click
        objects: List of objects with 'bbox' field

    Returns:
        Clicked object or None
    """
    hits = (
        candidate
        for candidate in objects
        if "bbox" in candidate and point_in_bbox(click_point, candidate["bbox"])
    )
    return next(hits, None)
+ """ -Simple test script for Detection3D processor with ZED camera. -Press 'q' to quit, 's' to save current frame. +Test script for PBVS with ZED camera supporting robot arm frame. +Click on objects to select targets (requires origin to be set first). +Press 'o' to set manipulator origin at current camera pose. """ import cv2 import numpy as np import sys import os +import time -# Add parent directory to path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from dimos.hardware.zed_camera import ZEDCamera from dimos.manipulation.ibvs.detection3d import Detection3DProcessor +from dimos.manipulation.ibvs.utils import parse_zed_pose +from dimos.perception.common.utils import find_clicked_object +from dimos.manipulation.ibvs.pbvs import PBVSController try: import pyzed.sl as sl except ImportError: - print("Error: ZED SDK not installed. Please install pyzed package.") + print("Error: ZED SDK not installed.") sys.exit(1) +# Global for mouse events +mouse_click = None +warning_message = None +warning_time = None + + +def mouse_callback(event, x, y, flags, param): + global mouse_click + if event == cv2.EVENT_LBUTTONDOWN: + mouse_click = (x, y) + + def main(): - """Main test function.""" - print("Starting Detection3D test with ZED camera...") - - # Initialize ZED camera - print("Initializing ZED camera...") - zed_camera = ZEDCamera( - camera_id=0, - resolution=sl.RESOLUTION.HD720, # 1280x720 for good performance - depth_mode=sl.DEPTH_MODE.NEURAL, # Best quality depth - fps=30, - ) - - # Open camera - if not zed_camera.open(): - print("Failed to open ZED camera!") - return + global mouse_click, warning_message, warning_time - # Get camera intrinsics - camera_info = zed_camera.get_camera_info() - left_cam = camera_info.get("left_cam", {}) + print("=== PBVS Test with Robot Frame Support ===") + print("IMPORTANT: Press 'o' to set manipulator origin FIRST") + print("Then click objects to select targets | 'r' - reset | 'q' - quit") - # Extract 
intrinsics [fx, fy, cx, cy] + # Initialize camera + zed = ZEDCamera(resolution=sl.RESOLUTION.HD720, depth_mode=sl.DEPTH_MODE.NEURAL) + if not zed.open() or not zed.enable_positional_tracking(): + print("Camera initialization failed!") + return + + # Get intrinsics + cam_info = zed.get_camera_info() intrinsics = [ - left_cam.get("fx", 700), - left_cam.get("fy", 700), - left_cam.get("cx", 640), - left_cam.get("cy", 360), + cam_info["left_cam"]["fx"], + cam_info["left_cam"]["fy"], + cam_info["left_cam"]["cx"], + cam_info["left_cam"]["cy"], ] - print( - f"Camera intrinsics: fx={intrinsics[0]:.1f}, fy={intrinsics[1]:.1f}, " - f"cx={intrinsics[2]:.1f}, cy={intrinsics[3]:.1f}" - ) - - # Initialize Detection3D processor - print("Initializing Detection3D processor...") - detector = Detection3DProcessor( - camera_intrinsics=intrinsics, - min_confidence=0.5, # Lower threshold for more detections - min_points=20, # Lower for better real-time performance - max_depth=3.0, # Limit to 3 meters - ) + # Initialize processors + detector = Detection3DProcessor(intrinsics) + pbvs = PBVSController(position_gain=0.3, rotation_gain=0.2, target_tolerance=0.1) - print("\nStarting detection loop...") - print("Press 'q' to quit, 's' to save current frame") - - frame_count = 0 + # Setup window + cv2.namedWindow("PBVS") + cv2.setMouseCallback("PBVS", mouse_callback) try: while True: - # Capture frame - left_img, right_img, depth = zed_camera.capture_frame() - - if left_img is None or depth is None: - print("Failed to capture frame") + # Capture + bgr, _, depth, pose_data = zed.capture_frame_with_pose() + if bgr is None or depth is None: continue - # Convert BGR to RGB for detection - rgb_img = cv2.cvtColor(left_img, cv2.COLOR_BGR2RGB) + # Process + rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) + camera_pose = parse_zed_pose(pose_data) if pose_data else None + results = detector.process_frame(rgb, depth, camera_pose) + detections = results["detections"] + + # Handle click + if mouse_click: + 
clicked = find_clicked_object(mouse_click, detections) + if clicked: + # Try to set target (will fail if no origin) + if not pbvs.set_target(clicked): + warning_message = "SET ORIGIN FIRST! Press 'o'" + warning_time = time.time() + mouse_click = None + + # Create visualization with position overlays (robot frame if available) + viz = detector.visualize_detections(rgb, detections, pbvs_controller=pbvs) + + # PBVS control + if camera_pose: + vel_cmd, ang_vel_cmd, reached, has_target = pbvs.compute_control( + camera_pose, detections + ) - # Process frame - results = detector.process_frame(rgb_img, depth) + # Apply PBVS overlay + viz = pbvs.create_status_overlay(viz, intrinsics) - # Create visualization - viz = detector.visualize_detections(rgb_img, results["detections"], show_3d=True) + # Highlight target + if has_target and pbvs.current_target and "bbox" in pbvs.current_target: + x1, y1, x2, y2 = map(int, pbvs.current_target["bbox"]) + cv2.rectangle(viz, (x1, y1), (x2, y2), (0, 255, 0), 3) + cv2.putText( + viz, "TARGET", (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2 + ) + + # Print velocity commands for debugging (only if origin set) + if vel_cmd and ang_vel_cmd: + print(f"Linear vel: ({vel_cmd.x:.3f}, {vel_cmd.y:.3f}, {vel_cmd.z:.3f}) m/s") + print( + f"Angular vel: ({ang_vel_cmd.x:.3f}, {ang_vel_cmd.y:.3f}, {ang_vel_cmd.z:.3f}) rad/s" + ) # Convert back to BGR for OpenCV display viz_bgr = cv2.cvtColor(viz, cv2.COLOR_RGB2BGR) - # Add info text - info_text = [ - f"Frame: {frame_count}", - f"Detections: {len(results['detections'])}", - f"3D Valid: {sum(1 for d in results['detections'] if d.get('has_3d', False))}", - f"Time: {results['processing_time'] * 1000:.1f}ms", - ] + # Add camera pose info + if camera_pose: + # Show camera pose in appropriate frame + if pbvs.manipulator_origin is not None: + cam_robot = pbvs.get_camera_pose_robot_frame(camera_pose) + if cam_robot: + pose_text = f"Camera [Robot]: ({cam_robot.pos.x:.2f}, {cam_robot.pos.y:.2f}, 
{cam_robot.pos.z:.2f})m" + else: + pose_text = f"Camera [ZED]: ({camera_pose.pos.x:.2f}, {camera_pose.pos.y:.2f}, {camera_pose.pos.z:.2f})m" + else: + pose_text = f"Camera [ZED]: ({camera_pose.pos.x:.2f}, {camera_pose.pos.y:.2f}, {camera_pose.pos.z:.2f})m" - y_offset = 20 - for text in info_text: cv2.putText( - viz_bgr, text, (10, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2 + viz_bgr, pose_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 0), 1 ) - y_offset += 25 - # Find closest detection - closest = detector.get_closest_detection(results["detections"]) - if closest: - text = f"Closest: {closest['class_name']} @ {closest['centroid'][2]:.2f}m" - cv2.putText( - viz_bgr, text, (10, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2 - ) + # Show origin status + if pbvs.manipulator_origin is not None: + cv2.putText( + viz_bgr, + "Manipulator Origin SET", + (10, 50), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + (0, 255, 0), + 1, + ) + else: + cv2.putText( + viz_bgr, + "Press 'o' to set manipulator origin", + (10, 50), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + (255, 0, 0), + 1, + ) + + # Display warning message if active + if warning_message and warning_time: + # Show warning for 3 seconds + if time.time() - warning_time < 3.0: + # Draw warning box + height, width = viz_bgr.shape[:2] + box_height = 80 + box_y = height // 2 - box_height // 2 + + # Semi-transparent red background + overlay = viz_bgr.copy() + cv2.rectangle( + overlay, (50, box_y), (width - 50, box_y + box_height), (0, 0, 255), -1 + ) + viz_bgr = cv2.addWeighted(viz_bgr, 0.7, overlay, 0.3, 0) + + # Warning text + text_size = cv2.getTextSize(warning_message, cv2.FONT_HERSHEY_SIMPLEX, 1, 2)[0] + text_x = (width - text_size[0]) // 2 + text_y = box_y + box_height // 2 + text_size[1] // 2 + + cv2.putText( + viz_bgr, + warning_message, + (text_x, text_y), + cv2.FONT_HERSHEY_SIMPLEX, + 1, + (255, 255, 255), + 2, + ) + else: + warning_message = None + warning_time = None # Display - 
cv2.imshow("Detection3D Test", viz_bgr) + cv2.imshow("PBVS", viz_bgr) - # Handle key press + # Keyboard key = cv2.waitKey(1) & 0xFF if key == ord("q"): break - elif key == ord("s"): - # Save current frame - cv2.imwrite(f"detection3d_frame_{frame_count:04d}.png", viz_bgr) - print(f"Saved frame {frame_count}") - - frame_count += 1 - - # Print detections every 30 frames - if frame_count % 30 == 0: - print(f"\nFrame {frame_count}:") - for det in results["detections"]: - if det.get("has_3d", False): - print(f" - {det['class_name']}: {det['centroid'][2]:.2f}m away") + elif key == ord("r"): + pbvs.clear_target() + elif key == ord("o") and camera_pose: + pbvs.set_manipulator_origin(camera_pose) + print( + f"Set manipulator origin at: ({camera_pose.pos.x:.3f}, {camera_pose.pos.y:.3f}, {camera_pose.pos.z:.3f})" + ) except KeyboardInterrupt: - print("\nInterrupted by user") - + pass finally: - # Cleanup - print("\nCleaning up...") cv2.destroyAllWindows() detector.cleanup() - zed_camera.close() - print("Done!") + zed.close() if __name__ == "__main__": From 856acad181364e1f431a6c2f7cf00233eef7d26a Mon Sep 17 00:00:00 2001 From: rapmusta Date: Sun, 13 Jul 2025 02:29:03 +0000 Subject: [PATCH 19/89] added piper arm wrapper --- dimos/hardware/piper_arm.py | 72 +++++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 dimos/hardware/piper_arm.py diff --git a/dimos/hardware/piper_arm.py b/dimos/hardware/piper_arm.py new file mode 100644 index 0000000000..9c8a258da7 --- /dev/null +++ b/dimos/hardware/piper_arm.py @@ -0,0 +1,72 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
class PiperArm:
    """Thin wrapper around the Piper SDK ``C_PiperInterface_V2`` arm interface."""

    # Scale between the values callers pass in and the integer values exchanged
    # with the SDK.
    # NOTE(review): the SDK presumably works in 0.001 mm / 0.001 deg, which would
    # make the caller-side units mm and degrees -- confirm against the SDK docs.
    _UNIT_SCALE = 1000

    def __init__(self, arm_name: str = "arm"):
        """Connect to the arm, wait until it reports enabled, and set the motion mode.

        Args:
            arm_name: Label used only in the log message.
        """
        self.arm = C_PiperInterface_V2()
        self.arm.ConnectPort()
        time.sleep(0.1)
        # Poll until the arm reports enabled, sleeping between attempts so the
        # wait does not spin a CPU core.
        while not self.arm.EnablePiper():
            time.sleep(0.01)
        self.arm.MotionCtrl_1(0x02, 0, 0)
        self.arm.MotionCtrl_2(0, 0, 0, 0x00)
        self.arm.MotionCtrl_2(0x01, 0x00, 100, 0x00)
        print(f"[PiperArm] Connected to {arm_name}")

    def softStop(self):
        """Send the ``MotionCtrl_1(0x01, 0, 0)`` stop command and wait briefly."""
        self.arm.MotionCtrl_1(0x01, 0, 0)
        time.sleep(0.01)

    def cmd_EE_pose(self, x, y, z, r, p, y_):
        """Command end-effector to target pose in space (position + Euler angles)"""
        # EndPoseCtrl takes six separate integer arguments, not one list of floats.
        pose = [int(round(value * self._UNIT_SCALE)) for value in (x, y, z, r, p, y_)]
        self.arm.EndPoseCtrl(*pose)
        print(f"[PiperArm] Moving to pose: {pose}")

    def get_EE_pose(self):
        """Return the current end-effector pose as (x, y, z, r, p, y)

        Values are divided by the same scale that ``cmd_EE_pose`` multiplies by,
        so a pose read here can be passed back to ``cmd_EE_pose``.
        """
        feedback = self.arm.GetArmEndPoseMsgs()
        end_pose = feedback.end_pose
        pose = tuple(
            value / self._UNIT_SCALE
            for value in (
                end_pose.X_axis,
                end_pose.Y_axis,
                end_pose.Z_axis,
                end_pose.RX_axis,
                end_pose.RY_axis,
                end_pose.RZ_axis,
            )
        )
        print(f"[PiperArm] Current pose: {pose}")
        return pose

    def cmd_gripper_ctrl(self, position):
        """Command end-effector gripper"""
        position = position * self._UNIT_SCALE

        # The SDK receives the absolute, rounded opening.
        self.arm.GripperCtrl(abs(round(position)), 1000, 0x01, 0)
        print(f"[PiperArm] Commanding gripper position: {position}")

    def resetArm(self):
        """Re-issue the ``MotionCtrl_1(0x02, ...)`` / ``MotionCtrl_2(0, ...)`` reset sequence."""
        self.arm.MotionCtrl_1(0x02, 0, 0)
        self.arm.MotionCtrl_2(0, 0, 0, 0x00)
        print("[PiperArm] Resetting arm")
time.sleep(1) + From 8774e0e68ce571736f0676de41bbce594712840e Mon Sep 17 00:00:00 2001 From: mustafab0 <39084056+mustafab0@users.noreply.github.com> Date: Sun, 13 Jul 2025 02:35:11 +0000 Subject: [PATCH 20/89] CI code cleanup --- dimos/hardware/piper_arm.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/dimos/hardware/piper_arm.py b/dimos/hardware/piper_arm.py index 9c8a258da7..8abef04e4b 100644 --- a/dimos/hardware/piper_arm.py +++ b/dimos/hardware/piper_arm.py @@ -21,27 +21,28 @@ import numpy as np import time + class PiperArm: def __init__(self, arm_name: str = "arm"): self.arm = C_PiperInterface_V2() self.arm.ConnectPort() time.sleep(0.1) - while( not self.arm.EnablePiper()): + while not self.arm.EnablePiper(): pass time.sleep(0.01) - self.arm.MotionCtrl_1(0x02,0,0) + self.arm.MotionCtrl_1(0x02, 0, 0) self.arm.MotionCtrl_2(0, 0, 0, 0x00) self.arm.MotionCtrl_2(0x01, 0x00, 100, 0x00) print(f"[PiperArm] Connected to {arm_name}") def softStop(self): - self.arm.MotionCtrl_1(0x01,0,0) + self.arm.MotionCtrl_1(0x01, 0, 0) time.sleep(0.01) def cmd_EE_pose(self, x, y, z, r, p, y_): """Command end-effector to target pose in space (position + Euler angles)""" factor = 1000 - pose = [x*factor, y*factor, z*factor, r*factor, p*factor, y_*factor] + pose = [x * factor, y * factor, z * factor, r * factor, p * factor, y_ * factor] self.arm.EndPoseCtrl(pose) print(f"[PiperArm] Moving to pose: {pose}") @@ -55,18 +56,18 @@ def cmd_gripper_ctrl(self, position): """Command end-effector gripper""" position = position * 1000 - self.arm.GripperCtrl(abs(round(position)), 1000, 0x01,0) + self.arm.GripperCtrl(abs(round(position)), 1000, 0x01, 0) print(f"[PiperArm] Commanding gripper position: {position}") def resetArm(self): - self.arm.MotionCtrl_1(0x02,0,0) + self.arm.MotionCtrl_1(0x02, 0, 0) self.arm.MotionCtrl_2(0, 0, 0, 0x00) print(f"[PiperArm] Resetting arm") + if __name__ == "__main__": arm = PiperArm() arm.cmd_EE_pose(0, 0, 0, 0, 0, 0) time.sleep(1) 
arm.get_EE_pose() time.sleep(1) - From 62a032045c88cdc8c9a75a28b8fe585bc10e7940 Mon Sep 17 00:00:00 2001 From: mustafab0 Date: Mon, 14 Jul 2025 22:08:50 +0000 Subject: [PATCH 21/89] added can init support and soft stopping to avoid crashes when quiting --- dimos/hardware/can_activate.sh | 138 +++++++++++++++++++++++++++++++++ dimos/hardware/piper_arm.py | 80 ++++++++++++++++--- 2 files changed, 207 insertions(+), 11 deletions(-) create mode 100644 dimos/hardware/can_activate.sh diff --git a/dimos/hardware/can_activate.sh b/dimos/hardware/can_activate.sh new file mode 100644 index 0000000000..60cc95e7ea --- /dev/null +++ b/dimos/hardware/can_activate.sh @@ -0,0 +1,138 @@ +#!/bin/bash + +# The default CAN name can be set by the user via command-line parameters. +DEFAULT_CAN_NAME="${1:-can0}" + +# The default bitrate for a single CAN module can be set by the user via command-line parameters. +DEFAULT_BITRATE="${2:-1000000}" + +# USB hardware address (optional parameter) +USB_ADDRESS="${3}" +echo "-------------------START-----------------------" +# Check if ethtool is installed. +if ! dpkg -l | grep -q "ethtool"; then + echo "\e[31mError: ethtool not detected in the system.\e[0m" + echo "Please use the following command to install ethtool:" + echo "sudo apt update && sudo apt install ethtool" + exit 1 +fi + +# Check if can-utils is installed. +if ! dpkg -l | grep -q "can-utils"; then + echo "\e[31mError: can-utils not detected in the system.\e[0m" + echo "Please use the following command to install ethtool:" + echo "sudo apt update && sudo apt install can-utils" + exit 1 +fi + +echo "Both ethtool and can-utils are installed." + +# Retrieve the number of CAN modules in the current system. +CURRENT_CAN_COUNT=$(ip link show type can | grep -c "link/can") + +# Verify if the number of CAN modules in the current system matches the expected value. +if [ "$CURRENT_CAN_COUNT" -ne "1" ]; then + if [ -z "$USB_ADDRESS" ]; then + # Iterate through all CAN interfaces. 
+ for iface in $(ip -br link show type can | awk '{print $1}'); do + # Use ethtool to retrieve bus-info. + BUS_INFO=$(sudo ethtool -i "$iface" | grep "bus-info" | awk '{print $2}') + + if [ -z "$BUS_INFO" ];then + echo "Error: Unable to retrieve bus-info for interface $iface." + continue + fi + + echo "Interface $iface is inserted into USB port $BUS_INFO" + done + echo -e " \e[31m Error: The number of CAN modules detected by the system ($CURRENT_CAN_COUNT) does not match the expected number (1). \e[0m" + echo -e " \e[31m Please add the USB hardware address parameter, such as: \e[0m" + echo -e " bash can_activate.sh can0 1000000 1-2:1.0" + echo "-------------------ERROR-----------------------" + exit 1 + fi +fi + +# Load the gs_usb module. +# sudo modprobe gs_usb +# if [ $? -ne 0 ]; then +# echo "Error: Unable to load the gs_usb module." +# exit 1 +# fi + +if [ -n "$USB_ADDRESS" ]; then + echo "Detected USB hardware address parameter: $USB_ADDRESS" + + # Use ethtool to find the CAN interface corresponding to the USB hardware address. + INTERFACE_NAME="" + for iface in $(ip -br link show type can | awk '{print $1}'); do + BUS_INFO=$(sudo ethtool -i "$iface" | grep "bus-info" | awk '{print $2}') + if [ "$BUS_INFO" = "$USB_ADDRESS" ]; then + INTERFACE_NAME="$iface" + break + fi + done + + if [ -z "$INTERFACE_NAME" ]; then + echo "Error: Unable to find CAN interface corresponding to USB hardware address $USB_ADDRESS." + exit 1 + else + echo "Found the interface corresponding to USB hardware address $USB_ADDRESS: $INTERFACE_NAME." + fi +else + # Retrieve the unique CAN interface. + INTERFACE_NAME=$(ip -br link show type can | awk '{print $1}') + + # Check if the interface name has been retrieved. + if [ -z "$INTERFACE_NAME" ]; then + echo "Error: Unable to detect CAN interface." 
+ exit 1 + fi + BUS_INFO=$(sudo ethtool -i "$INTERFACE_NAME" | grep "bus-info" | awk '{print $2}') + echo "Expected to configure a single CAN module, detected interface $INTERFACE_NAME with corresponding USB address $BUS_INFO." +fi + +# Check if the current interface is already activated. +IS_LINK_UP=$(ip link show "$INTERFACE_NAME" | grep -q "UP" && echo "yes" || echo "no") + +# Retrieve the bitrate of the current interface. +CURRENT_BITRATE=$(ip -details link show "$INTERFACE_NAME" | grep -oP 'bitrate \K\d+') + +if [ "$IS_LINK_UP" = "yes" ] && [ "$CURRENT_BITRATE" -eq "$DEFAULT_BITRATE" ]; then + echo "Interface $INTERFACE_NAME is already activated with a bitrate of $DEFAULT_BITRATE." + + # Check if the interface name matches the default name. + if [ "$INTERFACE_NAME" != "$DEFAULT_CAN_NAME" ]; then + echo "Rename interface $INTERFACE_NAME to $DEFAULT_CAN_NAME." + sudo ip link set "$INTERFACE_NAME" down + sudo ip link set "$INTERFACE_NAME" name "$DEFAULT_CAN_NAME" + sudo ip link set "$DEFAULT_CAN_NAME" up + echo "The interface has been renamed to $DEFAULT_CAN_NAME and reactivated." + else + echo "The interface name is already $DEFAULT_CAN_NAME." + fi +else + # If the interface is not activated or the bitrate is different, configure it. + if [ "$IS_LINK_UP" = "yes" ]; then + echo "Interface $INTERFACE_NAME is already activated, but the bitrate is $CURRENT_BITRATE, which does not match the set value of $DEFAULT_BITRATE." + else + echo "Interface $INTERFACE_NAME is not activated or bitrate is not set." + fi + + # Set the interface bitrate and activate it. + sudo ip link set "$INTERFACE_NAME" down + sudo ip link set "$INTERFACE_NAME" type can bitrate $DEFAULT_BITRATE + sudo ip link set "$INTERFACE_NAME" up + echo "Interface $INTERFACE_NAME has been reset to bitrate $DEFAULT_BITRATE and activated." + + # Rename the interface to the default name. + if [ "$INTERFACE_NAME" != "$DEFAULT_CAN_NAME" ]; then + echo "Rename interface $INTERFACE_NAME to $DEFAULT_CAN_NAME." 
+ sudo ip link set "$INTERFACE_NAME" down + sudo ip link set "$INTERFACE_NAME" name "$DEFAULT_CAN_NAME" + sudo ip link set "$DEFAULT_CAN_NAME" up + echo "The interface has been renamed to $DEFAULT_CAN_NAME and reactivated." + fi +fi + +echo "-------------------OVER------------------------" diff --git a/dimos/hardware/piper_arm.py b/dimos/hardware/piper_arm.py index 8abef04e4b..800b22cd6f 100644 --- a/dimos/hardware/piper_arm.py +++ b/dimos/hardware/piper_arm.py @@ -20,37 +20,76 @@ from piper_sdk import * # from the official Piper SDK import numpy as np import time +import subprocess class PiperArm: def __init__(self, arm_name: str = "arm"): + self.init_can() self.arm = C_PiperInterface_V2() self.arm.ConnectPort() time.sleep(0.1) + self.resetArm() + time.sleep(0.1) + self.enable() + self.gotoZero() + time.sleep(1) + + def init_can(self): + result = subprocess.run( + ["bash", "dimos/dimos/hardware/can_activate.sh"], # pass the script path directly if it has a shebang and execute perms + stdout=subprocess.PIPE, # capture stdout + stderr=subprocess.PIPE, # capture stderr + text=True # return strings instead of bytes + ) + + def enable(self): while not self.arm.EnablePiper(): pass time.sleep(0.01) - self.arm.MotionCtrl_1(0x02, 0, 0) - self.arm.MotionCtrl_2(0, 0, 0, 0x00) - self.arm.MotionCtrl_2(0x01, 0x00, 100, 0x00) - print(f"[PiperArm] Connected to {arm_name}") + print(f"[PiperArm] Enabled") + self.arm.MotionCtrl_2(0x01, 0x01, 80, 0x00) + + + + def gotoZero(self): + factor = 57295.7795 #1000*180/3.1415926 + position = [0,0,0,0,0,0,0] + + joint_0 = round(position[0]*factor) + joint_1 = round(position[1]*factor) + joint_2 = round(position[2]*factor) + joint_3 = round(position[3]*factor) + joint_4 = round(position[4]*factor) + joint_5 = round(position[5]*factor) + joint_6 = round(position[6]*1000*1000) + self.arm.ModeCtrl(0x01, 0x01, 30, 0x00) + self.arm.JointCtrl(joint_0, joint_1, joint_2, joint_3, joint_4, joint_5) + self.arm.GripperCtrl(abs(joint_6), 1000, 0x01, 0) + 
pass + def softStop(self): + self.gotoZero() + time.sleep(1) + self.arm.MotionCtrl_2(0x01, 0x00, 100, 0x00) self.arm.MotionCtrl_1(0x01, 0, 0) - time.sleep(0.01) + time.sleep(5) + def cmd_EE_pose(self, x, y, z, r, p, y_): """Command end-effector to target pose in space (position + Euler angles)""" factor = 1000 pose = [x * factor, y * factor, z * factor, r * factor, p * factor, y_ * factor] - self.arm.EndPoseCtrl(pose) + self.arm.MotionCtrl_2(0x01, 0x00, 100, 0x00) + self.arm.EndPoseCtrl(int(pose[0]), int(pose[1]), int(pose[2]), int(pose[3]), int(pose[4]), int(pose[5])) print(f"[PiperArm] Moving to pose: {pose}") def get_EE_pose(self): """Return the current end-effector pose as (x, y, z, r, p, y)""" - pose = self.arm.getArmEndPoseMsgs() + pose = self.arm.GetArmEndPoseMsgs() print(f"[PiperArm] Current pose: {pose}") - return tuple(pose) + return pose def cmd_gripper_ctrl(self, position): """Command end-effector gripper""" @@ -64,10 +103,29 @@ def resetArm(self): self.arm.MotionCtrl_2(0, 0, 0, 0x00) print(f"[PiperArm] Resetting arm") + def disable(self): + self.softStop() + + while(self.arm.DisablePiper()): + pass + time.sleep(0.01) + self.arm.DisconnectPort() + if __name__ == "__main__": arm = PiperArm() - arm.cmd_EE_pose(0, 0, 0, 0, 0, 0) - time.sleep(1) + + print("get_EE_pose") arm.get_EE_pose() - time.sleep(1) + + while True: + arm.cmd_EE_pose(60, 0, 300, 0, 85, 0) + time.sleep(1) + arm.cmd_EE_pose(60, 0, 260, 0, 85, 0) + time.sleep(1) + + user_input = input("Press Enter to repeat, or type 'q' to quit: ") + if user_input.strip().lower() == 'q': + arm.disable() + break + From 4ca7cbd6ca13bee33837c78ac7e21bf0a3830110 Mon Sep 17 00:00:00 2001 From: mustafab0 <39084056+mustafab0@users.noreply.github.com> Date: Mon, 14 Jul 2025 22:09:36 +0000 Subject: [PATCH 22/89] CI code cleanup --- dimos/hardware/piper_arm.py | 46 ++++++++++++++++++------------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/dimos/hardware/piper_arm.py 
b/dimos/hardware/piper_arm.py index 800b22cd6f..3bd2a2cf15 100644 --- a/dimos/hardware/piper_arm.py +++ b/dimos/hardware/piper_arm.py @@ -37,10 +37,13 @@ def __init__(self, arm_name: str = "arm"): def init_can(self): result = subprocess.run( - ["bash", "dimos/dimos/hardware/can_activate.sh"], # pass the script path directly if it has a shebang and execute perms - stdout=subprocess.PIPE, # capture stdout - stderr=subprocess.PIPE, # capture stderr - text=True # return strings instead of bytes + [ + "bash", + "dimos/dimos/hardware/can_activate.sh", + ], # pass the script path directly if it has a shebang and execute perms + stdout=subprocess.PIPE, # capture stdout + stderr=subprocess.PIPE, # capture stderr + text=True, # return strings instead of bytes ) def enable(self): @@ -50,25 +53,22 @@ def enable(self): print(f"[PiperArm] Enabled") self.arm.MotionCtrl_2(0x01, 0x01, 80, 0x00) - - def gotoZero(self): - factor = 57295.7795 #1000*180/3.1415926 - position = [0,0,0,0,0,0,0] - - joint_0 = round(position[0]*factor) - joint_1 = round(position[1]*factor) - joint_2 = round(position[2]*factor) - joint_3 = round(position[3]*factor) - joint_4 = round(position[4]*factor) - joint_5 = round(position[5]*factor) - joint_6 = round(position[6]*1000*1000) + factor = 57295.7795 # 1000*180/3.1415926 + position = [0, 0, 0, 0, 0, 0, 0] + + joint_0 = round(position[0] * factor) + joint_1 = round(position[1] * factor) + joint_2 = round(position[2] * factor) + joint_3 = round(position[3] * factor) + joint_4 = round(position[4] * factor) + joint_5 = round(position[5] * factor) + joint_6 = round(position[6] * 1000 * 1000) self.arm.ModeCtrl(0x01, 0x01, 30, 0x00) self.arm.JointCtrl(joint_0, joint_1, joint_2, joint_3, joint_4, joint_5) self.arm.GripperCtrl(abs(joint_6), 1000, 0x01, 0) pass - def softStop(self): self.gotoZero() time.sleep(1) @@ -76,13 +76,14 @@ def softStop(self): self.arm.MotionCtrl_1(0x01, 0, 0) time.sleep(5) - def cmd_EE_pose(self, x, y, z, r, p, y_): """Command end-effector 
to target pose in space (position + Euler angles)""" factor = 1000 pose = [x * factor, y * factor, z * factor, r * factor, p * factor, y_ * factor] self.arm.MotionCtrl_2(0x01, 0x00, 100, 0x00) - self.arm.EndPoseCtrl(int(pose[0]), int(pose[1]), int(pose[2]), int(pose[3]), int(pose[4]), int(pose[5])) + self.arm.EndPoseCtrl( + int(pose[0]), int(pose[1]), int(pose[2]), int(pose[3]), int(pose[4]), int(pose[5]) + ) print(f"[PiperArm] Moving to pose: {pose}") def get_EE_pose(self): @@ -105,8 +106,8 @@ def resetArm(self): def disable(self): self.softStop() - - while(self.arm.DisablePiper()): + + while self.arm.DisablePiper(): pass time.sleep(0.01) self.arm.DisconnectPort() @@ -125,7 +126,6 @@ def disable(self): time.sleep(1) user_input = input("Press Enter to repeat, or type 'q' to quit: ") - if user_input.strip().lower() == 'q': + if user_input.strip().lower() == "q": arm.disable() break - From a8a47c4ae3f019378ece1b59dcee28b7c2cf29dc Mon Sep 17 00:00:00 2001 From: mustafab0 Date: Tue, 15 Jul 2025 01:27:09 +0000 Subject: [PATCH 23/89] addded velocity controller and urdf file --- dimos/hardware/piper_arm.py | 101 +++++- dimos/hardware/piper_description.urdf | 497 ++++++++++++++++++++++++++ 2 files changed, 588 insertions(+), 10 deletions(-) create mode 100755 dimos/hardware/piper_description.urdf diff --git a/dimos/hardware/piper_arm.py b/dimos/hardware/piper_arm.py index 3bd2a2cf15..6b38d64bce 100644 --- a/dimos/hardware/piper_arm.py +++ b/dimos/hardware/piper_arm.py @@ -21,7 +21,11 @@ import numpy as np import time import subprocess - +import kinpy as kp +import sys +import termios +import tty +import select class PiperArm: def __init__(self, arm_name: str = "arm"): @@ -34,6 +38,7 @@ def __init__(self, arm_name: str = "arm"): self.enable() self.gotoZero() time.sleep(1) + self.init_vel_controller() def init_can(self): result = subprocess.run( @@ -104,6 +109,38 @@ def resetArm(self): self.arm.MotionCtrl_2(0, 0, 0, 0x00) print(f"[PiperArm] Resetting arm") + def 
init_vel_controller(self): + self.chain = kp.build_serial_chain_from_urdf(open("dimos/dimos/hardware/piper_description.urdf"), "gripper_base") + self.J = self.chain.jacobian(np.zeros(6)) + self.J_pinv = np.linalg.pinv(self.J) + self.dt = 0.01 + + def cmd_vel(self, x_dot, y_dot, z_dot, R_dot, P_dot, Y_dot): + x_dot = x_dot * 1000 + y_dot = y_dot * 1000 + z_dot = z_dot * 1000 + R_dot = R_dot * 1000 + P_dot = P_dot * 1000 + Y_dot = Y_dot * 1000 + + joint_state = self.arm.GetArmJointMsgs().joint_state + # print(f"[PiperArm] Current Joints: {joint_state}", type(joint_state)) + joint_angles = np.array([joint_state.joint_1, joint_state.joint_2, joint_state.joint_3, joint_state.joint_4, joint_state.joint_5, joint_state.joint_6]) + # print(f"[PiperArm] Current Joints: {joint_angles}", type(joint_angles)) + factor = 57295.7795 #1000*180/3.1415926 + joint_angles = joint_angles * factor # convert to radians + # print(f"[PiperArm] Current Joints: {joint_angles}", type(joint_angles)) + + q = np.array([joint_angles[0], joint_angles[1], joint_angles[2], joint_angles[3], joint_angles[4], joint_angles[5]]) + # print(f"[PiperArm] Current Joints: {q}") + time.sleep(0.005) + dq = self.J_pinv@np.array([x_dot, y_dot, z_dot, R_dot, P_dot, Y_dot])*self.dt + newq = q + dq + + self.arm.MotionCtrl_2(0x01, 0x01, 100, 0x00) + self.arm.JointCtrl(int(round(newq[0])), int(round(newq[1])), int(round(newq[2])), int(round(newq[3])), int(round(newq[4])), int(round(newq[5]))) + # print(f"[PiperArm] Moving to Joints to : {newq}") + def disable(self): self.softStop() @@ -119,13 +156,57 @@ def disable(self): print("get_EE_pose") arm.get_EE_pose() - while True: - arm.cmd_EE_pose(60, 0, 300, 0, 85, 0) - time.sleep(1) - arm.cmd_EE_pose(60, 0, 260, 0, 85, 0) - time.sleep(1) - user_input = input("Press Enter to repeat, or type 'q' to quit: ") - if user_input.strip().lower() == "q": - arm.disable() - break + def get_key(timeout=0.1): + """Non-blocking key reader for arrow keys.""" + fd = sys.stdin.fileno() + 
old_settings = termios.tcgetattr(fd) + try: + tty.setraw(fd) + rlist, _, _ = select.select([fd], [], [], timeout) + if rlist: + ch1 = sys.stdin.read(1) + if ch1 == '\x1b': # Arrow keys start with ESC + ch2 = sys.stdin.read(1) + if ch2 == '[': + ch3 = sys.stdin.read(1) + return ch1 + ch2 + ch3 + else: + return ch1 + return None + finally: + termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) + + def teleop_linear_vel(arm): + print("Use arrow keys to control linear velocity (x/y/z). Press 'q' to quit.") + print("Up/Down: +x/-x, Left/Right: +y/-y, 'w'/'s': +z/-z") + x_dot, y_dot, z_dot = 0.0, 0.0, 0.0 + while True: + key = get_key(timeout=0.1) + if key == '\x1b[A': # Up arrow + x_dot += 0.01 + elif key == '\x1b[B': # Down arrow + x_dot -= 0.01 + elif key == '\x1b[C': # Right arrow + y_dot += 0.01 + elif key == '\x1b[D': # Left arrow + y_dot -= 0.01 + elif key == 'w': + z_dot += 0.01 + elif key == 's': + z_dot -= 0.01 + elif key == 'q': + print("Exiting teleop.") + arm.disable() + break + + # Optionally, clamp velocities to reasonable limits + x_dot = max(min(x_dot, 0.2), -0.2) + y_dot = max(min(y_dot, 0.2), -0.2) + z_dot = max(min(z_dot, 0.2), -0.2) + + # Only linear velocities, angular set to zero + arm.cmd_vel(x_dot, y_dot, z_dot, 0, 0, 0) + print(f"Current linear velocity: x={x_dot:.3f} m/s, y={y_dot:.3f} m/s, z={z_dot:.3f} m/s") + + teleop_linear_vel(arm) diff --git a/dimos/hardware/piper_description.urdf b/dimos/hardware/piper_description.urdf new file mode 100755 index 0000000000..21209b6dbb --- /dev/null +++ b/dimos/hardware/piper_description.urdf @@ -0,0 +1,497 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + From 82a7a1132b401c5fb74226d7af8e84db2ebf8977 Mon Sep 17 00:00:00 2001 From: mustafab0 Date: Tue, 15 Jul 2025 02:29:14 +0000 Subject: [PATCH 24/89] added velocity controller with end pose ctrl --- dimos/hardware/piper_arm.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/dimos/hardware/piper_arm.py b/dimos/hardware/piper_arm.py index 6b38d64bce..3fdd427958 100644 --- a/dimos/hardware/piper_arm.py +++ b/dimos/hardware/piper_arm.py @@ -141,6 +141,22 @@ def cmd_vel(self, x_dot, y_dot, z_dot, R_dot, P_dot, Y_dot): self.arm.JointCtrl(int(round(newq[0])), int(round(newq[1])), int(round(newq[2])), int(round(newq[3])), int(round(newq[4])), int(round(newq[5]))) # print(f"[PiperArm] Moving to Joints to : {newq}") + def cmd_vel_ee(self, x_dot, y_dot, z_dot, R_dot, P_dot, Y_dot): + factor = 1000 + x_dot = x_dot * factor + y_dot = y_dot * factor + z_dot = z_dot * factor + R_dot = R_dot * factor + P_dot = P_dot * factor + Y_dot = Y_dot * factor + + current_pose = self.get_EE_pose().end_pose + current_pose = np.array([current_pose.X_axis, current_pose.Y_axis, current_pose.Z_axis, current_pose.RX_axis, current_pose.RY_axis, current_pose.RZ_axis]) + current_pose = current_pose * factor + current_pose = current_pose + np.array([x_dot, y_dot, z_dot, R_dot, P_dot, Y_dot])*self.dt + current_pose = current_pose / factor + self.cmd_EE_pose(current_pose[0], current_pose[1], current_pose[2], current_pose[3], current_pose[4], current_pose[5]) + def disable(self): self.softStop() @@ -206,7 +222,8 @@ def teleop_linear_vel(arm): z_dot = max(min(z_dot, 0.2), -0.2) # Only linear velocities, angular set to zero - arm.cmd_vel(x_dot, y_dot, z_dot, 0, 0, 0) + arm.cmd_vel_ee(x_dot, y_dot, z_dot, 0, 0, 0) print(f"Current linear velocity: x={x_dot:.3f} m/s, y={y_dot:.3f} m/s, z={z_dot:.3f} m/s") 
teleop_linear_vel(arm) + From 3b08d13cce8d3c251d3a6da3675f172b9d463745 Mon Sep 17 00:00:00 2001 From: mustafab0 Date: Tue, 15 Jul 2025 02:34:08 +0000 Subject: [PATCH 25/89] added dt to the loop --- dimos/hardware/piper_arm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dimos/hardware/piper_arm.py b/dimos/hardware/piper_arm.py index 3fdd427958..1ac841ad2b 100644 --- a/dimos/hardware/piper_arm.py +++ b/dimos/hardware/piper_arm.py @@ -156,6 +156,7 @@ def cmd_vel_ee(self, x_dot, y_dot, z_dot, R_dot, P_dot, Y_dot): current_pose = current_pose + np.array([x_dot, y_dot, z_dot, R_dot, P_dot, Y_dot])*self.dt current_pose = current_pose / factor self.cmd_EE_pose(current_pose[0], current_pose[1], current_pose[2], current_pose[3], current_pose[4], current_pose[5]) + time.sleep(self.dt) def disable(self): self.softStop() From 157ed2f7e477e60964b6ca455fa0bc8f60f89540 Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Wed, 14 May 2025 13:24:16 -0700 Subject: [PATCH 26/89] initial implementation of pointcloud filtering and segmentation --- dimos/perception/pointcloud/pointcloud_seg.py | 338 ++++++++++++++++++ 1 file changed, 338 insertions(+) create mode 100644 dimos/perception/pointcloud/pointcloud_seg.py diff --git a/dimos/perception/pointcloud/pointcloud_seg.py b/dimos/perception/pointcloud/pointcloud_seg.py new file mode 100644 index 0000000000..0968a0b338 --- /dev/null +++ b/dimos/perception/pointcloud/pointcloud_seg.py @@ -0,0 +1,338 @@ +import numpy as np +import cv2 +import yaml +import os +import sys +from PIL import Image, ImageDraw +from dimos.perception.segmentation import Sam2DSegmenter +from dimos.perception.pointcloud.utils import ( + load_camera_matrix_from_yaml, + create_masked_point_cloud, + o3d_point_cloud_to_numpy, + rotation_to_o3d +) +from dimos.perception.pointcloud.cuboid_fit import fit_cuboid, visualize_fit +import torch +import open3d as o3d + +class PointcloudSegmentation: + def __init__( + self, + model_path="FastSAM-s.pt", + device="cuda", + 
color_intrinsics=None, + depth_intrinsics=None, + enable_tracking=True, + enable_analysis=True, + ): + """ + Initialize processor to segment objects in RGB images and extract their point clouds. + + Args: + model_path: Path to the FastSAM model + device: Computation device ("cuda" or "cpu") + color_intrinsics: Path to YAML file or list with color camera intrinsics [fx, fy, cx, cy] + depth_intrinsics: Path to YAML file or list with depth camera intrinsics [fx, fy, cx, cy] + enable_tracking: Whether to enable object tracking + enable_analysis: Whether to enable object analysis (labels, etc.) + min_analysis_interval: Minimum interval between analysis runs in seconds + """ + # Initialize segmenter + self.segmenter = Sam2DSegmenter( + model_path=model_path, + device=device, + use_tracker=enable_tracking, + use_analyzer=enable_analysis, + ) + + # Store settings + self.enable_tracking = enable_tracking + self.enable_analysis = enable_analysis + + # Load camera matrices + self.color_camera_matrix = load_camera_matrix_from_yaml(color_intrinsics) + self.depth_camera_matrix = load_camera_matrix_from_yaml(depth_intrinsics) + + def generate_color_from_id(self, track_id): + """Generate a consistent color for a given tracking ID.""" + np.random.seed(track_id) + color = np.random.randint(0, 255, 3) + np.random.seed(None) + return color + + def process_images(self, color_img, depth_img, fit_3d_cuboids=True): + """ + Process color and depth images to segment objects and extract point clouds. + Uses Open3D for point cloud processing. 
+ + Args: + color_img: RGB image as numpy array (H, W, 3) + depth_img: Depth image as numpy array (H, W) in meters + fit_3d_cuboids: Whether to fit 3D cuboids to each object + + Returns: + dict: Dictionary containing: + - viz_image: Visualization image with detections + - objects: List of dicts for each object with: + - mask: Segmentation mask (H, W, bool) + - bbox: Bounding box [x1, y1, x2, y2] + - target_id: Tracking ID + - confidence: Detection confidence + - name: Object name (if analyzer enabled) + - point_cloud: Open3D point cloud object + - point_cloud_numpy: Nx6 array of XYZRGB points (for compatibility) + - color: RGB color for visualization + - cuboid_params: Cuboid parameters (if fit_3d_cuboids=True) + """ + if self.depth_camera_matrix is None: + raise ValueError("Depth camera matrix must be provided to process images") + + # Run segmentation + masks, bboxes, target_ids, probs, names = self.segmenter.process_image(color_img) + print(f"Found {len(masks)} segmentation masks") + + # Run analysis if enabled + if self.enable_analysis: + self.segmenter.run_analysis(color_img, bboxes, target_ids) + names = self.segmenter.get_object_names(target_ids, names) + + # Create visualization image + viz_img = self.segmenter.visualize_results( + color_img.copy(), + masks, + bboxes, + target_ids, + probs, + names + ) + + # Process each object + objects = [] + for i, (mask, bbox, target_id, prob, name) in enumerate(zip(masks, bboxes, target_ids, probs, names)): + # Convert mask to numpy if it's a tensor + if hasattr(mask, 'cpu'): + mask = mask.cpu().numpy() + + # Ensure mask is proper boolean array with correct dimensions + mask = mask.astype(bool) + + # Ensure mask has the same shape as the depth image + if mask.shape != depth_img.shape[:2]: + print(f"Warning: Mask shape {mask.shape} doesn't match depth image shape {depth_img.shape[:2]}") + if len(mask.shape) > 2: + # If mask has extra dimensions, take the first channel + mask = mask[:,:,0] if mask.shape[2] > 0 else 
mask[:,:,0] + + # If shapes still don't match, try to resize the mask + if mask.shape != depth_img.shape[:2]: + mask = cv2.resize(mask.astype(np.uint8), + (depth_img.shape[1], depth_img.shape[0]), + interpolation=cv2.INTER_NEAREST).astype(bool) + + try: + # Create point cloud using Open3D + pcd = create_masked_point_cloud( + color_img, + depth_img, + mask, + self.depth_camera_matrix, + depth_scale=1.0 # Assuming depth is already in meters + ) + + # Skip if no points + if len(np.asarray(pcd.points)) == 0: + print(f"Skipping object {i+1}: No points in point cloud") + continue + + # Generate color for visualization + rgb_color = self.generate_color_from_id(target_id) + + # Create object data + obj_data = { + "mask": mask, + "bbox": bbox, + "target_id": target_id, + "confidence": float(prob), + "name": name if name else "", + "point_cloud": pcd, + "point_cloud_numpy": o3d_point_cloud_to_numpy(pcd), + "color": rgb_color + } + + # Fit 3D cuboid if requested + if fit_3d_cuboids: + points = np.asarray(pcd.points) + cuboid_params = fit_cuboid(points) + obj_data["cuboid_params"] = cuboid_params + + # Update visualization with cuboid if available + if cuboid_params is not None and self.color_camera_matrix is not None: + viz_img = visualize_fit(viz_img, cuboid_params, self.color_camera_matrix) + + objects.append(obj_data) + + except Exception as e: + print(f"Error processing object {i+1}: {e}") + continue + + # Clean up GPU memory if using CUDA + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + return { + "viz_image": viz_img, + "objects": objects + } + + def cleanup(self): + """Clean up resources.""" + if hasattr(self, 'segmenter'): + self.segmenter.cleanup() + +def main(): + """ + Main function to test the PointcloudSegmentation class with data from rgbd_data folder. 
+ """ + + def find_first_image(directory): + """Find the first image file in the given directory.""" + image_extensions = ['.jpg', '.jpeg', '.png', '.bmp'] + for filename in sorted(os.listdir(directory)): + if any(filename.lower().endswith(ext) for ext in image_extensions): + return os.path.join(directory, filename) + return None + + # Define paths + script_dir = os.path.dirname(os.path.abspath(__file__)) + dimos_dir = os.path.abspath(os.path.join(script_dir, "../../../")) + data_dir = os.path.join(dimos_dir, "assets/rgbd_data") + + color_info_path = os.path.join(data_dir, "color_camera_info.yaml") + depth_info_path = os.path.join(data_dir, "depth_camera_info.yaml") + + color_dir = os.path.join(data_dir, "color") + depth_dir = os.path.join(data_dir, "depth") + + # Find first color and depth images + color_img_path = find_first_image(color_dir) + depth_img_path = find_first_image(depth_dir) + + if not color_img_path or not depth_img_path: + print(f"Error: Could not find color or depth images in {data_dir}") + return + + print(f"Found color image: {color_img_path}") + print(f"Found depth image: {depth_img_path}") + + # Load images + color_img = cv2.imread(color_img_path) + if color_img is None: + print(f"Error: Could not load color image from {color_img_path}") + return + + color_img = cv2.cvtColor(color_img, cv2.COLOR_BGR2RGB) # Convert to RGB + + depth_img = cv2.imread(depth_img_path, cv2.IMREAD_UNCHANGED) + if depth_img is None: + print(f"Error: Could not load depth image from {depth_img_path}") + return + + # Convert depth to meters if needed (adjust scale as needed for your data) + if depth_img.dtype == np.uint16: + # Convert from mm to meters for typical depth cameras + depth_img = depth_img.astype(np.float32) / 1000.0 + + # Verify image shapes for debugging + print(f"Color image shape: {color_img.shape}") + print(f"Depth image shape: {depth_img.shape}") + + # Initialize segmentation with direct camera matrices + seg = PointcloudSegmentation( + 
model_path="FastSAM-s.pt", # Adjust path as needed + device="cuda" if torch.cuda.is_available() else "cpu", + color_intrinsics=color_info_path, + depth_intrinsics=depth_info_path, + enable_tracking=False, + enable_analysis=True + ) + + # Process images + print("Processing images...") + try: + results = seg.process_images(color_img, depth_img, fit_3d_cuboids=True) + + # Show segmentation results using PIL instead of OpenCV + viz_img = results["viz_image"] + + # Convert OpenCV image (BGR) to PIL image (RGB) + pil_img = Image.fromarray(cv2.cvtColor(viz_img, cv2.COLOR_BGR2RGB)) + + # Display the image using PIL + pil_img.show(title="Segmentation Results") + + # Add a short pause to ensure the image has time to display + import time + time.sleep(0.5) + + print(f"Found {len(results['objects'])} objects with valid point clouds") + + # Visualize all point clouds in a single window + all_pcds = [] + for i, obj in enumerate(results['objects']): + pcd = obj['point_cloud'] + + # Optionally add axis-aligned bounding box visualization + if 'cuboid_params' in obj and obj['cuboid_params'] is not None: + cuboid = obj['cuboid_params'] + + # Create oriented bounding box using the rotation matrix instead of axis-aligned box + center = cuboid['center'] + dimensions = cuboid['dimensions'] + rotation = rotation_to_o3d(cuboid['rotation']) + + # Create oriented bounding box + obb = o3d.geometry.OrientedBoundingBox( + center=center, + R=rotation, + extent=dimensions + ) + obb.color = [1, 0, 0] # Red bounding box + all_pcds.append(obb) + + # Add a small coordinate frame at the center of each object to show orientation + coord_frame = o3d.geometry.TriangleMesh.create_coordinate_frame( + size=min(dimensions) * 0.5, + origin=center + ) + all_pcds.append(coord_frame) + + # Add the point cloud + all_pcds.append(pcd) + + # Add coordinate frame at origin + coordinate_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1) + all_pcds.append(coordinate_frame) + + # Show point clouds + if 
all_pcds: + o3d.visualization.draw_geometries(all_pcds, + window_name="Segmented Objects", + width=1280, + height=720, + left=50, + top=50) + else: + print("No objects with valid point clouds found.") + + except Exception as e: + print(f"Error during processing: {str(e)}") + import traceback + traceback.print_exc() + + # Clean up resources + seg.cleanup() + print("Done!") + + +if __name__ == "__main__": + main() From d89efb835281a2762ec0a2ae9f69a076c1557605 Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Wed, 14 May 2025 13:39:32 -0700 Subject: [PATCH 27/89] small bug fix --- dimos/perception/pointcloud/pointcloud_seg.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dimos/perception/pointcloud/pointcloud_seg.py b/dimos/perception/pointcloud/pointcloud_seg.py index 0968a0b338..a915cbd69e 100644 --- a/dimos/perception/pointcloud/pointcloud_seg.py +++ b/dimos/perception/pointcloud/pointcloud_seg.py @@ -9,7 +9,6 @@ load_camera_matrix_from_yaml, create_masked_point_cloud, o3d_point_cloud_to_numpy, - rotation_to_o3d ) from dimos.perception.pointcloud.cuboid_fit import fit_cuboid, visualize_fit import torch @@ -288,7 +287,7 @@ def find_first_image(directory): # Create oriented bounding box using the rotation matrix instead of axis-aligned box center = cuboid['center'] dimensions = cuboid['dimensions'] - rotation = rotation_to_o3d(cuboid['rotation']) + rotation = cuboid['rotation'] # Create oriented bounding box obb = o3d.geometry.OrientedBoundingBox( From 1b77341926d8b31db48401f608256e08c5d450f0 Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Tue, 20 May 2025 10:54:59 -0700 Subject: [PATCH 28/89] added basic RANSAC plane remove algorithm --- dimos/perception/pointcloud/pointcloud_seg.py | 28 +++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/dimos/perception/pointcloud/pointcloud_seg.py b/dimos/perception/pointcloud/pointcloud_seg.py index a915cbd69e..75607d65be 100644 --- a/dimos/perception/pointcloud/pointcloud_seg.py +++ 
b/dimos/perception/pointcloud/pointcloud_seg.py @@ -9,6 +9,8 @@ load_camera_matrix_from_yaml, create_masked_point_cloud, o3d_point_cloud_to_numpy, + create_o3d_point_cloud_from_rgbd, + segment_and_remove_plane ) from dimos.perception.pointcloud.cuboid_fit import fit_cuboid, visualize_fit import torch @@ -82,6 +84,8 @@ def process_images(self, color_img, depth_img, fit_3d_cuboids=True): - point_cloud_numpy: Nx6 array of XYZRGB points (for compatibility) - color: RGB color for visualization - cuboid_params: Cuboid parameters (if fit_3d_cuboids=True) + - raw_point_cloud: Open3D point cloud object + - plane_removed_point_cloud: Open3D point cloud object with dominant plane removed """ if self.depth_camera_matrix is None: raise ValueError("Depth camera matrix must be provided to process images") @@ -173,6 +177,9 @@ def process_images(self, color_img, depth_img, fit_3d_cuboids=True): except Exception as e: print(f"Error processing object {i+1}: {e}") continue + + raw_point_cloud = create_o3d_point_cloud_from_rgbd(color_img, depth_img, self.depth_camera_matrix) + plane_removed_point_cloud = segment_and_remove_plane(raw_point_cloud) # Clean up GPU memory if using CUDA if torch.cuda.is_available(): @@ -180,7 +187,9 @@ def process_images(self, color_img, depth_img, fit_3d_cuboids=True): return { "viz_image": viz_img, - "objects": objects + "objects": objects, + "raw_point_cloud": raw_point_cloud, + "plane_removed_point_cloud": plane_removed_point_cloud } def cleanup(self): @@ -322,7 +331,22 @@ def find_first_image(directory): top=50) else: print("No objects with valid point clouds found.") - + + # Show raw point cloud + o3d.visualization.draw_geometries([results['raw_point_cloud']], + window_name="Raw Point Cloud", + width=1280, + height=720, + left=50, + top=50) + + # Show plane removed point cloud + o3d.visualization.draw_geometries([results['plane_removed_point_cloud']], + window_name="Plane Removed Point Cloud", + width=1280, + height=720, + left=50, + top=50) except 
Exception as e: print(f"Error during processing: {str(e)}") import traceback From dae60e8dee3105a710d9d6bc5aad9d47d38bd9a7 Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Mon, 2 Jun 2025 00:40:31 -0700 Subject: [PATCH 29/89] refactored and cleanup pointcloud filtering --- dimos/perception/pointcloud/pointcloud_seg.py | 280 +++++++++--------- 1 file changed, 135 insertions(+), 145 deletions(-) diff --git a/dimos/perception/pointcloud/pointcloud_seg.py b/dimos/perception/pointcloud/pointcloud_seg.py index 75607d65be..de3297b7c9 100644 --- a/dimos/perception/pointcloud/pointcloud_seg.py +++ b/dimos/perception/pointcloud/pointcloud_seg.py @@ -5,113 +5,99 @@ import sys from PIL import Image, ImageDraw from dimos.perception.segmentation import Sam2DSegmenter +from dimos.types.segmentation import SegmentationType from dimos.perception.pointcloud.utils import ( load_camera_matrix_from_yaml, create_masked_point_cloud, o3d_point_cloud_to_numpy, create_o3d_point_cloud_from_rgbd, - segment_and_remove_plane ) from dimos.perception.pointcloud.cuboid_fit import fit_cuboid, visualize_fit import torch import open3d as o3d -class PointcloudSegmentation: +class PointcloudFiltering: def __init__( self, - model_path="FastSAM-s.pt", - device="cuda", color_intrinsics=None, depth_intrinsics=None, - enable_tracking=True, - enable_analysis=True, + enable_statistical_filtering=True, + enable_cuboid_fitting=True, + color_weight=0.3, + statistical_neighbors=20, + statistical_std_ratio=2.0, ): """ - Initialize processor to segment objects in RGB images and extract their point clouds. + Initialize processor to filter point clouds from segmented objects. 
Args: - model_path: Path to the FastSAM model - device: Computation device ("cuda" or "cpu") color_intrinsics: Path to YAML file or list with color camera intrinsics [fx, fy, cx, cy] depth_intrinsics: Path to YAML file or list with depth camera intrinsics [fx, fy, cx, cy] - enable_tracking: Whether to enable object tracking - enable_analysis: Whether to enable object analysis (labels, etc.) - min_analysis_interval: Minimum interval between analysis runs in seconds + enable_statistical_filtering: Whether to apply statistical outlier filtering + enable_cuboid_fitting: Whether to fit 3D cuboids to objects + color_weight: Weight for blending generated color with original color (0.0 = original, 1.0 = generated) + statistical_neighbors: Number of neighbors for statistical filtering + statistical_std_ratio: Standard deviation ratio for statistical filtering """ - # Initialize segmenter - self.segmenter = Sam2DSegmenter( - model_path=model_path, - device=device, - use_tracker=enable_tracking, - use_analyzer=enable_analysis, - ) - # Store settings - self.enable_tracking = enable_tracking - self.enable_analysis = enable_analysis + self.enable_statistical_filtering = enable_statistical_filtering + self.enable_cuboid_fitting = enable_cuboid_fitting + self.color_weight = color_weight + self.statistical_neighbors = statistical_neighbors + self.statistical_std_ratio = statistical_std_ratio # Load camera matrices self.color_camera_matrix = load_camera_matrix_from_yaml(color_intrinsics) self.depth_camera_matrix = load_camera_matrix_from_yaml(depth_intrinsics) - def generate_color_from_id(self, track_id): - """Generate a consistent color for a given tracking ID.""" - np.random.seed(track_id) + def generate_color_from_id(self, object_id): + """Generate a consistent color for a given object ID.""" + np.random.seed(object_id) color = np.random.randint(0, 255, 3) np.random.seed(None) return color - def process_images(self, color_img, depth_img, fit_3d_cuboids=True): + def 
process_images(self, color_img, depth_img, segmentation_result): """ - Process color and depth images to segment objects and extract point clouds. - Uses Open3D for point cloud processing. + Process color and depth images with segmentation results to create filtered point clouds. Args: color_img: RGB image as numpy array (H, W, 3) depth_img: Depth image as numpy array (H, W) in meters - fit_3d_cuboids: Whether to fit 3D cuboids to each object + segmentation_result: SegmentationType object containing masks and metadata Returns: dict: Dictionary containing: - - viz_image: Visualization image with detections - objects: List of dicts for each object with: + - object_id: Object tracking ID - mask: Segmentation mask (H, W, bool) - bbox: Bounding box [x1, y1, x2, y2] - - target_id: Tracking ID - confidence: Detection confidence - - name: Object name (if analyzer enabled) - - point_cloud: Open3D point cloud object + - label: Object label/name + - point_cloud: Open3D point cloud object (filtered and colored) - point_cloud_numpy: Nx6 array of XYZRGB points (for compatibility) - color: RGB color for visualization - - cuboid_params: Cuboid parameters (if fit_3d_cuboids=True) - - raw_point_cloud: Open3D point cloud object - - plane_removed_point_cloud: Open3D point cloud object with dominant plane removed + - cuboid_params: Cuboid parameters (if enabled) + - filtering_stats: Filtering statistics (if filtering enabled) """ if self.depth_camera_matrix is None: raise ValueError("Depth camera matrix must be provided to process images") - # Run segmentation - masks, bboxes, target_ids, probs, names = self.segmenter.process_image(color_img) - print(f"Found {len(masks)} segmentation masks") - - # Run analysis if enabled - if self.enable_analysis: - self.segmenter.run_analysis(color_img, bboxes, target_ids) - names = self.segmenter.get_object_names(target_ids, names) - - # Create visualization image - viz_img = self.segmenter.visualize_results( - color_img.copy(), - masks, - bboxes, - 
target_ids, - probs, - names - ) + # Extract masks and metadata from segmentation result + masks = segmentation_result.masks + metadata = segmentation_result.metadata + objects_metadata = metadata.get('objects', []) # Process each object objects = [] - for i, (mask, bbox, target_id, prob, name) in enumerate(zip(masks, bboxes, target_ids, probs, names)): + for i, mask in enumerate(masks): + # Get object metadata if available + obj_meta = objects_metadata[i] if i < len(objects_metadata) else {} + object_id = obj_meta.get('object_id', i) + bbox = obj_meta.get('bbox', [0, 0, 0, 0]) + confidence = obj_meta.get('prob', 1.0) + label = obj_meta.get('label', '') + # Convert mask to numpy if it's a tensor if hasattr(mask, 'cpu'): mask = mask.cpu().numpy() @@ -121,12 +107,9 @@ def process_images(self, color_img, depth_img, fit_3d_cuboids=True): # Ensure mask has the same shape as the depth image if mask.shape != depth_img.shape[:2]: - print(f"Warning: Mask shape {mask.shape} doesn't match depth image shape {depth_img.shape[:2]}") if len(mask.shape) > 2: - # If mask has extra dimensions, take the first channel mask = mask[:,:,0] if mask.shape[2] > 0 else mask[:,:,0] - # If shapes still don't match, try to resize the mask if mask.shape != depth_img.shape[:2]: mask = cv2.resize(mask.astype(np.uint8), (depth_img.shape[1], depth_img.shape[0]), @@ -139,67 +122,88 @@ def process_images(self, color_img, depth_img, fit_3d_cuboids=True): depth_img, mask, self.depth_camera_matrix, - depth_scale=1.0 # Assuming depth is already in meters + depth_scale=1.0 ) # Skip if no points if len(np.asarray(pcd.points)) == 0: - print(f"Skipping object {i+1}: No points in point cloud") continue # Generate color for visualization - rgb_color = self.generate_color_from_id(target_id) + rgb_color = self.generate_color_from_id(object_id) + + # Apply weighted colored mask to the point cloud + if len(np.asarray(pcd.colors)) > 0: + original_colors = np.asarray(pcd.colors) + generated_color = 
np.array(rgb_color) / 255.0 + colored_mask = (1.0 - self.color_weight) * original_colors + self.color_weight * generated_color + colored_mask = np.clip(colored_mask, 0.0, 1.0) + pcd.colors = o3d.utility.Vector3dVector(colored_mask) + + # Apply statistical outlier filtering if enabled + filtering_stats = None + if self.enable_statistical_filtering: + num_points_before = len(np.asarray(pcd.points)) + pcd_filtered, outlier_indices = pcd.remove_statistical_outlier( + nb_neighbors=self.statistical_neighbors, + std_ratio=self.statistical_std_ratio + ) + num_points_after = len(np.asarray(pcd_filtered.points)) + num_outliers_removed = num_points_before - num_points_after + + pcd = pcd_filtered + + filtering_stats = { + "points_before": num_points_before, + "points_after": num_points_after, + "outliers_removed": num_outliers_removed, + "outlier_percentage": 100.0 * num_outliers_removed / num_points_before if num_points_before > 0 else 0 + } # Create object data obj_data = { + "object_id": object_id, "mask": mask, "bbox": bbox, - "target_id": target_id, - "confidence": float(prob), - "name": name if name else "", + "confidence": float(confidence), + "label": label, "point_cloud": pcd, "point_cloud_numpy": o3d_point_cloud_to_numpy(pcd), - "color": rgb_color + "color": rgb_color, } - # Fit 3D cuboid if requested - if fit_3d_cuboids: + # Add optional data if available + if filtering_stats is not None: + obj_data["filtering_stats"] = filtering_stats + + # Fit 3D cuboid if enabled + if self.enable_cuboid_fitting: points = np.asarray(pcd.points) cuboid_params = fit_cuboid(points) - obj_data["cuboid_params"] = cuboid_params - - # Update visualization with cuboid if available - if cuboid_params is not None and self.color_camera_matrix is not None: - viz_img = visualize_fit(viz_img, cuboid_params, self.color_camera_matrix) + if cuboid_params is not None: + obj_data["cuboid_params"] = cuboid_params objects.append(obj_data) except Exception as e: - print(f"Error processing object 
{i+1}: {e}") continue - raw_point_cloud = create_o3d_point_cloud_from_rgbd(color_img, depth_img, self.depth_camera_matrix) - plane_removed_point_cloud = segment_and_remove_plane(raw_point_cloud) - # Clean up GPU memory if using CUDA if torch.cuda.is_available(): torch.cuda.empty_cache() return { - "viz_image": viz_img, "objects": objects, - "raw_point_cloud": raw_point_cloud, - "plane_removed_point_cloud": plane_removed_point_cloud } def cleanup(self): """Clean up resources.""" - if hasattr(self, 'segmenter'): - self.segmenter.cleanup() + if torch.cuda.is_available(): + torch.cuda.empty_cache() def main(): """ - Main function to test the PointcloudSegmentation class with data from rgbd_data folder. + Main function to test the PointcloudFiltering class with data from rgbd_data folder. """ def find_first_image(directory): @@ -229,132 +233,118 @@ def find_first_image(directory): print(f"Error: Could not find color or depth images in {data_dir}") return - print(f"Found color image: {color_img_path}") - print(f"Found depth image: {depth_img_path}") - # Load images color_img = cv2.imread(color_img_path) if color_img is None: print(f"Error: Could not load color image from {color_img_path}") return - color_img = cv2.cvtColor(color_img, cv2.COLOR_BGR2RGB) # Convert to RGB + color_img = cv2.cvtColor(color_img, cv2.COLOR_BGR2RGB) depth_img = cv2.imread(depth_img_path, cv2.IMREAD_UNCHANGED) if depth_img is None: print(f"Error: Could not load depth image from {depth_img_path}") return - # Convert depth to meters if needed (adjust scale as needed for your data) + # Convert depth to meters if needed if depth_img.dtype == np.uint16: - # Convert from mm to meters for typical depth cameras depth_img = depth_img.astype(np.float32) / 1000.0 - # Verify image shapes for debugging - print(f"Color image shape: {color_img.shape}") - print(f"Depth image shape: {depth_img.shape}") - - # Initialize segmentation with direct camera matrices - seg = PointcloudSegmentation( - 
model_path="FastSAM-s.pt", # Adjust path as needed + # Run segmentation + segmenter = Sam2DSegmenter( + model_path="FastSAM-s.pt", device="cuda" if torch.cuda.is_available() else "cpu", + use_tracker=False, + use_analyzer=True + ) + + masks, bboxes, target_ids, probs, names = segmenter.process_image(color_img) + segmenter.run_analysis(color_img, bboxes, target_ids) + names = segmenter.get_object_names(target_ids, names) + + # Create metadata + objects_metadata = [] + for i in range(len(bboxes)): + obj_data = { + "object_id": target_ids[i] if i < len(target_ids) else i, + "bbox": bboxes[i], + "prob": probs[i] if i < len(probs) else 1.0, + "label": names[i] if i < len(names) else "", + } + objects_metadata.append(obj_data) + + metadata = { + "frame": color_img, + "objects": objects_metadata + } + + numpy_masks = [mask.cpu().numpy() if hasattr(mask, 'cpu') else mask for mask in masks] + segmentation_result = SegmentationType(masks=numpy_masks, metadata=metadata) + + # Initialize filtering pipeline + filter_pipeline = PointcloudFiltering( color_intrinsics=color_info_path, depth_intrinsics=depth_info_path, - enable_tracking=False, - enable_analysis=True + enable_statistical_filtering=True, + enable_cuboid_fitting=True, + color_weight=0.3, + statistical_neighbors=20, + statistical_std_ratio=2.0, ) - # Process images - print("Processing images...") + # Process images through filtering pipeline try: - results = seg.process_images(color_img, depth_img, fit_3d_cuboids=True) - - # Show segmentation results using PIL instead of OpenCV - viz_img = results["viz_image"] + results = filter_pipeline.process_images(color_img, depth_img, segmentation_result) - # Convert OpenCV image (BGR) to PIL image (RGB) - pil_img = Image.fromarray(cv2.cvtColor(viz_img, cv2.COLOR_BGR2RGB)) - - # Display the image using PIL - pil_img.show(title="Segmentation Results") - - # Add a short pause to ensure the image has time to display - import time - time.sleep(0.5) - - print(f"Found 
{len(results['objects'])} objects with valid point clouds") - - # Visualize all point clouds in a single window + # Visualize filtered point clouds all_pcds = [] for i, obj in enumerate(results['objects']): pcd = obj['point_cloud'] - # Optionally add axis-aligned bounding box visualization + # Add cuboid visualization if available if 'cuboid_params' in obj and obj['cuboid_params'] is not None: cuboid = obj['cuboid_params'] - - # Create oriented bounding box using the rotation matrix instead of axis-aligned box center = cuboid['center'] dimensions = cuboid['dimensions'] rotation = cuboid['rotation'] - # Create oriented bounding box obb = o3d.geometry.OrientedBoundingBox( center=center, R=rotation, extent=dimensions ) - obb.color = [1, 0, 0] # Red bounding box + obb.color = [1, 0, 0] all_pcds.append(obb) - # Add a small coordinate frame at the center of each object to show orientation coord_frame = o3d.geometry.TriangleMesh.create_coordinate_frame( size=min(dimensions) * 0.5, origin=center ) all_pcds.append(coord_frame) - # Add the point cloud all_pcds.append(pcd) # Add coordinate frame at origin coordinate_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1) all_pcds.append(coordinate_frame) - # Show point clouds + # Show filtered point clouds if all_pcds: o3d.visualization.draw_geometries(all_pcds, - window_name="Segmented Objects", + window_name="Filtered Point Clouds", width=1280, height=720, left=50, top=50) - else: - print("No objects with valid point clouds found.") - - # Show raw point cloud - o3d.visualization.draw_geometries([results['raw_point_cloud']], - window_name="Raw Point Cloud", - width=1280, - height=720, - left=50, - top=50) - - # Show plane removed point cloud - o3d.visualization.draw_geometries([results['plane_removed_point_cloud']], - window_name="Plane Removed Point Cloud", - width=1280, - height=720, - left=50, - top=50) + except Exception as e: print(f"Error during processing: {str(e)}") import traceback traceback.print_exc() 
# Clean up resources - seg.cleanup() - print("Done!") + segmenter.cleanup() + filter_pipeline.cleanup() if __name__ == "__main__": From da1bc671f4c9ae811adb80d62b02ff171099974d Mon Sep 17 00:00:00 2001 From: alexlin2 <44330195+alexlin2@users.noreply.github.com> Date: Mon, 2 Jun 2025 07:45:16 +0000 Subject: [PATCH 30/89] CI code cleanup --- dimos/perception/pointcloud/pointcloud_seg.py | 189 +++++++++--------- 1 file changed, 94 insertions(+), 95 deletions(-) diff --git a/dimos/perception/pointcloud/pointcloud_seg.py b/dimos/perception/pointcloud/pointcloud_seg.py index de3297b7c9..6c6c60c262 100644 --- a/dimos/perception/pointcloud/pointcloud_seg.py +++ b/dimos/perception/pointcloud/pointcloud_seg.py @@ -16,6 +16,7 @@ import torch import open3d as o3d + class PointcloudFiltering: def __init__( self, @@ -29,7 +30,7 @@ def __init__( ): """ Initialize processor to filter point clouds from segmented objects. - + Args: color_intrinsics: Path to YAML file or list with color camera intrinsics [fx, fy, cx, cy] depth_intrinsics: Path to YAML file or list with depth camera intrinsics [fx, fy, cx, cy] @@ -45,27 +46,27 @@ def __init__( self.color_weight = color_weight self.statistical_neighbors = statistical_neighbors self.statistical_std_ratio = statistical_std_ratio - + # Load camera matrices self.color_camera_matrix = load_camera_matrix_from_yaml(color_intrinsics) self.depth_camera_matrix = load_camera_matrix_from_yaml(depth_intrinsics) - + def generate_color_from_id(self, object_id): """Generate a consistent color for a given object ID.""" np.random.seed(object_id) color = np.random.randint(0, 255, 3) np.random.seed(None) return color - + def process_images(self, color_img, depth_img, segmentation_result): """ Process color and depth images with segmentation results to create filtered point clouds. 
- + Args: color_img: RGB image as numpy array (H, W, 3) depth_img: Depth image as numpy array (H, W) in meters segmentation_result: SegmentationType object containing masks and metadata - + Returns: dict: Dictionary containing: - objects: List of dicts for each object with: @@ -82,84 +83,86 @@ def process_images(self, color_img, depth_img, segmentation_result): """ if self.depth_camera_matrix is None: raise ValueError("Depth camera matrix must be provided to process images") - + # Extract masks and metadata from segmentation result masks = segmentation_result.masks metadata = segmentation_result.metadata - objects_metadata = metadata.get('objects', []) - + objects_metadata = metadata.get("objects", []) + # Process each object objects = [] for i, mask in enumerate(masks): # Get object metadata if available obj_meta = objects_metadata[i] if i < len(objects_metadata) else {} - object_id = obj_meta.get('object_id', i) - bbox = obj_meta.get('bbox', [0, 0, 0, 0]) - confidence = obj_meta.get('prob', 1.0) - label = obj_meta.get('label', '') - + object_id = obj_meta.get("object_id", i) + bbox = obj_meta.get("bbox", [0, 0, 0, 0]) + confidence = obj_meta.get("prob", 1.0) + label = obj_meta.get("label", "") + # Convert mask to numpy if it's a tensor - if hasattr(mask, 'cpu'): + if hasattr(mask, "cpu"): mask = mask.cpu().numpy() - + # Ensure mask is proper boolean array with correct dimensions mask = mask.astype(bool) - + # Ensure mask has the same shape as the depth image if mask.shape != depth_img.shape[:2]: if len(mask.shape) > 2: - mask = mask[:,:,0] if mask.shape[2] > 0 else mask[:,:,0] - + mask = mask[:, :, 0] if mask.shape[2] > 0 else mask[:, :, 0] + if mask.shape != depth_img.shape[:2]: - mask = cv2.resize(mask.astype(np.uint8), - (depth_img.shape[1], depth_img.shape[0]), - interpolation=cv2.INTER_NEAREST).astype(bool) - + mask = cv2.resize( + mask.astype(np.uint8), + (depth_img.shape[1], depth_img.shape[0]), + interpolation=cv2.INTER_NEAREST, + ).astype(bool) + try: # 
Create point cloud using Open3D pcd = create_masked_point_cloud( - color_img, - depth_img, - mask, - self.depth_camera_matrix, - depth_scale=1.0 + color_img, depth_img, mask, self.depth_camera_matrix, depth_scale=1.0 ) - + # Skip if no points if len(np.asarray(pcd.points)) == 0: continue - + # Generate color for visualization rgb_color = self.generate_color_from_id(object_id) - + # Apply weighted colored mask to the point cloud if len(np.asarray(pcd.colors)) > 0: original_colors = np.asarray(pcd.colors) generated_color = np.array(rgb_color) / 255.0 - colored_mask = (1.0 - self.color_weight) * original_colors + self.color_weight * generated_color + colored_mask = ( + 1.0 - self.color_weight + ) * original_colors + self.color_weight * generated_color colored_mask = np.clip(colored_mask, 0.0, 1.0) pcd.colors = o3d.utility.Vector3dVector(colored_mask) - + # Apply statistical outlier filtering if enabled filtering_stats = None if self.enable_statistical_filtering: num_points_before = len(np.asarray(pcd.points)) pcd_filtered, outlier_indices = pcd.remove_statistical_outlier( nb_neighbors=self.statistical_neighbors, - std_ratio=self.statistical_std_ratio + std_ratio=self.statistical_std_ratio, ) num_points_after = len(np.asarray(pcd_filtered.points)) num_outliers_removed = num_points_before - num_points_after - + pcd = pcd_filtered - + filtering_stats = { "points_before": num_points_before, "points_after": num_points_after, "outliers_removed": num_outliers_removed, - "outlier_percentage": 100.0 * num_outliers_removed / num_points_before if num_points_before > 0 else 0 + "outlier_percentage": 100.0 * num_outliers_removed / num_points_before + if num_points_before > 0 + else 0, } - + # Create object data obj_data = { "object_id": object_id, @@ -171,36 +174,37 @@ def process_images(self, color_img, depth_img, segmentation_result): "point_cloud_numpy": o3d_point_cloud_to_numpy(pcd), "color": rgb_color, } - + # Add optional data if available if filtering_stats is not None: 
obj_data["filtering_stats"] = filtering_stats - + # Fit 3D cuboid if enabled if self.enable_cuboid_fitting: points = np.asarray(pcd.points) cuboid_params = fit_cuboid(points) if cuboid_params is not None: obj_data["cuboid_params"] = cuboid_params - + objects.append(obj_data) - + except Exception as e: continue # Clean up GPU memory if using CUDA if torch.cuda.is_available(): torch.cuda.empty_cache() - + return { "objects": objects, } - + def cleanup(self): """Clean up resources.""" if torch.cuda.is_available(): torch.cuda.empty_cache() + def main(): """ Main function to test the PointcloudFiltering class with data from rgbd_data folder. @@ -208,7 +212,7 @@ def main(): def find_first_image(directory): """Find the first image file in the given directory.""" - image_extensions = ['.jpg', '.jpeg', '.png', '.bmp'] + image_extensions = [".jpg", ".jpeg", ".png", ".bmp"] for filename in sorted(os.listdir(directory)): if any(filename.lower().endswith(ext) for ext in image_extensions): return os.path.join(directory, filename) @@ -218,50 +222,50 @@ def find_first_image(directory): script_dir = os.path.dirname(os.path.abspath(__file__)) dimos_dir = os.path.abspath(os.path.join(script_dir, "../../../")) data_dir = os.path.join(dimos_dir, "assets/rgbd_data") - + color_info_path = os.path.join(data_dir, "color_camera_info.yaml") depth_info_path = os.path.join(data_dir, "depth_camera_info.yaml") - + color_dir = os.path.join(data_dir, "color") depth_dir = os.path.join(data_dir, "depth") - + # Find first color and depth images color_img_path = find_first_image(color_dir) depth_img_path = find_first_image(depth_dir) - + if not color_img_path or not depth_img_path: print(f"Error: Could not find color or depth images in {data_dir}") return - + # Load images color_img = cv2.imread(color_img_path) if color_img is None: print(f"Error: Could not load color image from {color_img_path}") return - + color_img = cv2.cvtColor(color_img, cv2.COLOR_BGR2RGB) - + depth_img = 
cv2.imread(depth_img_path, cv2.IMREAD_UNCHANGED) if depth_img is None: print(f"Error: Could not load depth image from {depth_img_path}") return - + # Convert depth to meters if needed if depth_img.dtype == np.uint16: depth_img = depth_img.astype(np.float32) / 1000.0 - + # Run segmentation segmenter = Sam2DSegmenter( model_path="FastSAM-s.pt", device="cuda" if torch.cuda.is_available() else "cpu", use_tracker=False, - use_analyzer=True + use_analyzer=True, ) - + masks, bboxes, target_ids, probs, names = segmenter.process_image(color_img) segmenter.run_analysis(color_img, bboxes, target_ids) names = segmenter.get_object_names(target_ids, names) - + # Create metadata objects_metadata = [] for i in range(len(bboxes)): @@ -272,15 +276,12 @@ def find_first_image(directory): "label": names[i] if i < len(names) else "", } objects_metadata.append(obj_data) - - metadata = { - "frame": color_img, - "objects": objects_metadata - } - - numpy_masks = [mask.cpu().numpy() if hasattr(mask, 'cpu') else mask for mask in masks] + + metadata = {"frame": color_img, "objects": objects_metadata} + + numpy_masks = [mask.cpu().numpy() if hasattr(mask, "cpu") else mask for mask in masks] segmentation_result = SegmentationType(masks=numpy_masks, metadata=metadata) - + # Initialize filtering pipeline filter_pipeline = PointcloudFiltering( color_intrinsics=color_info_path, @@ -291,57 +292,55 @@ def find_first_image(directory): statistical_neighbors=20, statistical_std_ratio=2.0, ) - + # Process images through filtering pipeline try: results = filter_pipeline.process_images(color_img, depth_img, segmentation_result) - + # Visualize filtered point clouds all_pcds = [] - for i, obj in enumerate(results['objects']): - pcd = obj['point_cloud'] - + for i, obj in enumerate(results["objects"]): + pcd = obj["point_cloud"] + # Add cuboid visualization if available - if 'cuboid_params' in obj and obj['cuboid_params'] is not None: - cuboid = obj['cuboid_params'] - center = cuboid['center'] - dimensions = 
cuboid['dimensions'] - rotation = cuboid['rotation'] - - obb = o3d.geometry.OrientedBoundingBox( - center=center, - R=rotation, - extent=dimensions - ) + if "cuboid_params" in obj and obj["cuboid_params"] is not None: + cuboid = obj["cuboid_params"] + center = cuboid["center"] + dimensions = cuboid["dimensions"] + rotation = cuboid["rotation"] + + obb = o3d.geometry.OrientedBoundingBox(center=center, R=rotation, extent=dimensions) obb.color = [1, 0, 0] all_pcds.append(obb) - + coord_frame = o3d.geometry.TriangleMesh.create_coordinate_frame( - size=min(dimensions) * 0.5, - origin=center + size=min(dimensions) * 0.5, origin=center ) all_pcds.append(coord_frame) - + all_pcds.append(pcd) - + # Add coordinate frame at origin coordinate_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1) all_pcds.append(coordinate_frame) - + # Show filtered point clouds if all_pcds: - o3d.visualization.draw_geometries(all_pcds, - window_name="Filtered Point Clouds", - width=1280, - height=720, - left=50, - top=50) - + o3d.visualization.draw_geometries( + all_pcds, + window_name="Filtered Point Clouds", + width=1280, + height=720, + left=50, + top=50, + ) + except Exception as e: print(f"Error during processing: {str(e)}") import traceback + traceback.print_exc() - + # Clean up resources segmenter.cleanup() filter_pipeline.cleanup() From 96e5e09c042a76a0bccc74a3f89dbf6780db0ae6 Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Mon, 2 Jun 2025 17:22:17 -0700 Subject: [PATCH 31/89] cleanup pointcloud filtering --- dimos/perception/pointcloud/pointcloud_seg.py | 350 ------------------ 1 file changed, 350 deletions(-) delete mode 100644 dimos/perception/pointcloud/pointcloud_seg.py diff --git a/dimos/perception/pointcloud/pointcloud_seg.py b/dimos/perception/pointcloud/pointcloud_seg.py deleted file mode 100644 index 6c6c60c262..0000000000 --- a/dimos/perception/pointcloud/pointcloud_seg.py +++ /dev/null @@ -1,350 +0,0 @@ -import numpy as np -import cv2 -import yaml -import os 
-import sys -from PIL import Image, ImageDraw -from dimos.perception.segmentation import Sam2DSegmenter -from dimos.types.segmentation import SegmentationType -from dimos.perception.pointcloud.utils import ( - load_camera_matrix_from_yaml, - create_masked_point_cloud, - o3d_point_cloud_to_numpy, - create_o3d_point_cloud_from_rgbd, -) -from dimos.perception.pointcloud.cuboid_fit import fit_cuboid, visualize_fit -import torch -import open3d as o3d - - -class PointcloudFiltering: - def __init__( - self, - color_intrinsics=None, - depth_intrinsics=None, - enable_statistical_filtering=True, - enable_cuboid_fitting=True, - color_weight=0.3, - statistical_neighbors=20, - statistical_std_ratio=2.0, - ): - """ - Initialize processor to filter point clouds from segmented objects. - - Args: - color_intrinsics: Path to YAML file or list with color camera intrinsics [fx, fy, cx, cy] - depth_intrinsics: Path to YAML file or list with depth camera intrinsics [fx, fy, cx, cy] - enable_statistical_filtering: Whether to apply statistical outlier filtering - enable_cuboid_fitting: Whether to fit 3D cuboids to objects - color_weight: Weight for blending generated color with original color (0.0 = original, 1.0 = generated) - statistical_neighbors: Number of neighbors for statistical filtering - statistical_std_ratio: Standard deviation ratio for statistical filtering - """ - # Store settings - self.enable_statistical_filtering = enable_statistical_filtering - self.enable_cuboid_fitting = enable_cuboid_fitting - self.color_weight = color_weight - self.statistical_neighbors = statistical_neighbors - self.statistical_std_ratio = statistical_std_ratio - - # Load camera matrices - self.color_camera_matrix = load_camera_matrix_from_yaml(color_intrinsics) - self.depth_camera_matrix = load_camera_matrix_from_yaml(depth_intrinsics) - - def generate_color_from_id(self, object_id): - """Generate a consistent color for a given object ID.""" - np.random.seed(object_id) - color = 
np.random.randint(0, 255, 3) - np.random.seed(None) - return color - - def process_images(self, color_img, depth_img, segmentation_result): - """ - Process color and depth images with segmentation results to create filtered point clouds. - - Args: - color_img: RGB image as numpy array (H, W, 3) - depth_img: Depth image as numpy array (H, W) in meters - segmentation_result: SegmentationType object containing masks and metadata - - Returns: - dict: Dictionary containing: - - objects: List of dicts for each object with: - - object_id: Object tracking ID - - mask: Segmentation mask (H, W, bool) - - bbox: Bounding box [x1, y1, x2, y2] - - confidence: Detection confidence - - label: Object label/name - - point_cloud: Open3D point cloud object (filtered and colored) - - point_cloud_numpy: Nx6 array of XYZRGB points (for compatibility) - - color: RGB color for visualization - - cuboid_params: Cuboid parameters (if enabled) - - filtering_stats: Filtering statistics (if filtering enabled) - """ - if self.depth_camera_matrix is None: - raise ValueError("Depth camera matrix must be provided to process images") - - # Extract masks and metadata from segmentation result - masks = segmentation_result.masks - metadata = segmentation_result.metadata - objects_metadata = metadata.get("objects", []) - - # Process each object - objects = [] - for i, mask in enumerate(masks): - # Get object metadata if available - obj_meta = objects_metadata[i] if i < len(objects_metadata) else {} - object_id = obj_meta.get("object_id", i) - bbox = obj_meta.get("bbox", [0, 0, 0, 0]) - confidence = obj_meta.get("prob", 1.0) - label = obj_meta.get("label", "") - - # Convert mask to numpy if it's a tensor - if hasattr(mask, "cpu"): - mask = mask.cpu().numpy() - - # Ensure mask is proper boolean array with correct dimensions - mask = mask.astype(bool) - - # Ensure mask has the same shape as the depth image - if mask.shape != depth_img.shape[:2]: - if len(mask.shape) > 2: - mask = mask[:, :, 0] if 
mask.shape[2] > 0 else mask[:, :, 0] - - if mask.shape != depth_img.shape[:2]: - mask = cv2.resize( - mask.astype(np.uint8), - (depth_img.shape[1], depth_img.shape[0]), - interpolation=cv2.INTER_NEAREST, - ).astype(bool) - - try: - # Create point cloud using Open3D - pcd = create_masked_point_cloud( - color_img, depth_img, mask, self.depth_camera_matrix, depth_scale=1.0 - ) - - # Skip if no points - if len(np.asarray(pcd.points)) == 0: - continue - - # Generate color for visualization - rgb_color = self.generate_color_from_id(object_id) - - # Apply weighted colored mask to the point cloud - if len(np.asarray(pcd.colors)) > 0: - original_colors = np.asarray(pcd.colors) - generated_color = np.array(rgb_color) / 255.0 - colored_mask = ( - 1.0 - self.color_weight - ) * original_colors + self.color_weight * generated_color - colored_mask = np.clip(colored_mask, 0.0, 1.0) - pcd.colors = o3d.utility.Vector3dVector(colored_mask) - - # Apply statistical outlier filtering if enabled - filtering_stats = None - if self.enable_statistical_filtering: - num_points_before = len(np.asarray(pcd.points)) - pcd_filtered, outlier_indices = pcd.remove_statistical_outlier( - nb_neighbors=self.statistical_neighbors, - std_ratio=self.statistical_std_ratio, - ) - num_points_after = len(np.asarray(pcd_filtered.points)) - num_outliers_removed = num_points_before - num_points_after - - pcd = pcd_filtered - - filtering_stats = { - "points_before": num_points_before, - "points_after": num_points_after, - "outliers_removed": num_outliers_removed, - "outlier_percentage": 100.0 * num_outliers_removed / num_points_before - if num_points_before > 0 - else 0, - } - - # Create object data - obj_data = { - "object_id": object_id, - "mask": mask, - "bbox": bbox, - "confidence": float(confidence), - "label": label, - "point_cloud": pcd, - "point_cloud_numpy": o3d_point_cloud_to_numpy(pcd), - "color": rgb_color, - } - - # Add optional data if available - if filtering_stats is not None: - 
obj_data["filtering_stats"] = filtering_stats - - # Fit 3D cuboid if enabled - if self.enable_cuboid_fitting: - points = np.asarray(pcd.points) - cuboid_params = fit_cuboid(points) - if cuboid_params is not None: - obj_data["cuboid_params"] = cuboid_params - - objects.append(obj_data) - - except Exception as e: - continue - - # Clean up GPU memory if using CUDA - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - return { - "objects": objects, - } - - def cleanup(self): - """Clean up resources.""" - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - -def main(): - """ - Main function to test the PointcloudFiltering class with data from rgbd_data folder. - """ - - def find_first_image(directory): - """Find the first image file in the given directory.""" - image_extensions = [".jpg", ".jpeg", ".png", ".bmp"] - for filename in sorted(os.listdir(directory)): - if any(filename.lower().endswith(ext) for ext in image_extensions): - return os.path.join(directory, filename) - return None - - # Define paths - script_dir = os.path.dirname(os.path.abspath(__file__)) - dimos_dir = os.path.abspath(os.path.join(script_dir, "../../../")) - data_dir = os.path.join(dimos_dir, "assets/rgbd_data") - - color_info_path = os.path.join(data_dir, "color_camera_info.yaml") - depth_info_path = os.path.join(data_dir, "depth_camera_info.yaml") - - color_dir = os.path.join(data_dir, "color") - depth_dir = os.path.join(data_dir, "depth") - - # Find first color and depth images - color_img_path = find_first_image(color_dir) - depth_img_path = find_first_image(depth_dir) - - if not color_img_path or not depth_img_path: - print(f"Error: Could not find color or depth images in {data_dir}") - return - - # Load images - color_img = cv2.imread(color_img_path) - if color_img is None: - print(f"Error: Could not load color image from {color_img_path}") - return - - color_img = cv2.cvtColor(color_img, cv2.COLOR_BGR2RGB) - - depth_img = cv2.imread(depth_img_path, cv2.IMREAD_UNCHANGED) - 
if depth_img is None: - print(f"Error: Could not load depth image from {depth_img_path}") - return - - # Convert depth to meters if needed - if depth_img.dtype == np.uint16: - depth_img = depth_img.astype(np.float32) / 1000.0 - - # Run segmentation - segmenter = Sam2DSegmenter( - model_path="FastSAM-s.pt", - device="cuda" if torch.cuda.is_available() else "cpu", - use_tracker=False, - use_analyzer=True, - ) - - masks, bboxes, target_ids, probs, names = segmenter.process_image(color_img) - segmenter.run_analysis(color_img, bboxes, target_ids) - names = segmenter.get_object_names(target_ids, names) - - # Create metadata - objects_metadata = [] - for i in range(len(bboxes)): - obj_data = { - "object_id": target_ids[i] if i < len(target_ids) else i, - "bbox": bboxes[i], - "prob": probs[i] if i < len(probs) else 1.0, - "label": names[i] if i < len(names) else "", - } - objects_metadata.append(obj_data) - - metadata = {"frame": color_img, "objects": objects_metadata} - - numpy_masks = [mask.cpu().numpy() if hasattr(mask, "cpu") else mask for mask in masks] - segmentation_result = SegmentationType(masks=numpy_masks, metadata=metadata) - - # Initialize filtering pipeline - filter_pipeline = PointcloudFiltering( - color_intrinsics=color_info_path, - depth_intrinsics=depth_info_path, - enable_statistical_filtering=True, - enable_cuboid_fitting=True, - color_weight=0.3, - statistical_neighbors=20, - statistical_std_ratio=2.0, - ) - - # Process images through filtering pipeline - try: - results = filter_pipeline.process_images(color_img, depth_img, segmentation_result) - - # Visualize filtered point clouds - all_pcds = [] - for i, obj in enumerate(results["objects"]): - pcd = obj["point_cloud"] - - # Add cuboid visualization if available - if "cuboid_params" in obj and obj["cuboid_params"] is not None: - cuboid = obj["cuboid_params"] - center = cuboid["center"] - dimensions = cuboid["dimensions"] - rotation = cuboid["rotation"] - - obb = 
o3d.geometry.OrientedBoundingBox(center=center, R=rotation, extent=dimensions) - obb.color = [1, 0, 0] - all_pcds.append(obb) - - coord_frame = o3d.geometry.TriangleMesh.create_coordinate_frame( - size=min(dimensions) * 0.5, origin=center - ) - all_pcds.append(coord_frame) - - all_pcds.append(pcd) - - # Add coordinate frame at origin - coordinate_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1) - all_pcds.append(coordinate_frame) - - # Show filtered point clouds - if all_pcds: - o3d.visualization.draw_geometries( - all_pcds, - window_name="Filtered Point Clouds", - width=1280, - height=720, - left=50, - top=50, - ) - - except Exception as e: - print(f"Error during processing: {str(e)}") - import traceback - - traceback.print_exc() - - # Clean up resources - segmenter.cleanup() - filter_pipeline.cleanup() - - -if __name__ == "__main__": - main() From 32581e759321edc21df4daa0eb48263e1e8e51ac Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Wed, 4 Jun 2025 17:45:31 -0700 Subject: [PATCH 32/89] added an all-in-one manipulation pipeline that goes from stereo -> labels -> pointcloud --- dimos/perception/manip_aio_pipeline.py | 235 +++++++++++++++++++++++++ 1 file changed, 235 insertions(+) create mode 100644 dimos/perception/manip_aio_pipeline.py diff --git a/dimos/perception/manip_aio_pipeline.py b/dimos/perception/manip_aio_pipeline.py new file mode 100644 index 0000000000..5594cc11df --- /dev/null +++ b/dimos/perception/manip_aio_pipeline.py @@ -0,0 +1,235 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import cv2 +import threading +from collections import deque +from reactivex import Observable +from reactivex import operators as ops +from typing import List, Optional, Dict +import time + +from dimos.perception.detection2d.detic_2d_det import Detic2DDetector +from dimos.perception.pointcloud.pointcloud_filtering import PointcloudFiltering +from dimos.perception.object_detection_stream import ObjectDetectionStream +from dimos.perception.pointcloud.utils import create_point_cloud_overlay_visualization +from dimos.perception.common.utils import colorize_depth +from dimos.types.manipulation import ObjectData +from dimos.utils.logging_config import setup_logger + +logger = setup_logger("dimos.perception.manip_aio_pipeline") + + +class ManipulationPipeline: + """ + Clean separated stream pipeline with frame buffering. + + - Object detection runs independently on RGB stream + - Point cloud processing subscribes to both detection and ZED streams separately + - Simple frame buffering to match RGB+depth+objects + """ + + def __init__( + self, + camera_intrinsics: List[float], # [fx, fy, cx, cy] + min_confidence: float = 0.6, + max_objects: int = 10, + vocabulary: Optional[str] = None, + ): + """ + Initialize the manipulation pipeline. 
+ + Args: + camera_intrinsics: [fx, fy, cx, cy] camera parameters + min_confidence: Minimum detection confidence threshold + max_objects: Maximum number of objects to process + max_bbox_size_percent: Maximum bbox size as percentage of image + vocabulary: Optional vocabulary for Detic detector + """ + self.camera_intrinsics = camera_intrinsics + self.min_confidence = min_confidence + + # Initialize object detector + self.detector = Detic2DDetector(vocabulary=vocabulary, threshold=min_confidence) + + # Initialize point cloud processor + self.pointcloud_filter = PointcloudFiltering( + color_intrinsics=camera_intrinsics, + depth_intrinsics=camera_intrinsics, # ZED uses same intrinsics + max_num_objects=max_objects, + ) + + logger.info(f"Initialized ManipulationPipeline with confidence={min_confidence}") + + def create_streams(self, zed_stream: Observable) -> Dict[str, Observable]: + """ + Create streams using exact old main logic. + """ + # Create ZED streams (from old main) + zed_frame_stream = zed_stream.pipe(ops.share()) + + # RGB stream for object detection (from old main) + video_stream = zed_frame_stream.pipe( + ops.map(lambda x: x.get("rgb") if x is not None else None), + ops.filter(lambda x: x is not None), + ops.share(), + ) + object_detector = ObjectDetectionStream( + camera_intrinsics=self.camera_intrinsics, + min_confidence=self.min_confidence, + class_filter=None, + detector=self.detector, + video_stream=video_stream, + disable_depth=True, + ) + + # Store latest frames for point cloud processing (from old main) + latest_rgb = None + latest_depth = None + latest_point_cloud_overlay = None + frame_lock = threading.Lock() + + # Subscribe to combined ZED frames (from old main) + def on_zed_frame(zed_data): + nonlocal latest_rgb, latest_depth + if zed_data is not None: + with frame_lock: + latest_rgb = zed_data.get("rgb") + latest_depth = zed_data.get("depth") + + # Depth stream for point cloud filtering (from old main) + def get_depth_or_overlay(zed_data): + 
if zed_data is None: + return None + + # Check if we have a point cloud overlay available + with frame_lock: + overlay = latest_point_cloud_overlay + + if overlay is not None: + return overlay + else: + # Return regular colorized depth + return colorize_depth(zed_data.get("depth"), max_depth=10.0) + + depth_stream = zed_frame_stream.pipe( + ops.map(get_depth_or_overlay), ops.filter(lambda x: x is not None), ops.share() + ) + + # Process object detection results with point cloud filtering (from old main) + def on_detection_next(result): + nonlocal latest_point_cloud_overlay + if "objects" in result and result["objects"]: + # Get latest RGB and depth frames + with frame_lock: + rgb = latest_rgb + depth = latest_depth + + if rgb is not None and depth is not None: + try: + filtered_objects = self.pointcloud_filter.process_images( + rgb, depth, result["objects"] + ) + + if filtered_objects: + # Create base image (colorized depth) + base_image = colorize_depth(depth, max_depth=10.0) + + # Create point cloud overlay visualization + overlay_viz = create_point_cloud_overlay_visualization( + base_image=base_image, + filtered_objects=filtered_objects, + camera_matrix=self.camera_intrinsics, + ) + + # Store the overlay for the stream + with frame_lock: + latest_point_cloud_overlay = overlay_viz + else: + # No filtered objects, clear overlay + with frame_lock: + latest_point_cloud_overlay = None + + except Exception as e: + logger.error(f"Error in point cloud filtering: {e}") + with frame_lock: + latest_point_cloud_overlay = None + + def on_error(error): + logger.error(f"Error in stream: {error}") + + def on_completed(): + logger.info("Stream completed") + + def start_subscriptions(): + """Start subscriptions in background thread (from old main)""" + # Subscribe to combined ZED frames + zed_frame_stream.subscribe(on_next=on_zed_frame) + + # Start subscriptions in background thread (from old main) + subscription_thread = threading.Thread(target=start_subscriptions, daemon=True) 
+ subscription_thread.start() + time.sleep(2) # Give subscriptions time to start + + # Subscribe to object detection stream (from old main) + object_detector.get_stream().subscribe( + on_next=on_detection_next, on_error=on_error, on_completed=on_completed + ) + + # Create visualization stream for web interface (from old main) + viz_stream = object_detector.get_stream().pipe( + ops.map(lambda x: x["viz_frame"] if x is not None else None), + ops.filter(lambda x: x is not None), + ) + + return { + "detection_viz": viz_stream, + "pointcloud_viz": depth_stream, + "objects": object_detector.get_stream().pipe(ops.map(lambda x: x.get("objects", []))), + } + + def cleanup(self): + """Clean up resources.""" + if hasattr(self.detector, "cleanup"): + self.detector.cleanup() + if hasattr(self.pointcloud_filter, "cleanup"): + self.pointcloud_filter.cleanup() + logger.info("ManipulationPipeline cleaned up") + + +def create_manipulation_pipeline( + camera_intrinsics: List[float], + min_confidence: float = 0.6, + max_objects: int = 10, + vocabulary: Optional[str] = None, +) -> ManipulationPipeline: + """ + Factory function to create a ManipulationPipeline with sensible defaults. 
+ + Args: + camera_intrinsics: [fx, fy, cx, cy] camera parameters + min_confidence: Minimum detection confidence threshold + max_objects: Maximum number of objects to process + vocabulary: Optional vocabulary for Detic detector + + Returns: + Configured ManipulationPipeline instance + """ + return ManipulationPipeline( + camera_intrinsics=camera_intrinsics, + min_confidence=min_confidence, + max_objects=max_objects, + vocabulary=vocabulary, + ) From 996629679e3185059ed43e93120efc974f8f0dd5 Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Thu, 5 Jun 2025 18:18:56 -0700 Subject: [PATCH 33/89] added grasp generation to pipeline --- dimos/perception/manip_aio_pipeline.py | 304 +++++++++++++++--- ...est_manipulation_pipeline_visualization.py | 187 +++++++++++ 2 files changed, 453 insertions(+), 38 deletions(-) create mode 100644 tests/test_manipulation_pipeline_visualization.py diff --git a/dimos/perception/manip_aio_pipeline.py b/dimos/perception/manip_aio_pipeline.py index 5594cc11df..64f978bc78 100644 --- a/dimos/perception/manip_aio_pipeline.py +++ b/dimos/perception/manip_aio_pipeline.py @@ -12,21 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np -import cv2 +""" +Asynchronous, reactive manipulation pipeline for realtime detection, filtering, and grasp generation. 
+""" + +import asyncio +import json import threading -from collections import deque -from reactivex import Observable -from reactivex import operators as ops -from typing import List, Optional, Dict import time - +from typing import Dict, List, Optional +import numpy as np +import reactivex as rx +import reactivex.operators as ops +import websockets +from dimos.utils.logging_config import setup_logger from dimos.perception.detection2d.detic_2d_det import Detic2DDetector from dimos.perception.pointcloud.pointcloud_filtering import PointcloudFiltering from dimos.perception.object_detection_stream import ObjectDetectionStream from dimos.perception.pointcloud.utils import create_point_cloud_overlay_visualization from dimos.perception.common.utils import colorize_depth -from dimos.types.manipulation import ObjectData from dimos.utils.logging_config import setup_logger logger = setup_logger("dimos.perception.manip_aio_pipeline") @@ -47,6 +51,8 @@ def __init__( min_confidence: float = 0.6, max_objects: int = 10, vocabulary: Optional[str] = None, + grasp_server_url: Optional[str] = None, + enable_grasp_generation: bool = False, ): """ Initialize the manipulation pipeline. 
@@ -55,12 +61,37 @@ def __init__( camera_intrinsics: [fx, fy, cx, cy] camera parameters min_confidence: Minimum detection confidence threshold max_objects: Maximum number of objects to process - max_bbox_size_percent: Maximum bbox size as percentage of image vocabulary: Optional vocabulary for Detic detector + grasp_server_url: Optional WebSocket URL for AnyGrasp server + enable_grasp_generation: Whether to enable async grasp generation """ self.camera_intrinsics = camera_intrinsics self.min_confidence = min_confidence + # Grasp generation settings + self.grasp_server_url = grasp_server_url + self.enable_grasp_generation = enable_grasp_generation + + # Asyncio event loop for WebSocket communication + self.grasp_loop = None + self.grasp_loop_thread = None + + # Storage for grasp results and filtered objects + self.latest_grasps: List[dict] = [] # Simplified: just a list of grasps + self.latest_filtered_objects = [] + self.grasp_lock = threading.Lock() + + # Track pending requests - simplified to single task + self.grasp_task: Optional[asyncio.Task] = None + + # Reactive subjects for streaming filtered objects and grasps + self.filtered_objects_subject = rx.subject.Subject() + self.grasps_subject = rx.subject.Subject() + + # Initialize grasp client if enabled + if self.enable_grasp_generation and self.grasp_server_url: + self._start_grasp_loop() + # Initialize object detector self.detector = Detic2DDetector(vocabulary=vocabulary, threshold=min_confidence) @@ -73,7 +104,7 @@ def __init__( logger.info(f"Initialized ManipulationPipeline with confidence={min_confidence}") - def create_streams(self, zed_stream: Observable) -> Dict[str, Observable]: + def create_streams(self, zed_stream: rx.Observable) -> Dict[str, rx.Observable]: """ Create streams using exact old main logic. 
""" @@ -144,14 +175,49 @@ def on_detection_next(result): ) if filtered_objects: + # Store filtered objects + with self.grasp_lock: + self.latest_filtered_objects = filtered_objects + self.filtered_objects_subject.on_next(filtered_objects) + + # Request grasps if enabled + if self.enable_grasp_generation and filtered_objects: + logger.debug( + f"Requesting grasps for {len(filtered_objects)} filtered objects" + ) + task = self.request_scene_grasps(filtered_objects) + if task: + logger.debug( + "Grasp request task created, waiting for results..." + ) + + # Check for results after a delay + def check_grasps_later(): + logger.debug("Starting delayed grasp check...") + time.sleep(2.0) # Wait for grasp processing + grasps = self.get_latest_grasps() + if grasps: + logger.debug( + f"Found {len(grasps)} grasps in delayed check" + ) + self.grasps_subject.on_next(grasps) + logger.info(f"Received {len(grasps)} grasps for scene") + logger.debug(f"Grasps for scene: {grasps}") + else: + logger.debug("No grasps found in delayed check") + + threading.Thread(target=check_grasps_later, daemon=True).start() + else: + logger.debug("Failed to create grasp request task") + # Create base image (colorized depth) base_image = colorize_depth(depth, max_depth=10.0) # Create point cloud overlay visualization overlay_viz = create_point_cloud_overlay_visualization( base_image=base_image, - filtered_objects=filtered_objects, - camera_matrix=self.camera_intrinsics, + objects=filtered_objects, + intrinsics=self.camera_intrinsics, ) # Store the overlay for the stream @@ -194,42 +260,204 @@ def start_subscriptions(): ops.filter(lambda x: x is not None), ) + # Create filtered objects stream + filtered_objects_stream = self.filtered_objects_subject + + # Create grasps stream + grasps_stream = self.grasps_subject + return { "detection_viz": viz_stream, "pointcloud_viz": depth_stream, "objects": object_detector.get_stream().pipe(ops.map(lambda x: x.get("objects", []))), + "filtered_objects": 
filtered_objects_stream, + "grasps": grasps_stream, } + def _start_grasp_loop(self): + """Start asyncio event loop in a background thread for WebSocket communication.""" + + def run_loop(): + self.grasp_loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.grasp_loop) + self.grasp_loop.run_forever() + + self.grasp_loop_thread = threading.Thread(target=run_loop, daemon=True) + self.grasp_loop_thread.start() + + # Wait for loop to start + while self.grasp_loop is None: + time.sleep(0.01) + + async def _send_grasp_request( + self, points: np.ndarray, colors: Optional[np.ndarray] + ) -> Optional[List[dict]]: + """Send grasp request to AnyGrasp server.""" + logger.debug(f"_send_grasp_request called with {len(points)} points") + + try: + logger.debug(f"Connecting to WebSocket: {self.grasp_server_url}") + async with websockets.connect(self.grasp_server_url) as websocket: + logger.debug("WebSocket connected successfully") + + # Use the correct format expected by AnyGrasp server + request = { + "points": points.tolist(), + "colors": colors.tolist() if colors is not None else None, + "lims": [-0.19, 0.12, 0.02, 0.15, 0.0, 1.0], # Default workspace limits + } + + logger.debug(f"Sending grasp request with {len(points)} points") + await websocket.send(json.dumps(request)) + + logger.debug("Waiting for response...") + response = await websocket.recv() + logger.debug(f"Received response: {len(response)} characters") + + # Parse response - server returns list of grasps directly + grasps = json.loads(response) + logger.debug(f"Received {len(grasps) if grasps else 0} grasps from server") + + if grasps and len(grasps) > 0: + # Convert to our format and store + converted_grasps = self._convert_grasp_format(grasps) + logger.debug(f"Converted to {len(converted_grasps)} grasps") + + with self.grasp_lock: + self.latest_grasps = converted_grasps + logger.debug(f"Stored {len(converted_grasps)} grasps") + return converted_grasps + else: + logger.warning("No grasps returned from 
server") + + except Exception as e: + logger.error(f"Error requesting grasps: {e}") + logger.debug(f"Error details: {e}") + + return None + + def request_scene_grasps(self, objects: List[dict]) -> Optional[asyncio.Task]: + """Request grasps for entire scene by combining all object point clouds.""" + logger.debug(f"request_scene_grasps called with {len(objects)} objects") + + if not self.grasp_loop or not objects: + logger.debug( + f"Cannot request grasps: grasp_loop={self.grasp_loop is not None}, objects={len(objects) if objects else 0}" + ) + return None + + # Combine all object point clouds + all_points = [] + all_colors = [] + + for obj in objects: + if "point_cloud_numpy" in obj and len(obj["point_cloud_numpy"]) > 0: + all_points.append(obj["point_cloud_numpy"]) + if "colors_numpy" in obj and obj["colors_numpy"] is not None: + all_colors.append(obj["colors_numpy"]) + logger.debug(f"Added object with {len(obj['point_cloud_numpy'])} points") + + if not all_points: + logger.debug("No points found in objects, cannot request grasps") + return None + + # Concatenate all points and colors + combined_points = np.vstack(all_points) + combined_colors = np.vstack(all_colors) if all_colors else None + + logger.debug( + f"Requesting scene grasps for combined point cloud with {len(combined_points)} points" + ) + logger.debug(f"Grasp server URL: {self.grasp_server_url}") + + # Create and schedule the task + try: + task = asyncio.run_coroutine_threadsafe( + self._send_grasp_request(combined_points, combined_colors), self.grasp_loop + ) + + self.grasp_task = task + logger.debug("Successfully created grasp request task") + return task + except Exception as e: + logger.error(f"Failed to create grasp request task: {e}") + return None + + def get_latest_grasps(self) -> Optional[List[dict]]: + """Get latest grasp results.""" + with self.grasp_lock: + return self.latest_grasps + + def clear_grasps(self) -> None: + """Clear all stored grasp results.""" + with self.grasp_lock: + 
self.latest_grasps = [] + + def _prepare_colors(self, colors: Optional[np.ndarray]) -> Optional[np.ndarray]: + """Prepare colors array, converting from various formats if needed.""" + if colors is None: + return None + + # Convert from 0-255 to 0-1 range if needed + if colors.max() > 1.0: + colors = colors / 255.0 + + return colors + + def _convert_grasp_format(self, anygrasp_grasps: List[dict]) -> List[dict]: + """Convert AnyGrasp format to our visualization format.""" + converted = [] + + for i, grasp in enumerate(anygrasp_grasps): + # Extract rotation matrix and convert to Euler angles + rotation_matrix = np.array(grasp.get("rotation_matrix", np.eye(3))) + euler_angles = self._rotation_matrix_to_euler(rotation_matrix) + + converted_grasp = { + "id": f"grasp_{i}", + "score": grasp.get("score", 0.0), + "width": grasp.get("width", 0.0), + "height": grasp.get("height", 0.0), + "depth": grasp.get("depth", 0.0), + "translation": grasp.get("translation", [0, 0, 0]), + "rotation_matrix": rotation_matrix.tolist(), + "euler_angles": euler_angles, + } + converted.append(converted_grasp) + + # Sort by score descending + converted.sort(key=lambda x: x["score"], reverse=True) + + return converted + + def _rotation_matrix_to_euler(self, rotation_matrix: np.ndarray) -> Dict[str, float]: + """Convert rotation matrix to Euler angles (in radians).""" + # Check for gimbal lock + sy = np.sqrt(rotation_matrix[0, 0] ** 2 + rotation_matrix[1, 0] ** 2) + + singular = sy < 1e-6 + + if not singular: + x = np.arctan2(rotation_matrix[2, 1], rotation_matrix[2, 2]) + y = np.arctan2(-rotation_matrix[2, 0], sy) + z = np.arctan2(rotation_matrix[1, 0], rotation_matrix[0, 0]) + else: + x = np.arctan2(-rotation_matrix[1, 2], rotation_matrix[1, 1]) + y = np.arctan2(-rotation_matrix[2, 0], sy) + z = 0 + + return {"roll": x, "pitch": y, "yaw": z} + def cleanup(self): """Clean up resources.""" if hasattr(self.detector, "cleanup"): self.detector.cleanup() + + # Stop the grasp event loop + if 
self.grasp_loop and self.grasp_loop_thread: + self.grasp_loop.call_soon_threadsafe(self.grasp_loop.stop) + self.grasp_loop_thread.join(timeout=1.0) + if hasattr(self.pointcloud_filter, "cleanup"): self.pointcloud_filter.cleanup() logger.info("ManipulationPipeline cleaned up") - - -def create_manipulation_pipeline( - camera_intrinsics: List[float], - min_confidence: float = 0.6, - max_objects: int = 10, - vocabulary: Optional[str] = None, -) -> ManipulationPipeline: - """ - Factory function to create a ManipulationPipeline with sensible defaults. - - Args: - camera_intrinsics: [fx, fy, cx, cy] camera parameters - min_confidence: Minimum detection confidence threshold - max_objects: Maximum number of objects to process - vocabulary: Optional vocabulary for Detic detector - - Returns: - Configured ManipulationPipeline instance - """ - return ManipulationPipeline( - camera_intrinsics=camera_intrinsics, - min_confidence=min_confidence, - max_objects=max_objects, - vocabulary=vocabulary, - ) diff --git a/tests/test_manipulation_pipeline_visualization.py b/tests/test_manipulation_pipeline_visualization.py new file mode 100644 index 0000000000..a97ed473cd --- /dev/null +++ b/tests/test_manipulation_pipeline_visualization.py @@ -0,0 +1,187 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Test manipulation pipeline with direct visualization and grasp data output.""" + +import os +import sys +import cv2 +import numpy as np +import time +import argparse +import matplotlib.pyplot as plt +import open3d as o3d +from typing import Dict, List +import threading +from reactivex import Observable, operators as ops +from reactivex.subject import Subject + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from dimos.perception.manip_aio_pipeline import ManipulationPipeline +from dimos.perception.grasp_generation.utils import visualize_grasps_3d +from dimos.perception.pointcloud.utils import load_camera_matrix_from_yaml +from dimos.utils.logging_config import setup_logger + +logger = setup_logger("test_pipeline_viz") + + +def load_first_frame(data_dir: str): + """Load first RGB-D frame and camera intrinsics.""" + # Load images + color_img = cv2.imread(os.path.join(data_dir, "color", "00000.png")) + color_img = cv2.cvtColor(color_img, cv2.COLOR_BGR2RGB) + + depth_img = cv2.imread(os.path.join(data_dir, "depth", "00000.png"), cv2.IMREAD_ANYDEPTH) + if depth_img.dtype == np.uint16: + depth_img = depth_img.astype(np.float32) / 1000.0 + # Load intrinsics + camera_matrix = load_camera_matrix_from_yaml(os.path.join(data_dir, "color_camera_info.yaml")) + intrinsics = [ + camera_matrix[0, 0], + camera_matrix[1, 1], + camera_matrix[0, 2], + camera_matrix[1, 2], + ] + + return color_img, depth_img, intrinsics + + +def create_point_cloud(color_img, depth_img, intrinsics): + """Create Open3D point cloud.""" + fx, fy, cx, cy = intrinsics + height, width = depth_img.shape + + o3d_intrinsics = o3d.camera.PinholeCameraIntrinsic(width, height, fx, fy, cx, cy) + color_o3d = o3d.geometry.Image(color_img) + depth_o3d = o3d.geometry.Image((depth_img * 1000).astype(np.uint16)) + + rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth( + color_o3d, depth_o3d, depth_scale=1000.0, convert_rgb_to_intensity=False + ) + + return 
o3d.geometry.PointCloud.create_from_rgbd_image(rgbd, o3d_intrinsics) + + +def run_pipeline(color_img, depth_img, intrinsics, wait_time=5.0): + """Run pipeline and collect results.""" + # Create pipeline + pipeline = ManipulationPipeline( + camera_intrinsics=intrinsics, + grasp_server_url="ws://10.0.0.125:8000/ws/grasp", + enable_grasp_generation=True, + ) + + # Create single-frame stream + subject = Subject() + streams = pipeline.create_streams(subject) + + # Debug: print available streams + print(f"Available streams: {list(streams.keys())}") + + # Collect results + results = {} + + def collect(key): + def on_next(value): + results[key] = value + logger.info(f"Received {key}") + + return on_next + + # Subscribe to streams + for key, stream in streams.items(): + if stream: + stream.pipe(ops.take(1)).subscribe(on_next=collect(key)) + + # Send frame + threading.Timer( + 0.5, + lambda: subject.on_next({"rgb": color_img, "depth": depth_img, "timestamp": time.time()}), + ).start() + + # Wait for results + time.sleep(wait_time) + + # If grasp generation is enabled, also check for latest grasps + if pipeline.latest_grasps: + results["grasps"] = pipeline.latest_grasps + logger.info(f"Retrieved latest grasps: {len(pipeline.latest_grasps)} grasps") + + pipeline.cleanup() + + return results + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--data-dir", default="assets/rgbd_data") + parser.add_argument("--wait-time", type=float, default=5.0) + args = parser.parse_args() + + # Load data + color_img, depth_img, intrinsics = load_first_frame(args.data_dir) + logger.info(f"Loaded images: color {color_img.shape}, depth {depth_img.shape}") + + # Run pipeline + results = run_pipeline(color_img, depth_img, intrinsics, args.wait_time) + + # Debug: Print what we received + print(f"\n✅ Pipeline Results:") + print(f" Available streams: {list(results.keys())}") + + if "filtered_objects" in results and results["filtered_objects"]: + print(f" Objects detected: 
{len(results['filtered_objects'])}") + + # Print grasp summary + if "grasps" in results and results["grasps"]: + total_grasps = 0 + best_score = 0 + for grasp in results["grasps"]: + score = grasp.get("score", 0) + if score > best_score: + best_score = score + total_grasps += 1 + print(f" Grasps generated: {total_grasps} (best score: {best_score:.3f})") + else: + print(" Grasps: None generated") + + # Visualize 2D results + fig, axes = plt.subplots(1, 2, figsize=(12, 6)) + + if "detection_viz" in results and results["detection_viz"] is not None: + axes[0].imshow(results["detection_viz"]) + axes[0].set_title("Object Detection") + axes[0].axis("off") + + if "pointcloud_viz" in results and results["pointcloud_viz"] is not None: + axes[1].imshow(results["pointcloud_viz"]) + axes[1].set_title("Point Cloud Overlay") + axes[1].axis("off") + + plt.tight_layout() + plt.show() + + # 3D visualization with grasps + if "grasps" in results and results["grasps"]: + pcd = create_point_cloud(color_img, depth_img, intrinsics) + all_grasps = results["grasps"] + + if all_grasps: + logger.info(f"Visualizing {len(all_grasps)} grasps in 3D") + visualize_grasps_3d(pcd, all_grasps) + + +if __name__ == "__main__": + main() From c140b073968020226ca850409d44d853f513a4ce Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Thu, 5 Jun 2025 20:53:41 -0700 Subject: [PATCH 34/89] I'm so f**ing tired after this --- dimos/perception/manip_aio_pipeline.py | 315 ++++++++++++------ ...est_manipulation_pipeline_visualization.py | 187 ----------- 2 files changed, 221 insertions(+), 281 deletions(-) delete mode 100644 tests/test_manipulation_pipeline_visualization.py diff --git a/dimos/perception/manip_aio_pipeline.py b/dimos/perception/manip_aio_pipeline.py index 64f978bc78..22e3f5d49e 100644 --- a/dimos/perception/manip_aio_pipeline.py +++ b/dimos/perception/manip_aio_pipeline.py @@ -18,20 +18,24 @@ import asyncio import json +import logging import threading import time -from typing import Dict, List, Optional 
+import traceback +import websockets +from typing import Dict, List, Optional, Any import numpy as np import reactivex as rx import reactivex.operators as ops -import websockets from dimos.utils.logging_config import setup_logger from dimos.perception.detection2d.detic_2d_det import Detic2DDetector from dimos.perception.pointcloud.pointcloud_filtering import PointcloudFiltering from dimos.perception.object_detection_stream import ObjectDetectionStream +from dimos.perception.grasp_generation.utils import draw_grasps_on_image from dimos.perception.pointcloud.utils import create_point_cloud_overlay_visualization from dimos.perception.common.utils import colorize_depth from dimos.utils.logging_config import setup_logger +import cv2 logger = setup_logger("dimos.perception.manip_aio_pipeline") @@ -78,7 +82,9 @@ def __init__( # Storage for grasp results and filtered objects self.latest_grasps: List[dict] = [] # Simplified: just a list of grasps + self.grasps_consumed = False self.latest_filtered_objects = [] + self.latest_rgb_for_grasps = None # Store RGB image for grasp overlay self.grasp_lock = threading.Lock() # Track pending requests - simplified to single task @@ -87,6 +93,7 @@ def __init__( # Reactive subjects for streaming filtered objects and grasps self.filtered_objects_subject = rx.subject.Subject() self.grasps_subject = rx.subject.Subject() + self.grasp_overlay_subject = rx.subject.Subject() # Add grasp overlay subject # Initialize grasp client if enabled if self.enable_grasp_generation and self.grasp_server_url: @@ -180,36 +187,6 @@ def on_detection_next(result): self.latest_filtered_objects = filtered_objects self.filtered_objects_subject.on_next(filtered_objects) - # Request grasps if enabled - if self.enable_grasp_generation and filtered_objects: - logger.debug( - f"Requesting grasps for {len(filtered_objects)} filtered objects" - ) - task = self.request_scene_grasps(filtered_objects) - if task: - logger.debug( - "Grasp request task created, waiting for 
results..." - ) - - # Check for results after a delay - def check_grasps_later(): - logger.debug("Starting delayed grasp check...") - time.sleep(2.0) # Wait for grasp processing - grasps = self.get_latest_grasps() - if grasps: - logger.debug( - f"Found {len(grasps)} grasps in delayed check" - ) - self.grasps_subject.on_next(grasps) - logger.info(f"Received {len(grasps)} grasps for scene") - logger.debug(f"Grasps for scene: {grasps}") - else: - logger.debug("No grasps found in delayed check") - - threading.Thread(target=check_grasps_later, daemon=True).start() - else: - logger.debug("Failed to create grasp request task") - # Create base image (colorized depth) base_image = colorize_depth(depth, max_depth=10.0) @@ -223,11 +200,59 @@ def check_grasps_later(): # Store the overlay for the stream with frame_lock: latest_point_cloud_overlay = overlay_viz - else: - # No filtered objects, clear overlay - with frame_lock: - latest_point_cloud_overlay = None + # Request grasps if enabled + if self.enable_grasp_generation and len(filtered_objects) > 0: + # Save RGB image for later grasp overlay + with frame_lock: + self.latest_rgb_for_grasps = rgb.copy() + + task = self.request_scene_grasps(filtered_objects) + if task: + # Check for results after a delay + def check_grasps_later(): + time.sleep(2.0) # Wait for grasp processing + # Wait for task to complete + if hasattr(self, "grasp_task") and self.grasp_task: + try: + result = self.grasp_task.result( + timeout=3.0 + ) # Get result with timeout + except Exception as e: + logger.warning(f"Grasp task failed or timeout: {e}") + + # Try to get latest grasps and create overlay + with self.grasp_lock: + grasps = self.latest_grasps + + if grasps and hasattr(self, "latest_rgb_for_grasps"): + # Create grasp overlay on the saved RGB image + try: + bgr_image = cv2.cvtColor( + self.latest_rgb_for_grasps, cv2.COLOR_RGB2BGR + ) + result_bgr = draw_grasps_on_image( + bgr_image, + grasps, + self.camera_intrinsics, + max_grasps=-1, # Show all 
grasps + ) + result_rgb = cv2.cvtColor( + result_bgr, cv2.COLOR_BGR2RGB + ) + + # Emit grasp overlay immediately + self.grasp_overlay_subject.on_next(result_rgb) + + except Exception as e: + logger.error(f"Error creating grasp overlay: {e}") + + # Emit grasps to stream + self.grasps_subject.on_next(grasps) + + threading.Thread(target=check_grasps_later, daemon=True).start() + else: + logger.warning("Failed to create grasp task") except Exception as e: logger.error(f"Error in point cloud filtering: {e}") with frame_lock: @@ -266,12 +291,16 @@ def start_subscriptions(): # Create grasps stream grasps_stream = self.grasps_subject + # Create grasp overlay subject for immediate emission + grasp_overlay_stream = self.grasp_overlay_subject + return { "detection_viz": viz_stream, "pointcloud_viz": depth_stream, "objects": object_detector.get_stream().pipe(ops.map(lambda x: x.get("objects", []))), "filtered_objects": filtered_objects_stream, "grasps": grasps_stream, + "grasp_overlay": grasp_overlay_stream, } def _start_grasp_loop(self): @@ -293,100 +322,203 @@ async def _send_grasp_request( self, points: np.ndarray, colors: Optional[np.ndarray] ) -> Optional[List[dict]]: """Send grasp request to AnyGrasp server.""" - logger.debug(f"_send_grasp_request called with {len(points)} points") - try: - logger.debug(f"Connecting to WebSocket: {self.grasp_server_url}") - async with websockets.connect(self.grasp_server_url) as websocket: - logger.debug("WebSocket connected successfully") + # Comprehensive client-side validation to prevent server errors - # Use the correct format expected by AnyGrasp server + # Validate points array + if points is None: + logger.error("Points array is None") + return None + if not isinstance(points, np.ndarray): + logger.error(f"Points is not numpy array: {type(points)}") + return None + if points.size == 0: + logger.error("Points array is empty") + return None + if len(points.shape) != 2 or points.shape[1] != 3: + logger.error(f"Points has invalid 
shape {points.shape}, expected (N, 3)") + return None + if points.shape[0] < 100: # Minimum points for stable grasp detection + logger.error(f"Insufficient points for grasp detection: {points.shape[0]} < 100") + return None + + # Validate and prepare colors + if colors is not None: + if not isinstance(colors, np.ndarray): + colors = None + elif colors.size == 0: + colors = None + elif len(colors.shape) != 2 or colors.shape[1] != 3: + colors = None + elif colors.shape[0] != points.shape[0]: + colors = None + + # If no valid colors, create default colors (required by server) + if colors is None: + # Create default white colors for all points + colors = np.ones((points.shape[0], 3), dtype=np.float32) * 0.5 + + # Ensure data types are correct (server expects float32) + points = points.astype(np.float32) + colors = colors.astype(np.float32) + + # Validate ranges (basic sanity checks) + if np.any(np.isnan(points)) or np.any(np.isinf(points)): + logger.error("Points contain NaN or Inf values") + return None + if np.any(np.isnan(colors)) or np.any(np.isinf(colors)): + logger.error("Colors contain NaN or Inf values") + return None + + # Clamp color values to valid range [0, 1] + colors = np.clip(colors, 0.0, 1.0) + + async with websockets.connect(self.grasp_server_url) as websocket: request = { "points": points.tolist(), - "colors": colors.tolist() if colors is not None else None, + "colors": colors.tolist(), # Always send colors array "lims": [-0.19, 0.12, 0.02, 0.15, 0.0, 1.0], # Default workspace limits } - logger.debug(f"Sending grasp request with {len(points)} points") await websocket.send(json.dumps(request)) - logger.debug("Waiting for response...") response = await websocket.recv() - logger.debug(f"Received response: {len(response)} characters") - - # Parse response - server returns list of grasps directly grasps = json.loads(response) - logger.debug(f"Received {len(grasps) if grasps else 0} grasps from server") - - if grasps and len(grasps) > 0: - # Convert to our 
format and store - converted_grasps = self._convert_grasp_format(grasps) - logger.debug(f"Converted to {len(converted_grasps)} grasps") - - with self.grasp_lock: - self.latest_grasps = converted_grasps - logger.debug(f"Stored {len(converted_grasps)} grasps") - return converted_grasps - else: - logger.warning("No grasps returned from server") + # Handle server response validation + if isinstance(grasps, dict) and "error" in grasps: + logger.error(f"Server returned error: {grasps['error']}") + return None + elif isinstance(grasps, (int, float)) and grasps == 0: + return None + elif not isinstance(grasps, list): + logger.error( + f"Server returned unexpected response type: {type(grasps)}, value: {grasps}" + ) + return None + elif len(grasps) == 0: + return None + + converted_grasps = self._convert_grasp_format(grasps) + with self.grasp_lock: + self.latest_grasps = converted_grasps + self.grasps_consumed = False # Reset consumed flag + + # Emit to reactive stream + self.grasps_subject.on_next(self.latest_grasps) + + return converted_grasps + except websockets.exceptions.ConnectionClosed as e: + logger.error(f"WebSocket connection closed: {e}") + except websockets.exceptions.WebSocketException as e: + logger.error(f"WebSocket error: {e}") + except json.JSONDecodeError as e: + logger.error(f"Failed to parse server response as JSON: {e}") except Exception as e: logger.error(f"Error requesting grasps: {e}") - logger.debug(f"Error details: {e}") return None def request_scene_grasps(self, objects: List[dict]) -> Optional[asyncio.Task]: """Request grasps for entire scene by combining all object point clouds.""" - logger.debug(f"request_scene_grasps called with {len(objects)} objects") - if not self.grasp_loop or not objects: - logger.debug( - f"Cannot request grasps: grasp_loop={self.grasp_loop is not None}, objects={len(objects) if objects else 0}" - ) return None - # Combine all object point clouds all_points = [] all_colors = [] - - for obj in objects: - if 
"point_cloud_numpy" in obj and len(obj["point_cloud_numpy"]) > 0: - all_points.append(obj["point_cloud_numpy"]) - if "colors_numpy" in obj and obj["colors_numpy"] is not None: - all_colors.append(obj["colors_numpy"]) - logger.debug(f"Added object with {len(obj['point_cloud_numpy'])} points") + valid_objects = 0 + + for i, obj in enumerate(objects): + # Validate point cloud data + if "point_cloud_numpy" not in obj or obj["point_cloud_numpy"] is None: + continue + + points = obj["point_cloud_numpy"] + if not isinstance(points, np.ndarray) or points.size == 0: + continue + + # Ensure points have correct shape (N, 3) + if len(points.shape) != 2 or points.shape[1] != 3: + continue + + # Validate colors if present + colors = None + if "colors_numpy" in obj and obj["colors_numpy"] is not None: + colors = obj["colors_numpy"] + if isinstance(colors, np.ndarray) and colors.size > 0: + # Ensure colors match points count and have correct shape + if colors.shape[0] != points.shape[0]: + colors = None # Ignore colors for this object + elif len(colors.shape) != 2 or colors.shape[1] != 3: + colors = None # Ignore colors for this object + + all_points.append(points) + if colors is not None: + all_colors.append(colors) + valid_objects += 1 if not all_points: - logger.debug("No points found in objects, cannot request grasps") return None - # Concatenate all points and colors - combined_points = np.vstack(all_points) - combined_colors = np.vstack(all_colors) if all_colors else None + try: + combined_points = np.vstack(all_points) - logger.debug( - f"Requesting scene grasps for combined point cloud with {len(combined_points)} points" - ) - logger.debug(f"Grasp server URL: {self.grasp_server_url}") + # Only combine colors if ALL objects have valid colors + combined_colors = None + if len(all_colors) == valid_objects and len(all_colors) > 0: + combined_colors = np.vstack(all_colors) + + # Validate final combined data + if combined_points.size == 0: + logger.warning("Combined point cloud 
is empty") + return None + + if combined_colors is not None and combined_colors.shape[0] != combined_points.shape[0]: + logger.warning( + f"Color/point count mismatch: {combined_colors.shape[0]} colors vs {combined_points.shape[0]} points, dropping colors" + ) + combined_colors = None + + except Exception as e: + logger.error(f"Failed to combine point clouds: {e}") + return None - # Create and schedule the task try: + # Check if there's already a grasp task running + if hasattr(self, "grasp_task") and self.grasp_task and not self.grasp_task.done(): + return self.grasp_task + task = asyncio.run_coroutine_threadsafe( self._send_grasp_request(combined_points, combined_colors), self.grasp_loop ) self.grasp_task = task - logger.debug("Successfully created grasp request task") return task except Exception as e: - logger.error(f"Failed to create grasp request task: {e}") + logger.warning("Failed to create grasp task") return None - def get_latest_grasps(self) -> Optional[List[dict]]: - """Get latest grasp results.""" + def get_latest_grasps(self, timeout: float = 5.0) -> Optional[List[dict]]: + """Get latest grasp results, waiting for new ones if current ones have been consumed.""" + # Mark current grasps as consumed and get a reference with self.grasp_lock: - return self.latest_grasps + current_grasps = self.latest_grasps + self.grasps_consumed = True + + # If we already have grasps and they haven't been consumed, return them + if current_grasps is not None and not getattr(self, "grasps_consumed", False): + return current_grasps + + # Wait for new grasps + start_time = time.time() + while time.time() - start_time < timeout: + with self.grasp_lock: + # Check if we have new grasps (different from what we marked as consumed) + if self.latest_grasps is not None and not getattr(self, "grasps_consumed", False): + return self.latest_grasps + time.sleep(0.1) # Check every 100ms + + return None # Timeout reached def clear_grasps(self) -> None: """Clear all stored grasp 
results.""" @@ -398,7 +530,6 @@ def _prepare_colors(self, colors: Optional[np.ndarray]) -> Optional[np.ndarray]: if colors is None: return None - # Convert from 0-255 to 0-1 range if needed if colors.max() > 1.0: colors = colors / 255.0 @@ -409,7 +540,6 @@ def _convert_grasp_format(self, anygrasp_grasps: List[dict]) -> List[dict]: converted = [] for i, grasp in enumerate(anygrasp_grasps): - # Extract rotation matrix and convert to Euler angles rotation_matrix = np.array(grasp.get("rotation_matrix", np.eye(3))) euler_angles = self._rotation_matrix_to_euler(rotation_matrix) @@ -425,14 +555,12 @@ def _convert_grasp_format(self, anygrasp_grasps: List[dict]) -> List[dict]: } converted.append(converted_grasp) - # Sort by score descending converted.sort(key=lambda x: x["score"], reverse=True) return converted def _rotation_matrix_to_euler(self, rotation_matrix: np.ndarray) -> Dict[str, float]: """Convert rotation matrix to Euler angles (in radians).""" - # Check for gimbal lock sy = np.sqrt(rotation_matrix[0, 0] ** 2 + rotation_matrix[1, 0] ** 2) singular = sy < 1e-6 @@ -453,7 +581,6 @@ def cleanup(self): if hasattr(self.detector, "cleanup"): self.detector.cleanup() - # Stop the grasp event loop if self.grasp_loop and self.grasp_loop_thread: self.grasp_loop.call_soon_threadsafe(self.grasp_loop.stop) self.grasp_loop_thread.join(timeout=1.0) diff --git a/tests/test_manipulation_pipeline_visualization.py b/tests/test_manipulation_pipeline_visualization.py deleted file mode 100644 index a97ed473cd..0000000000 --- a/tests/test_manipulation_pipeline_visualization.py +++ /dev/null @@ -1,187 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Test manipulation pipeline with direct visualization and grasp data output.""" - -import os -import sys -import cv2 -import numpy as np -import time -import argparse -import matplotlib.pyplot as plt -import open3d as o3d -from typing import Dict, List -import threading -from reactivex import Observable, operators as ops -from reactivex.subject import Subject - -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -from dimos.perception.manip_aio_pipeline import ManipulationPipeline -from dimos.perception.grasp_generation.utils import visualize_grasps_3d -from dimos.perception.pointcloud.utils import load_camera_matrix_from_yaml -from dimos.utils.logging_config import setup_logger - -logger = setup_logger("test_pipeline_viz") - - -def load_first_frame(data_dir: str): - """Load first RGB-D frame and camera intrinsics.""" - # Load images - color_img = cv2.imread(os.path.join(data_dir, "color", "00000.png")) - color_img = cv2.cvtColor(color_img, cv2.COLOR_BGR2RGB) - - depth_img = cv2.imread(os.path.join(data_dir, "depth", "00000.png"), cv2.IMREAD_ANYDEPTH) - if depth_img.dtype == np.uint16: - depth_img = depth_img.astype(np.float32) / 1000.0 - # Load intrinsics - camera_matrix = load_camera_matrix_from_yaml(os.path.join(data_dir, "color_camera_info.yaml")) - intrinsics = [ - camera_matrix[0, 0], - camera_matrix[1, 1], - camera_matrix[0, 2], - camera_matrix[1, 2], - ] - - return color_img, depth_img, intrinsics - - -def create_point_cloud(color_img, depth_img, intrinsics): - """Create Open3D point cloud.""" - fx, fy, cx, cy = 
intrinsics - height, width = depth_img.shape - - o3d_intrinsics = o3d.camera.PinholeCameraIntrinsic(width, height, fx, fy, cx, cy) - color_o3d = o3d.geometry.Image(color_img) - depth_o3d = o3d.geometry.Image((depth_img * 1000).astype(np.uint16)) - - rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth( - color_o3d, depth_o3d, depth_scale=1000.0, convert_rgb_to_intensity=False - ) - - return o3d.geometry.PointCloud.create_from_rgbd_image(rgbd, o3d_intrinsics) - - -def run_pipeline(color_img, depth_img, intrinsics, wait_time=5.0): - """Run pipeline and collect results.""" - # Create pipeline - pipeline = ManipulationPipeline( - camera_intrinsics=intrinsics, - grasp_server_url="ws://10.0.0.125:8000/ws/grasp", - enable_grasp_generation=True, - ) - - # Create single-frame stream - subject = Subject() - streams = pipeline.create_streams(subject) - - # Debug: print available streams - print(f"Available streams: {list(streams.keys())}") - - # Collect results - results = {} - - def collect(key): - def on_next(value): - results[key] = value - logger.info(f"Received {key}") - - return on_next - - # Subscribe to streams - for key, stream in streams.items(): - if stream: - stream.pipe(ops.take(1)).subscribe(on_next=collect(key)) - - # Send frame - threading.Timer( - 0.5, - lambda: subject.on_next({"rgb": color_img, "depth": depth_img, "timestamp": time.time()}), - ).start() - - # Wait for results - time.sleep(wait_time) - - # If grasp generation is enabled, also check for latest grasps - if pipeline.latest_grasps: - results["grasps"] = pipeline.latest_grasps - logger.info(f"Retrieved latest grasps: {len(pipeline.latest_grasps)} grasps") - - pipeline.cleanup() - - return results - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--data-dir", default="assets/rgbd_data") - parser.add_argument("--wait-time", type=float, default=5.0) - args = parser.parse_args() - - # Load data - color_img, depth_img, intrinsics = load_first_frame(args.data_dir) - 
logger.info(f"Loaded images: color {color_img.shape}, depth {depth_img.shape}") - - # Run pipeline - results = run_pipeline(color_img, depth_img, intrinsics, args.wait_time) - - # Debug: Print what we received - print(f"\n✅ Pipeline Results:") - print(f" Available streams: {list(results.keys())}") - - if "filtered_objects" in results and results["filtered_objects"]: - print(f" Objects detected: {len(results['filtered_objects'])}") - - # Print grasp summary - if "grasps" in results and results["grasps"]: - total_grasps = 0 - best_score = 0 - for grasp in results["grasps"]: - score = grasp.get("score", 0) - if score > best_score: - best_score = score - total_grasps += 1 - print(f" Grasps generated: {total_grasps} (best score: {best_score:.3f})") - else: - print(" Grasps: None generated") - - # Visualize 2D results - fig, axes = plt.subplots(1, 2, figsize=(12, 6)) - - if "detection_viz" in results and results["detection_viz"] is not None: - axes[0].imshow(results["detection_viz"]) - axes[0].set_title("Object Detection") - axes[0].axis("off") - - if "pointcloud_viz" in results and results["pointcloud_viz"] is not None: - axes[1].imshow(results["pointcloud_viz"]) - axes[1].set_title("Point Cloud Overlay") - axes[1].axis("off") - - plt.tight_layout() - plt.show() - - # 3D visualization with grasps - if "grasps" in results and results["grasps"]: - pcd = create_point_cloud(color_img, depth_img, intrinsics) - all_grasps = results["grasps"] - - if all_grasps: - logger.info(f"Visualizing {len(all_grasps)} grasps in 3D") - visualize_grasps_3d(pcd, all_grasps) - - -if __name__ == "__main__": - main() From 474e442e8b049b256edebef05312ee70f692e5e2 Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Tue, 17 Jun 2025 01:07:54 -0700 Subject: [PATCH 35/89] added SAM2 support for segmentation, added manipulation perception processer with any streaming --- dimos/perception/manip_aio_processer.py | 553 ++++++++++++++++++++++++ 1 file changed, 553 insertions(+) create mode 100644 
dimos/perception/manip_aio_processer.py diff --git a/dimos/perception/manip_aio_processer.py b/dimos/perception/manip_aio_processer.py new file mode 100644 index 0000000000..0d39cb39e1 --- /dev/null +++ b/dimos/perception/manip_aio_processer.py @@ -0,0 +1,553 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Sequential manipulation processor for single-frame processing without reactive streams. +""" + +import json +import logging +import time +import asyncio +import websockets +from typing import Dict, List, Optional, Any, Tuple +import numpy as np +import cv2 + +from dimos.utils.logging_config import setup_logger +from dimos.perception.detection2d.detic_2d_det import Detic2DDetector +from dimos.perception.pointcloud.pointcloud_filtering import PointcloudFiltering +from dimos.perception.segmentation.sam_2d_seg import Sam2DSegmenter +from dimos.perception.grasp_generation.utils import draw_grasps_on_image +from dimos.perception.pointcloud.utils import create_point_cloud_overlay_visualization +from dimos.perception.common.utils import colorize_depth, detection_results_to_object_data + +logger = setup_logger("dimos.perception.manip_aio_processor") + + +class ManipulationProcessor: + """ + Sequential manipulation processor for single-frame processing. + + Processes RGB-D frames through object detection, point cloud filtering, + and optional grasp generation in a single thread without reactive streams. 
+ """ + + def __init__( + self, + camera_intrinsics: List[float], # [fx, fy, cx, cy] + min_confidence: float = 0.6, + max_objects: int = 20, + vocabulary: Optional[str] = None, + grasp_server_url: Optional[str] = None, + enable_grasp_generation: bool = False, + enable_segmentation: bool = True, + segmentation_model: str = "sam2_b.pt", + ): + """ + Initialize the manipulation processor. + + Args: + camera_intrinsics: [fx, fy, cx, cy] camera parameters + min_confidence: Minimum detection confidence threshold + max_objects: Maximum number of objects to process + vocabulary: Optional vocabulary for Detic detector + grasp_server_url: Optional WebSocket URL for AnyGrasp server + enable_grasp_generation: Whether to enable grasp generation + enable_segmentation: Whether to enable semantic segmentation + segmentation_model: Segmentation model to use (SAM 2 or FastSAM) + """ + self.camera_intrinsics = camera_intrinsics + self.min_confidence = min_confidence + self.max_objects = max_objects + self.grasp_server_url = grasp_server_url + self.enable_grasp_generation = enable_grasp_generation + self.enable_segmentation = enable_segmentation + + # Initialize object detector + self.detector = Detic2DDetector(vocabulary=vocabulary, threshold=min_confidence) + + # Initialize point cloud processor + self.pointcloud_filter = PointcloudFiltering( + color_intrinsics=camera_intrinsics, + depth_intrinsics=camera_intrinsics, # ZED uses same intrinsics + max_num_objects=max_objects, + ) + + # Initialize semantic segmentation + self.segmenter = None + if self.enable_segmentation: + self.segmenter = Sam2DSegmenter( + model_path=segmentation_model, + device="cuda", + use_tracker=False, # Disable tracker for simple segmentation + use_analyzer=False, # Disable analyzer for simple segmentation + model_type="auto", # Auto-detect model type + ) + + logger.info(f"Initialized ManipulationProcessor with confidence={min_confidence}") + + def process_frame( + self, + rgb_image: np.ndarray, + depth_image: 
np.ndarray, + generate_grasps: bool = None + ) -> Dict[str, Any]: + """ + Process a single RGB-D frame through the complete pipeline. + + Args: + rgb_image: RGB image (H, W, 3) + depth_image: Depth image (H, W) in meters + generate_grasps: Override grasp generation setting for this frame + + Returns: + Dictionary containing: + - detection_viz: Visualization of object detection + - pointcloud_viz: Visualization of point cloud overlay + - segmentation_viz: Visualization of semantic segmentation (if enabled) + - detection2d_objects: Raw detection results as ObjectData + - segmentation2d_objects: Raw segmentation results as ObjectData (if enabled) + - detected_objects: Detection (Object Detection) objects with point clouds filtered + - all_objects: All objects (including misc objects) (SAM segmentation) with point clouds filtered + - full_pointcloud: Complete scene point cloud (if point cloud processing enabled) + - grasps: Grasp results (if enabled) + - grasp_overlay: Grasp visualization (if enabled) + - processing_time: Total processing time + """ + start_time = time.time() + results = {} + + try: + # Step 1: Object Detection + step_start = time.time() + logger.debug("Running object detection...") + detection_results = self._run_object_detection(rgb_image) + + results['detection2d_objects'] = detection_results.get('objects', []) + results['detection_viz'] = detection_results.get('viz_frame') + detection_time = time.time() - step_start + + # Step 2: Semantic Segmentation (if enabled) + segmentation_time = 0 + segmentation_results = None + if self.enable_segmentation: + step_start = time.time() + logger.debug("Running semantic segmentation...") + segmentation_results = self._run_segmentation(rgb_image) + results['segmentation2d_objects'] = segmentation_results.get('objects', []) + results['segmentation_viz'] = segmentation_results.get('viz_frame') + segmentation_time = time.time() - step_start + + # Step 3: Point Cloud Processing + pointcloud_time = 0 + 
detection2d_objects = results.get('detection2d_objects', []) + segmentation2d_objects = results.get('segmentation2d_objects', []) + + # Process detection objects if available + detected_objects = [] + if detection2d_objects: + step_start = time.time() + logger.debug(f"Processing {len(detection2d_objects)} detection2d_objects...") + detected_objects = self._run_pointcloud_filtering( + rgb_image, depth_image, detection2d_objects + ) + pointcloud_time += time.time() - step_start + + # Process segmentation objects if available + segmentation_filtered_objects = [] + if segmentation2d_objects: + step_start = time.time() + logger.debug(f"Processing {len(segmentation2d_objects)} segmentation objects...") + segmentation_filtered_objects = self._run_pointcloud_filtering( + rgb_image, depth_image, segmentation2d_objects + ) + pointcloud_time += time.time() - step_start + + # Combine all objects + all_objects = segmentation_filtered_objects + + # Get full point cloud + full_pcd = self.pointcloud_filter.get_full_point_cloud() + + results['detected_objects'] = detected_objects + results['all_objects'] = all_objects + results['full_pointcloud'] = full_pcd + + # Create point cloud visualizations + base_image = colorize_depth(depth_image, max_depth=10.0) + + # Main pointcloud visualization (all objects) + if all_objects: + results['pointcloud_viz'] = create_point_cloud_overlay_visualization( + base_image=base_image, + objects=all_objects, + intrinsics=self.camera_intrinsics, + ) + else: + results['pointcloud_viz'] = base_image + + # Detection objects pointcloud visualization + if detected_objects: + results['detected_pointcloud_viz'] = create_point_cloud_overlay_visualization( + base_image=base_image, + objects=detected_objects, + intrinsics=self.camera_intrinsics, + ) + else: + results['detected_pointcloud_viz'] = base_image + + # Step 4: Grasp Generation (if enabled) + should_generate_grasps = ( + generate_grasps if generate_grasps is not None + else self.enable_grasp_generation 
+ ) + + if should_generate_grasps and all_objects: + logger.debug("Generating grasps...") + grasps = self._run_grasp_generation(all_objects) + results['grasps'] = grasps + + # Create grasp overlay + if grasps: + results['grasp_overlay'] = self._create_grasp_overlay(rgb_image, grasps) + + # Ensure segmentation runs even if no objects detected + if self.enable_segmentation and 'segmentation_viz' not in results: + logger.debug("Running semantic segmentation (no objects detected)...") + segmentation_results = self._run_segmentation(rgb_image) + results['segmentation2d_objects'] = segmentation_results.get('objects', []) + results['segmentation_viz'] = segmentation_results.get('viz_frame') + + except Exception as e: + logger.error(f"Error processing frame: {e}") + results['error'] = str(e) + + # Add timing information + total_time = time.time() - start_time + results['processing_time'] = total_time + results['timing_breakdown'] = { + 'detection': detection_time if 'detection_time' in locals() else 0, + 'segmentation': segmentation_time if 'segmentation_time' in locals() else 0, + 'pointcloud': pointcloud_time if 'pointcloud_time' in locals() else 0, + 'total': total_time + } + logger.debug(f"Frame processing completed in {total_time:.3f}s") + logger.debug(f"Timing breakdown: detection={detection_time:.3f}s, segmentation={segmentation_time:.3f}s, pointcloud={pointcloud_time:.3f}s") + + return results + + def _run_object_detection(self, rgb_image: np.ndarray) -> Dict[str, Any]: + """Run object detection on RGB image.""" + try: + # Convert RGB to BGR for Detic detector + bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR) + + # Use process_image method from Detic detector + bboxes, track_ids, class_ids, confidences, names, masks = self.detector.process_image(bgr_image) + + # Convert to ObjectData format using utility function + objects = detection_results_to_object_data( + bboxes=bboxes, + track_ids=track_ids, + class_ids=class_ids, + confidences=confidences, + 
names=names, + masks=masks, + source="detection" + ) + + # Create visualization using detector's built-in method + viz_frame = self.detector.visualize_results( + rgb_image, bboxes, track_ids, class_ids, confidences, names + ) + + return { + 'objects': objects, + 'viz_frame': viz_frame + } + + except Exception as e: + logger.error(f"Object detection failed: {e}") + return {'objects': [], 'viz_frame': rgb_image.copy()} + + def _run_pointcloud_filtering( + self, + rgb_image: np.ndarray, + depth_image: np.ndarray, + objects: List[Dict] + ) -> List[Dict]: + """Run point cloud filtering on detected objects.""" + try: + filtered_objects = self.pointcloud_filter.process_images( + rgb_image, depth_image, objects + ) + return filtered_objects if filtered_objects else [] + except Exception as e: + logger.error(f"Point cloud filtering failed: {e}") + return [] + + def _run_segmentation(self, rgb_image: np.ndarray) -> Dict[str, Any]: + """Run semantic segmentation on RGB image.""" + if not self.segmenter: + return {'objects': [], 'viz_frame': rgb_image.copy()} + + try: + # Convert RGB to BGR for segmenter + bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR) + + # Get segmentation results + masks, bboxes, track_ids, probs, names = self.segmenter.process_image(bgr_image) + + # Convert to ObjectData format using utility function + objects = detection_results_to_object_data( + bboxes=bboxes, + track_ids=track_ids, + class_ids=list(range(len(bboxes))), # Use indices as class IDs for segmentation + confidences=probs, + names=names, + masks=masks, + source="segmentation" + ) + + # Create visualization + if masks: + viz_bgr = self.segmenter.visualize_results( + bgr_image, masks, bboxes, track_ids, probs, names + ) + # Convert back to RGB + viz_frame = cv2.cvtColor(viz_bgr, cv2.COLOR_BGR2RGB) + else: + viz_frame = rgb_image.copy() + + return { + 'objects': objects, + 'viz_frame': viz_frame + } + + except Exception as e: + logger.error(f"Segmentation failed: {e}") + return 
{'objects': [], 'viz_frame': rgb_image.copy()} + + def _run_grasp_generation(self, filtered_objects: List[Dict]) -> Optional[List[Dict]]: + """Run grasp generation on filtered objects.""" + if not self.grasp_server_url: + logger.warning("Grasp generation requested but no server URL provided") + return None + + try: + # Combine all point clouds + all_points = [] + all_colors = [] + valid_objects = 0 + + for obj in filtered_objects: + if "point_cloud_numpy" not in obj or obj["point_cloud_numpy"] is None: + continue + + points = obj["point_cloud_numpy"] + if not isinstance(points, np.ndarray) or points.size == 0: + continue + + if len(points.shape) != 2 or points.shape[1] != 3: + continue + + colors = None + if "colors_numpy" in obj and obj["colors_numpy"] is not None: + colors = obj["colors_numpy"] + if isinstance(colors, np.ndarray) and colors.size > 0: + if colors.shape[0] != points.shape[0] or len(colors.shape) != 2 or colors.shape[1] != 3: + colors = None + + all_points.append(points) + if colors is not None: + all_colors.append(colors) + valid_objects += 1 + + if not all_points: + return None + + # Combine point clouds + combined_points = np.vstack(all_points) + combined_colors = None + if len(all_colors) == valid_objects and len(all_colors) > 0: + combined_colors = np.vstack(all_colors) + + # Send grasp request synchronously + return self._send_grasp_request_sync(combined_points, combined_colors) + + except Exception as e: + logger.error(f"Grasp generation failed: {e}") + return None + + def _send_grasp_request_sync( + self, + points: np.ndarray, + colors: Optional[np.ndarray] + ) -> Optional[List[Dict]]: + """Send synchronous grasp request to AnyGrasp server.""" + try: + # Validation (same as async version) + if points is None or not isinstance(points, np.ndarray) or points.size == 0: + logger.error("Invalid points array") + return None + + if len(points.shape) != 2 or points.shape[1] != 3: + logger.error(f"Points has invalid shape {points.shape}, expected (N, 
3)") + return None + + if points.shape[0] < 100: + logger.error(f"Insufficient points for grasp detection: {points.shape[0]} < 100") + return None + + # Prepare colors + if colors is not None: + if not isinstance(colors, np.ndarray) or colors.size == 0: + colors = None + elif len(colors.shape) != 2 or colors.shape[1] != 3: + colors = None + elif colors.shape[0] != points.shape[0]: + colors = None + + if colors is None: + colors = np.ones((points.shape[0], 3), dtype=np.float32) * 0.5 + + # Ensure correct data types + points = points.astype(np.float32) + colors = colors.astype(np.float32) + + # Validate ranges + if np.any(np.isnan(points)) or np.any(np.isinf(points)): + logger.error("Points contain NaN or Inf values") + return None + if np.any(np.isnan(colors)) or np.any(np.isinf(colors)): + logger.error("Colors contain NaN or Inf values") + return None + + colors = np.clip(colors, 0.0, 1.0) + + # Run async request in sync context + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + result = loop.run_until_complete( + self._async_grasp_request(points, colors) + ) + return result + finally: + loop.close() + + except Exception as e: + logger.error(f"Error in synchronous grasp request: {e}") + return None + + async def _async_grasp_request( + self, + points: np.ndarray, + colors: np.ndarray + ) -> Optional[List[Dict]]: + """Async grasp request helper.""" + try: + async with websockets.connect(self.grasp_server_url) as websocket: + request = { + "points": points.tolist(), + "colors": colors.tolist(), + "lims": [-0.19, 0.12, 0.02, 0.15, 0.0, 1.0], + } + + await websocket.send(json.dumps(request)) + response = await websocket.recv() + grasps = json.loads(response) + + if isinstance(grasps, dict) and "error" in grasps: + logger.error(f"Server returned error: {grasps['error']}") + return None + elif isinstance(grasps, (int, float)) and grasps == 0: + return None + elif not isinstance(grasps, list): + logger.error(f"Server returned unexpected response 
type: {type(grasps)}") + return None + elif len(grasps) == 0: + return None + + return self._convert_grasp_format(grasps) + + except Exception as e: + logger.error(f"Async grasp request failed: {e}") + return None + + def _create_grasp_overlay(self, rgb_image: np.ndarray, grasps: List[Dict]) -> np.ndarray: + """Create grasp visualization overlay on RGB image.""" + try: + bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR) + result_bgr = draw_grasps_on_image( + bgr_image, + grasps, + self.camera_intrinsics, + max_grasps=-1, # Show all grasps + ) + return cv2.cvtColor(result_bgr, cv2.COLOR_BGR2RGB) + except Exception as e: + logger.error(f"Error creating grasp overlay: {e}") + return rgb_image.copy() + + def _convert_grasp_format(self, anygrasp_grasps: List[dict]) -> List[dict]: + """Convert AnyGrasp format to visualization format.""" + converted = [] + + for i, grasp in enumerate(anygrasp_grasps): + rotation_matrix = np.array(grasp.get("rotation_matrix", np.eye(3))) + euler_angles = self._rotation_matrix_to_euler(rotation_matrix) + + converted_grasp = { + "id": f"grasp_{i}", + "score": grasp.get("score", 0.0), + "width": grasp.get("width", 0.0), + "height": grasp.get("height", 0.0), + "depth": grasp.get("depth", 0.0), + "translation": grasp.get("translation", [0, 0, 0]), + "rotation_matrix": rotation_matrix.tolist(), + "euler_angles": euler_angles, + } + converted.append(converted_grasp) + + converted.sort(key=lambda x: x["score"], reverse=True) + return converted + + def _rotation_matrix_to_euler(self, rotation_matrix: np.ndarray) -> Dict[str, float]: + """Convert rotation matrix to Euler angles (in radians).""" + sy = np.sqrt(rotation_matrix[0, 0] ** 2 + rotation_matrix[1, 0] ** 2) + + singular = sy < 1e-6 + + if not singular: + x = np.arctan2(rotation_matrix[2, 1], rotation_matrix[2, 2]) + y = np.arctan2(-rotation_matrix[2, 0], sy) + z = np.arctan2(rotation_matrix[1, 0], rotation_matrix[0, 0]) + else: + x = np.arctan2(-rotation_matrix[1, 2], rotation_matrix[1, 
1]) + y = np.arctan2(-rotation_matrix[2, 0], sy) + z = 0 + + return {"roll": x, "pitch": y, "yaw": z} + + def cleanup(self): + """Clean up resources.""" + if hasattr(self.detector, "cleanup"): + self.detector.cleanup() + if hasattr(self.pointcloud_filter, "cleanup"): + self.pointcloud_filter.cleanup() + if self.segmenter and hasattr(self.segmenter, "cleanup"): + self.segmenter.cleanup() + logger.info("ManipulationProcessor cleaned up") From 298192684c4c88a7c7cadaa3d841ddc9ad3986f1 Mon Sep 17 00:00:00 2001 From: alexlin2 <44330195+alexlin2@users.noreply.github.com> Date: Tue, 17 Jun 2025 08:08:48 +0000 Subject: [PATCH 36/89] CI code cleanup --- dimos/perception/manip_aio_processer.py | 155 +++++++++++------------- 1 file changed, 72 insertions(+), 83 deletions(-) diff --git a/dimos/perception/manip_aio_processer.py b/dimos/perception/manip_aio_processer.py index 0d39cb39e1..eb51643255 100644 --- a/dimos/perception/manip_aio_processer.py +++ b/dimos/perception/manip_aio_processer.py @@ -39,7 +39,7 @@ class ManipulationProcessor: """ Sequential manipulation processor for single-frame processing. - + Processes RGB-D frames through object detection, point cloud filtering, and optional grasp generation in a single thread without reactive streams. """ @@ -99,10 +99,7 @@ def __init__( logger.info(f"Initialized ManipulationProcessor with confidence={min_confidence}") def process_frame( - self, - rgb_image: np.ndarray, - depth_image: np.ndarray, - generate_grasps: bool = None + self, rgb_image: np.ndarray, depth_image: np.ndarray, generate_grasps: bool = None ) -> Dict[str, Any]: """ Process a single RGB-D frame through the complete pipeline. 
@@ -134,9 +131,9 @@ def process_frame( step_start = time.time() logger.debug("Running object detection...") detection_results = self._run_object_detection(rgb_image) - - results['detection2d_objects'] = detection_results.get('objects', []) - results['detection_viz'] = detection_results.get('viz_frame') + + results["detection2d_objects"] = detection_results.get("objects", []) + results["detection_viz"] = detection_results.get("viz_frame") detection_time = time.time() - step_start # Step 2: Semantic Segmentation (if enabled) @@ -146,15 +143,15 @@ def process_frame( step_start = time.time() logger.debug("Running semantic segmentation...") segmentation_results = self._run_segmentation(rgb_image) - results['segmentation2d_objects'] = segmentation_results.get('objects', []) - results['segmentation_viz'] = segmentation_results.get('viz_frame') + results["segmentation2d_objects"] = segmentation_results.get("objects", []) + results["segmentation_viz"] = segmentation_results.get("viz_frame") segmentation_time = time.time() - step_start # Step 3: Point Cloud Processing pointcloud_time = 0 - detection2d_objects = results.get('detection2d_objects', []) - segmentation2d_objects = results.get('segmentation2d_objects', []) - + detection2d_objects = results.get("detection2d_objects", []) + segmentation2d_objects = results.get("segmentation2d_objects", []) + # Process detection objects if available detected_objects = [] if detection2d_objects: @@ -164,7 +161,7 @@ def process_frame( rgb_image, depth_image, detection2d_objects ) pointcloud_time += time.time() - step_start - + # Process segmentation objects if available segmentation_filtered_objects = [] if segmentation2d_objects: @@ -174,77 +171,78 @@ def process_frame( rgb_image, depth_image, segmentation2d_objects ) pointcloud_time += time.time() - step_start - + # Combine all objects all_objects = segmentation_filtered_objects - + # Get full point cloud full_pcd = self.pointcloud_filter.get_full_point_cloud() - - 
results['detected_objects'] = detected_objects - results['all_objects'] = all_objects - results['full_pointcloud'] = full_pcd + + results["detected_objects"] = detected_objects + results["all_objects"] = all_objects + results["full_pointcloud"] = full_pcd # Create point cloud visualizations base_image = colorize_depth(depth_image, max_depth=10.0) - + # Main pointcloud visualization (all objects) if all_objects: - results['pointcloud_viz'] = create_point_cloud_overlay_visualization( + results["pointcloud_viz"] = create_point_cloud_overlay_visualization( base_image=base_image, objects=all_objects, intrinsics=self.camera_intrinsics, ) else: - results['pointcloud_viz'] = base_image - + results["pointcloud_viz"] = base_image + # Detection objects pointcloud visualization if detected_objects: - results['detected_pointcloud_viz'] = create_point_cloud_overlay_visualization( + results["detected_pointcloud_viz"] = create_point_cloud_overlay_visualization( base_image=base_image, objects=detected_objects, intrinsics=self.camera_intrinsics, ) else: - results['detected_pointcloud_viz'] = base_image + results["detected_pointcloud_viz"] = base_image # Step 4: Grasp Generation (if enabled) should_generate_grasps = ( - generate_grasps if generate_grasps is not None - else self.enable_grasp_generation + generate_grasps if generate_grasps is not None else self.enable_grasp_generation ) - + if should_generate_grasps and all_objects: logger.debug("Generating grasps...") grasps = self._run_grasp_generation(all_objects) - results['grasps'] = grasps + results["grasps"] = grasps # Create grasp overlay if grasps: - results['grasp_overlay'] = self._create_grasp_overlay(rgb_image, grasps) + results["grasp_overlay"] = self._create_grasp_overlay(rgb_image, grasps) # Ensure segmentation runs even if no objects detected - if self.enable_segmentation and 'segmentation_viz' not in results: + if self.enable_segmentation and "segmentation_viz" not in results: logger.debug("Running semantic 
segmentation (no objects detected)...") segmentation_results = self._run_segmentation(rgb_image) - results['segmentation2d_objects'] = segmentation_results.get('objects', []) - results['segmentation_viz'] = segmentation_results.get('viz_frame') + results["segmentation2d_objects"] = segmentation_results.get("objects", []) + results["segmentation_viz"] = segmentation_results.get("viz_frame") except Exception as e: logger.error(f"Error processing frame: {e}") - results['error'] = str(e) + results["error"] = str(e) # Add timing information total_time = time.time() - start_time - results['processing_time'] = total_time - results['timing_breakdown'] = { - 'detection': detection_time if 'detection_time' in locals() else 0, - 'segmentation': segmentation_time if 'segmentation_time' in locals() else 0, - 'pointcloud': pointcloud_time if 'pointcloud_time' in locals() else 0, - 'total': total_time + results["processing_time"] = total_time + results["timing_breakdown"] = { + "detection": detection_time if "detection_time" in locals() else 0, + "segmentation": segmentation_time if "segmentation_time" in locals() else 0, + "pointcloud": pointcloud_time if "pointcloud_time" in locals() else 0, + "total": total_time, } logger.debug(f"Frame processing completed in {total_time:.3f}s") - logger.debug(f"Timing breakdown: detection={detection_time:.3f}s, segmentation={segmentation_time:.3f}s, pointcloud={pointcloud_time:.3f}s") + logger.debug( + f"Timing breakdown: detection={detection_time:.3f}s, segmentation={segmentation_time:.3f}s, pointcloud={pointcloud_time:.3f}s" + ) return results @@ -253,10 +251,12 @@ def _run_object_detection(self, rgb_image: np.ndarray) -> Dict[str, Any]: try: # Convert RGB to BGR for Detic detector bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR) - + # Use process_image method from Detic detector - bboxes, track_ids, class_ids, confidences, names, masks = self.detector.process_image(bgr_image) - + bboxes, track_ids, class_ids, confidences, names, 
masks = self.detector.process_image( + bgr_image + ) + # Convert to ObjectData format using utility function objects = detection_results_to_object_data( bboxes=bboxes, @@ -265,28 +265,22 @@ def _run_object_detection(self, rgb_image: np.ndarray) -> Dict[str, Any]: confidences=confidences, names=names, masks=masks, - source="detection" + source="detection", ) - + # Create visualization using detector's built-in method viz_frame = self.detector.visualize_results( rgb_image, bboxes, track_ids, class_ids, confidences, names ) - - return { - 'objects': objects, - 'viz_frame': viz_frame - } - + + return {"objects": objects, "viz_frame": viz_frame} + except Exception as e: logger.error(f"Object detection failed: {e}") - return {'objects': [], 'viz_frame': rgb_image.copy()} + return {"objects": [], "viz_frame": rgb_image.copy()} def _run_pointcloud_filtering( - self, - rgb_image: np.ndarray, - depth_image: np.ndarray, - objects: List[Dict] + self, rgb_image: np.ndarray, depth_image: np.ndarray, objects: List[Dict] ) -> List[Dict]: """Run point cloud filtering on detected objects.""" try: @@ -301,15 +295,15 @@ def _run_pointcloud_filtering( def _run_segmentation(self, rgb_image: np.ndarray) -> Dict[str, Any]: """Run semantic segmentation on RGB image.""" if not self.segmenter: - return {'objects': [], 'viz_frame': rgb_image.copy()} - + return {"objects": [], "viz_frame": rgb_image.copy()} + try: # Convert RGB to BGR for segmenter bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR) - + # Get segmentation results masks, bboxes, track_ids, probs, names = self.segmenter.process_image(bgr_image) - + # Convert to ObjectData format using utility function objects = detection_results_to_object_data( bboxes=bboxes, @@ -318,9 +312,9 @@ def _run_segmentation(self, rgb_image: np.ndarray) -> Dict[str, Any]: confidences=probs, names=names, masks=masks, - source="segmentation" + source="segmentation", ) - + # Create visualization if masks: viz_bgr = self.segmenter.visualize_results( @@ 
-330,15 +324,12 @@ def _run_segmentation(self, rgb_image: np.ndarray) -> Dict[str, Any]: viz_frame = cv2.cvtColor(viz_bgr, cv2.COLOR_BGR2RGB) else: viz_frame = rgb_image.copy() - - return { - 'objects': objects, - 'viz_frame': viz_frame - } - + + return {"objects": objects, "viz_frame": viz_frame} + except Exception as e: logger.error(f"Segmentation failed: {e}") - return {'objects': [], 'viz_frame': rgb_image.copy()} + return {"objects": [], "viz_frame": rgb_image.copy()} def _run_grasp_generation(self, filtered_objects: List[Dict]) -> Optional[List[Dict]]: """Run grasp generation on filtered objects.""" @@ -367,7 +358,11 @@ def _run_grasp_generation(self, filtered_objects: List[Dict]) -> Optional[List[D if "colors_numpy" in obj and obj["colors_numpy"] is not None: colors = obj["colors_numpy"] if isinstance(colors, np.ndarray) and colors.size > 0: - if colors.shape[0] != points.shape[0] or len(colors.shape) != 2 or colors.shape[1] != 3: + if ( + colors.shape[0] != points.shape[0] + or len(colors.shape) != 2 + or colors.shape[1] != 3 + ): colors = None all_points.append(points) @@ -392,9 +387,7 @@ def _run_grasp_generation(self, filtered_objects: List[Dict]) -> Optional[List[D return None def _send_grasp_request_sync( - self, - points: np.ndarray, - colors: Optional[np.ndarray] + self, points: np.ndarray, colors: Optional[np.ndarray] ) -> Optional[List[Dict]]: """Send synchronous grasp request to AnyGrasp server.""" try: @@ -402,11 +395,11 @@ def _send_grasp_request_sync( if points is None or not isinstance(points, np.ndarray) or points.size == 0: logger.error("Invalid points array") return None - + if len(points.shape) != 2 or points.shape[1] != 3: logger.error(f"Points has invalid shape {points.shape}, expected (N, 3)") return None - + if points.shape[0] < 100: logger.error(f"Insufficient points for grasp detection: {points.shape[0]} < 100") return None @@ -441,9 +434,7 @@ def _send_grasp_request_sync( loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) 
try: - result = loop.run_until_complete( - self._async_grasp_request(points, colors) - ) + result = loop.run_until_complete(self._async_grasp_request(points, colors)) return result finally: loop.close() @@ -453,9 +444,7 @@ def _send_grasp_request_sync( return None async def _async_grasp_request( - self, - points: np.ndarray, - colors: np.ndarray + self, points: np.ndarray, colors: np.ndarray ) -> Optional[List[Dict]]: """Async grasp request helper.""" try: From 0f6632e80dc4c5112381d715fcf1e2cd136a6fff Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Thu, 19 Jun 2025 15:59:37 -0700 Subject: [PATCH 37/89] added misc points clustering --- dimos/perception/manip_aio_processer.py | 61 ++++++++++++++++++++----- 1 file changed, 50 insertions(+), 11 deletions(-) diff --git a/dimos/perception/manip_aio_processer.py b/dimos/perception/manip_aio_processer.py index eb51643255..0fa564be55 100644 --- a/dimos/perception/manip_aio_processer.py +++ b/dimos/perception/manip_aio_processer.py @@ -119,6 +119,8 @@ def process_frame( - detected_objects: Detection (Object Detection) objects with point clouds filtered - all_objects: All objects (including misc objects) (SAM segmentation) with point clouds filtered - full_pointcloud: Complete scene point cloud (if point cloud processing enabled) + - misc_clusters: List of clustered background/miscellaneous point clouds (DBSCAN) + - misc_pointcloud_viz: Visualization of misc/background cluster overlay - grasps: Grasp results (if enabled) - grasp_overlay: Grasp visualization (if enabled) - processing_time: Total processing time @@ -177,10 +179,23 @@ def process_frame( # Get full point cloud full_pcd = self.pointcloud_filter.get_full_point_cloud() - - results["detected_objects"] = detected_objects - results["all_objects"] = all_objects - results["full_pointcloud"] = full_pcd + + # Calculate misc_points clusters (full point cloud minus all object points) + misc_start = time.time() + from dimos.perception.pointcloud.utils import 
extract_and_cluster_misc_points + misc_clusters = extract_and_cluster_misc_points( + full_pcd, + all_objects, + eps=0.05, # 5cm cluster distance + min_points=50, # Minimum 50 points per cluster + enable_filtering=True + ) + misc_time = time.time() - misc_start + + results['detected_objects'] = detected_objects + results['all_objects'] = all_objects + results['full_pointcloud'] = full_pcd + results['misc_clusters'] = misc_clusters # Create point cloud visualizations base_image = colorize_depth(depth_image, max_depth=10.0) @@ -203,7 +218,28 @@ def process_frame( intrinsics=self.camera_intrinsics, ) else: - results["detected_pointcloud_viz"] = base_image + results['detected_pointcloud_viz'] = base_image + + # Misc clusters visualization overlay + if misc_clusters: + from dimos.perception.pointcloud.utils import overlay_point_clouds_on_image + # Generate random colors for each cluster + cluster_colors = [] + for i in range(len(misc_clusters)): + np.random.seed(i + 100) # Consistent colors + color = tuple((np.random.rand(3) * 255).astype(int)) + cluster_colors.append(color) + + results['misc_pointcloud_viz'] = overlay_point_clouds_on_image( + base_image=base_image, + point_clouds=misc_clusters, + camera_intrinsics=self.camera_intrinsics, + colors=cluster_colors, + point_size=2, + alpha=0.6, + ) + else: + results['misc_pointcloud_viz'] = base_image # Step 4: Grasp Generation (if enabled) should_generate_grasps = ( @@ -232,12 +268,13 @@ def process_frame( # Add timing information total_time = time.time() - start_time - results["processing_time"] = total_time - results["timing_breakdown"] = { - "detection": detection_time if "detection_time" in locals() else 0, - "segmentation": segmentation_time if "segmentation_time" in locals() else 0, - "pointcloud": pointcloud_time if "pointcloud_time" in locals() else 0, - "total": total_time, + results['processing_time'] = total_time + results['timing_breakdown'] = { + 'detection': detection_time if 'detection_time' in locals() else 
0, + 'segmentation': segmentation_time if 'segmentation_time' in locals() else 0, + 'pointcloud': pointcloud_time if 'pointcloud_time' in locals() else 0, + 'misc_extraction': misc_time if 'misc_time' in locals() else 0, + 'total': total_time } logger.debug(f"Frame processing completed in {total_time:.3f}s") logger.debug( @@ -531,6 +568,8 @@ def _rotation_matrix_to_euler(self, rotation_matrix: np.ndarray) -> Dict[str, fl return {"roll": x, "pitch": y, "yaw": z} + + def cleanup(self): """Clean up resources.""" if hasattr(self.detector, "cleanup"): From bd5176f51a47a5c0249cb7d223c4a99270edacc2 Mon Sep 17 00:00:00 2001 From: alexlin2 <44330195+alexlin2@users.noreply.github.com> Date: Thu, 19 Jun 2025 23:01:24 +0000 Subject: [PATCH 38/89] CI code cleanup --- dimos/perception/manip_aio_processer.py | 44 ++++++++++++------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/dimos/perception/manip_aio_processer.py b/dimos/perception/manip_aio_processer.py index 0fa564be55..e71955a2c6 100644 --- a/dimos/perception/manip_aio_processer.py +++ b/dimos/perception/manip_aio_processer.py @@ -179,23 +179,24 @@ def process_frame( # Get full point cloud full_pcd = self.pointcloud_filter.get_full_point_cloud() - + # Calculate misc_points clusters (full point cloud minus all object points) misc_start = time.time() from dimos.perception.pointcloud.utils import extract_and_cluster_misc_points + misc_clusters = extract_and_cluster_misc_points( - full_pcd, + full_pcd, all_objects, eps=0.05, # 5cm cluster distance min_points=50, # Minimum 50 points per cluster - enable_filtering=True + enable_filtering=True, ) misc_time = time.time() - misc_start - - results['detected_objects'] = detected_objects - results['all_objects'] = all_objects - results['full_pointcloud'] = full_pcd - results['misc_clusters'] = misc_clusters + + results["detected_objects"] = detected_objects + results["all_objects"] = all_objects + results["full_pointcloud"] = full_pcd + results["misc_clusters"] 
= misc_clusters # Create point cloud visualizations base_image = colorize_depth(depth_image, max_depth=10.0) @@ -218,19 +219,20 @@ def process_frame( intrinsics=self.camera_intrinsics, ) else: - results['detected_pointcloud_viz'] = base_image - + results["detected_pointcloud_viz"] = base_image + # Misc clusters visualization overlay if misc_clusters: from dimos.perception.pointcloud.utils import overlay_point_clouds_on_image + # Generate random colors for each cluster cluster_colors = [] for i in range(len(misc_clusters)): np.random.seed(i + 100) # Consistent colors color = tuple((np.random.rand(3) * 255).astype(int)) cluster_colors.append(color) - - results['misc_pointcloud_viz'] = overlay_point_clouds_on_image( + + results["misc_pointcloud_viz"] = overlay_point_clouds_on_image( base_image=base_image, point_clouds=misc_clusters, camera_intrinsics=self.camera_intrinsics, @@ -239,7 +241,7 @@ def process_frame( alpha=0.6, ) else: - results['misc_pointcloud_viz'] = base_image + results["misc_pointcloud_viz"] = base_image # Step 4: Grasp Generation (if enabled) should_generate_grasps = ( @@ -268,13 +270,13 @@ def process_frame( # Add timing information total_time = time.time() - start_time - results['processing_time'] = total_time - results['timing_breakdown'] = { - 'detection': detection_time if 'detection_time' in locals() else 0, - 'segmentation': segmentation_time if 'segmentation_time' in locals() else 0, - 'pointcloud': pointcloud_time if 'pointcloud_time' in locals() else 0, - 'misc_extraction': misc_time if 'misc_time' in locals() else 0, - 'total': total_time + results["processing_time"] = total_time + results["timing_breakdown"] = { + "detection": detection_time if "detection_time" in locals() else 0, + "segmentation": segmentation_time if "segmentation_time" in locals() else 0, + "pointcloud": pointcloud_time if "pointcloud_time" in locals() else 0, + "misc_extraction": misc_time if "misc_time" in locals() else 0, + "total": total_time, } 
logger.debug(f"Frame processing completed in {total_time:.3f}s") logger.debug( @@ -568,8 +570,6 @@ def _rotation_matrix_to_euler(self, rotation_matrix: np.ndarray) -> Dict[str, fl return {"roll": x, "pitch": y, "yaw": z} - - def cleanup(self): """Clean up resources.""" if hasattr(self.detector, "cleanup"): From 4e34d654e99acbd8055e3e2444da59fa66f07de9 Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Thu, 19 Jun 2025 18:11:29 -0700 Subject: [PATCH 39/89] added open3d's voxelgrid to manip_output --- dimos/perception/manip_aio_processer.py | 113 +++++++++++------------- 1 file changed, 51 insertions(+), 62 deletions(-) diff --git a/dimos/perception/manip_aio_processer.py b/dimos/perception/manip_aio_processer.py index e71955a2c6..1360443dd3 100644 --- a/dimos/perception/manip_aio_processer.py +++ b/dimos/perception/manip_aio_processer.py @@ -30,7 +30,11 @@ from dimos.perception.pointcloud.pointcloud_filtering import PointcloudFiltering from dimos.perception.segmentation.sam_2d_seg import Sam2DSegmenter from dimos.perception.grasp_generation.utils import draw_grasps_on_image -from dimos.perception.pointcloud.utils import create_point_cloud_overlay_visualization +from dimos.perception.pointcloud.utils import ( + create_point_cloud_overlay_visualization, + extract_and_cluster_misc_points, + overlay_point_clouds_on_image, +) from dimos.perception.common.utils import colorize_depth, detection_results_to_object_data logger = setup_logger("dimos.perception.manip_aio_processor") @@ -120,6 +124,7 @@ def process_frame( - all_objects: All objects (including misc objects) (SAM segmentation) with point clouds filtered - full_pointcloud: Complete scene point cloud (if point cloud processing enabled) - misc_clusters: List of clustered background/miscellaneous point clouds (DBSCAN) + - misc_voxel_grid: Open3D voxel grid approximating all misc/background points - misc_pointcloud_viz: Visualization of misc/background cluster overlay - grasps: Grasp results (if enabled) - grasp_overlay: 
Grasp visualization (if enabled) @@ -131,19 +136,15 @@ def process_frame( try: # Step 1: Object Detection step_start = time.time() - logger.debug("Running object detection...") detection_results = self._run_object_detection(rgb_image) - results["detection2d_objects"] = detection_results.get("objects", []) results["detection_viz"] = detection_results.get("viz_frame") detection_time = time.time() - step_start # Step 2: Semantic Segmentation (if enabled) segmentation_time = 0 - segmentation_results = None if self.enable_segmentation: step_start = time.time() - logger.debug("Running semantic segmentation...") segmentation_results = self._run_segmentation(rgb_image) results["segmentation2d_objects"] = segmentation_results.get("objects", []) results["segmentation_viz"] = segmentation_results.get("viz_frame") @@ -158,7 +159,6 @@ def process_frame( detected_objects = [] if detection2d_objects: step_start = time.time() - logger.debug(f"Processing {len(detection2d_objects)} detection2d_objects...") detected_objects = self._run_pointcloud_filtering( rgb_image, depth_image, detection2d_objects ) @@ -168,7 +168,6 @@ def process_frame( segmentation_filtered_objects = [] if segmentation2d_objects: step_start = time.time() - logger.debug(f"Processing {len(segmentation2d_objects)} segmentation objects...") segmentation_filtered_objects = self._run_pointcloud_filtering( rgb_image, depth_image, segmentation2d_objects ) @@ -179,60 +178,56 @@ def process_frame( # Get full point cloud full_pcd = self.pointcloud_filter.get_full_point_cloud() - - # Calculate misc_points clusters (full point cloud minus all object points) + + # Extract misc/background points and create voxel grid misc_start = time.time() - from dimos.perception.pointcloud.utils import extract_and_cluster_misc_points - - misc_clusters = extract_and_cluster_misc_points( - full_pcd, - all_objects, - eps=0.05, # 5cm cluster distance - min_points=50, # Minimum 50 points per cluster + all_filtered_objects = 
segmentation_filtered_objects + detected_objects + misc_clusters, misc_voxel_grid = extract_and_cluster_misc_points( + full_pcd, + all_filtered_objects, + eps=0.03, + min_points=100, enable_filtering=True, + voxel_size=0.02 ) misc_time = time.time() - misc_start - - results["detected_objects"] = detected_objects - results["all_objects"] = all_objects - results["full_pointcloud"] = full_pcd - results["misc_clusters"] = misc_clusters + + # Store results + results.update({ + 'detected_objects': detected_objects, + 'all_objects': all_objects, + 'full_pointcloud': full_pcd, + 'misc_clusters': misc_clusters, + 'misc_voxel_grid': misc_voxel_grid + }) # Create point cloud visualizations base_image = colorize_depth(depth_image, max_depth=10.0) - # Main pointcloud visualization (all objects) - if all_objects: - results["pointcloud_viz"] = create_point_cloud_overlay_visualization( + # Create visualizations + results["pointcloud_viz"] = ( + create_point_cloud_overlay_visualization( base_image=base_image, objects=all_objects, intrinsics=self.camera_intrinsics, - ) - else: - results["pointcloud_viz"] = base_image - - # Detection objects pointcloud visualization - if detected_objects: - results["detected_pointcloud_viz"] = create_point_cloud_overlay_visualization( + ) if all_objects else base_image + ) + + results["detected_pointcloud_viz"] = ( + create_point_cloud_overlay_visualization( base_image=base_image, objects=detected_objects, intrinsics=self.camera_intrinsics, - ) - else: - results["detected_pointcloud_viz"] = base_image - - # Misc clusters visualization overlay + ) if detected_objects else base_image + ) + if misc_clusters: - from dimos.perception.pointcloud.utils import overlay_point_clouds_on_image - - # Generate random colors for each cluster - cluster_colors = [] - for i in range(len(misc_clusters)): - np.random.seed(i + 100) # Consistent colors - color = tuple((np.random.rand(3) * 255).astype(int)) - cluster_colors.append(color) - - results["misc_pointcloud_viz"] 
= overlay_point_clouds_on_image( + # Generate consistent colors for clusters + cluster_colors = [ + tuple((np.random.RandomState(i + 100).rand(3) * 255).astype(int)) + for i in range(len(misc_clusters)) + ] + results['misc_pointcloud_viz'] = overlay_point_clouds_on_image( base_image=base_image, point_clouds=misc_clusters, camera_intrinsics=self.camera_intrinsics, @@ -249,17 +244,13 @@ def process_frame( ) if should_generate_grasps and all_objects: - logger.debug("Generating grasps...") grasps = self._run_grasp_generation(all_objects) results["grasps"] = grasps - - # Create grasp overlay if grasps: results["grasp_overlay"] = self._create_grasp_overlay(rgb_image, grasps) # Ensure segmentation runs even if no objects detected if self.enable_segmentation and "segmentation_viz" not in results: - logger.debug("Running semantic segmentation (no objects detected)...") segmentation_results = self._run_segmentation(rgb_image) results["segmentation2d_objects"] = segmentation_results.get("objects", []) results["segmentation_viz"] = segmentation_results.get("viz_frame") @@ -270,18 +261,16 @@ def process_frame( # Add timing information total_time = time.time() - start_time - results["processing_time"] = total_time - results["timing_breakdown"] = { - "detection": detection_time if "detection_time" in locals() else 0, - "segmentation": segmentation_time if "segmentation_time" in locals() else 0, - "pointcloud": pointcloud_time if "pointcloud_time" in locals() else 0, - "misc_extraction": misc_time if "misc_time" in locals() else 0, - "total": total_time, - } - logger.debug(f"Frame processing completed in {total_time:.3f}s") - logger.debug( - f"Timing breakdown: detection={detection_time:.3f}s, segmentation={segmentation_time:.3f}s, pointcloud={pointcloud_time:.3f}s" - ) + results.update({ + 'processing_time': total_time, + 'timing_breakdown': { + 'detection': detection_time if 'detection_time' in locals() else 0, + 'segmentation': segmentation_time if 'segmentation_time' in 
locals() else 0, + 'pointcloud': pointcloud_time if 'pointcloud_time' in locals() else 0, + 'misc_extraction': misc_time if 'misc_time' in locals() else 0, + 'total': total_time + } + }) return results From 304dcc28e2e1ce59cbadcc708b7dc820e9e53531 Mon Sep 17 00:00:00 2001 From: alexlin2 <44330195+alexlin2@users.noreply.github.com> Date: Fri, 20 Jun 2025 01:13:02 +0000 Subject: [PATCH 40/89] CI code cleanup --- dimos/perception/manip_aio_processer.py | 58 ++++++++++++++----------- 1 file changed, 33 insertions(+), 25 deletions(-) diff --git a/dimos/perception/manip_aio_processer.py b/dimos/perception/manip_aio_processer.py index 1360443dd3..6465bf6576 100644 --- a/dimos/perception/manip_aio_processer.py +++ b/dimos/perception/manip_aio_processer.py @@ -178,28 +178,30 @@ def process_frame( # Get full point cloud full_pcd = self.pointcloud_filter.get_full_point_cloud() - + # Extract misc/background points and create voxel grid misc_start = time.time() all_filtered_objects = segmentation_filtered_objects + detected_objects misc_clusters, misc_voxel_grid = extract_and_cluster_misc_points( - full_pcd, + full_pcd, all_filtered_objects, eps=0.03, min_points=100, enable_filtering=True, - voxel_size=0.02 + voxel_size=0.02, ) misc_time = time.time() - misc_start - + # Store results - results.update({ - 'detected_objects': detected_objects, - 'all_objects': all_objects, - 'full_pointcloud': full_pcd, - 'misc_clusters': misc_clusters, - 'misc_voxel_grid': misc_voxel_grid - }) + results.update( + { + "detected_objects": detected_objects, + "all_objects": all_objects, + "full_pointcloud": full_pcd, + "misc_clusters": misc_clusters, + "misc_voxel_grid": misc_voxel_grid, + } + ) # Create point cloud visualizations base_image = colorize_depth(depth_image, max_depth=10.0) @@ -210,24 +212,28 @@ def process_frame( base_image=base_image, objects=all_objects, intrinsics=self.camera_intrinsics, - ) if all_objects else base_image + ) + if all_objects + else base_image ) - + 
results["detected_pointcloud_viz"] = ( create_point_cloud_overlay_visualization( base_image=base_image, objects=detected_objects, intrinsics=self.camera_intrinsics, - ) if detected_objects else base_image + ) + if detected_objects + else base_image ) - + if misc_clusters: # Generate consistent colors for clusters cluster_colors = [ tuple((np.random.RandomState(i + 100).rand(3) * 255).astype(int)) for i in range(len(misc_clusters)) ] - results['misc_pointcloud_viz'] = overlay_point_clouds_on_image( + results["misc_pointcloud_viz"] = overlay_point_clouds_on_image( base_image=base_image, point_clouds=misc_clusters, camera_intrinsics=self.camera_intrinsics, @@ -261,16 +267,18 @@ def process_frame( # Add timing information total_time = time.time() - start_time - results.update({ - 'processing_time': total_time, - 'timing_breakdown': { - 'detection': detection_time if 'detection_time' in locals() else 0, - 'segmentation': segmentation_time if 'segmentation_time' in locals() else 0, - 'pointcloud': pointcloud_time if 'pointcloud_time' in locals() else 0, - 'misc_extraction': misc_time if 'misc_time' in locals() else 0, - 'total': total_time + results.update( + { + "processing_time": total_time, + "timing_breakdown": { + "detection": detection_time if "detection_time" in locals() else 0, + "segmentation": segmentation_time if "segmentation_time" in locals() else 0, + "pointcloud": pointcloud_time if "pointcloud_time" in locals() else 0, + "misc_extraction": misc_time if "misc_time" in locals() else 0, + "total": total_time, + }, } - }) + ) return results From a26d7a45d7a0c69d645429ce6b90a2a2f0ab7cd2 Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Wed, 25 Jun 2025 14:59:16 -0700 Subject: [PATCH 41/89] supports contact graspnet --- dimos/perception/manip_aio_processer.py | 248 +----- .../pointcloud/pointcloud_filtering.py | 2 +- tests/manipulation_pipeline_demo.ipynb | 839 ++++++++++++++++++ 3 files changed, 879 insertions(+), 210 deletions(-) create mode 100644 
tests/manipulation_pipeline_demo.ipynb diff --git a/dimos/perception/manip_aio_processer.py b/dimos/perception/manip_aio_processer.py index 6465bf6576..b8b0c0b72d 100644 --- a/dimos/perception/manip_aio_processer.py +++ b/dimos/perception/manip_aio_processer.py @@ -16,11 +16,8 @@ Sequential manipulation processor for single-frame processing without reactive streams. """ -import json import logging import time -import asyncio -import websockets from typing import Dict, List, Optional, Any, Tuple import numpy as np import cv2 @@ -29,7 +26,7 @@ from dimos.perception.detection2d.detic_2d_det import Detic2DDetector from dimos.perception.pointcloud.pointcloud_filtering import PointcloudFiltering from dimos.perception.segmentation.sam_2d_seg import Sam2DSegmenter -from dimos.perception.grasp_generation.utils import draw_grasps_on_image +from dimos.perception.grasp_generation.grasp_generation import ContactGraspNetGenerator from dimos.perception.pointcloud.utils import ( create_point_cloud_overlay_visualization, extract_and_cluster_misc_points, @@ -45,7 +42,7 @@ class ManipulationProcessor: Sequential manipulation processor for single-frame processing. Processes RGB-D frames through object detection, point cloud filtering, - and optional grasp generation in a single thread without reactive streams. + and ContactGraspNet grasp generation in a single thread without reactive streams. 
""" def __init__( @@ -54,7 +51,6 @@ def __init__( min_confidence: float = 0.6, max_objects: int = 20, vocabulary: Optional[str] = None, - grasp_server_url: Optional[str] = None, enable_grasp_generation: bool = False, enable_segmentation: bool = True, segmentation_model: str = "sam2_b.pt", @@ -67,15 +63,13 @@ def __init__( min_confidence: Minimum detection confidence threshold max_objects: Maximum number of objects to process vocabulary: Optional vocabulary for Detic detector - grasp_server_url: Optional WebSocket URL for AnyGrasp server - enable_grasp_generation: Whether to enable grasp generation + enable_grasp_generation: Whether to enable ContactGraspNet grasp generation enable_segmentation: Whether to enable semantic segmentation segmentation_model: Segmentation model to use (SAM 2 or FastSAM) """ self.camera_intrinsics = camera_intrinsics self.min_confidence = min_confidence self.max_objects = max_objects - self.grasp_server_url = grasp_server_url self.enable_grasp_generation = enable_grasp_generation self.enable_segmentation = enable_segmentation @@ -100,7 +94,20 @@ def __init__( model_type="auto", # Auto-detect model type ) - logger.info(f"Initialized ManipulationProcessor with confidence={min_confidence}") + # Initialize ContactGraspNet generator if enabled + self.grasp_generator = None + if self.enable_grasp_generation: + try: + self.grasp_generator = ContactGraspNetGenerator() + logger.info("ContactGraspNet generator initialized successfully") + except Exception as e: + logger.error(f"Failed to initialize ContactGraspNet generator: {e}") + self.grasp_generator = None + self.enable_grasp_generation = False + + logger.info( + f"Initialized ManipulationProcessor with confidence={min_confidence}, grasp_generation={enable_grasp_generation}" + ) def process_frame( self, rgb_image: np.ndarray, depth_image: np.ndarray, generate_grasps: bool = None @@ -126,8 +133,7 @@ def process_frame( - misc_clusters: List of clustered background/miscellaneous point clouds 
(DBSCAN) - misc_voxel_grid: Open3D voxel grid approximating all misc/background points - misc_pointcloud_viz: Visualization of misc/background cluster overlay - - grasps: Grasp results (if enabled) - - grasp_overlay: Grasp visualization (if enabled) + - grasps: ContactGraspNet results (if enabled) - processing_time: Total processing time """ start_time = time.time() @@ -244,16 +250,14 @@ def process_frame( else: results["misc_pointcloud_viz"] = base_image - # Step 4: Grasp Generation (if enabled) + # Step 4: ContactGraspNet Grasp Generation (if enabled) should_generate_grasps = ( generate_grasps if generate_grasps is not None else self.enable_grasp_generation ) - if should_generate_grasps and all_objects: - grasps = self._run_grasp_generation(all_objects) + if should_generate_grasps and all_objects and full_pcd: + grasps = self._run_grasp_generation(all_objects, full_pcd) results["grasps"] = grasps - if grasps: - results["grasp_overlay"] = self._create_grasp_overlay(rgb_image, grasps) # Ensure segmentation runs even if no objects detected if self.enable_segmentation and "segmentation_viz" not in results: @@ -367,206 +371,30 @@ def _run_segmentation(self, rgb_image: np.ndarray) -> Dict[str, Any]: logger.error(f"Segmentation failed: {e}") return {"objects": [], "viz_frame": rgb_image.copy()} - def _run_grasp_generation(self, filtered_objects: List[Dict]) -> Optional[List[Dict]]: - """Run grasp generation on filtered objects.""" - if not self.grasp_server_url: - logger.warning("Grasp generation requested but no server URL provided") - return None - - try: - # Combine all point clouds - all_points = [] - all_colors = [] - valid_objects = 0 - - for obj in filtered_objects: - if "point_cloud_numpy" not in obj or obj["point_cloud_numpy"] is None: - continue - - points = obj["point_cloud_numpy"] - if not isinstance(points, np.ndarray) or points.size == 0: - continue - - if len(points.shape) != 2 or points.shape[1] != 3: - continue - - colors = None - if "colors_numpy" in 
obj and obj["colors_numpy"] is not None: - colors = obj["colors_numpy"] - if isinstance(colors, np.ndarray) and colors.size > 0: - if ( - colors.shape[0] != points.shape[0] - or len(colors.shape) != 2 - or colors.shape[1] != 3 - ): - colors = None - - all_points.append(points) - if colors is not None: - all_colors.append(colors) - valid_objects += 1 - - if not all_points: - return None - - # Combine point clouds - combined_points = np.vstack(all_points) - combined_colors = None - if len(all_colors) == valid_objects and len(all_colors) > 0: - combined_colors = np.vstack(all_colors) - - # Send grasp request synchronously - return self._send_grasp_request_sync(combined_points, combined_colors) - - except Exception as e: - logger.error(f"Grasp generation failed: {e}") + def _run_grasp_generation(self, filtered_objects: List[Dict], full_pcd) -> Optional[Dict]: + """Run ContactGraspNet grasp generation.""" + if not self.grasp_generator: + logger.warning("Grasp generation requested but ContactGraspNet not available") return None - def _send_grasp_request_sync( - self, points: np.ndarray, colors: Optional[np.ndarray] - ) -> Optional[List[Dict]]: - """Send synchronous grasp request to AnyGrasp server.""" try: - # Validation (same as async version) - if points is None or not isinstance(points, np.ndarray) or points.size == 0: - logger.error("Invalid points array") - return None - - if len(points.shape) != 2 or points.shape[1] != 3: - logger.error(f"Points has invalid shape {points.shape}, expected (N, 3)") - return None - - if points.shape[0] < 100: - logger.error(f"Insufficient points for grasp detection: {points.shape[0]} < 100") - return None - - # Prepare colors - if colors is not None: - if not isinstance(colors, np.ndarray) or colors.size == 0: - colors = None - elif len(colors.shape) != 2 or colors.shape[1] != 3: - colors = None - elif colors.shape[0] != points.shape[0]: - colors = None - - if colors is None: - colors = np.ones((points.shape[0], 3), dtype=np.float32) 
* 0.5 - - # Ensure correct data types - points = points.astype(np.float32) - colors = colors.astype(np.float32) - - # Validate ranges - if np.any(np.isnan(points)) or np.any(np.isinf(points)): - logger.error("Points contain NaN or Inf values") - return None - if np.any(np.isnan(colors)) or np.any(np.isinf(colors)): - logger.error("Colors contain NaN or Inf values") - return None - - colors = np.clip(colors, 0.0, 1.0) - - # Run async request in sync context - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - result = loop.run_until_complete(self._async_grasp_request(points, colors)) - return result - finally: - loop.close() - - except Exception as e: - logger.error(f"Error in synchronous grasp request: {e}") - return None - - async def _async_grasp_request( - self, points: np.ndarray, colors: np.ndarray - ) -> Optional[List[Dict]]: - """Async grasp request helper.""" - try: - async with websockets.connect(self.grasp_server_url) as websocket: - request = { - "points": points.tolist(), - "colors": colors.tolist(), - "lims": [-0.19, 0.12, 0.02, 0.15, 0.0, 1.0], - } - - await websocket.send(json.dumps(request)) - response = await websocket.recv() - grasps = json.loads(response) - - if isinstance(grasps, dict) and "error" in grasps: - logger.error(f"Server returned error: {grasps['error']}") - return None - elif isinstance(grasps, (int, float)) and grasps == 0: - return None - elif not isinstance(grasps, list): - logger.error(f"Server returned unexpected response type: {type(grasps)}") - return None - elif len(grasps) == 0: - return None + # Generate grasps using ContactGraspNet + pred_grasps_cam, scores, contact_pts, gripper_openings = ( + self.grasp_generator.generate_grasps_from_objects(filtered_objects, full_pcd) + ) - return self._convert_grasp_format(grasps) + # Return ContactGraspNet results directly + return { + "pred_grasps_cam": pred_grasps_cam, + "scores": scores, + "contact_pts": contact_pts, + "gripper_openings": gripper_openings, + } 
except Exception as e: - logger.error(f"Async grasp request failed: {e}") + logger.error(f"ContactGraspNet grasp generation failed: {e}") return None - def _create_grasp_overlay(self, rgb_image: np.ndarray, grasps: List[Dict]) -> np.ndarray: - """Create grasp visualization overlay on RGB image.""" - try: - bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR) - result_bgr = draw_grasps_on_image( - bgr_image, - grasps, - self.camera_intrinsics, - max_grasps=-1, # Show all grasps - ) - return cv2.cvtColor(result_bgr, cv2.COLOR_BGR2RGB) - except Exception as e: - logger.error(f"Error creating grasp overlay: {e}") - return rgb_image.copy() - - def _convert_grasp_format(self, anygrasp_grasps: List[dict]) -> List[dict]: - """Convert AnyGrasp format to visualization format.""" - converted = [] - - for i, grasp in enumerate(anygrasp_grasps): - rotation_matrix = np.array(grasp.get("rotation_matrix", np.eye(3))) - euler_angles = self._rotation_matrix_to_euler(rotation_matrix) - - converted_grasp = { - "id": f"grasp_{i}", - "score": grasp.get("score", 0.0), - "width": grasp.get("width", 0.0), - "height": grasp.get("height", 0.0), - "depth": grasp.get("depth", 0.0), - "translation": grasp.get("translation", [0, 0, 0]), - "rotation_matrix": rotation_matrix.tolist(), - "euler_angles": euler_angles, - } - converted.append(converted_grasp) - - converted.sort(key=lambda x: x["score"], reverse=True) - return converted - - def _rotation_matrix_to_euler(self, rotation_matrix: np.ndarray) -> Dict[str, float]: - """Convert rotation matrix to Euler angles (in radians).""" - sy = np.sqrt(rotation_matrix[0, 0] ** 2 + rotation_matrix[1, 0] ** 2) - - singular = sy < 1e-6 - - if not singular: - x = np.arctan2(rotation_matrix[2, 1], rotation_matrix[2, 2]) - y = np.arctan2(-rotation_matrix[2, 0], sy) - z = np.arctan2(rotation_matrix[1, 0], rotation_matrix[0, 0]) - else: - x = np.arctan2(-rotation_matrix[1, 2], rotation_matrix[1, 1]) - y = np.arctan2(-rotation_matrix[2, 0], sy) - z = 0 - - 
return {"roll": x, "pitch": y, "yaw": z} - def cleanup(self): """Clean up resources.""" if hasattr(self.detector, "cleanup"): @@ -575,4 +403,6 @@ def cleanup(self): self.pointcloud_filter.cleanup() if self.segmenter and hasattr(self.segmenter, "cleanup"): self.segmenter.cleanup() + if self.grasp_generator and hasattr(self.grasp_generator, "cleanup"): + self.grasp_generator.cleanup() logger.info("ManipulationProcessor cleaned up") diff --git a/dimos/perception/pointcloud/pointcloud_filtering.py b/dimos/perception/pointcloud/pointcloud_filtering.py index 3de2f3ae6a..47d351bd14 100644 --- a/dimos/perception/pointcloud/pointcloud_filtering.py +++ b/dimos/perception/pointcloud/pointcloud_filtering.py @@ -292,7 +292,7 @@ def process_images( pcd = self._apply_color_mask(pcd, rgb_color) # Apply subsampling to control point cloud size - pcd = self._apply_subsampling(pcd) + # pcd = self._apply_subsampling(pcd) # Apply filtering (optional based on flags) pcd_filtered = self._apply_filtering(pcd) diff --git a/tests/manipulation_pipeline_demo.ipynb b/tests/manipulation_pipeline_demo.ipynb new file mode 100644 index 0000000000..df43a7c6ac --- /dev/null +++ b/tests/manipulation_pipeline_demo.ipynb @@ -0,0 +1,839 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Manipulation Pipeline Demo with ContactGraspNet\n", + "\n", + "This notebook demonstrates the complete manipulation pipeline including:\n", + "- Object detection (Detic)\n", + "- Semantic segmentation (SAM/FastSAM)\n", + "- Point cloud processing\n", + "- 6-DoF grasp generation (ContactGraspNet)\n", + "- 3D visualization\n", + "\n", + "---" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Setup and Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Jupyter environment detected. 
Enabling Open3D WebVisualizer.\n", + "[Open3D INFO] WebRTC GUI backend enabled.\n", + "[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.\n", + "\u2705 All imports successful!\n" + ] + } + ], + "source": [ + "import os\n", + "import sys\n", + "import cv2\n", + "import numpy as np\n", + "import time\n", + "import matplotlib\n", + "\n", + "# Configure matplotlib backend\n", + "try:\n", + " matplotlib.use(\"TkAgg\")\n", + "except:\n", + " try:\n", + " matplotlib.use(\"Qt5Agg\")\n", + " except:\n", + " matplotlib.use(\"Agg\")\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import open3d as o3d\n", + "from typing import Dict, List\n", + "\n", + "# Add project root to path\n", + "sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(\"__file__\"))))\n", + "\n", + "# Import DIMOS modules\n", + "from dimos.perception.pointcloud.utils import visualize_clustered_point_clouds, visualize_voxel_grid\n", + "from dimos.perception.manip_aio_processer import ManipulationProcessor\n", + "from dimos.perception.pointcloud.utils import load_camera_matrix_from_yaml, visualize_pcd\n", + "from dimos.utils.logging_config import setup_logger\n", + "\n", + "# Import ContactGraspNet visualization\n", + "from dimos.models.manipulation.contact_graspnet_pytorch.contact_graspnet_pytorch.visualization_utils_o3d import (\n", + " visualize_grasps,\n", + ")\n", + "\n", + "logger = setup_logger(\"manipulation_pipeline_demo\")\n", + "print(\"\u2705 All imports successful!\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. 
Configuration" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Configuration:\n", + " data_dir: /home/alex-lin/dimos/assets/rgbd_data\n", + " enable_grasp_generation: True\n", + " enable_segmentation: True\n", + " segmentation_model: FastSAM-x.pt\n", + " min_confidence: 0.6\n", + " max_objects: 20\n", + " show_3d_visualizations: True\n", + " save_results: True\n" + ] + } + ], + "source": [ + "# Configuration parameters\n", + "CONFIG = {\n", + " \"data_dir\": \"/home/alex-lin/dimos/assets/rgbd_data\",\n", + " \"enable_grasp_generation\": True,\n", + " \"enable_segmentation\": True,\n", + " \"segmentation_model\": \"FastSAM-x.pt\", # or \"sam2_b.pt\"\n", + " \"min_confidence\": 0.6,\n", + " \"max_objects\": 20,\n", + " \"show_3d_visualizations\": True,\n", + " \"save_results\": True,\n", + "}\n", + "\n", + "print(f\"Configuration:\")\n", + "for key, value in CONFIG.items():\n", + " print(f\" {key}: {value}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. 
Data Loading Functions" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u2705 Data loading functions defined!\n" + ] + } + ], + "source": [ + "def load_first_frame(data_dir: str):\n", + " \"\"\"Load first RGB-D frame and camera intrinsics.\"\"\"\n", + " # Load images\n", + " color_img = cv2.imread(os.path.join(data_dir, \"color\", \"00000.png\"))\n", + " color_img = cv2.cvtColor(color_img, cv2.COLOR_BGR2RGB)\n", + "\n", + " depth_img = cv2.imread(os.path.join(data_dir, \"depth\", \"00000.png\"), cv2.IMREAD_ANYDEPTH)\n", + " if depth_img.dtype == np.uint16:\n", + " depth_img = depth_img.astype(np.float32) / 1000.0\n", + "\n", + " # Load intrinsics\n", + " camera_matrix = load_camera_matrix_from_yaml(os.path.join(data_dir, \"color_camera_info.yaml\"))\n", + " intrinsics = [\n", + " camera_matrix[0, 0], # fx\n", + " camera_matrix[1, 1], # fy\n", + " camera_matrix[0, 2], # cx\n", + " camera_matrix[1, 2], # cy\n", + " ]\n", + "\n", + " return color_img, depth_img, intrinsics\n", + "\n", + "\n", + "def create_point_cloud(color_img, depth_img, intrinsics):\n", + " \"\"\"Create Open3D point cloud for reference.\"\"\"\n", + " fx, fy, cx, cy = intrinsics\n", + " height, width = depth_img.shape\n", + "\n", + " o3d_intrinsics = o3d.camera.PinholeCameraIntrinsic(width, height, fx, fy, cx, cy)\n", + " color_o3d = o3d.geometry.Image(color_img)\n", + " depth_o3d = o3d.geometry.Image((depth_img * 1000).astype(np.uint16))\n", + "\n", + " rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth(\n", + " color_o3d, depth_o3d, depth_scale=1000.0, convert_rgb_to_intensity=False\n", + " )\n", + "\n", + " return o3d.geometry.PointCloud.create_from_rgbd_image(rgbd, o3d_intrinsics)\n", + "\n", + "\n", + "print(\"\u2705 Data loading functions defined!\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. 
Load RGB-D Data" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2025-06-25 13:29:47,127 - manipulation_pipeline_demo - INFO - Loaded images: color (720, 1280, 3), depth (720, 1280)\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Camera intrinsics: fx=749.3, fy=748.6, cx=639.4, cy=357.2\n" + ] + } + ], + "source": [ + "# Load data\n", + "color_img, depth_img, intrinsics = load_first_frame(CONFIG[\"data_dir\"])\n", + "logger.info(f\"Loaded images: color {color_img.shape}, depth {depth_img.shape}\")\n", + "\n", + "# Display input images\n", + "fig, axes = plt.subplots(1, 2, figsize=(12, 5))\n", + "\n", + "axes[0].imshow(color_img)\n", + "axes[0].set_title(\"RGB Image\")\n", + "axes[0].axis(\"off\")\n", + "\n", + "# Colorize depth for visualization\n", + "depth_colorized = cv2.applyColorMap(\n", + " cv2.convertScaleAbs(depth_img, alpha=255.0 / depth_img.max()), cv2.COLORMAP_JET\n", + ")\n", + "depth_colorized = cv2.cvtColor(depth_colorized, cv2.COLOR_BGR2RGB)\n", + "axes[1].imshow(depth_colorized)\n", + "axes[1].set_title(\"Depth Image\")\n", + "axes[1].axis(\"off\")\n", + "\n", + "plt.tight_layout()\n", + "plt.show()\n", + "\n", + "print(\n", + " f\"Camera intrinsics: fx={intrinsics[0]:.1f}, fy={intrinsics[1]:.1f}, cx={intrinsics[2]:.1f}, cy={intrinsics[3]:.1f}\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. 
Initialize Manipulation Processor" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/alex-lin/miniconda3/envs/contact-graspnet/lib/python3.10/site-packages/timm/models/layers/__init__.py:48: FutureWarning: Importing from timm.models.layers is deprecated, please import via timm.layers\n", + " warnings.warn(f\"Importing from {__name__} is deprecated, please import via timm.layers\", FutureWarning)\n", + "/home/alex-lin/miniconda3/envs/contact-graspnet/lib/python3.10/site-packages/timm/models/helpers.py:7: FutureWarning: Importing from timm.models.helpers is deprecated, please import via timm.models\n", + " warnings.warn(f\"Importing from {__name__} is deprecated, please import via timm.models\", FutureWarning)\n", + "/home/alex-lin/miniconda3/envs/contact-graspnet/lib/python3.10/site-packages/timm/models/registry.py:4: FutureWarning: Importing from timm.models.registry is deprecated, please import via timm.models\n", + " warnings.warn(f\"Importing from {__name__} is deprecated, please import via timm.models\", FutureWarning)\n", + "/home/alex-lin/miniconda3/envs/contact-graspnet/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. 
(Triggered internally at /opt/conda/conda-bld/pytorch_1702400410390/work/aten/src/ATen/native/TensorShape.cpp:3526.)\n", + " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Resetting zs_weight /home/alex-lin/dimos/dimos/perception/detection2d/../../models/Detic/datasets/metadata/lvis_v1_clip_a+cname.npy\n", + "Downloading https://github.com/ultralytics/assets/releases/download/v8.3.0/FastSAM-x.pt to 'FastSAM-x.pt'...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 138M/138M [00:03<00:00, 41.5MB/s] \n", + "\u001b[32m2025-06-25 13:30:01,134 - dimos.perception.grasp_generation - INFO - Initializing ContactGraspNet on device: cuda\u001b[0m\n", + "\u001b[32m2025-06-25 13:30:01,141 - dimos.perception.grasp_generation - INFO - Loaded config from dimos/models/manipulation/contact_graspnet_pytorch/checkpoints\u001b[0m\n", + "\u001b[32m2025-06-25 13:30:01,164 - dimos.perception.grasp_generation - INFO - ContactGraspNet model initialized\u001b[0m\n", + "\u001b[32m2025-06-25 13:30:01,165 - dimos.perception.grasp_generation - INFO - ContactGraspNet grasp generator initialized successfully\u001b[0m\n", + "\u001b[32m2025-06-25 13:30:01,165 - dimos.perception.manip_aio_processor - INFO - ContactGraspNet generator initialized successfully\u001b[0m\n", + "\u001b[32m2025-06-25 13:30:01,165 - dimos.perception.manip_aio_processor - INFO - Initialized ManipulationProcessor with confidence=0.6, grasp_generation=True\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "model func: \n", + "\u2705 ManipulationProcessor initialized successfully!\n" + ] + } + ], + "source": [ + "# Create processor with ContactGraspNet enabled\n", + "processor = ManipulationProcessor(\n", + " camera_intrinsics=intrinsics,\n", + " min_confidence=CONFIG[\"min_confidence\"],\n", + " 
max_objects=CONFIG[\"max_objects\"],\n", + " enable_grasp_generation=CONFIG[\"enable_grasp_generation\"],\n", + " enable_segmentation=CONFIG[\"enable_segmentation\"],\n", + " segmentation_model=CONFIG[\"segmentation_model\"],\n", + ")\n", + "\n", + "print(\"\u2705 ManipulationProcessor initialized successfully!\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6. Run Processing Pipeline" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\ud83d\udd04 Processing frame through pipeline...\n", + "DBSCAN clustering found 11 clusters from 28067 points\n", + "Created voxel grid with 2220 voxels (voxel_size=0.02)\n", + "using local regions\n", + "Extracted Region Cube Size: 0.311576783657074\n", + "Extracted Region Cube Size: 0.445679247379303\n", + "Extracted Region Cube Size: 0.24130240082740784\n", + "Extracted Region Cube Size: 0.2\n", + "Extracted Region Cube Size: 0.46059030294418335\n", + "Extracted Region Cube Size: 0.2357255220413208\n", + "Extracted Region Cube Size: 0.3680998980998993\n", + "Extracted Region Cube Size: 0.2\n", + "Extracted Region Cube Size: 0.2\n", + "Extracted Region Cube Size: 0.24357137084007263\n", + "Extracted Region Cube Size: 0.2\n", + "Extracted Region Cube Size: 0.2409430295228958\n", + "Extracted Region Cube Size: 0.2\n", + "Extracted Region Cube Size: 0.2\n", + "Extracted Region Cube Size: 0.23709678649902344\n", + "Extracted Region Cube Size: 0.2\n", + "Extracted Region Cube Size: 0.5130001306533813\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/alex-lin/miniconda3/envs/contact-graspnet/lib/python3.10/site-packages/numpy/core/fromnumeric.py:3504: RuntimeWarning: Mean of empty slice.\n", + " return _methods._mean(a, axis=axis, dtype=dtype,\n", + "/home/alex-lin/miniconda3/envs/contact-graspnet/lib/python3.10/site-packages/numpy/core/_methods.py:121: 
RuntimeWarning: invalid value encountered in divide\n", + " ret = um.true_divide(\n", + "\u001b[32m2025-06-25 13:30:19,727 - dimos.perception.grasp_generation - INFO - Generated 3400 grasps across 17 objects in 12.91s\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u2705 Processing completed in 14.768s\n" + ] + } + ], + "source": [ + "# Process single frame\n", + "print(\"\ud83d\udd04 Processing frame through pipeline...\")\n", + "start_time = time.time()\n", + "\n", + "results = processor.process_frame(color_img, depth_img)\n", + "\n", + "processing_time = time.time() - start_time\n", + "print(f\"\u2705 Processing completed in {processing_time:.3f}s\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 7. Results Summary" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\ud83d\udcca PROCESSING RESULTS SUMMARY\n", + "==================================================\n", + "Available results: ['detection2d_objects', 'detection_viz', 'segmentation2d_objects', 'segmentation_viz', 'detected_objects', 'all_objects', 'full_pointcloud', 'misc_clusters', 'misc_voxel_grid', 'pointcloud_viz', 'detected_pointcloud_viz', 'misc_pointcloud_viz', 'grasps', 'processing_time', 'timing_breakdown']\n", + "Total processing time: 14.768s\n", + "\n", + "\u23f1\ufe0f Timing breakdown:\n", + " Detection: 0.550s\n", + " Segmentation: 0.733s\n", + " Point cloud: 0.144s\n", + " Misc extraction: 0.371s\n", + "\n", + "\ud83c\udfaf Object Detection:\n", + " Detection objects: 13\n", + " All objects processed: 18\n", + "\n", + "\ud83e\udde9 Background Analysis:\n", + " Misc clusters: 11 clusters with 26,692 total points\n", + "\n", + "\ud83e\udd16 ContactGraspNet Results:\n", + " Total grasps: 3400\n", + " Best score: 0.911\n", + " Objects with grasps: 17\n", + "\n", + 
"==================================================\n" + ] + } + ], + "source": [ + "# Print comprehensive results summary\n", + "print(f\"\\n\ud83d\udcca PROCESSING RESULTS SUMMARY\")\n", + "print(f\"\" + \"=\" * 50)\n", + "print(f\"Available results: {list(results.keys())}\")\n", + "print(f\"Total processing time: {results.get('processing_time', 0):.3f}s\")\n", + "\n", + "# Show timing breakdown\n", + "if \"timing_breakdown\" in results:\n", + " breakdown = results[\"timing_breakdown\"]\n", + " print(f\"\\n\u23f1\ufe0f Timing breakdown:\")\n", + " print(f\" Detection: {breakdown.get('detection', 0):.3f}s\")\n", + " print(f\" Segmentation: {breakdown.get('segmentation', 0):.3f}s\")\n", + " print(f\" Point cloud: {breakdown.get('pointcloud', 0):.3f}s\")\n", + " print(f\" Misc extraction: {breakdown.get('misc_extraction', 0):.3f}s\")\n", + "\n", + "# Object counts\n", + "detected_count = len(results.get(\"detected_objects\", []))\n", + "all_count = len(results.get(\"all_objects\", []))\n", + "print(f\"\\n\ud83c\udfaf Object Detection:\")\n", + "print(f\" Detection objects: {detected_count}\")\n", + "print(f\" All objects processed: {all_count}\")\n", + "\n", + "# Misc clusters info\n", + "if \"misc_clusters\" in results and results[\"misc_clusters\"]:\n", + " cluster_count = len(results[\"misc_clusters\"])\n", + " total_misc_points = sum(len(np.asarray(cluster.points)) for cluster in results[\"misc_clusters\"])\n", + " print(f\"\\n\ud83e\udde9 Background Analysis:\")\n", + " print(f\" Misc clusters: {cluster_count} clusters with {total_misc_points:,} total points\")\n", + "else:\n", + " print(f\"\\n\ud83e\udde9 Background Analysis: No clusters found\")\n", + "\n", + "# ContactGraspNet grasp summary\n", + "if \"grasps\" in results and results[\"grasps\"]:\n", + " grasp_data = results[\"grasps\"]\n", + " if isinstance(grasp_data, dict):\n", + " pred_grasps = grasp_data.get(\"pred_grasps_cam\", {})\n", + " scores = grasp_data.get(\"scores\", {})\n", + "\n", + " 
total_grasps = 0\n", + " best_score = 0\n", + "\n", + " for obj_id, obj_grasps in pred_grasps.items():\n", + " num_grasps = len(obj_grasps) if hasattr(obj_grasps, \"__len__\") else 0\n", + " total_grasps += num_grasps\n", + "\n", + " if obj_id in scores and len(scores[obj_id]) > 0:\n", + " obj_best_score = max(scores[obj_id])\n", + " if obj_best_score > best_score:\n", + " best_score = obj_best_score\n", + "\n", + " print(f\"\\n\ud83e\udd16 ContactGraspNet Results:\")\n", + " print(f\" Total grasps: {total_grasps}\")\n", + " print(f\" Best score: {best_score:.3f}\")\n", + " print(f\" Objects with grasps: {len(pred_grasps)}\")\n", + " else:\n", + " print(f\"\\n\ud83e\udd16 ContactGraspNet Results: Invalid format\")\n", + "else:\n", + " print(f\"\\n\ud83e\udd16 ContactGraspNet Results: No grasps generated\")\n", + "\n", + "print(\"\\n\" + \"=\" * 50)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 8. 2D Visualization Results" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\ud83d\udcf8 Results saved to: manipulation_results.png\n" + ] + } + ], + "source": [ + "# Collect available visualizations\n", + "viz_configs = []\n", + "\n", + "if \"detection_viz\" in results and results[\"detection_viz\"] is not None:\n", + " viz_configs.append((\"detection_viz\", \"Object Detection\"))\n", + "\n", + "if \"segmentation_viz\" in results and results[\"segmentation_viz\"] is not None:\n", + " viz_configs.append((\"segmentation_viz\", \"Semantic Segmentation\"))\n", + "\n", + "if \"pointcloud_viz\" in results and results[\"pointcloud_viz\"] is not None:\n", + " viz_configs.append((\"pointcloud_viz\", \"All Objects Point Cloud\"))\n", + "\n", + "if \"detected_pointcloud_viz\" in results and results[\"detected_pointcloud_viz\"] is not None:\n", + " viz_configs.append((\"detected_pointcloud_viz\", \"Detection Objects Point Cloud\"))\n", + "\n", 
+ "if \"misc_pointcloud_viz\" in results and results[\"misc_pointcloud_viz\"] is not None:\n", + " viz_configs.append((\"misc_pointcloud_viz\", \"Misc/Background Points\"))\n", + "\n", + "# Create visualization grid\n", + "if viz_configs:\n", + " num_plots = len(viz_configs)\n", + "\n", + " if num_plots <= 3:\n", + " fig, axes = plt.subplots(1, num_plots, figsize=(6 * num_plots, 5))\n", + " else:\n", + " rows = 2\n", + " cols = (num_plots + 1) // 2\n", + " fig, axes = plt.subplots(rows, cols, figsize=(6 * cols, 5 * rows))\n", + "\n", + " # Ensure axes is always iterable\n", + " if num_plots == 1:\n", + " axes = [axes]\n", + " elif num_plots > 2:\n", + " axes = axes.flatten()\n", + "\n", + " # Plot each result\n", + " for i, (key, title) in enumerate(viz_configs):\n", + " axes[i].imshow(results[key])\n", + " axes[i].set_title(title, fontsize=12, fontweight=\"bold\")\n", + " axes[i].axis(\"off\")\n", + "\n", + " # Hide unused subplots\n", + " if num_plots > 3:\n", + " for i in range(num_plots, len(axes)):\n", + " axes[i].axis(\"off\")\n", + "\n", + " plt.tight_layout()\n", + "\n", + " if CONFIG[\"save_results\"]:\n", + " output_path = \"manipulation_results.png\"\n", + " plt.savefig(output_path, dpi=150, bbox_inches=\"tight\")\n", + " print(f\"\ud83d\udcf8 Results saved to: {output_path}\")\n", + "\n", + " plt.show()\n", + "else:\n", + " print(\"\u26a0\ufe0f No 2D visualization results to display\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 9. 
3D ContactGraspNet Visualization" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\ud83c\udfaf Launching 3D visualization with 3400 ContactGraspNet grasps...\n", + "\ud83d\udccc Note: Close the 3D window to continue with the notebook\n", + "Visualizing...\n", + "\u2705 3D grasp visualization completed!\n" + ] + } + ], + "source": [ + "# 3D ContactGraspNet visualization\n", + "if (\n", + " CONFIG[\"show_3d_visualizations\"]\n", + " and \"grasps\" in results\n", + " and results[\"grasps\"]\n", + " and \"full_pointcloud\" in results\n", + "):\n", + " grasp_data = results[\"grasps\"]\n", + " full_pcd = results[\"full_pointcloud\"]\n", + "\n", + " if isinstance(grasp_data, dict) and full_pcd is not None:\n", + " try:\n", + " # Extract ContactGraspNet data\n", + " pred_grasps_cam = grasp_data.get(\"pred_grasps_cam\", {})\n", + " scores = grasp_data.get(\"scores\", {})\n", + " contact_pts = grasp_data.get(\"contact_pts\", {})\n", + " gripper_openings = grasp_data.get(\"gripper_openings\", {})\n", + "\n", + " # Check if we have valid grasp data\n", + " total_grasps = (\n", + " sum(len(grasps) for grasps in pred_grasps_cam.values()) if pred_grasps_cam else 0\n", + " )\n", + "\n", + " if total_grasps > 0:\n", + " print(\n", + " f\"\ud83c\udfaf Launching 3D visualization with {total_grasps} ContactGraspNet grasps...\"\n", + " )\n", + " print(\"\ud83d\udccc Note: Close the 3D window to continue with the notebook\")\n", + "\n", + " # Use ContactGraspNet's native visualization - pass dictionaries directly\n", + " visualize_grasps(\n", + " full_pcd,\n", + " pred_grasps_cam, # Pass dictionary directly\n", + " scores, # Pass dictionary directly\n", + " gripper_openings=gripper_openings,\n", + " )\n", + "\n", + " print(\"\u2705 3D grasp visualization completed!\")\n", + " else:\n", + " print(\"\u26a0\ufe0f No valid grasps to visualize in 3D\")\n", + "\n", + " except 
Exception as e:\n", + " print(f\"\u274c Error in ContactGraspNet 3D visualization: {e}\")\n", + " print(\" Skipping 3D grasp visualization\")\n", + "else:\n", + " if not CONFIG[\"show_3d_visualizations\"]:\n", + " print(\"\u23ed\ufe0f 3D visualizations disabled in config\")\n", + " else:\n", + " print(\"\u26a0\ufe0f ContactGraspNet grasp generation disabled or no results\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 10. Additional 3D Visualizations" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 10.1 Full Scene Point Cloud" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if (\n", + " CONFIG[\"show_3d_visualizations\"]\n", + " and \"full_pointcloud\" in results\n", + " and results[\"full_pointcloud\"] is not None\n", + "):\n", + " full_pcd = results[\"full_pointcloud\"]\n", + " num_points = len(np.asarray(full_pcd.points))\n", + " print(f\"\ud83c\udf0d Launching full scene point cloud visualization ({num_points:,} points)...\")\n", + " print(\"\ud83d\udccc Note: Close the 3D window to continue\")\n", + "\n", + " try:\n", + " visualize_pcd(\n", + " full_pcd,\n", + " window_name=\"Full Scene Point Cloud\",\n", + " point_size=2.0,\n", + " show_coordinate_frame=True,\n", + " )\n", + " print(\"\u2705 Full point cloud visualization completed!\")\n", + " except (KeyboardInterrupt, EOFError):\n", + " print(\"\u23ed\ufe0f Full point cloud visualization skipped\")\n", + "else:\n", + " print(\"\u26a0\ufe0f No full point cloud available for visualization\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 10.2 Background/Misc Clusters" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if CONFIG[\"show_3d_visualizations\"] and \"misc_clusters\" in results and results[\"misc_clusters\"]:\n", + " misc_clusters = results[\"misc_clusters\"]\n", + " 
cluster_count = len(misc_clusters)\n", + " total_misc_points = sum(len(np.asarray(cluster.points)) for cluster in misc_clusters)\n", + "\n", + " print(\n", + " f\"\ud83e\udde9 Launching misc/background clusters visualization ({cluster_count} clusters, {total_misc_points:,} points)...\"\n", + " )\n", + " print(\"\ud83d\udccc Note: Close the 3D window to continue\")\n", + "\n", + " try:\n", + " visualize_clustered_point_clouds(\n", + " misc_clusters,\n", + " window_name=\"Misc/Background Clusters (DBSCAN)\",\n", + " point_size=3.0,\n", + " show_coordinate_frame=True,\n", + " )\n", + " print(\"\u2705 Misc clusters visualization completed!\")\n", + " except (KeyboardInterrupt, EOFError):\n", + " print(\"\u23ed\ufe0f Misc clusters visualization skipped\")\n", + "else:\n", + " print(\"\u26a0\ufe0f No misc clusters available for visualization\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 10.3 Voxel Grid Visualization" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if (\n", + " CONFIG[\"show_3d_visualizations\"]\n", + " and \"misc_voxel_grid\" in results\n", + " and results[\"misc_voxel_grid\"] is not None\n", + "):\n", + " misc_voxel_grid = results[\"misc_voxel_grid\"]\n", + " voxel_count = len(misc_voxel_grid.get_voxels())\n", + "\n", + " print(f\"\ud83d\udce6 Launching voxel grid visualization ({voxel_count} voxels)...\")\n", + " print(\"\ud83d\udccc Note: Close the 3D window to continue\")\n", + "\n", + " try:\n", + " visualize_voxel_grid(\n", + " misc_voxel_grid,\n", + " window_name=\"Misc/Background Voxel Grid\",\n", + " show_coordinate_frame=True,\n", + " )\n", + " print(\"\u2705 Voxel grid visualization completed!\")\n", + " except (KeyboardInterrupt, EOFError):\n", + " print(\"\u23ed\ufe0f Voxel grid visualization skipped\")\n", + " except Exception as e:\n", + " print(f\"\u274c Error in voxel grid visualization: {e}\")\n", + "else:\n", + " print(\"\u26a0\ufe0f 
No voxel grid available for visualization\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 11. Cleanup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Clean up resources\n", + "processor.cleanup()\n", + "print(\"\u2705 Pipeline cleanup completed!\")\n", + "print(\"\\n\ud83c\udf89 Manipulation pipeline demo finished successfully!\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "\n", + "## Summary\n", + "\n", + "This notebook demonstrated the complete DIMOS manipulation pipeline:\n", + "\n", + "1. **Object Detection**: Using Detic for 2D object detection\n", + "2. **Semantic Segmentation**: Using SAM/FastSAM for detailed segmentation\n", + "3. **Point Cloud Processing**: Converting RGB-D to 3D point clouds with filtering\n", + "4. **Background Analysis**: DBSCAN clustering of miscellaneous points\n", + "5. **Grasp Generation**: ContactGraspNet for 6-DoF grasp pose estimation\n", + "6. 
**Visualization**: Comprehensive 2D and 3D visualizations\n", + "\n", + "The pipeline is designed for real-time robotic manipulation tasks and provides rich visual feedback for debugging and analysis.\n", + "\n", + "### Key Features:\n", + "- \u2705 Modular design with clean interfaces\n", + "- \u2705 Real-time performance optimizations\n", + "- \u2705 Comprehensive error handling\n", + "- \u2705 Rich visualization capabilities\n", + "- \u2705 ContactGraspNet integration for state-of-the-art grasp generation\n", + "\n", + "### Next Steps:\n", + "- Integrate with robotic control systems\n", + "- Add grasp execution and feedback\n", + "- Implement multi-frame tracking\n", + "- Add custom object vocabularies\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "contact-graspnet", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.18" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} From 6575e0eb122a4cab21c521896399aed35f6677ab Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Wed, 25 Jun 2025 16:33:12 -0700 Subject: [PATCH 42/89] added parsing of contact graspnet results into dict --- tests/manipulation_pipeline_demo.ipynb | 186 ++++++++++++++++--------- 1 file changed, 119 insertions(+), 67 deletions(-) diff --git a/tests/manipulation_pipeline_demo.ipynb b/tests/manipulation_pipeline_demo.ipynb index df43a7c6ac..01470c6355 100644 --- a/tests/manipulation_pipeline_demo.ipynb +++ b/tests/manipulation_pipeline_demo.ipynb @@ -87,7 +87,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -133,7 +133,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -195,14 +195,14 @@ }, { "cell_type": "code", - 
"execution_count": 10, + "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2025-06-25 13:29:47,127 - manipulation_pipeline_demo - INFO - Loaded images: color (720, 1280, 3), depth (720, 1280)\u001b[0m\n" + "\u001b[32m2025-06-25 15:04:05,846 - manipulation_pipeline_demo - INFO - Loaded images: color (720, 1280, 3), depth (720, 1280)\u001b[0m\n" ] }, { @@ -251,7 +251,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -265,7 +265,13 @@ "/home/alex-lin/miniconda3/envs/contact-graspnet/lib/python3.10/site-packages/timm/models/registry.py:4: FutureWarning: Importing from timm.models.registry is deprecated, please import via timm.models\n", " warnings.warn(f\"Importing from {__name__} is deprecated, please import via timm.models\", FutureWarning)\n", "/home/alex-lin/miniconda3/envs/contact-graspnet/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. 
(Triggered internally at /opt/conda/conda-bld/pytorch_1702400410390/work/aten/src/ATen/native/TensorShape.cpp:3526.)\n", - " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n" + " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n", + "\u001b[32m2025-06-25 15:04:11,530 - dimos.perception.grasp_generation - INFO - Initializing ContactGraspNet on device: cuda\u001b[0m\n", + "\u001b[32m2025-06-25 15:04:11,541 - dimos.perception.grasp_generation - INFO - Loaded config from dimos/models/manipulation/contact_graspnet_pytorch/checkpoints\u001b[0m\n", + "\u001b[32m2025-06-25 15:04:11,565 - dimos.perception.grasp_generation - INFO - ContactGraspNet model initialized\u001b[0m\n", + "\u001b[32m2025-06-25 15:04:11,566 - dimos.perception.grasp_generation - INFO - ContactGraspNet grasp generator initialized successfully\u001b[0m\n", + "\u001b[32m2025-06-25 15:04:11,566 - dimos.perception.manip_aio_processor - INFO - ContactGraspNet generator initialized successfully\u001b[0m\n", + "\u001b[32m2025-06-25 15:04:11,567 - dimos.perception.manip_aio_processor - INFO - Initialized ManipulationProcessor with confidence=0.6, grasp_generation=True\u001b[0m\n" ] }, { @@ -273,26 +279,6 @@ "output_type": "stream", "text": [ "Resetting zs_weight /home/alex-lin/dimos/dimos/perception/detection2d/../../models/Detic/datasets/metadata/lvis_v1_clip_a+cname.npy\n", - "Downloading https://github.com/ultralytics/assets/releases/download/v8.3.0/FastSAM-x.pt to 'FastSAM-x.pt'...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 138M/138M [00:03<00:00, 41.5MB/s] \n", - "\u001b[32m2025-06-25 13:30:01,134 - dimos.perception.grasp_generation - INFO - Initializing ContactGraspNet on device: cuda\u001b[0m\n", - "\u001b[32m2025-06-25 13:30:01,141 - dimos.perception.grasp_generation - INFO - Loaded config from 
dimos/models/manipulation/contact_graspnet_pytorch/checkpoints\u001b[0m\n", - "\u001b[32m2025-06-25 13:30:01,164 - dimos.perception.grasp_generation - INFO - ContactGraspNet model initialized\u001b[0m\n", - "\u001b[32m2025-06-25 13:30:01,165 - dimos.perception.grasp_generation - INFO - ContactGraspNet grasp generator initialized successfully\u001b[0m\n", - "\u001b[32m2025-06-25 13:30:01,165 - dimos.perception.manip_aio_processor - INFO - ContactGraspNet generator initialized successfully\u001b[0m\n", - "\u001b[32m2025-06-25 13:30:01,165 - dimos.perception.manip_aio_processor - INFO - Initialized ManipulationProcessor with confidence=0.6, grasp_generation=True\u001b[0m\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ "model func: \n", "\u2705 ManipulationProcessor initialized successfully!\n" ] @@ -321,7 +307,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -329,44 +315,59 @@ "output_type": "stream", "text": [ "\ud83d\udd04 Processing frame through pipeline...\n", - "DBSCAN clustering found 11 clusters from 28067 points\n", - "Created voxel grid with 2220 voxels (voxel_size=0.02)\n", + "DBSCAN clustering found 13 clusters from 26536 points\n", + "Created voxel grid with 2074 voxels (voxel_size=0.02)\n", "using local regions\n", - "Extracted Region Cube Size: 0.311576783657074\n", - "Extracted Region Cube Size: 0.445679247379303\n", - "Extracted Region Cube Size: 0.24130240082740784\n", - "Extracted Region Cube Size: 0.2\n", - "Extracted Region Cube Size: 0.46059030294418335\n", - "Extracted Region Cube Size: 0.2357255220413208\n", - "Extracted Region Cube Size: 0.3680998980998993\n", + "Extracted Region Cube Size: 0.3148665130138397\n", + "Extracted Region Cube Size: 0.4740000367164612\n", + "Extracted Region Cube Size: 0.2676139771938324\n", "Extracted Region Cube Size: 0.2\n", + "Extracted Region Cube Size: 0.4960000514984131\n", + "Extracted Region Cube Size: 
0.30400002002716064\n", + "Extracted Region Cube Size: 0.38946154713630676\n", + "Extracted Region Cube Size: 0.2087651789188385\n", "Extracted Region Cube Size: 0.2\n", - "Extracted Region Cube Size: 0.24357137084007263\n", "Extracted Region Cube Size: 0.2\n", - "Extracted Region Cube Size: 0.2409430295228958\n", + "Extracted Region Cube Size: 0.24777960777282715\n", "Extracted Region Cube Size: 0.2\n", + "Extracted Region Cube Size: 0.2502080202102661\n", "Extracted Region Cube Size: 0.2\n", - "Extracted Region Cube Size: 0.23709678649902344\n", + "Extracted Region Cube Size: 0.3400000333786011\n", + "Extracted Region Cube Size: 0.22946105897426605\n", "Extracted Region Cube Size: 0.2\n", - "Extracted Region Cube Size: 0.5130001306533813\n" + "Extracted Region Cube Size: 0.5360000133514404\n", + "Generated 18 grasps for object 2\n", + "Generated 44 grasps for object 3\n", + "Generated 14 grasps for object 4\n", + "Generated 6 grasps for object 7\n", + "Generated 9 grasps for object 8\n", + "Generated 15 grasps for object 9\n", + "Generated 25 grasps for object 10\n", + "Generated 25 grasps for object 11\n", + "Generated 16 grasps for object 14\n", + "Generated 3 grasps for object 15\n", + "Generated 13 grasps for object 16\n", + "Generated 15 grasps for object 19\n", + "Generated 12 grasps for object 27\n", + "Generated 17 grasps for object 29\n", + "Generated 19 grasps for object 31\n", + "Generated 19 grasps for object 32\n", + "Generated 3 grasps for object 33\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "/home/alex-lin/miniconda3/envs/contact-graspnet/lib/python3.10/site-packages/numpy/core/fromnumeric.py:3504: RuntimeWarning: Mean of empty slice.\n", - " return _methods._mean(a, axis=axis, dtype=dtype,\n", - "/home/alex-lin/miniconda3/envs/contact-graspnet/lib/python3.10/site-packages/numpy/core/_methods.py:121: RuntimeWarning: invalid value encountered in divide\n", - " ret = um.true_divide(\n", - "\u001b[32m2025-06-25 13:30:19,727 - 
dimos.perception.grasp_generation - INFO - Generated 3400 grasps across 17 objects in 12.91s\u001b[0m\n" + "\u001b[32m2025-06-25 15:04:30,107 - dimos.perception.grasp_generation - INFO - Generated 296 grasps across 18 objects in 14.69s\u001b[0m\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "\u2705 Processing completed in 14.768s\n" + "Generated 23 grasps for object 37\n", + "\u2705 Processing completed in 18.517s\n" ] } ], @@ -390,7 +391,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -401,25 +402,25 @@ "\ud83d\udcca PROCESSING RESULTS SUMMARY\n", "==================================================\n", "Available results: ['detection2d_objects', 'detection_viz', 'segmentation2d_objects', 'segmentation_viz', 'detected_objects', 'all_objects', 'full_pointcloud', 'misc_clusters', 'misc_voxel_grid', 'pointcloud_viz', 'detected_pointcloud_viz', 'misc_pointcloud_viz', 'grasps', 'processing_time', 'timing_breakdown']\n", - "Total processing time: 14.768s\n", + "Total processing time: 18.517s\n", "\n", "\u23f1\ufe0f Timing breakdown:\n", - " Detection: 0.550s\n", - " Segmentation: 0.733s\n", - " Point cloud: 0.144s\n", - " Misc extraction: 0.371s\n", + " Detection: 0.529s\n", + " Segmentation: 0.720s\n", + " Point cloud: 1.837s\n", + " Misc extraction: 0.385s\n", "\n", "\ud83c\udfaf Object Detection:\n", " Detection objects: 13\n", " All objects processed: 18\n", "\n", "\ud83e\udde9 Background Analysis:\n", - " Misc clusters: 11 clusters with 26,692 total points\n", + " Misc clusters: 13 clusters with 25,628 total points\n", "\n", "\ud83e\udd16 ContactGraspNet Results:\n", - " Total grasps: 3400\n", - " Best score: 0.911\n", - " Objects with grasps: 17\n", + " Total grasps: 296\n", + " Best score: 0.798\n", + " Objects with grasps: 18\n", "\n", "==================================================\n" ] @@ -497,7 +498,7 @@ }, { "cell_type": "code", - "execution_count": 14, + 
"execution_count": 8, "metadata": {}, "outputs": [ { @@ -576,14 +577,14 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "\ud83c\udfaf Launching 3D visualization with 3400 ContactGraspNet grasps...\n", + "\ud83c\udfaf Launching 3D visualization with 296 ContactGraspNet grasps...\n", "\ud83d\udccc Note: Close the 3D window to continue with the notebook\n", "Visualizing...\n", "\u2705 3D grasp visualization completed!\n" @@ -658,9 +659,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\ud83c\udf0d Launching full scene point cloud visualization (526,100 points)...\n", + "\ud83d\udccc Note: Close the 3D window to continue\n", + "Visualizing point cloud with 526100 points\n", + "\u2705 Full point cloud visualization completed!\n" + ] + } + ], "source": [ "if (\n", " CONFIG[\"show_3d_visualizations\"]\n", @@ -695,9 +707,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\ud83e\udde9 Launching misc/background clusters visualization (13 clusters, 25,628 points)...\n", + "\ud83d\udccc Note: Close the 3D window to continue\n", + "Visualizing 13 clusters with 25628 total points\n", + "\u2705 Misc clusters visualization completed!\n" + ] + } + ], "source": [ "if CONFIG[\"show_3d_visualizations\"] and \"misc_clusters\" in results and results[\"misc_clusters\"]:\n", " misc_clusters = results[\"misc_clusters\"]\n", @@ -732,9 +755,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\ud83d\udce6 Launching voxel grid visualization (2074 
voxels)...\n", + "\ud83d\udccc Note: Close the 3D window to continue\n", + "Visualizing voxel grid with 2074 voxels\n", + "\u2705 Voxel grid visualization completed!\n" + ] + } + ], "source": [ "if (\n", " CONFIG[\"show_3d_visualizations\"]\n", @@ -771,9 +805,27 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2025-06-25 15:05:01,624 - dimos.perception.grasp_generation - INFO - ContactGraspNet grasp generator cleaned up\u001b[0m\n", + "\u001b[32m2025-06-25 15:05:01,626 - dimos.perception.manip_aio_processor - INFO - ManipulationProcessor cleaned up\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u2705 Pipeline cleanup completed!\n", + "\n", + "\ud83c\udf89 Manipulation pipeline demo finished successfully!\n" + ] + } + ], "source": [ "# Clean up resources\n", "processor.cleanup()\n", From db37462368c076b4ef4856880e9003f88c7be8a3 Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Wed, 25 Jun 2025 18:19:38 -0700 Subject: [PATCH 43/89] refactored some code, fixed a few bugs --- dimos/perception/manip_aio_processer.py | 39 ++++++++++++------------- 1 file changed, 19 insertions(+), 20 deletions(-) diff --git a/dimos/perception/manip_aio_processer.py b/dimos/perception/manip_aio_processer.py index b8b0c0b72d..7ebd7e9726 100644 --- a/dimos/perception/manip_aio_processer.py +++ b/dimos/perception/manip_aio_processer.py @@ -32,7 +32,11 @@ extract_and_cluster_misc_points, overlay_point_clouds_on_image, ) -from dimos.perception.common.utils import colorize_depth, detection_results_to_object_data +from dimos.perception.common.utils import ( + colorize_depth, + detection_results_to_object_data, + combine_object_data, +) logger = setup_logger("dimos.perception.manip_aio_processor") @@ -142,7 +146,7 @@ def process_frame( try: # Step 1: Object Detection step_start = time.time() - 
detection_results = self._run_object_detection(rgb_image) + detection_results = self.run_object_detection(rgb_image) results["detection2d_objects"] = detection_results.get("objects", []) results["detection_viz"] = detection_results.get("viz_frame") detection_time = time.time() - step_start @@ -151,7 +155,7 @@ def process_frame( segmentation_time = 0 if self.enable_segmentation: step_start = time.time() - segmentation_results = self._run_segmentation(rgb_image) + segmentation_results = self.run_segmentation(rgb_image) results["segmentation2d_objects"] = segmentation_results.get("objects", []) results["segmentation_viz"] = segmentation_results.get("viz_frame") segmentation_time = time.time() - step_start @@ -165,7 +169,7 @@ def process_frame( detected_objects = [] if detection2d_objects: step_start = time.time() - detected_objects = self._run_pointcloud_filtering( + detected_objects = self.run_pointcloud_filtering( rgb_image, depth_image, detection2d_objects ) pointcloud_time += time.time() - step_start @@ -174,23 +178,24 @@ def process_frame( segmentation_filtered_objects = [] if segmentation2d_objects: step_start = time.time() - segmentation_filtered_objects = self._run_pointcloud_filtering( + segmentation_filtered_objects = self.run_pointcloud_filtering( rgb_image, depth_image, segmentation2d_objects ) pointcloud_time += time.time() - step_start - # Combine all objects - all_objects = segmentation_filtered_objects + # Combine all objects using intelligent duplicate removal + all_objects = combine_object_data( + detected_objects, segmentation_filtered_objects, overlap_threshold=0.8 + ) # Get full point cloud full_pcd = self.pointcloud_filter.get_full_point_cloud() # Extract misc/background points and create voxel grid misc_start = time.time() - all_filtered_objects = segmentation_filtered_objects + detected_objects misc_clusters, misc_voxel_grid = extract_and_cluster_misc_points( full_pcd, - all_filtered_objects, + all_objects, eps=0.03, min_points=100, 
enable_filtering=True, @@ -256,15 +261,9 @@ def process_frame( ) if should_generate_grasps and all_objects and full_pcd: - grasps = self._run_grasp_generation(all_objects, full_pcd) + grasps = self.run_grasp_generation(all_objects, full_pcd) results["grasps"] = grasps - # Ensure segmentation runs even if no objects detected - if self.enable_segmentation and "segmentation_viz" not in results: - segmentation_results = self._run_segmentation(rgb_image) - results["segmentation2d_objects"] = segmentation_results.get("objects", []) - results["segmentation_viz"] = segmentation_results.get("viz_frame") - except Exception as e: logger.error(f"Error processing frame: {e}") results["error"] = str(e) @@ -286,7 +285,7 @@ def process_frame( return results - def _run_object_detection(self, rgb_image: np.ndarray) -> Dict[str, Any]: + def run_object_detection(self, rgb_image: np.ndarray) -> Dict[str, Any]: """Run object detection on RGB image.""" try: # Convert RGB to BGR for Detic detector @@ -319,7 +318,7 @@ def _run_object_detection(self, rgb_image: np.ndarray) -> Dict[str, Any]: logger.error(f"Object detection failed: {e}") return {"objects": [], "viz_frame": rgb_image.copy()} - def _run_pointcloud_filtering( + def run_pointcloud_filtering( self, rgb_image: np.ndarray, depth_image: np.ndarray, objects: List[Dict] ) -> List[Dict]: """Run point cloud filtering on detected objects.""" @@ -332,7 +331,7 @@ def _run_pointcloud_filtering( logger.error(f"Point cloud filtering failed: {e}") return [] - def _run_segmentation(self, rgb_image: np.ndarray) -> Dict[str, Any]: + def run_segmentation(self, rgb_image: np.ndarray) -> Dict[str, Any]: """Run semantic segmentation on RGB image.""" if not self.segmenter: return {"objects": [], "viz_frame": rgb_image.copy()} @@ -371,7 +370,7 @@ def _run_segmentation(self, rgb_image: np.ndarray) -> Dict[str, Any]: logger.error(f"Segmentation failed: {e}") return {"objects": [], "viz_frame": rgb_image.copy()} - def _run_grasp_generation(self, 
filtered_objects: List[Dict], full_pcd) -> Optional[Dict]: + def run_grasp_generation(self, filtered_objects: List[Dict], full_pcd) -> Optional[Dict]: """Run ContactGraspNet grasp generation.""" if not self.grasp_generator: logger.warning("Grasp generation requested but ContactGraspNet not available") From 4f5bbf53cc597bc500b1d27d46e8d59f5b49eb1c Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Wed, 25 Jun 2025 22:12:34 -0700 Subject: [PATCH 44/89] added anygrasp and contact graspnet support --- dimos/perception/manip_aio_processer.py | 62 ++++++++++++------- .../pointcloud/pointcloud_filtering.py | 2 +- 2 files changed, 40 insertions(+), 24 deletions(-) diff --git a/dimos/perception/manip_aio_processer.py b/dimos/perception/manip_aio_processer.py index 7ebd7e9726..8247a6a85c 100644 --- a/dimos/perception/manip_aio_processer.py +++ b/dimos/perception/manip_aio_processer.py @@ -26,7 +26,8 @@ from dimos.perception.detection2d.detic_2d_det import Detic2DDetector from dimos.perception.pointcloud.pointcloud_filtering import PointcloudFiltering from dimos.perception.segmentation.sam_2d_seg import Sam2DSegmenter -from dimos.perception.grasp_generation.grasp_generation import ContactGraspNetGenerator +from dimos.perception.grasp_generation.grasp_generation import GraspGeneratorFactory +from dimos.perception.grasp_generation.utils import create_grasp_overlay from dimos.perception.pointcloud.utils import ( create_point_cloud_overlay_visualization, extract_and_cluster_misc_points, @@ -56,6 +57,8 @@ def __init__( max_objects: int = 20, vocabulary: Optional[str] = None, enable_grasp_generation: bool = False, + grasp_model: str = "contactgraspnet", # "contactgraspnet" or "anygrasp" + grasp_server_url: Optional[str] = None, # Required for AnyGrasp enable_segmentation: bool = True, segmentation_model: str = "sam2_b.pt", ): @@ -67,7 +70,9 @@ def __init__( min_confidence: Minimum detection confidence threshold max_objects: Maximum number of objects to process vocabulary: Optional 
vocabulary for Detic detector - enable_grasp_generation: Whether to enable ContactGraspNet grasp generation + enable_grasp_generation: Whether to enable grasp generation + grasp_model: Type of grasp generator ("contactgraspnet" or "anygrasp") + grasp_server_url: WebSocket URL for AnyGrasp server (required if grasp_model="anygrasp") enable_segmentation: Whether to enable semantic segmentation segmentation_model: Segmentation model to use (SAM 2 or FastSAM) """ @@ -75,6 +80,8 @@ def __init__( self.min_confidence = min_confidence self.max_objects = max_objects self.enable_grasp_generation = enable_grasp_generation + self.grasp_model = grasp_model + self.grasp_server_url = grasp_server_url self.enable_segmentation = enable_segmentation # Initialize object detector @@ -98,19 +105,28 @@ def __init__( model_type="auto", # Auto-detect model type ) - # Initialize ContactGraspNet generator if enabled + # Initialize grasp generator if enabled self.grasp_generator = None if self.enable_grasp_generation: try: - self.grasp_generator = ContactGraspNetGenerator() - logger.info("ContactGraspNet generator initialized successfully") + if grasp_model.lower() == "anygrasp": + if not grasp_server_url: + raise ValueError("AnyGrasp requires grasp_server_url parameter") + self.grasp_generator = GraspGeneratorFactory.create_generator( + "anygrasp", server_url=grasp_server_url + ) + else: + self.grasp_generator = GraspGeneratorFactory.create_generator("contactgraspnet") + + logger.info(f"{grasp_model} generator initialized successfully") except Exception as e: - logger.error(f"Failed to initialize ContactGraspNet generator: {e}") + logger.error(f"Failed to initialize {grasp_model} generator: {e}") self.grasp_generator = None self.enable_grasp_generation = False logger.info( - f"Initialized ManipulationProcessor with confidence={min_confidence}, grasp_generation={enable_grasp_generation}" + f"Initialized ManipulationProcessor with confidence={min_confidence}, " + 
f"grasp_generation={enable_grasp_generation} ({grasp_model})" ) def process_frame( @@ -132,12 +148,13 @@ def process_frame( - detection2d_objects: Raw detection results as ObjectData - segmentation2d_objects: Raw segmentation results as ObjectData (if enabled) - detected_objects: Detection (Object Detection) objects with point clouds filtered - - all_objects: All objects (including misc objects) (SAM segmentation) with point clouds filtered + - all_objects: Combined objects with intelligent duplicate removal - full_pointcloud: Complete scene point cloud (if point cloud processing enabled) - misc_clusters: List of clustered background/miscellaneous point clouds (DBSCAN) - misc_voxel_grid: Open3D voxel grid approximating all misc/background points - misc_pointcloud_viz: Visualization of misc/background cluster overlay - - grasps: ContactGraspNet results (if enabled) + - grasps: Grasp results (ContactGraspNet or AnyGrasp, if enabled) + - grasp_overlay: Grasp visualization overlay (if enabled) - processing_time: Total processing time """ start_time = time.time() @@ -255,7 +272,7 @@ def process_frame( else: results["misc_pointcloud_viz"] = base_image - # Step 4: ContactGraspNet Grasp Generation (if enabled) + # Step 4: Grasp Generation (if enabled) should_generate_grasps = ( generate_grasps if generate_grasps is not None else self.enable_grasp_generation ) @@ -263,6 +280,10 @@ def process_frame( if should_generate_grasps and all_objects and full_pcd: grasps = self.run_grasp_generation(all_objects, full_pcd) results["grasps"] = grasps + if grasps: + results["grasp_overlay"] = create_grasp_overlay( + rgb_image, grasps, self.camera_intrinsics + ) except Exception as e: logger.error(f"Error processing frame: {e}") @@ -371,27 +392,22 @@ def run_segmentation(self, rgb_image: np.ndarray) -> Dict[str, Any]: return {"objects": [], "viz_frame": rgb_image.copy()} def run_grasp_generation(self, filtered_objects: List[Dict], full_pcd) -> Optional[Dict]: - """Run ContactGraspNet 
grasp generation.""" + """Run grasp generation using the configured generator (ContactGraspNet or AnyGrasp).""" if not self.grasp_generator: - logger.warning("Grasp generation requested but ContactGraspNet not available") + logger.warning("Grasp generation requested but no generator available") return None try: - # Generate grasps using ContactGraspNet - pred_grasps_cam, scores, contact_pts, gripper_openings = ( - self.grasp_generator.generate_grasps_from_objects(filtered_objects, full_pcd) + # Generate grasps using the configured generator + parsed_grasps = self.grasp_generator.generate_grasps_from_objects( + filtered_objects, full_pcd ) - # Return ContactGraspNet results directly - return { - "pred_grasps_cam": pred_grasps_cam, - "scores": scores, - "contact_pts": contact_pts, - "gripper_openings": gripper_openings, - } + # Return parsed results directly + return parsed_grasps except Exception as e: - logger.error(f"ContactGraspNet grasp generation failed: {e}") + logger.error(f"{self.grasp_model} grasp generation failed: {e}") return None def cleanup(self): diff --git a/dimos/perception/pointcloud/pointcloud_filtering.py b/dimos/perception/pointcloud/pointcloud_filtering.py index 47d351bd14..3de2f3ae6a 100644 --- a/dimos/perception/pointcloud/pointcloud_filtering.py +++ b/dimos/perception/pointcloud/pointcloud_filtering.py @@ -292,7 +292,7 @@ def process_images( pcd = self._apply_color_mask(pcd, rgb_color) # Apply subsampling to control point cloud size - # pcd = self._apply_subsampling(pcd) + pcd = self._apply_subsampling(pcd) # Apply filtering (optional based on flags) pcd_filtered = self._apply_filtering(pcd) From e54306edfa9725fe68ab1682e622df4b8acd90d2 Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Wed, 25 Jun 2025 22:51:58 -0700 Subject: [PATCH 45/89] removed all contactgraspnet stuff --- dimos/perception/manip_aio_processer.py | 47 ++++++++++--------------- 1 file changed, 19 insertions(+), 28 deletions(-) diff --git 
a/dimos/perception/manip_aio_processer.py b/dimos/perception/manip_aio_processer.py index 8247a6a85c..6e7083b0f3 100644 --- a/dimos/perception/manip_aio_processer.py +++ b/dimos/perception/manip_aio_processer.py @@ -26,7 +26,7 @@ from dimos.perception.detection2d.detic_2d_det import Detic2DDetector from dimos.perception.pointcloud.pointcloud_filtering import PointcloudFiltering from dimos.perception.segmentation.sam_2d_seg import Sam2DSegmenter -from dimos.perception.grasp_generation.grasp_generation import GraspGeneratorFactory +from dimos.perception.grasp_generation.grasp_generation import AnyGraspGenerator from dimos.perception.grasp_generation.utils import create_grasp_overlay from dimos.perception.pointcloud.utils import ( create_point_cloud_overlay_visualization, @@ -47,7 +47,7 @@ class ManipulationProcessor: Sequential manipulation processor for single-frame processing. Processes RGB-D frames through object detection, point cloud filtering, - and ContactGraspNet grasp generation in a single thread without reactive streams. + and AnyGrasp grasp generation in a single thread without reactive streams. 
""" def __init__( @@ -57,8 +57,7 @@ def __init__( max_objects: int = 20, vocabulary: Optional[str] = None, enable_grasp_generation: bool = False, - grasp_model: str = "contactgraspnet", # "contactgraspnet" or "anygrasp" - grasp_server_url: Optional[str] = None, # Required for AnyGrasp + grasp_server_url: Optional[str] = None, # Required when enable_grasp_generation=True enable_segmentation: bool = True, segmentation_model: str = "sam2_b.pt", ): @@ -71,8 +70,7 @@ def __init__( max_objects: Maximum number of objects to process vocabulary: Optional vocabulary for Detic detector enable_grasp_generation: Whether to enable grasp generation - grasp_model: Type of grasp generator ("contactgraspnet" or "anygrasp") - grasp_server_url: WebSocket URL for AnyGrasp server (required if grasp_model="anygrasp") + grasp_server_url: WebSocket URL for AnyGrasp server (required when enable_grasp_generation=True) enable_segmentation: Whether to enable semantic segmentation segmentation_model: Segmentation model to use (SAM 2 or FastSAM) """ @@ -80,10 +78,13 @@ def __init__( self.min_confidence = min_confidence self.max_objects = max_objects self.enable_grasp_generation = enable_grasp_generation - self.grasp_model = grasp_model self.grasp_server_url = grasp_server_url self.enable_segmentation = enable_segmentation + # Validate grasp generation requirements + if enable_grasp_generation and not grasp_server_url: + raise ValueError("grasp_server_url is required when enable_grasp_generation=True") + # Initialize object detector self.detector = Detic2DDetector(vocabulary=vocabulary, threshold=min_confidence) @@ -109,24 +110,16 @@ def __init__( self.grasp_generator = None if self.enable_grasp_generation: try: - if grasp_model.lower() == "anygrasp": - if not grasp_server_url: - raise ValueError("AnyGrasp requires grasp_server_url parameter") - self.grasp_generator = GraspGeneratorFactory.create_generator( - "anygrasp", server_url=grasp_server_url - ) - else: - self.grasp_generator = 
GraspGeneratorFactory.create_generator("contactgraspnet") - - logger.info(f"{grasp_model} generator initialized successfully") + self.grasp_generator = AnyGraspGenerator(server_url=grasp_server_url) + logger.info("AnyGrasp generator initialized successfully") except Exception as e: - logger.error(f"Failed to initialize {grasp_model} generator: {e}") + logger.error(f"Failed to initialize AnyGrasp generator: {e}") self.grasp_generator = None self.enable_grasp_generation = False logger.info( f"Initialized ManipulationProcessor with confidence={min_confidence}, " - f"grasp_generation={enable_grasp_generation} ({grasp_model})" + f"grasp_generation={enable_grasp_generation}" ) def process_frame( @@ -153,7 +146,7 @@ def process_frame( - misc_clusters: List of clustered background/miscellaneous point clouds (DBSCAN) - misc_voxel_grid: Open3D voxel grid approximating all misc/background points - misc_pointcloud_viz: Visualization of misc/background cluster overlay - - grasps: Grasp results (ContactGraspNet or AnyGrasp, if enabled) + - grasps: Grasp results (AnyGrasp list of dictionaries, if enabled) - grasp_overlay: Grasp visualization overlay (if enabled) - processing_time: Total processing time """ @@ -391,23 +384,21 @@ def run_segmentation(self, rgb_image: np.ndarray) -> Dict[str, Any]: logger.error(f"Segmentation failed: {e}") return {"objects": [], "viz_frame": rgb_image.copy()} - def run_grasp_generation(self, filtered_objects: List[Dict], full_pcd) -> Optional[Dict]: - """Run grasp generation using the configured generator (ContactGraspNet or AnyGrasp).""" + def run_grasp_generation(self, filtered_objects: List[Dict], full_pcd) -> Optional[List[Dict]]: + """Run grasp generation using the configured generator (AnyGrasp).""" if not self.grasp_generator: logger.warning("Grasp generation requested but no generator available") return None try: # Generate grasps using the configured generator - parsed_grasps = self.grasp_generator.generate_grasps_from_objects( - 
filtered_objects, full_pcd - ) + grasps = self.grasp_generator.generate_grasps_from_objects(filtered_objects, full_pcd) - # Return parsed results directly - return parsed_grasps + # Return parsed results directly (list of grasp dictionaries) + return grasps except Exception as e: - logger.error(f"{self.grasp_model} grasp generation failed: {e}") + logger.error(f"AnyGrasp grasp generation failed: {e}") return None def cleanup(self): From 653bc355628928a94a0c7d2dee4f4ca6cb0c3029 Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Thu, 26 Jun 2025 02:45:47 -0700 Subject: [PATCH 46/89] zed frames saving --- tests/manipulation_pipeline_demo.ipynb | 891 ------------------------- 1 file changed, 891 deletions(-) delete mode 100644 tests/manipulation_pipeline_demo.ipynb diff --git a/tests/manipulation_pipeline_demo.ipynb b/tests/manipulation_pipeline_demo.ipynb deleted file mode 100644 index 01470c6355..0000000000 --- a/tests/manipulation_pipeline_demo.ipynb +++ /dev/null @@ -1,891 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Manipulation Pipeline Demo with ContactGraspNet\n", - "\n", - "This notebook demonstrates the complete manipulation pipeline including:\n", - "- Object detection (Detic)\n", - "- Semantic segmentation (SAM/FastSAM)\n", - "- Point cloud processing\n", - "- 6-DoF grasp generation (ContactGraspNet)\n", - "- 3D visualization\n", - "\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 1. Setup and Imports" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Jupyter environment detected. 
Enabling Open3D WebVisualizer.\n", - "[Open3D INFO] WebRTC GUI backend enabled.\n", - "[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.\n", - "\u2705 All imports successful!\n" - ] - } - ], - "source": [ - "import os\n", - "import sys\n", - "import cv2\n", - "import numpy as np\n", - "import time\n", - "import matplotlib\n", - "\n", - "# Configure matplotlib backend\n", - "try:\n", - " matplotlib.use(\"TkAgg\")\n", - "except:\n", - " try:\n", - " matplotlib.use(\"Qt5Agg\")\n", - " except:\n", - " matplotlib.use(\"Agg\")\n", - "\n", - "import matplotlib.pyplot as plt\n", - "import open3d as o3d\n", - "from typing import Dict, List\n", - "\n", - "# Add project root to path\n", - "sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(\"__file__\"))))\n", - "\n", - "# Import DIMOS modules\n", - "from dimos.perception.pointcloud.utils import visualize_clustered_point_clouds, visualize_voxel_grid\n", - "from dimos.perception.manip_aio_processer import ManipulationProcessor\n", - "from dimos.perception.pointcloud.utils import load_camera_matrix_from_yaml, visualize_pcd\n", - "from dimos.utils.logging_config import setup_logger\n", - "\n", - "# Import ContactGraspNet visualization\n", - "from dimos.models.manipulation.contact_graspnet_pytorch.contact_graspnet_pytorch.visualization_utils_o3d import (\n", - " visualize_grasps,\n", - ")\n", - "\n", - "logger = setup_logger(\"manipulation_pipeline_demo\")\n", - "print(\"\u2705 All imports successful!\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2. 
Configuration" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Configuration:\n", - " data_dir: /home/alex-lin/dimos/assets/rgbd_data\n", - " enable_grasp_generation: True\n", - " enable_segmentation: True\n", - " segmentation_model: FastSAM-x.pt\n", - " min_confidence: 0.6\n", - " max_objects: 20\n", - " show_3d_visualizations: True\n", - " save_results: True\n" - ] - } - ], - "source": [ - "# Configuration parameters\n", - "CONFIG = {\n", - " \"data_dir\": \"/home/alex-lin/dimos/assets/rgbd_data\",\n", - " \"enable_grasp_generation\": True,\n", - " \"enable_segmentation\": True,\n", - " \"segmentation_model\": \"FastSAM-x.pt\", # or \"sam2_b.pt\"\n", - " \"min_confidence\": 0.6,\n", - " \"max_objects\": 20,\n", - " \"show_3d_visualizations\": True,\n", - " \"save_results\": True,\n", - "}\n", - "\n", - "print(f\"Configuration:\")\n", - "for key, value in CONFIG.items():\n", - " print(f\" {key}: {value}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 3. 
Data Loading Functions" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u2705 Data loading functions defined!\n" - ] - } - ], - "source": [ - "def load_first_frame(data_dir: str):\n", - " \"\"\"Load first RGB-D frame and camera intrinsics.\"\"\"\n", - " # Load images\n", - " color_img = cv2.imread(os.path.join(data_dir, \"color\", \"00000.png\"))\n", - " color_img = cv2.cvtColor(color_img, cv2.COLOR_BGR2RGB)\n", - "\n", - " depth_img = cv2.imread(os.path.join(data_dir, \"depth\", \"00000.png\"), cv2.IMREAD_ANYDEPTH)\n", - " if depth_img.dtype == np.uint16:\n", - " depth_img = depth_img.astype(np.float32) / 1000.0\n", - "\n", - " # Load intrinsics\n", - " camera_matrix = load_camera_matrix_from_yaml(os.path.join(data_dir, \"color_camera_info.yaml\"))\n", - " intrinsics = [\n", - " camera_matrix[0, 0], # fx\n", - " camera_matrix[1, 1], # fy\n", - " camera_matrix[0, 2], # cx\n", - " camera_matrix[1, 2], # cy\n", - " ]\n", - "\n", - " return color_img, depth_img, intrinsics\n", - "\n", - "\n", - "def create_point_cloud(color_img, depth_img, intrinsics):\n", - " \"\"\"Create Open3D point cloud for reference.\"\"\"\n", - " fx, fy, cx, cy = intrinsics\n", - " height, width = depth_img.shape\n", - "\n", - " o3d_intrinsics = o3d.camera.PinholeCameraIntrinsic(width, height, fx, fy, cx, cy)\n", - " color_o3d = o3d.geometry.Image(color_img)\n", - " depth_o3d = o3d.geometry.Image((depth_img * 1000).astype(np.uint16))\n", - "\n", - " rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth(\n", - " color_o3d, depth_o3d, depth_scale=1000.0, convert_rgb_to_intensity=False\n", - " )\n", - "\n", - " return o3d.geometry.PointCloud.create_from_rgbd_image(rgbd, o3d_intrinsics)\n", - "\n", - "\n", - "print(\"\u2705 Data loading functions defined!\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 4. 
Load RGB-D Data" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[32m2025-06-25 15:04:05,846 - manipulation_pipeline_demo - INFO - Loaded images: color (720, 1280, 3), depth (720, 1280)\u001b[0m\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Camera intrinsics: fx=749.3, fy=748.6, cx=639.4, cy=357.2\n" - ] - } - ], - "source": [ - "# Load data\n", - "color_img, depth_img, intrinsics = load_first_frame(CONFIG[\"data_dir\"])\n", - "logger.info(f\"Loaded images: color {color_img.shape}, depth {depth_img.shape}\")\n", - "\n", - "# Display input images\n", - "fig, axes = plt.subplots(1, 2, figsize=(12, 5))\n", - "\n", - "axes[0].imshow(color_img)\n", - "axes[0].set_title(\"RGB Image\")\n", - "axes[0].axis(\"off\")\n", - "\n", - "# Colorize depth for visualization\n", - "depth_colorized = cv2.applyColorMap(\n", - " cv2.convertScaleAbs(depth_img, alpha=255.0 / depth_img.max()), cv2.COLORMAP_JET\n", - ")\n", - "depth_colorized = cv2.cvtColor(depth_colorized, cv2.COLOR_BGR2RGB)\n", - "axes[1].imshow(depth_colorized)\n", - "axes[1].set_title(\"Depth Image\")\n", - "axes[1].axis(\"off\")\n", - "\n", - "plt.tight_layout()\n", - "plt.show()\n", - "\n", - "print(\n", - " f\"Camera intrinsics: fx={intrinsics[0]:.1f}, fy={intrinsics[1]:.1f}, cx={intrinsics[2]:.1f}, cy={intrinsics[3]:.1f}\"\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 5. 
Initialize Manipulation Processor" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/alex-lin/miniconda3/envs/contact-graspnet/lib/python3.10/site-packages/timm/models/layers/__init__.py:48: FutureWarning: Importing from timm.models.layers is deprecated, please import via timm.layers\n", - " warnings.warn(f\"Importing from {__name__} is deprecated, please import via timm.layers\", FutureWarning)\n", - "/home/alex-lin/miniconda3/envs/contact-graspnet/lib/python3.10/site-packages/timm/models/helpers.py:7: FutureWarning: Importing from timm.models.helpers is deprecated, please import via timm.models\n", - " warnings.warn(f\"Importing from {__name__} is deprecated, please import via timm.models\", FutureWarning)\n", - "/home/alex-lin/miniconda3/envs/contact-graspnet/lib/python3.10/site-packages/timm/models/registry.py:4: FutureWarning: Importing from timm.models.registry is deprecated, please import via timm.models\n", - " warnings.warn(f\"Importing from {__name__} is deprecated, please import via timm.models\", FutureWarning)\n", - "/home/alex-lin/miniconda3/envs/contact-graspnet/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. 
(Triggered internally at /opt/conda/conda-bld/pytorch_1702400410390/work/aten/src/ATen/native/TensorShape.cpp:3526.)\n", - " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n", - "\u001b[32m2025-06-25 15:04:11,530 - dimos.perception.grasp_generation - INFO - Initializing ContactGraspNet on device: cuda\u001b[0m\n", - "\u001b[32m2025-06-25 15:04:11,541 - dimos.perception.grasp_generation - INFO - Loaded config from dimos/models/manipulation/contact_graspnet_pytorch/checkpoints\u001b[0m\n", - "\u001b[32m2025-06-25 15:04:11,565 - dimos.perception.grasp_generation - INFO - ContactGraspNet model initialized\u001b[0m\n", - "\u001b[32m2025-06-25 15:04:11,566 - dimos.perception.grasp_generation - INFO - ContactGraspNet grasp generator initialized successfully\u001b[0m\n", - "\u001b[32m2025-06-25 15:04:11,566 - dimos.perception.manip_aio_processor - INFO - ContactGraspNet generator initialized successfully\u001b[0m\n", - "\u001b[32m2025-06-25 15:04:11,567 - dimos.perception.manip_aio_processor - INFO - Initialized ManipulationProcessor with confidence=0.6, grasp_generation=True\u001b[0m\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Resetting zs_weight /home/alex-lin/dimos/dimos/perception/detection2d/../../models/Detic/datasets/metadata/lvis_v1_clip_a+cname.npy\n", - "model func: \n", - "\u2705 ManipulationProcessor initialized successfully!\n" - ] - } - ], - "source": [ - "# Create processor with ContactGraspNet enabled\n", - "processor = ManipulationProcessor(\n", - " camera_intrinsics=intrinsics,\n", - " min_confidence=CONFIG[\"min_confidence\"],\n", - " max_objects=CONFIG[\"max_objects\"],\n", - " enable_grasp_generation=CONFIG[\"enable_grasp_generation\"],\n", - " enable_segmentation=CONFIG[\"enable_segmentation\"],\n", - " segmentation_model=CONFIG[\"segmentation_model\"],\n", - ")\n", - "\n", - "print(\"\u2705 ManipulationProcessor initialized successfully!\")" - ] - }, - { - "cell_type": "markdown", - "metadata": 
{}, - "source": [ - "## 6. Run Processing Pipeline" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\ud83d\udd04 Processing frame through pipeline...\n", - "DBSCAN clustering found 13 clusters from 26536 points\n", - "Created voxel grid with 2074 voxels (voxel_size=0.02)\n", - "using local regions\n", - "Extracted Region Cube Size: 0.3148665130138397\n", - "Extracted Region Cube Size: 0.4740000367164612\n", - "Extracted Region Cube Size: 0.2676139771938324\n", - "Extracted Region Cube Size: 0.2\n", - "Extracted Region Cube Size: 0.4960000514984131\n", - "Extracted Region Cube Size: 0.30400002002716064\n", - "Extracted Region Cube Size: 0.38946154713630676\n", - "Extracted Region Cube Size: 0.2087651789188385\n", - "Extracted Region Cube Size: 0.2\n", - "Extracted Region Cube Size: 0.2\n", - "Extracted Region Cube Size: 0.24777960777282715\n", - "Extracted Region Cube Size: 0.2\n", - "Extracted Region Cube Size: 0.2502080202102661\n", - "Extracted Region Cube Size: 0.2\n", - "Extracted Region Cube Size: 0.3400000333786011\n", - "Extracted Region Cube Size: 0.22946105897426605\n", - "Extracted Region Cube Size: 0.2\n", - "Extracted Region Cube Size: 0.5360000133514404\n", - "Generated 18 grasps for object 2\n", - "Generated 44 grasps for object 3\n", - "Generated 14 grasps for object 4\n", - "Generated 6 grasps for object 7\n", - "Generated 9 grasps for object 8\n", - "Generated 15 grasps for object 9\n", - "Generated 25 grasps for object 10\n", - "Generated 25 grasps for object 11\n", - "Generated 16 grasps for object 14\n", - "Generated 3 grasps for object 15\n", - "Generated 13 grasps for object 16\n", - "Generated 15 grasps for object 19\n", - "Generated 12 grasps for object 27\n", - "Generated 17 grasps for object 29\n", - "Generated 19 grasps for object 31\n", - "Generated 19 grasps for object 32\n", - "Generated 3 grasps for object 33\n" - ] - }, - { 
- "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[32m2025-06-25 15:04:30,107 - dimos.perception.grasp_generation - INFO - Generated 296 grasps across 18 objects in 14.69s\u001b[0m\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Generated 23 grasps for object 37\n", - "\u2705 Processing completed in 18.517s\n" - ] - } - ], - "source": [ - "# Process single frame\n", - "print(\"\ud83d\udd04 Processing frame through pipeline...\")\n", - "start_time = time.time()\n", - "\n", - "results = processor.process_frame(color_img, depth_img)\n", - "\n", - "processing_time = time.time() - start_time\n", - "print(f\"\u2705 Processing completed in {processing_time:.3f}s\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 7. Results Summary" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "\ud83d\udcca PROCESSING RESULTS SUMMARY\n", - "==================================================\n", - "Available results: ['detection2d_objects', 'detection_viz', 'segmentation2d_objects', 'segmentation_viz', 'detected_objects', 'all_objects', 'full_pointcloud', 'misc_clusters', 'misc_voxel_grid', 'pointcloud_viz', 'detected_pointcloud_viz', 'misc_pointcloud_viz', 'grasps', 'processing_time', 'timing_breakdown']\n", - "Total processing time: 18.517s\n", - "\n", - "\u23f1\ufe0f Timing breakdown:\n", - " Detection: 0.529s\n", - " Segmentation: 0.720s\n", - " Point cloud: 1.837s\n", - " Misc extraction: 0.385s\n", - "\n", - "\ud83c\udfaf Object Detection:\n", - " Detection objects: 13\n", - " All objects processed: 18\n", - "\n", - "\ud83e\udde9 Background Analysis:\n", - " Misc clusters: 13 clusters with 25,628 total points\n", - "\n", - "\ud83e\udd16 ContactGraspNet Results:\n", - " Total grasps: 296\n", - " Best score: 0.798\n", - " Objects with grasps: 18\n", - "\n", - 
"==================================================\n" - ] - } - ], - "source": [ - "# Print comprehensive results summary\n", - "print(f\"\\n\ud83d\udcca PROCESSING RESULTS SUMMARY\")\n", - "print(f\"\" + \"=\" * 50)\n", - "print(f\"Available results: {list(results.keys())}\")\n", - "print(f\"Total processing time: {results.get('processing_time', 0):.3f}s\")\n", - "\n", - "# Show timing breakdown\n", - "if \"timing_breakdown\" in results:\n", - " breakdown = results[\"timing_breakdown\"]\n", - " print(f\"\\n\u23f1\ufe0f Timing breakdown:\")\n", - " print(f\" Detection: {breakdown.get('detection', 0):.3f}s\")\n", - " print(f\" Segmentation: {breakdown.get('segmentation', 0):.3f}s\")\n", - " print(f\" Point cloud: {breakdown.get('pointcloud', 0):.3f}s\")\n", - " print(f\" Misc extraction: {breakdown.get('misc_extraction', 0):.3f}s\")\n", - "\n", - "# Object counts\n", - "detected_count = len(results.get(\"detected_objects\", []))\n", - "all_count = len(results.get(\"all_objects\", []))\n", - "print(f\"\\n\ud83c\udfaf Object Detection:\")\n", - "print(f\" Detection objects: {detected_count}\")\n", - "print(f\" All objects processed: {all_count}\")\n", - "\n", - "# Misc clusters info\n", - "if \"misc_clusters\" in results and results[\"misc_clusters\"]:\n", - " cluster_count = len(results[\"misc_clusters\"])\n", - " total_misc_points = sum(len(np.asarray(cluster.points)) for cluster in results[\"misc_clusters\"])\n", - " print(f\"\\n\ud83e\udde9 Background Analysis:\")\n", - " print(f\" Misc clusters: {cluster_count} clusters with {total_misc_points:,} total points\")\n", - "else:\n", - " print(f\"\\n\ud83e\udde9 Background Analysis: No clusters found\")\n", - "\n", - "# ContactGraspNet grasp summary\n", - "if \"grasps\" in results and results[\"grasps\"]:\n", - " grasp_data = results[\"grasps\"]\n", - " if isinstance(grasp_data, dict):\n", - " pred_grasps = grasp_data.get(\"pred_grasps_cam\", {})\n", - " scores = grasp_data.get(\"scores\", {})\n", - "\n", - " 
total_grasps = 0\n", - " best_score = 0\n", - "\n", - " for obj_id, obj_grasps in pred_grasps.items():\n", - " num_grasps = len(obj_grasps) if hasattr(obj_grasps, \"__len__\") else 0\n", - " total_grasps += num_grasps\n", - "\n", - " if obj_id in scores and len(scores[obj_id]) > 0:\n", - " obj_best_score = max(scores[obj_id])\n", - " if obj_best_score > best_score:\n", - " best_score = obj_best_score\n", - "\n", - " print(f\"\\n\ud83e\udd16 ContactGraspNet Results:\")\n", - " print(f\" Total grasps: {total_grasps}\")\n", - " print(f\" Best score: {best_score:.3f}\")\n", - " print(f\" Objects with grasps: {len(pred_grasps)}\")\n", - " else:\n", - " print(f\"\\n\ud83e\udd16 ContactGraspNet Results: Invalid format\")\n", - "else:\n", - " print(f\"\\n\ud83e\udd16 ContactGraspNet Results: No grasps generated\")\n", - "\n", - "print(\"\\n\" + \"=\" * 50)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 8. 2D Visualization Results" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\ud83d\udcf8 Results saved to: manipulation_results.png\n" - ] - } - ], - "source": [ - "# Collect available visualizations\n", - "viz_configs = []\n", - "\n", - "if \"detection_viz\" in results and results[\"detection_viz\"] is not None:\n", - " viz_configs.append((\"detection_viz\", \"Object Detection\"))\n", - "\n", - "if \"segmentation_viz\" in results and results[\"segmentation_viz\"] is not None:\n", - " viz_configs.append((\"segmentation_viz\", \"Semantic Segmentation\"))\n", - "\n", - "if \"pointcloud_viz\" in results and results[\"pointcloud_viz\"] is not None:\n", - " viz_configs.append((\"pointcloud_viz\", \"All Objects Point Cloud\"))\n", - "\n", - "if \"detected_pointcloud_viz\" in results and results[\"detected_pointcloud_viz\"] is not None:\n", - " viz_configs.append((\"detected_pointcloud_viz\", \"Detection Objects Point Cloud\"))\n", - "\n", 
- "if \"misc_pointcloud_viz\" in results and results[\"misc_pointcloud_viz\"] is not None:\n", - " viz_configs.append((\"misc_pointcloud_viz\", \"Misc/Background Points\"))\n", - "\n", - "# Create visualization grid\n", - "if viz_configs:\n", - " num_plots = len(viz_configs)\n", - "\n", - " if num_plots <= 3:\n", - " fig, axes = plt.subplots(1, num_plots, figsize=(6 * num_plots, 5))\n", - " else:\n", - " rows = 2\n", - " cols = (num_plots + 1) // 2\n", - " fig, axes = plt.subplots(rows, cols, figsize=(6 * cols, 5 * rows))\n", - "\n", - " # Ensure axes is always iterable\n", - " if num_plots == 1:\n", - " axes = [axes]\n", - " elif num_plots > 2:\n", - " axes = axes.flatten()\n", - "\n", - " # Plot each result\n", - " for i, (key, title) in enumerate(viz_configs):\n", - " axes[i].imshow(results[key])\n", - " axes[i].set_title(title, fontsize=12, fontweight=\"bold\")\n", - " axes[i].axis(\"off\")\n", - "\n", - " # Hide unused subplots\n", - " if num_plots > 3:\n", - " for i in range(num_plots, len(axes)):\n", - " axes[i].axis(\"off\")\n", - "\n", - " plt.tight_layout()\n", - "\n", - " if CONFIG[\"save_results\"]:\n", - " output_path = \"manipulation_results.png\"\n", - " plt.savefig(output_path, dpi=150, bbox_inches=\"tight\")\n", - " print(f\"\ud83d\udcf8 Results saved to: {output_path}\")\n", - "\n", - " plt.show()\n", - "else:\n", - " print(\"\u26a0\ufe0f No 2D visualization results to display\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 9. 
3D ContactGraspNet Visualization" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\ud83c\udfaf Launching 3D visualization with 296 ContactGraspNet grasps...\n", - "\ud83d\udccc Note: Close the 3D window to continue with the notebook\n", - "Visualizing...\n", - "\u2705 3D grasp visualization completed!\n" - ] - } - ], - "source": [ - "# 3D ContactGraspNet visualization\n", - "if (\n", - " CONFIG[\"show_3d_visualizations\"]\n", - " and \"grasps\" in results\n", - " and results[\"grasps\"]\n", - " and \"full_pointcloud\" in results\n", - "):\n", - " grasp_data = results[\"grasps\"]\n", - " full_pcd = results[\"full_pointcloud\"]\n", - "\n", - " if isinstance(grasp_data, dict) and full_pcd is not None:\n", - " try:\n", - " # Extract ContactGraspNet data\n", - " pred_grasps_cam = grasp_data.get(\"pred_grasps_cam\", {})\n", - " scores = grasp_data.get(\"scores\", {})\n", - " contact_pts = grasp_data.get(\"contact_pts\", {})\n", - " gripper_openings = grasp_data.get(\"gripper_openings\", {})\n", - "\n", - " # Check if we have valid grasp data\n", - " total_grasps = (\n", - " sum(len(grasps) for grasps in pred_grasps_cam.values()) if pred_grasps_cam else 0\n", - " )\n", - "\n", - " if total_grasps > 0:\n", - " print(\n", - " f\"\ud83c\udfaf Launching 3D visualization with {total_grasps} ContactGraspNet grasps...\"\n", - " )\n", - " print(\"\ud83d\udccc Note: Close the 3D window to continue with the notebook\")\n", - "\n", - " # Use ContactGraspNet's native visualization - pass dictionaries directly\n", - " visualize_grasps(\n", - " full_pcd,\n", - " pred_grasps_cam, # Pass dictionary directly\n", - " scores, # Pass dictionary directly\n", - " gripper_openings=gripper_openings,\n", - " )\n", - "\n", - " print(\"\u2705 3D grasp visualization completed!\")\n", - " else:\n", - " print(\"\u26a0\ufe0f No valid grasps to visualize in 3D\")\n", - "\n", - " except 
Exception as e:\n", - " print(f\"\u274c Error in ContactGraspNet 3D visualization: {e}\")\n", - " print(\" Skipping 3D grasp visualization\")\n", - "else:\n", - " if not CONFIG[\"show_3d_visualizations\"]:\n", - " print(\"\u23ed\ufe0f 3D visualizations disabled in config\")\n", - " else:\n", - " print(\"\u26a0\ufe0f ContactGraspNet grasp generation disabled or no results\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 10. Additional 3D Visualizations" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 10.1 Full Scene Point Cloud" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\ud83c\udf0d Launching full scene point cloud visualization (526,100 points)...\n", - "\ud83d\udccc Note: Close the 3D window to continue\n", - "Visualizing point cloud with 526100 points\n", - "\u2705 Full point cloud visualization completed!\n" - ] - } - ], - "source": [ - "if (\n", - " CONFIG[\"show_3d_visualizations\"]\n", - " and \"full_pointcloud\" in results\n", - " and results[\"full_pointcloud\"] is not None\n", - "):\n", - " full_pcd = results[\"full_pointcloud\"]\n", - " num_points = len(np.asarray(full_pcd.points))\n", - " print(f\"\ud83c\udf0d Launching full scene point cloud visualization ({num_points:,} points)...\")\n", - " print(\"\ud83d\udccc Note: Close the 3D window to continue\")\n", - "\n", - " try:\n", - " visualize_pcd(\n", - " full_pcd,\n", - " window_name=\"Full Scene Point Cloud\",\n", - " point_size=2.0,\n", - " show_coordinate_frame=True,\n", - " )\n", - " print(\"\u2705 Full point cloud visualization completed!\")\n", - " except (KeyboardInterrupt, EOFError):\n", - " print(\"\u23ed\ufe0f Full point cloud visualization skipped\")\n", - "else:\n", - " print(\"\u26a0\ufe0f No full point cloud available for visualization\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ 
- "### 10.2 Background/Misc Clusters" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\ud83e\udde9 Launching misc/background clusters visualization (13 clusters, 25,628 points)...\n", - "\ud83d\udccc Note: Close the 3D window to continue\n", - "Visualizing 13 clusters with 25628 total points\n", - "\u2705 Misc clusters visualization completed!\n" - ] - } - ], - "source": [ - "if CONFIG[\"show_3d_visualizations\"] and \"misc_clusters\" in results and results[\"misc_clusters\"]:\n", - " misc_clusters = results[\"misc_clusters\"]\n", - " cluster_count = len(misc_clusters)\n", - " total_misc_points = sum(len(np.asarray(cluster.points)) for cluster in misc_clusters)\n", - "\n", - " print(\n", - " f\"\ud83e\udde9 Launching misc/background clusters visualization ({cluster_count} clusters, {total_misc_points:,} points)...\"\n", - " )\n", - " print(\"\ud83d\udccc Note: Close the 3D window to continue\")\n", - "\n", - " try:\n", - " visualize_clustered_point_clouds(\n", - " misc_clusters,\n", - " window_name=\"Misc/Background Clusters (DBSCAN)\",\n", - " point_size=3.0,\n", - " show_coordinate_frame=True,\n", - " )\n", - " print(\"\u2705 Misc clusters visualization completed!\")\n", - " except (KeyboardInterrupt, EOFError):\n", - " print(\"\u23ed\ufe0f Misc clusters visualization skipped\")\n", - "else:\n", - " print(\"\u26a0\ufe0f No misc clusters available for visualization\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 10.3 Voxel Grid Visualization" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\ud83d\udce6 Launching voxel grid visualization (2074 voxels)...\n", - "\ud83d\udccc Note: Close the 3D window to continue\n", - "Visualizing voxel grid with 2074 voxels\n", - "\u2705 Voxel grid visualization completed!\n" - 
] - } - ], - "source": [ - "if (\n", - " CONFIG[\"show_3d_visualizations\"]\n", - " and \"misc_voxel_grid\" in results\n", - " and results[\"misc_voxel_grid\"] is not None\n", - "):\n", - " misc_voxel_grid = results[\"misc_voxel_grid\"]\n", - " voxel_count = len(misc_voxel_grid.get_voxels())\n", - "\n", - " print(f\"\ud83d\udce6 Launching voxel grid visualization ({voxel_count} voxels)...\")\n", - " print(\"\ud83d\udccc Note: Close the 3D window to continue\")\n", - "\n", - " try:\n", - " visualize_voxel_grid(\n", - " misc_voxel_grid,\n", - " window_name=\"Misc/Background Voxel Grid\",\n", - " show_coordinate_frame=True,\n", - " )\n", - " print(\"\u2705 Voxel grid visualization completed!\")\n", - " except (KeyboardInterrupt, EOFError):\n", - " print(\"\u23ed\ufe0f Voxel grid visualization skipped\")\n", - " except Exception as e:\n", - " print(f\"\u274c Error in voxel grid visualization: {e}\")\n", - "else:\n", - " print(\"\u26a0\ufe0f No voxel grid available for visualization\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 11. 
Cleanup" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[32m2025-06-25 15:05:01,624 - dimos.perception.grasp_generation - INFO - ContactGraspNet grasp generator cleaned up\u001b[0m\n", - "\u001b[32m2025-06-25 15:05:01,626 - dimos.perception.manip_aio_processor - INFO - ManipulationProcessor cleaned up\u001b[0m\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u2705 Pipeline cleanup completed!\n", - "\n", - "\ud83c\udf89 Manipulation pipeline demo finished successfully!\n" - ] - } - ], - "source": [ - "# Clean up resources\n", - "processor.cleanup()\n", - "print(\"\u2705 Pipeline cleanup completed!\")\n", - "print(\"\\n\ud83c\udf89 Manipulation pipeline demo finished successfully!\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "---\n", - "\n", - "## Summary\n", - "\n", - "This notebook demonstrated the complete DIMOS manipulation pipeline:\n", - "\n", - "1. **Object Detection**: Using Detic for 2D object detection\n", - "2. **Semantic Segmentation**: Using SAM/FastSAM for detailed segmentation\n", - "3. **Point Cloud Processing**: Converting RGB-D to 3D point clouds with filtering\n", - "4. **Background Analysis**: DBSCAN clustering of miscellaneous points\n", - "5. **Grasp Generation**: ContactGraspNet for 6-DoF grasp pose estimation\n", - "6. 
**Visualization**: Comprehensive 2D and 3D visualizations\n", - "\n", - "The pipeline is designed for real-time robotic manipulation tasks and provides rich visual feedback for debugging and analysis.\n", - "\n", - "### Key Features:\n", - "- \u2705 Modular design with clean interfaces\n", - "- \u2705 Real-time performance optimizations\n", - "- \u2705 Comprehensive error handling\n", - "- \u2705 Rich visualization capabilities\n", - "- \u2705 ContactGraspNet integration for state-of-the-art grasp generation\n", - "\n", - "### Next Steps:\n", - "- Integrate with robotic control systems\n", - "- Add grasp execution and feedback\n", - "- Implement multi-frame tracking\n", - "- Add custom object vocabularies\n" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "contact-graspnet", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.18" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} From 0599c19da3a25375ce5f26635592cf1a45ed953d Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Wed, 9 Jul 2025 14:07:58 -0700 Subject: [PATCH 47/89] fixes --- dimos/perception/manip_aio_processer.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/dimos/perception/manip_aio_processer.py b/dimos/perception/manip_aio_processer.py index 6e7083b0f3..a8afc96a7c 100644 --- a/dimos/perception/manip_aio_processer.py +++ b/dimos/perception/manip_aio_processer.py @@ -59,7 +59,6 @@ def __init__( enable_grasp_generation: bool = False, grasp_server_url: Optional[str] = None, # Required when enable_grasp_generation=True enable_segmentation: bool = True, - segmentation_model: str = "sam2_b.pt", ): """ Initialize the manipulation processor. 
@@ -99,11 +98,9 @@ def __init__( self.segmenter = None if self.enable_segmentation: self.segmenter = Sam2DSegmenter( - model_path=segmentation_model, device="cuda", use_tracker=False, # Disable tracker for simple segmentation use_analyzer=False, # Disable analyzer for simple segmentation - model_type="auto", # Auto-detect model type ) # Initialize grasp generator if enabled From d1b96c1c40dbc504be1eff60abceb4a6b771ae1a Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Wed, 9 Jul 2025 17:08:10 -0700 Subject: [PATCH 48/89] first commit --- dimos/perception/manip_aio_pipeline.py | 590 ------------------------ dimos/perception/manip_aio_processer.py | 411 ----------------- 2 files changed, 1001 deletions(-) delete mode 100644 dimos/perception/manip_aio_pipeline.py delete mode 100644 dimos/perception/manip_aio_processer.py diff --git a/dimos/perception/manip_aio_pipeline.py b/dimos/perception/manip_aio_pipeline.py deleted file mode 100644 index 22e3f5d49e..0000000000 --- a/dimos/perception/manip_aio_pipeline.py +++ /dev/null @@ -1,590 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Asynchronous, reactive manipulation pipeline for realtime detection, filtering, and grasp generation. 
-""" - -import asyncio -import json -import logging -import threading -import time -import traceback -import websockets -from typing import Dict, List, Optional, Any -import numpy as np -import reactivex as rx -import reactivex.operators as ops -from dimos.utils.logging_config import setup_logger -from dimos.perception.detection2d.detic_2d_det import Detic2DDetector -from dimos.perception.pointcloud.pointcloud_filtering import PointcloudFiltering -from dimos.perception.object_detection_stream import ObjectDetectionStream -from dimos.perception.grasp_generation.utils import draw_grasps_on_image -from dimos.perception.pointcloud.utils import create_point_cloud_overlay_visualization -from dimos.perception.common.utils import colorize_depth -from dimos.utils.logging_config import setup_logger -import cv2 - -logger = setup_logger("dimos.perception.manip_aio_pipeline") - - -class ManipulationPipeline: - """ - Clean separated stream pipeline with frame buffering. - - - Object detection runs independently on RGB stream - - Point cloud processing subscribes to both detection and ZED streams separately - - Simple frame buffering to match RGB+depth+objects - """ - - def __init__( - self, - camera_intrinsics: List[float], # [fx, fy, cx, cy] - min_confidence: float = 0.6, - max_objects: int = 10, - vocabulary: Optional[str] = None, - grasp_server_url: Optional[str] = None, - enable_grasp_generation: bool = False, - ): - """ - Initialize the manipulation pipeline. 
- - Args: - camera_intrinsics: [fx, fy, cx, cy] camera parameters - min_confidence: Minimum detection confidence threshold - max_objects: Maximum number of objects to process - vocabulary: Optional vocabulary for Detic detector - grasp_server_url: Optional WebSocket URL for AnyGrasp server - enable_grasp_generation: Whether to enable async grasp generation - """ - self.camera_intrinsics = camera_intrinsics - self.min_confidence = min_confidence - - # Grasp generation settings - self.grasp_server_url = grasp_server_url - self.enable_grasp_generation = enable_grasp_generation - - # Asyncio event loop for WebSocket communication - self.grasp_loop = None - self.grasp_loop_thread = None - - # Storage for grasp results and filtered objects - self.latest_grasps: List[dict] = [] # Simplified: just a list of grasps - self.grasps_consumed = False - self.latest_filtered_objects = [] - self.latest_rgb_for_grasps = None # Store RGB image for grasp overlay - self.grasp_lock = threading.Lock() - - # Track pending requests - simplified to single task - self.grasp_task: Optional[asyncio.Task] = None - - # Reactive subjects for streaming filtered objects and grasps - self.filtered_objects_subject = rx.subject.Subject() - self.grasps_subject = rx.subject.Subject() - self.grasp_overlay_subject = rx.subject.Subject() # Add grasp overlay subject - - # Initialize grasp client if enabled - if self.enable_grasp_generation and self.grasp_server_url: - self._start_grasp_loop() - - # Initialize object detector - self.detector = Detic2DDetector(vocabulary=vocabulary, threshold=min_confidence) - - # Initialize point cloud processor - self.pointcloud_filter = PointcloudFiltering( - color_intrinsics=camera_intrinsics, - depth_intrinsics=camera_intrinsics, # ZED uses same intrinsics - max_num_objects=max_objects, - ) - - logger.info(f"Initialized ManipulationPipeline with confidence={min_confidence}") - - def create_streams(self, zed_stream: rx.Observable) -> Dict[str, rx.Observable]: - """ - 
Create streams using exact old main logic. - """ - # Create ZED streams (from old main) - zed_frame_stream = zed_stream.pipe(ops.share()) - - # RGB stream for object detection (from old main) - video_stream = zed_frame_stream.pipe( - ops.map(lambda x: x.get("rgb") if x is not None else None), - ops.filter(lambda x: x is not None), - ops.share(), - ) - object_detector = ObjectDetectionStream( - camera_intrinsics=self.camera_intrinsics, - min_confidence=self.min_confidence, - class_filter=None, - detector=self.detector, - video_stream=video_stream, - disable_depth=True, - ) - - # Store latest frames for point cloud processing (from old main) - latest_rgb = None - latest_depth = None - latest_point_cloud_overlay = None - frame_lock = threading.Lock() - - # Subscribe to combined ZED frames (from old main) - def on_zed_frame(zed_data): - nonlocal latest_rgb, latest_depth - if zed_data is not None: - with frame_lock: - latest_rgb = zed_data.get("rgb") - latest_depth = zed_data.get("depth") - - # Depth stream for point cloud filtering (from old main) - def get_depth_or_overlay(zed_data): - if zed_data is None: - return None - - # Check if we have a point cloud overlay available - with frame_lock: - overlay = latest_point_cloud_overlay - - if overlay is not None: - return overlay - else: - # Return regular colorized depth - return colorize_depth(zed_data.get("depth"), max_depth=10.0) - - depth_stream = zed_frame_stream.pipe( - ops.map(get_depth_or_overlay), ops.filter(lambda x: x is not None), ops.share() - ) - - # Process object detection results with point cloud filtering (from old main) - def on_detection_next(result): - nonlocal latest_point_cloud_overlay - if "objects" in result and result["objects"]: - # Get latest RGB and depth frames - with frame_lock: - rgb = latest_rgb - depth = latest_depth - - if rgb is not None and depth is not None: - try: - filtered_objects = self.pointcloud_filter.process_images( - rgb, depth, result["objects"] - ) - - if filtered_objects: 
- # Store filtered objects - with self.grasp_lock: - self.latest_filtered_objects = filtered_objects - self.filtered_objects_subject.on_next(filtered_objects) - - # Create base image (colorized depth) - base_image = colorize_depth(depth, max_depth=10.0) - - # Create point cloud overlay visualization - overlay_viz = create_point_cloud_overlay_visualization( - base_image=base_image, - objects=filtered_objects, - intrinsics=self.camera_intrinsics, - ) - - # Store the overlay for the stream - with frame_lock: - latest_point_cloud_overlay = overlay_viz - - # Request grasps if enabled - if self.enable_grasp_generation and len(filtered_objects) > 0: - # Save RGB image for later grasp overlay - with frame_lock: - self.latest_rgb_for_grasps = rgb.copy() - - task = self.request_scene_grasps(filtered_objects) - if task: - # Check for results after a delay - def check_grasps_later(): - time.sleep(2.0) # Wait for grasp processing - # Wait for task to complete - if hasattr(self, "grasp_task") and self.grasp_task: - try: - result = self.grasp_task.result( - timeout=3.0 - ) # Get result with timeout - except Exception as e: - logger.warning(f"Grasp task failed or timeout: {e}") - - # Try to get latest grasps and create overlay - with self.grasp_lock: - grasps = self.latest_grasps - - if grasps and hasattr(self, "latest_rgb_for_grasps"): - # Create grasp overlay on the saved RGB image - try: - bgr_image = cv2.cvtColor( - self.latest_rgb_for_grasps, cv2.COLOR_RGB2BGR - ) - result_bgr = draw_grasps_on_image( - bgr_image, - grasps, - self.camera_intrinsics, - max_grasps=-1, # Show all grasps - ) - result_rgb = cv2.cvtColor( - result_bgr, cv2.COLOR_BGR2RGB - ) - - # Emit grasp overlay immediately - self.grasp_overlay_subject.on_next(result_rgb) - - except Exception as e: - logger.error(f"Error creating grasp overlay: {e}") - - # Emit grasps to stream - self.grasps_subject.on_next(grasps) - - threading.Thread(target=check_grasps_later, daemon=True).start() - else: - 
logger.warning("Failed to create grasp task") - except Exception as e: - logger.error(f"Error in point cloud filtering: {e}") - with frame_lock: - latest_point_cloud_overlay = None - - def on_error(error): - logger.error(f"Error in stream: {error}") - - def on_completed(): - logger.info("Stream completed") - - def start_subscriptions(): - """Start subscriptions in background thread (from old main)""" - # Subscribe to combined ZED frames - zed_frame_stream.subscribe(on_next=on_zed_frame) - - # Start subscriptions in background thread (from old main) - subscription_thread = threading.Thread(target=start_subscriptions, daemon=True) - subscription_thread.start() - time.sleep(2) # Give subscriptions time to start - - # Subscribe to object detection stream (from old main) - object_detector.get_stream().subscribe( - on_next=on_detection_next, on_error=on_error, on_completed=on_completed - ) - - # Create visualization stream for web interface (from old main) - viz_stream = object_detector.get_stream().pipe( - ops.map(lambda x: x["viz_frame"] if x is not None else None), - ops.filter(lambda x: x is not None), - ) - - # Create filtered objects stream - filtered_objects_stream = self.filtered_objects_subject - - # Create grasps stream - grasps_stream = self.grasps_subject - - # Create grasp overlay subject for immediate emission - grasp_overlay_stream = self.grasp_overlay_subject - - return { - "detection_viz": viz_stream, - "pointcloud_viz": depth_stream, - "objects": object_detector.get_stream().pipe(ops.map(lambda x: x.get("objects", []))), - "filtered_objects": filtered_objects_stream, - "grasps": grasps_stream, - "grasp_overlay": grasp_overlay_stream, - } - - def _start_grasp_loop(self): - """Start asyncio event loop in a background thread for WebSocket communication.""" - - def run_loop(): - self.grasp_loop = asyncio.new_event_loop() - asyncio.set_event_loop(self.grasp_loop) - self.grasp_loop.run_forever() - - self.grasp_loop_thread = threading.Thread(target=run_loop, 
daemon=True) - self.grasp_loop_thread.start() - - # Wait for loop to start - while self.grasp_loop is None: - time.sleep(0.01) - - async def _send_grasp_request( - self, points: np.ndarray, colors: Optional[np.ndarray] - ) -> Optional[List[dict]]: - """Send grasp request to AnyGrasp server.""" - try: - # Comprehensive client-side validation to prevent server errors - - # Validate points array - if points is None: - logger.error("Points array is None") - return None - if not isinstance(points, np.ndarray): - logger.error(f"Points is not numpy array: {type(points)}") - return None - if points.size == 0: - logger.error("Points array is empty") - return None - if len(points.shape) != 2 or points.shape[1] != 3: - logger.error(f"Points has invalid shape {points.shape}, expected (N, 3)") - return None - if points.shape[0] < 100: # Minimum points for stable grasp detection - logger.error(f"Insufficient points for grasp detection: {points.shape[0]} < 100") - return None - - # Validate and prepare colors - if colors is not None: - if not isinstance(colors, np.ndarray): - colors = None - elif colors.size == 0: - colors = None - elif len(colors.shape) != 2 or colors.shape[1] != 3: - colors = None - elif colors.shape[0] != points.shape[0]: - colors = None - - # If no valid colors, create default colors (required by server) - if colors is None: - # Create default white colors for all points - colors = np.ones((points.shape[0], 3), dtype=np.float32) * 0.5 - - # Ensure data types are correct (server expects float32) - points = points.astype(np.float32) - colors = colors.astype(np.float32) - - # Validate ranges (basic sanity checks) - if np.any(np.isnan(points)) or np.any(np.isinf(points)): - logger.error("Points contain NaN or Inf values") - return None - if np.any(np.isnan(colors)) or np.any(np.isinf(colors)): - logger.error("Colors contain NaN or Inf values") - return None - - # Clamp color values to valid range [0, 1] - colors = np.clip(colors, 0.0, 1.0) - - async with 
websockets.connect(self.grasp_server_url) as websocket: - request = { - "points": points.tolist(), - "colors": colors.tolist(), # Always send colors array - "lims": [-0.19, 0.12, 0.02, 0.15, 0.0, 1.0], # Default workspace limits - } - - await websocket.send(json.dumps(request)) - - response = await websocket.recv() - grasps = json.loads(response) - - # Handle server response validation - if isinstance(grasps, dict) and "error" in grasps: - logger.error(f"Server returned error: {grasps['error']}") - return None - elif isinstance(grasps, (int, float)) and grasps == 0: - return None - elif not isinstance(grasps, list): - logger.error( - f"Server returned unexpected response type: {type(grasps)}, value: {grasps}" - ) - return None - elif len(grasps) == 0: - return None - - converted_grasps = self._convert_grasp_format(grasps) - with self.grasp_lock: - self.latest_grasps = converted_grasps - self.grasps_consumed = False # Reset consumed flag - - # Emit to reactive stream - self.grasps_subject.on_next(self.latest_grasps) - - return converted_grasps - except websockets.exceptions.ConnectionClosed as e: - logger.error(f"WebSocket connection closed: {e}") - except websockets.exceptions.WebSocketException as e: - logger.error(f"WebSocket error: {e}") - except json.JSONDecodeError as e: - logger.error(f"Failed to parse server response as JSON: {e}") - except Exception as e: - logger.error(f"Error requesting grasps: {e}") - - return None - - def request_scene_grasps(self, objects: List[dict]) -> Optional[asyncio.Task]: - """Request grasps for entire scene by combining all object point clouds.""" - if not self.grasp_loop or not objects: - return None - - all_points = [] - all_colors = [] - valid_objects = 0 - - for i, obj in enumerate(objects): - # Validate point cloud data - if "point_cloud_numpy" not in obj or obj["point_cloud_numpy"] is None: - continue - - points = obj["point_cloud_numpy"] - if not isinstance(points, np.ndarray) or points.size == 0: - continue - - # Ensure 
points have correct shape (N, 3) - if len(points.shape) != 2 or points.shape[1] != 3: - continue - - # Validate colors if present - colors = None - if "colors_numpy" in obj and obj["colors_numpy"] is not None: - colors = obj["colors_numpy"] - if isinstance(colors, np.ndarray) and colors.size > 0: - # Ensure colors match points count and have correct shape - if colors.shape[0] != points.shape[0]: - colors = None # Ignore colors for this object - elif len(colors.shape) != 2 or colors.shape[1] != 3: - colors = None # Ignore colors for this object - - all_points.append(points) - if colors is not None: - all_colors.append(colors) - valid_objects += 1 - - if not all_points: - return None - - try: - combined_points = np.vstack(all_points) - - # Only combine colors if ALL objects have valid colors - combined_colors = None - if len(all_colors) == valid_objects and len(all_colors) > 0: - combined_colors = np.vstack(all_colors) - - # Validate final combined data - if combined_points.size == 0: - logger.warning("Combined point cloud is empty") - return None - - if combined_colors is not None and combined_colors.shape[0] != combined_points.shape[0]: - logger.warning( - f"Color/point count mismatch: {combined_colors.shape[0]} colors vs {combined_points.shape[0]} points, dropping colors" - ) - combined_colors = None - - except Exception as e: - logger.error(f"Failed to combine point clouds: {e}") - return None - - try: - # Check if there's already a grasp task running - if hasattr(self, "grasp_task") and self.grasp_task and not self.grasp_task.done(): - return self.grasp_task - - task = asyncio.run_coroutine_threadsafe( - self._send_grasp_request(combined_points, combined_colors), self.grasp_loop - ) - - self.grasp_task = task - return task - except Exception as e: - logger.warning("Failed to create grasp task") - return None - - def get_latest_grasps(self, timeout: float = 5.0) -> Optional[List[dict]]: - """Get latest grasp results, waiting for new ones if current ones have been 
consumed.""" - # Mark current grasps as consumed and get a reference - with self.grasp_lock: - current_grasps = self.latest_grasps - self.grasps_consumed = True - - # If we already have grasps and they haven't been consumed, return them - if current_grasps is not None and not getattr(self, "grasps_consumed", False): - return current_grasps - - # Wait for new grasps - start_time = time.time() - while time.time() - start_time < timeout: - with self.grasp_lock: - # Check if we have new grasps (different from what we marked as consumed) - if self.latest_grasps is not None and not getattr(self, "grasps_consumed", False): - return self.latest_grasps - time.sleep(0.1) # Check every 100ms - - return None # Timeout reached - - def clear_grasps(self) -> None: - """Clear all stored grasp results.""" - with self.grasp_lock: - self.latest_grasps = [] - - def _prepare_colors(self, colors: Optional[np.ndarray]) -> Optional[np.ndarray]: - """Prepare colors array, converting from various formats if needed.""" - if colors is None: - return None - - if colors.max() > 1.0: - colors = colors / 255.0 - - return colors - - def _convert_grasp_format(self, anygrasp_grasps: List[dict]) -> List[dict]: - """Convert AnyGrasp format to our visualization format.""" - converted = [] - - for i, grasp in enumerate(anygrasp_grasps): - rotation_matrix = np.array(grasp.get("rotation_matrix", np.eye(3))) - euler_angles = self._rotation_matrix_to_euler(rotation_matrix) - - converted_grasp = { - "id": f"grasp_{i}", - "score": grasp.get("score", 0.0), - "width": grasp.get("width", 0.0), - "height": grasp.get("height", 0.0), - "depth": grasp.get("depth", 0.0), - "translation": grasp.get("translation", [0, 0, 0]), - "rotation_matrix": rotation_matrix.tolist(), - "euler_angles": euler_angles, - } - converted.append(converted_grasp) - - converted.sort(key=lambda x: x["score"], reverse=True) - - return converted - - def _rotation_matrix_to_euler(self, rotation_matrix: np.ndarray) -> Dict[str, float]: - 
"""Convert rotation matrix to Euler angles (in radians).""" - sy = np.sqrt(rotation_matrix[0, 0] ** 2 + rotation_matrix[1, 0] ** 2) - - singular = sy < 1e-6 - - if not singular: - x = np.arctan2(rotation_matrix[2, 1], rotation_matrix[2, 2]) - y = np.arctan2(-rotation_matrix[2, 0], sy) - z = np.arctan2(rotation_matrix[1, 0], rotation_matrix[0, 0]) - else: - x = np.arctan2(-rotation_matrix[1, 2], rotation_matrix[1, 1]) - y = np.arctan2(-rotation_matrix[2, 0], sy) - z = 0 - - return {"roll": x, "pitch": y, "yaw": z} - - def cleanup(self): - """Clean up resources.""" - if hasattr(self.detector, "cleanup"): - self.detector.cleanup() - - if self.grasp_loop and self.grasp_loop_thread: - self.grasp_loop.call_soon_threadsafe(self.grasp_loop.stop) - self.grasp_loop_thread.join(timeout=1.0) - - if hasattr(self.pointcloud_filter, "cleanup"): - self.pointcloud_filter.cleanup() - logger.info("ManipulationPipeline cleaned up") diff --git a/dimos/perception/manip_aio_processer.py b/dimos/perception/manip_aio_processer.py deleted file mode 100644 index a8afc96a7c..0000000000 --- a/dimos/perception/manip_aio_processer.py +++ /dev/null @@ -1,411 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Sequential manipulation processor for single-frame processing without reactive streams. 
-""" - -import logging -import time -from typing import Dict, List, Optional, Any, Tuple -import numpy as np -import cv2 - -from dimos.utils.logging_config import setup_logger -from dimos.perception.detection2d.detic_2d_det import Detic2DDetector -from dimos.perception.pointcloud.pointcloud_filtering import PointcloudFiltering -from dimos.perception.segmentation.sam_2d_seg import Sam2DSegmenter -from dimos.perception.grasp_generation.grasp_generation import AnyGraspGenerator -from dimos.perception.grasp_generation.utils import create_grasp_overlay -from dimos.perception.pointcloud.utils import ( - create_point_cloud_overlay_visualization, - extract_and_cluster_misc_points, - overlay_point_clouds_on_image, -) -from dimos.perception.common.utils import ( - colorize_depth, - detection_results_to_object_data, - combine_object_data, -) - -logger = setup_logger("dimos.perception.manip_aio_processor") - - -class ManipulationProcessor: - """ - Sequential manipulation processor for single-frame processing. - - Processes RGB-D frames through object detection, point cloud filtering, - and AnyGrasp grasp generation in a single thread without reactive streams. - """ - - def __init__( - self, - camera_intrinsics: List[float], # [fx, fy, cx, cy] - min_confidence: float = 0.6, - max_objects: int = 20, - vocabulary: Optional[str] = None, - enable_grasp_generation: bool = False, - grasp_server_url: Optional[str] = None, # Required when enable_grasp_generation=True - enable_segmentation: bool = True, - ): - """ - Initialize the manipulation processor. 
- - Args: - camera_intrinsics: [fx, fy, cx, cy] camera parameters - min_confidence: Minimum detection confidence threshold - max_objects: Maximum number of objects to process - vocabulary: Optional vocabulary for Detic detector - enable_grasp_generation: Whether to enable grasp generation - grasp_server_url: WebSocket URL for AnyGrasp server (required when enable_grasp_generation=True) - enable_segmentation: Whether to enable semantic segmentation - segmentation_model: Segmentation model to use (SAM 2 or FastSAM) - """ - self.camera_intrinsics = camera_intrinsics - self.min_confidence = min_confidence - self.max_objects = max_objects - self.enable_grasp_generation = enable_grasp_generation - self.grasp_server_url = grasp_server_url - self.enable_segmentation = enable_segmentation - - # Validate grasp generation requirements - if enable_grasp_generation and not grasp_server_url: - raise ValueError("grasp_server_url is required when enable_grasp_generation=True") - - # Initialize object detector - self.detector = Detic2DDetector(vocabulary=vocabulary, threshold=min_confidence) - - # Initialize point cloud processor - self.pointcloud_filter = PointcloudFiltering( - color_intrinsics=camera_intrinsics, - depth_intrinsics=camera_intrinsics, # ZED uses same intrinsics - max_num_objects=max_objects, - ) - - # Initialize semantic segmentation - self.segmenter = None - if self.enable_segmentation: - self.segmenter = Sam2DSegmenter( - device="cuda", - use_tracker=False, # Disable tracker for simple segmentation - use_analyzer=False, # Disable analyzer for simple segmentation - ) - - # Initialize grasp generator if enabled - self.grasp_generator = None - if self.enable_grasp_generation: - try: - self.grasp_generator = AnyGraspGenerator(server_url=grasp_server_url) - logger.info("AnyGrasp generator initialized successfully") - except Exception as e: - logger.error(f"Failed to initialize AnyGrasp generator: {e}") - self.grasp_generator = None - self.enable_grasp_generation = 
False - - logger.info( - f"Initialized ManipulationProcessor with confidence={min_confidence}, " - f"grasp_generation={enable_grasp_generation}" - ) - - def process_frame( - self, rgb_image: np.ndarray, depth_image: np.ndarray, generate_grasps: bool = None - ) -> Dict[str, Any]: - """ - Process a single RGB-D frame through the complete pipeline. - - Args: - rgb_image: RGB image (H, W, 3) - depth_image: Depth image (H, W) in meters - generate_grasps: Override grasp generation setting for this frame - - Returns: - Dictionary containing: - - detection_viz: Visualization of object detection - - pointcloud_viz: Visualization of point cloud overlay - - segmentation_viz: Visualization of semantic segmentation (if enabled) - - detection2d_objects: Raw detection results as ObjectData - - segmentation2d_objects: Raw segmentation results as ObjectData (if enabled) - - detected_objects: Detection (Object Detection) objects with point clouds filtered - - all_objects: Combined objects with intelligent duplicate removal - - full_pointcloud: Complete scene point cloud (if point cloud processing enabled) - - misc_clusters: List of clustered background/miscellaneous point clouds (DBSCAN) - - misc_voxel_grid: Open3D voxel grid approximating all misc/background points - - misc_pointcloud_viz: Visualization of misc/background cluster overlay - - grasps: Grasp results (AnyGrasp list of dictionaries, if enabled) - - grasp_overlay: Grasp visualization overlay (if enabled) - - processing_time: Total processing time - """ - start_time = time.time() - results = {} - - try: - # Step 1: Object Detection - step_start = time.time() - detection_results = self.run_object_detection(rgb_image) - results["detection2d_objects"] = detection_results.get("objects", []) - results["detection_viz"] = detection_results.get("viz_frame") - detection_time = time.time() - step_start - - # Step 2: Semantic Segmentation (if enabled) - segmentation_time = 0 - if self.enable_segmentation: - step_start = time.time() 
- segmentation_results = self.run_segmentation(rgb_image) - results["segmentation2d_objects"] = segmentation_results.get("objects", []) - results["segmentation_viz"] = segmentation_results.get("viz_frame") - segmentation_time = time.time() - step_start - - # Step 3: Point Cloud Processing - pointcloud_time = 0 - detection2d_objects = results.get("detection2d_objects", []) - segmentation2d_objects = results.get("segmentation2d_objects", []) - - # Process detection objects if available - detected_objects = [] - if detection2d_objects: - step_start = time.time() - detected_objects = self.run_pointcloud_filtering( - rgb_image, depth_image, detection2d_objects - ) - pointcloud_time += time.time() - step_start - - # Process segmentation objects if available - segmentation_filtered_objects = [] - if segmentation2d_objects: - step_start = time.time() - segmentation_filtered_objects = self.run_pointcloud_filtering( - rgb_image, depth_image, segmentation2d_objects - ) - pointcloud_time += time.time() - step_start - - # Combine all objects using intelligent duplicate removal - all_objects = combine_object_data( - detected_objects, segmentation_filtered_objects, overlap_threshold=0.8 - ) - - # Get full point cloud - full_pcd = self.pointcloud_filter.get_full_point_cloud() - - # Extract misc/background points and create voxel grid - misc_start = time.time() - misc_clusters, misc_voxel_grid = extract_and_cluster_misc_points( - full_pcd, - all_objects, - eps=0.03, - min_points=100, - enable_filtering=True, - voxel_size=0.02, - ) - misc_time = time.time() - misc_start - - # Store results - results.update( - { - "detected_objects": detected_objects, - "all_objects": all_objects, - "full_pointcloud": full_pcd, - "misc_clusters": misc_clusters, - "misc_voxel_grid": misc_voxel_grid, - } - ) - - # Create point cloud visualizations - base_image = colorize_depth(depth_image, max_depth=10.0) - - # Create visualizations - results["pointcloud_viz"] = ( - 
create_point_cloud_overlay_visualization( - base_image=base_image, - objects=all_objects, - intrinsics=self.camera_intrinsics, - ) - if all_objects - else base_image - ) - - results["detected_pointcloud_viz"] = ( - create_point_cloud_overlay_visualization( - base_image=base_image, - objects=detected_objects, - intrinsics=self.camera_intrinsics, - ) - if detected_objects - else base_image - ) - - if misc_clusters: - # Generate consistent colors for clusters - cluster_colors = [ - tuple((np.random.RandomState(i + 100).rand(3) * 255).astype(int)) - for i in range(len(misc_clusters)) - ] - results["misc_pointcloud_viz"] = overlay_point_clouds_on_image( - base_image=base_image, - point_clouds=misc_clusters, - camera_intrinsics=self.camera_intrinsics, - colors=cluster_colors, - point_size=2, - alpha=0.6, - ) - else: - results["misc_pointcloud_viz"] = base_image - - # Step 4: Grasp Generation (if enabled) - should_generate_grasps = ( - generate_grasps if generate_grasps is not None else self.enable_grasp_generation - ) - - if should_generate_grasps and all_objects and full_pcd: - grasps = self.run_grasp_generation(all_objects, full_pcd) - results["grasps"] = grasps - if grasps: - results["grasp_overlay"] = create_grasp_overlay( - rgb_image, grasps, self.camera_intrinsics - ) - - except Exception as e: - logger.error(f"Error processing frame: {e}") - results["error"] = str(e) - - # Add timing information - total_time = time.time() - start_time - results.update( - { - "processing_time": total_time, - "timing_breakdown": { - "detection": detection_time if "detection_time" in locals() else 0, - "segmentation": segmentation_time if "segmentation_time" in locals() else 0, - "pointcloud": pointcloud_time if "pointcloud_time" in locals() else 0, - "misc_extraction": misc_time if "misc_time" in locals() else 0, - "total": total_time, - }, - } - ) - - return results - - def run_object_detection(self, rgb_image: np.ndarray) -> Dict[str, Any]: - """Run object detection on RGB 
image.""" - try: - # Convert RGB to BGR for Detic detector - bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR) - - # Use process_image method from Detic detector - bboxes, track_ids, class_ids, confidences, names, masks = self.detector.process_image( - bgr_image - ) - - # Convert to ObjectData format using utility function - objects = detection_results_to_object_data( - bboxes=bboxes, - track_ids=track_ids, - class_ids=class_ids, - confidences=confidences, - names=names, - masks=masks, - source="detection", - ) - - # Create visualization using detector's built-in method - viz_frame = self.detector.visualize_results( - rgb_image, bboxes, track_ids, class_ids, confidences, names - ) - - return {"objects": objects, "viz_frame": viz_frame} - - except Exception as e: - logger.error(f"Object detection failed: {e}") - return {"objects": [], "viz_frame": rgb_image.copy()} - - def run_pointcloud_filtering( - self, rgb_image: np.ndarray, depth_image: np.ndarray, objects: List[Dict] - ) -> List[Dict]: - """Run point cloud filtering on detected objects.""" - try: - filtered_objects = self.pointcloud_filter.process_images( - rgb_image, depth_image, objects - ) - return filtered_objects if filtered_objects else [] - except Exception as e: - logger.error(f"Point cloud filtering failed: {e}") - return [] - - def run_segmentation(self, rgb_image: np.ndarray) -> Dict[str, Any]: - """Run semantic segmentation on RGB image.""" - if not self.segmenter: - return {"objects": [], "viz_frame": rgb_image.copy()} - - try: - # Convert RGB to BGR for segmenter - bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR) - - # Get segmentation results - masks, bboxes, track_ids, probs, names = self.segmenter.process_image(bgr_image) - - # Convert to ObjectData format using utility function - objects = detection_results_to_object_data( - bboxes=bboxes, - track_ids=track_ids, - class_ids=list(range(len(bboxes))), # Use indices as class IDs for segmentation - confidences=probs, - names=names, - 
masks=masks, - source="segmentation", - ) - - # Create visualization - if masks: - viz_bgr = self.segmenter.visualize_results( - bgr_image, masks, bboxes, track_ids, probs, names - ) - # Convert back to RGB - viz_frame = cv2.cvtColor(viz_bgr, cv2.COLOR_BGR2RGB) - else: - viz_frame = rgb_image.copy() - - return {"objects": objects, "viz_frame": viz_frame} - - except Exception as e: - logger.error(f"Segmentation failed: {e}") - return {"objects": [], "viz_frame": rgb_image.copy()} - - def run_grasp_generation(self, filtered_objects: List[Dict], full_pcd) -> Optional[List[Dict]]: - """Run grasp generation using the configured generator (AnyGrasp).""" - if not self.grasp_generator: - logger.warning("Grasp generation requested but no generator available") - return None - - try: - # Generate grasps using the configured generator - grasps = self.grasp_generator.generate_grasps_from_objects(filtered_objects, full_pcd) - - # Return parsed results directly (list of grasp dictionaries) - return grasps - - except Exception as e: - logger.error(f"AnyGrasp grasp generation failed: {e}") - return None - - def cleanup(self): - """Clean up resources.""" - if hasattr(self.detector, "cleanup"): - self.detector.cleanup() - if hasattr(self.pointcloud_filter, "cleanup"): - self.pointcloud_filter.cleanup() - if self.segmenter and hasattr(self.segmenter, "cleanup"): - self.segmenter.cleanup() - if self.grasp_generator and hasattr(self.grasp_generator, "cleanup"): - self.grasp_generator.cleanup() - logger.info("ManipulationProcessor cleaned up") From 4d64e550fe9c7a6be97144decef404efc47eae5f Mon Sep 17 00:00:00 2001 From: mustafab0 <39084056+mustafab0@users.noreply.github.com> Date: Tue, 15 Jul 2025 01:27:45 +0000 Subject: [PATCH 49/89] CI code cleanup --- dimos/hardware/piper_arm.py | 61 ++++++++++++++++++++++++++----------- 1 file changed, 44 insertions(+), 17 deletions(-) diff --git a/dimos/hardware/piper_arm.py b/dimos/hardware/piper_arm.py index 1ac841ad2b..2b63bc13ee 100644 --- 
a/dimos/hardware/piper_arm.py +++ b/dimos/hardware/piper_arm.py @@ -27,6 +27,7 @@ import tty import select + class PiperArm: def __init__(self, arm_name: str = "arm"): self.init_can() @@ -110,7 +111,9 @@ def resetArm(self): print(f"[PiperArm] Resetting arm") def init_vel_controller(self): - self.chain = kp.build_serial_chain_from_urdf(open("dimos/dimos/hardware/piper_description.urdf"), "gripper_base") + self.chain = kp.build_serial_chain_from_urdf( + open("dimos/dimos/hardware/piper_description.urdf"), "gripper_base" + ) self.J = self.chain.jacobian(np.zeros(6)) self.J_pinv = np.linalg.pinv(self.J) self.dt = 0.01 @@ -125,20 +128,45 @@ def cmd_vel(self, x_dot, y_dot, z_dot, R_dot, P_dot, Y_dot): joint_state = self.arm.GetArmJointMsgs().joint_state # print(f"[PiperArm] Current Joints: {joint_state}", type(joint_state)) - joint_angles = np.array([joint_state.joint_1, joint_state.joint_2, joint_state.joint_3, joint_state.joint_4, joint_state.joint_5, joint_state.joint_6]) + joint_angles = np.array( + [ + joint_state.joint_1, + joint_state.joint_2, + joint_state.joint_3, + joint_state.joint_4, + joint_state.joint_5, + joint_state.joint_6, + ] + ) # print(f"[PiperArm] Current Joints: {joint_angles}", type(joint_angles)) - factor = 57295.7795 #1000*180/3.1415926 - joint_angles = joint_angles * factor # convert to radians + factor = 57295.7795 # 1000*180/3.1415926 + joint_angles = joint_angles * factor # convert to radians # print(f"[PiperArm] Current Joints: {joint_angles}", type(joint_angles)) - q = np.array([joint_angles[0], joint_angles[1], joint_angles[2], joint_angles[3], joint_angles[4], joint_angles[5]]) + q = np.array( + [ + joint_angles[0], + joint_angles[1], + joint_angles[2], + joint_angles[3], + joint_angles[4], + joint_angles[5], + ] + ) # print(f"[PiperArm] Current Joints: {q}") time.sleep(0.005) - dq = self.J_pinv@np.array([x_dot, y_dot, z_dot, R_dot, P_dot, Y_dot])*self.dt + dq = self.J_pinv @ np.array([x_dot, y_dot, z_dot, R_dot, P_dot, Y_dot]) * self.dt 
newq = q + dq self.arm.MotionCtrl_2(0x01, 0x01, 100, 0x00) - self.arm.JointCtrl(int(round(newq[0])), int(round(newq[1])), int(round(newq[2])), int(round(newq[3])), int(round(newq[4])), int(round(newq[5]))) + self.arm.JointCtrl( + int(round(newq[0])), + int(round(newq[1])), + int(round(newq[2])), + int(round(newq[3])), + int(round(newq[4])), + int(round(newq[5])), + ) # print(f"[PiperArm] Moving to Joints to : {newq}") def cmd_vel_ee(self, x_dot, y_dot, z_dot, R_dot, P_dot, Y_dot): @@ -173,7 +201,6 @@ def disable(self): print("get_EE_pose") arm.get_EE_pose() - def get_key(timeout=0.1): """Non-blocking key reader for arrow keys.""" fd = sys.stdin.fileno() @@ -183,9 +210,9 @@ def get_key(timeout=0.1): rlist, _, _ = select.select([fd], [], [], timeout) if rlist: ch1 = sys.stdin.read(1) - if ch1 == '\x1b': # Arrow keys start with ESC + if ch1 == "\x1b": # Arrow keys start with ESC ch2 = sys.stdin.read(1) - if ch2 == '[': + if ch2 == "[": ch3 = sys.stdin.read(1) return ch1 + ch2 + ch3 else: @@ -200,19 +227,19 @@ def teleop_linear_vel(arm): x_dot, y_dot, z_dot = 0.0, 0.0, 0.0 while True: key = get_key(timeout=0.1) - if key == '\x1b[A': # Up arrow + if key == "\x1b[A": # Up arrow x_dot += 0.01 - elif key == '\x1b[B': # Down arrow + elif key == "\x1b[B": # Down arrow x_dot -= 0.01 - elif key == '\x1b[C': # Right arrow + elif key == "\x1b[C": # Right arrow y_dot += 0.01 - elif key == '\x1b[D': # Left arrow + elif key == "\x1b[D": # Left arrow y_dot -= 0.01 - elif key == 'w': + elif key == "w": z_dot += 0.01 - elif key == 's': + elif key == "s": z_dot -= 0.01 - elif key == 'q': + elif key == "q": print("Exiting teleop.") arm.disable() break From c4493f2ea7cd1a33559b747ea7c607d15c12244d Mon Sep 17 00:00:00 2001 From: mustafab0 <39084056+mustafab0@users.noreply.github.com> Date: Tue, 15 Jul 2025 02:32:44 +0000 Subject: [PATCH 50/89] CI code cleanup --- dimos/hardware/piper_arm.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git 
a/dimos/hardware/piper_arm.py b/dimos/hardware/piper_arm.py index 2b63bc13ee..de4f69ef8e 100644 --- a/dimos/hardware/piper_arm.py +++ b/dimos/hardware/piper_arm.py @@ -179,9 +179,18 @@ def cmd_vel_ee(self, x_dot, y_dot, z_dot, R_dot, P_dot, Y_dot): Y_dot = Y_dot * factor current_pose = self.get_EE_pose().end_pose - current_pose = np.array([current_pose.X_axis, current_pose.Y_axis, current_pose.Z_axis, current_pose.RX_axis, current_pose.RY_axis, current_pose.RZ_axis]) + current_pose = np.array( + [ + current_pose.X_axis, + current_pose.Y_axis, + current_pose.Z_axis, + current_pose.RX_axis, + current_pose.RY_axis, + current_pose.RZ_axis, + ] + ) current_pose = current_pose * factor - current_pose = current_pose + np.array([x_dot, y_dot, z_dot, R_dot, P_dot, Y_dot])*self.dt + current_pose = current_pose + np.array([x_dot, y_dot, z_dot, R_dot, P_dot, Y_dot]) * self.dt current_pose = current_pose / factor self.cmd_EE_pose(current_pose[0], current_pose[1], current_pose[2], current_pose[3], current_pose[4], current_pose[5]) time.sleep(self.dt) @@ -251,7 +260,8 @@ def teleop_linear_vel(arm): # Only linear velocities, angular set to zero arm.cmd_vel_ee(x_dot, y_dot, z_dot, 0, 0, 0) - print(f"Current linear velocity: x={x_dot:.3f} m/s, y={y_dot:.3f} m/s, z={z_dot:.3f} m/s") + print( + f"Current linear velocity: x={x_dot:.3f} m/s, y={y_dot:.3f} m/s, z={z_dot:.3f} m/s" + ) teleop_linear_vel(arm) - From 6ac4da173b5ae08bb71d0d53d22c3a5021887e83 Mon Sep 17 00:00:00 2001 From: mustafab0 <39084056+mustafab0@users.noreply.github.com> Date: Tue, 15 Jul 2025 02:35:49 +0000 Subject: [PATCH 51/89] CI code cleanup --- dimos/hardware/piper_arm.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/dimos/hardware/piper_arm.py b/dimos/hardware/piper_arm.py index de4f69ef8e..84eac1d63c 100644 --- a/dimos/hardware/piper_arm.py +++ b/dimos/hardware/piper_arm.py @@ -192,7 +192,14 @@ def cmd_vel_ee(self, x_dot, y_dot, z_dot, R_dot, P_dot, Y_dot): current_pose = current_pose 
* factor current_pose = current_pose + np.array([x_dot, y_dot, z_dot, R_dot, P_dot, Y_dot]) * self.dt current_pose = current_pose / factor - self.cmd_EE_pose(current_pose[0], current_pose[1], current_pose[2], current_pose[3], current_pose[4], current_pose[5]) + self.cmd_EE_pose( + current_pose[0], + current_pose[1], + current_pose[2], + current_pose[3], + current_pose[4], + current_pose[5], + ) time.sleep(self.dt) def disable(self): From 6a760fa5e192ea7cc5630809625829e30f2e1e53 Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Mon, 14 Jul 2025 19:09:42 -0700 Subject: [PATCH 52/89] fixed frame transform --- dimos/hardware/piper_arm.py | 38 ++++++++----- dimos/hardware/zed_camera.py | 1 + dimos/manipulation/ibvs/detection3d.py | 77 ++++++++++++++------------ dimos/manipulation/ibvs/pbvs.py | 63 +++------------------ dimos/manipulation/ibvs/utils.py | 6 +- 5 files changed, 77 insertions(+), 108 deletions(-) diff --git a/dimos/hardware/piper_arm.py b/dimos/hardware/piper_arm.py index 84eac1d63c..a8d99e0594 100644 --- a/dimos/hardware/piper_arm.py +++ b/dimos/hardware/piper_arm.py @@ -60,20 +60,19 @@ def enable(self): self.arm.MotionCtrl_2(0x01, 0x01, 80, 0x00) def gotoZero(self): - factor = 57295.7795 # 1000*180/3.1415926 - position = [0, 0, 0, 0, 0, 0, 0] - - joint_0 = round(position[0] * factor) - joint_1 = round(position[1] * factor) - joint_2 = round(position[2] * factor) - joint_3 = round(position[3] * factor) - joint_4 = round(position[4] * factor) - joint_5 = round(position[5] * factor) - joint_6 = round(position[6] * 1000 * 1000) - self.arm.ModeCtrl(0x01, 0x01, 30, 0x00) - self.arm.JointCtrl(joint_0, joint_1, joint_2, joint_3, joint_4, joint_5) + factor = 1000 + position = [57.0, 0.0, 215.0, 0, 90.0, 0, 0] + X = round(position[0] * factor) + Y = round(position[1] * factor) + Z = round(position[2] * factor) + RX = round(position[3] * factor) + RY = round(position[4] * factor) + RZ = round(position[5] * factor) + joint_6 = round(position[6] * factor) + print(X, Y, Z, 
RX, RY, RZ) + self.arm.MotionCtrl_2(0x01, 0x00, 100, 0x00) + self.arm.EndPoseCtrl(X, Y, Z, RX, RY, RZ) self.arm.GripperCtrl(abs(joint_6), 1000, 0x01, 0) - pass def softStop(self): self.gotoZero() @@ -95,8 +94,17 @@ def cmd_EE_pose(self, x, y, z, r, p, y_): def get_EE_pose(self): """Return the current end-effector pose as (x, y, z, r, p, y)""" pose = self.arm.GetArmEndPoseMsgs() - print(f"[PiperArm] Current pose: {pose}") - return pose + # Extract individual pose values and convert to base units + # Position values are divided by 1000 to convert from SDK units to mm + # Rotation values are divided by 1000 to convert from SDK units to degrees + x = pose.end_pose.X_axis / 1000.0 + y = pose.end_pose.Y_axis / 1000.0 + z = pose.end_pose.Z_axis / 1000.0 + r = pose.end_pose.RX_axis / 1000.0 + p = pose.end_pose.RY_axis / 1000.0 + y_rot = pose.end_pose.RZ_axis / 1000.0 + + return (x, y, z, r, p, y_rot) def cmd_gripper_ctrl(self, position): """Command end-effector gripper""" diff --git a/dimos/hardware/zed_camera.py b/dimos/hardware/zed_camera.py index ba936cec3a..a2ceeba54e 100644 --- a/dimos/hardware/zed_camera.py +++ b/dimos/hardware/zed_camera.py @@ -64,6 +64,7 @@ def __init__( self.init_params = sl.InitParameters() self.init_params.camera_resolution = resolution self.init_params.depth_mode = depth_mode + self.init_params.coordinate_system = sl.COORDINATE_SYSTEM.RIGHT_HANDED_Z_UP_X_FWD self.init_params.coordinate_units = sl.UNIT.METER self.init_params.camera_fps = fps diff --git a/dimos/manipulation/ibvs/detection3d.py b/dimos/manipulation/ibvs/detection3d.py index aca0169bf6..928a27a879 100644 --- a/dimos/manipulation/ibvs/detection3d.py +++ b/dimos/manipulation/ibvs/detection3d.py @@ -29,7 +29,12 @@ from dimos.types.pose import Pose from dimos.types.vector import Vector from dimos.types.manipulation import ObjectData -from dimos.manipulation.ibvs.utils import estimate_object_depth +from dimos.manipulation.ibvs.utils import ( + estimate_object_depth, + 
optical_to_robot_convention, + pose_to_transform_matrix, + transform_matrix_to_pose, +) logger = setup_logger("dimos.perception.detection3d") @@ -47,7 +52,7 @@ def __init__( camera_intrinsics: List[float], # [fx, fy, cx, cy] min_confidence: float = 0.6, min_points: int = 30, - max_depth: float = 5.0, + max_depth: float = 1.0, ): """ Initialize the real-time 3D detection processor. @@ -125,7 +130,7 @@ def process_frame( # Build detection results detections = [] - pose_dict = {p["mask_idx"]: p for p in poses} + pose_dict = {p["mask_idx"]: p for p in poses if p["centroid"][2] < self.max_depth} for i, (bbox, name, prob, track_id) in enumerate(zip(bboxes, names, probs, track_ids)): # Create ObjectData object @@ -175,56 +180,60 @@ def process_frame( # Transform to world frame if camera pose is available if camera_pose is not None: - world_pos = self._transform_to_world(obj_cam_pos, camera_pose) - obj_data["world_position"] = world_pos - obj_data["position"] = world_pos # Use world position + # Get orientation as euler angles, default to no rotation if not available + obj_cam_orientation = pose.get( + "rotation", np.array([0.0, 0.0, 0.0]) + ) # Default to no rotation + world_pose = self._transform_to_world( + obj_cam_pos, obj_cam_orientation, camera_pose + ) + obj_data["world_position"] = world_pose.pos + obj_data["position"] = world_pose.pos # Use world position + obj_data["rotation"] = world_pose.rot # Use world rotation else: # If no camera pose, use camera coordinates obj_data["position"] = Vector(obj_cam_pos[0], obj_cam_pos[1], obj_cam_pos[2]) - detections.append(obj_data) + detections.append(obj_data) return {"detections": detections, "processing_time": time.time() - start_time} - def _transform_to_world(self, obj_pos: np.ndarray, camera_pose: Pose) -> Vector: + def _transform_to_world( + self, obj_pos: np.ndarray, obj_orientation: np.ndarray, camera_pose: Pose + ) -> Pose: """ - Transform object position from camera frame to world frame (ZED coordinates). 
+ Transform object pose from optical frame to world frame. Args: - obj_pos: Object position in camera frame [x, y, z] - camera_pose: Camera pose in world frame + obj_pos: Object position in optical frame [x, y, z] + obj_orientation: Object orientation in optical frame [roll, pitch, yaw] in radians + camera_pose: Camera pose in world frame (x forward, y left, z up) Returns: - Object position in world frame as Vector + Object pose in world frame as Pose """ - # Simple transformation: rotate and translate - roll = camera_pose.rot.x - pitch = camera_pose.rot.y - yaw = camera_pose.rot.z - - # Create rotation matrices - cos_roll = np.cos(roll) - sin_roll = np.sin(roll) - R_x = np.array([[1, 0, 0], [0, cos_roll, -sin_roll], [0, sin_roll, cos_roll]]) + # Create object pose in optical frame + obj_pose_optical = Pose( + Vector(obj_pos[0], obj_pos[1], obj_pos[2]), + Vector([obj_orientation[0], obj_orientation[1], obj_orientation[2]]), + ) - cos_pitch = np.cos(pitch) - sin_pitch = np.sin(pitch) - R_y = np.array([[cos_pitch, 0, sin_pitch], [0, 1, 0], [-sin_pitch, 0, cos_pitch]]) + # Transform object pose from optical frame to world frame convention + obj_pose_world_frame = optical_to_robot_convention(obj_pose_optical) - cos_yaw = np.cos(yaw) - sin_yaw = np.sin(yaw) - R_z = np.array([[cos_yaw, -sin_yaw, 0], [sin_yaw, cos_yaw, 0], [0, 0, 1]]) + # Create transformation matrix from camera pose + T_world_camera = pose_to_transform_matrix(camera_pose) - # Combined rotation (ZYX convention) - rot_matrix = R_z @ R_y @ R_x + # Create transformation matrix from object pose (relative to camera) + T_camera_object = pose_to_transform_matrix(obj_pose_world_frame) - # Rotate object position - rotated_pos = rot_matrix @ obj_pos + # Combine transformations: T_world_object = T_world_camera * T_camera_object + T_world_object = T_world_camera @ T_camera_object - # Translate by camera position - world_pos = camera_pose.pos + Vector(rotated_pos[0], rotated_pos[1], rotated_pos[2]) + # Convert back to 
pose + world_pose = transform_matrix_to_pose(T_world_object) - return world_pos + return world_pose def visualize_detections( self, diff --git a/dimos/manipulation/ibvs/pbvs.py b/dimos/manipulation/ibvs/pbvs.py index 7fbf828535..cdcffbf15b 100644 --- a/dimos/manipulation/ibvs/pbvs.py +++ b/dimos/manipulation/ibvs/pbvs.py @@ -27,7 +27,7 @@ from dimos.manipulation.ibvs.utils import ( pose_to_transform_matrix, apply_transform, - zed_to_robot_convention, + optical_to_robot_convention, calculate_yaw_to_origin, ) @@ -53,7 +53,7 @@ def __init__( max_velocity: float = 0.1, # m/s max_angular_velocity: float = 0.5, # rad/s target_tolerance: float = 0.05, # 5cm - tracking_distance_threshold: float = 0.1, # 10cm for target tracking + tracking_distance_threshold: float = 0.05, # 10cm for target tracking ): """ Initialize PBVS controller. @@ -97,7 +97,7 @@ def set_manipulator_origin(self, camera_pose: Pose): This establishes the robot arm coordinate frame. Args: - camera_pose: Current camera pose in ZED world frame + camera_pose: Current camera pose in world frame """ self.manipulator_origin_pose = camera_pose @@ -111,10 +111,6 @@ def set_manipulator_origin(self, camera_pose: Pose): f"{camera_pose.pos.y:.3f}, {camera_pose.pos.z:.3f})" ) - # Update current target if exists - if self.current_target and "position" in self.current_target: - self._update_target_robot_frame() - def _update_target_robot_frame(self): """Update current target with robot frame coordinates.""" if not self.current_target or "position" not in self.current_target: @@ -127,14 +123,11 @@ def _update_target_robot_frame(self): # Transform to manipulator frame target_pose_manip = apply_transform(target_pose_zed, self.manipulator_origin) - # Convert to robot convention - target_pose_robot = zed_to_robot_convention(target_pose_manip) - # Calculate orientation pointing at origin (in robot frame) - yaw_to_origin = calculate_yaw_to_origin(target_pose_robot.pos) + yaw_to_origin = 
calculate_yaw_to_origin(target_pose_manip.pos) # Update target with robot frame pose - self.current_target["robot_position"] = target_pose_robot.pos + self.current_target["robot_position"] = target_pose_manip.pos self.current_target["robot_rotation"] = Vector(0.0, 0.0, yaw_to_origin) # Level grasp def set_target(self, target_object: Dict[str, Any]) -> bool: @@ -254,8 +247,7 @@ def compute_control( self.update_target_tracking(new_detections) # Transform camera pose to robot frame - camera_pose_manip = apply_transform(camera_pose, self.manipulator_origin) - camera_pose_robot = zed_to_robot_convention(camera_pose_manip) + camera_pose_robot = apply_transform(camera_pose, self.manipulator_origin) # Get target in robot frame target_pos = self.current_target.get("robot_position") @@ -265,6 +257,7 @@ def compute_control( # Shouldn't happen but handle gracefully self._update_target_robot_frame() target_pos = self.current_target.get("robot_position", Vector(0, 0, 0)) + target_rot = self.current_target.get("robot_rotation", Vector(0, 0, 0)) # Calculate position error (target - camera) error = target_pos - camera_pose_robot.pos @@ -342,48 +335,6 @@ def _compute_angular_velocity(self, target_rot: Vector, current_pose: Pose) -> V return angular_velocity - def get_camera_pose_robot_frame(self, camera_pose_zed: Pose) -> Optional[Pose]: - """ - Get camera pose in robot frame coordinates. - - Args: - camera_pose_zed: Camera pose in ZED world frame - - Returns: - Camera pose in robot frame or None if no origin set - """ - if self.manipulator_origin is None: - return None - - camera_pose_manip = apply_transform(camera_pose_zed, self.manipulator_origin) - return zed_to_robot_convention(camera_pose_manip) - - def get_object_pose_robot_frame( - self, object_pos_zed: Vector - ) -> Optional[Tuple[Vector, Vector]]: - """ - Get object pose in robot frame coordinates with orientation. 
- - Args: - object_pos_zed: Object position in ZED world frame - - Returns: - Tuple of (position, rotation) in robot frame or None if no origin set - """ - if self.manipulator_origin is None: - return None - - # Transform position - obj_pose_zed = Pose(object_pos_zed, Vector(0, 0, 0)) - obj_pose_manip = apply_transform(obj_pose_zed, self.manipulator_origin) - obj_pose_robot = zed_to_robot_convention(obj_pose_manip) - - # Calculate orientation pointing at origin - yaw_to_origin = calculate_yaw_to_origin(obj_pose_robot.pos) - orientation = Vector(0.0, 0.0, yaw_to_origin) # Level grasp - - return obj_pose_robot.pos, orientation - def create_status_overlay( self, image: np.ndarray, camera_intrinsics: Optional[list] = None ) -> np.ndarray: diff --git a/dimos/manipulation/ibvs/utils.py b/dimos/manipulation/ibvs/utils.py index cca1acbab5..8befbdbad7 100644 --- a/dimos/manipulation/ibvs/utils.py +++ b/dimos/manipulation/ibvs/utils.py @@ -130,7 +130,7 @@ def apply_transform(pose: Pose, transform_matrix: np.ndarray) -> Pose: return transform_matrix_to_pose(T_result) -def zed_to_robot_convention(pose: Pose) -> Pose: +def optical_to_robot_convention(pose: Pose) -> Pose: """ Convert pose from ZED camera convention to robot arm convention. @@ -200,10 +200,10 @@ def zed_to_robot_convention(pose: Pose) -> Pose: return Pose(Vector(robot_x, robot_y, robot_z), Vector(robot_roll, robot_pitch, robot_yaw)) -def robot_to_zed_convention(pose: Pose) -> Pose: +def robot_to_optical_convention(pose: Pose) -> Pose: """ Convert pose from robot arm convention to ZED camera convention. - This is the inverse of zed_to_robot_convention. + This is the inverse of optical_to_robot_convention. 
Args: pose: Pose in robot arm convention From fa413616cd2d4ffb88508c60d96ea5c38389237d Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Mon, 14 Jul 2025 22:48:29 -0700 Subject: [PATCH 53/89] feature: PBVS fully working using ZED as origin --- dimos/hardware/piper_arm.py | 30 ++-- dimos/manipulation/ibvs/pbvs.py | 233 ++++++++++++++++++++++++++++---- tests/test_ibvs.py | 2 +- 3 files changed, 220 insertions(+), 45 deletions(-) diff --git a/dimos/hardware/piper_arm.py b/dimos/hardware/piper_arm.py index a8d99e0594..19bb7f866e 100644 --- a/dimos/hardware/piper_arm.py +++ b/dimos/hardware/piper_arm.py @@ -45,7 +45,7 @@ def init_can(self): result = subprocess.run( [ "bash", - "dimos/dimos/hardware/can_activate.sh", + "dimos/hardware/can_activate.sh", ], # pass the script path directly if it has a shebang and execute perms stdout=subprocess.PIPE, # capture stdout stderr=subprocess.PIPE, # capture stderr @@ -61,7 +61,7 @@ def enable(self): def gotoZero(self): factor = 1000 - position = [57.0, 0.0, 215.0, 0, 90.0, 0, 0] + position = [57.0, 0.0, 250.0, 0, 90.0, 0, 0] X = round(position[0] * factor) Y = round(position[1] * factor) Z = round(position[2] * factor) @@ -89,7 +89,6 @@ def cmd_EE_pose(self, x, y, z, r, p, y_): self.arm.EndPoseCtrl( int(pose[0]), int(pose[1]), int(pose[2]), int(pose[3]), int(pose[4]), int(pose[5]) ) - print(f"[PiperArm] Moving to pose: {pose}") def get_EE_pose(self): """Return the current end-effector pose as (x, y, z, r, p, y)""" @@ -120,7 +119,7 @@ def resetArm(self): def init_vel_controller(self): self.chain = kp.build_serial_chain_from_urdf( - open("dimos/dimos/hardware/piper_description.urdf"), "gripper_base" + open("dimos/hardware/piper_description.urdf"), "gripper_base" ) self.J = self.chain.jacobian(np.zeros(6)) self.J_pinv = np.linalg.pinv(self.J) @@ -186,20 +185,11 @@ def cmd_vel_ee(self, x_dot, y_dot, z_dot, R_dot, P_dot, Y_dot): P_dot = P_dot * factor Y_dot = Y_dot * factor - current_pose = self.get_EE_pose().end_pose - current_pose = 
np.array( - [ - current_pose.X_axis, - current_pose.Y_axis, - current_pose.Z_axis, - current_pose.RX_axis, - current_pose.RY_axis, - current_pose.RZ_axis, - ] - ) - current_pose = current_pose * factor + current_pose = self.get_EE_pose() + current_pose = np.array(current_pose) + current_pose = current_pose current_pose = current_pose + np.array([x_dot, y_dot, z_dot, R_dot, P_dot, Y_dot]) * self.dt - current_pose = current_pose / factor + current_pose = current_pose self.cmd_EE_pose( current_pose[0], current_pose[1], @@ -269,9 +259,9 @@ def teleop_linear_vel(arm): break # Optionally, clamp velocities to reasonable limits - x_dot = max(min(x_dot, 0.2), -0.2) - y_dot = max(min(y_dot, 0.2), -0.2) - z_dot = max(min(z_dot, 0.2), -0.2) + x_dot = max(min(x_dot, 0.5), -0.5) + y_dot = max(min(y_dot, 0.5), -0.5) + z_dot = max(min(z_dot, 0.5), -0.5) # Only linear velocities, angular set to zero arm.cmd_vel_ee(x_dot, y_dot, z_dot, 0, 0, 0) diff --git a/dimos/manipulation/ibvs/pbvs.py b/dimos/manipulation/ibvs/pbvs.py index cdcffbf15b..7f27099caa 100644 --- a/dimos/manipulation/ibvs/pbvs.py +++ b/dimos/manipulation/ibvs/pbvs.py @@ -44,6 +44,8 @@ class PBVSController: - Velocity command generation with gain control - Automatic target tracking across frames - Frame transformations from ZED to robot conventions + - Pregrasp distance functionality + - 6DOF EE to camera transform handling """ def __init__( @@ -52,8 +54,12 @@ def __init__( rotation_gain: float = 0.3, max_velocity: float = 0.1, # m/s max_angular_velocity: float = 0.5, # rad/s - target_tolerance: float = 0.05, # 5cm - tracking_distance_threshold: float = 0.05, # 10cm for target tracking + target_tolerance: float = 0.01, # 5cm + tracking_distance_threshold: float = 0.05, # 5cm for target tracking + pregrasp_distance: float = 0.15, # 15cm pregrasp distance + ee_to_camera_transform: Vector = Vector( + [0.0, 0.0, -0.06, 0.0, -1.57, 0.0] + ), # 6DOF: [x,y,z,rx,ry,rz] ): """ Initialize PBVS controller. 
@@ -65,6 +71,8 @@ def __init__( max_angular_velocity: Maximum angular velocity command magnitude (rad/s) target_tolerance: Distance threshold for considering target reached (m) tracking_distance_threshold: Max distance for target association (m) + pregrasp_distance: Distance to maintain before grasping (m) + ee_to_camera_transform: 6DOF transform from EE to camera [x,y,z,rx,ry,rz] """ self.position_gain = position_gain self.rotation_gain = rotation_gain @@ -72,6 +80,8 @@ def __init__( self.max_angular_velocity = max_angular_velocity self.target_tolerance = target_tolerance self.tracking_distance_threshold = tracking_distance_threshold + self.pregrasp_distance = pregrasp_distance + self.ee_to_camera_transform_vec = ee_to_camera_transform # State variables self.current_target = None @@ -85,12 +95,38 @@ def __init__( self.manipulator_origin = None # Transform matrix from world to manipulator frame self.manipulator_origin_pose = None # Original pose for reference + # Create 6DOF EE to camera transform matrix + self.ee_to_camera_transform = self._create_ee_to_camera_transform() + logger.info( f"Initialized PBVS controller: pos_gain={position_gain}, rot_gain={rotation_gain}, " f"max_vel={max_velocity}m/s, max_ang_vel={max_angular_velocity}rad/s, " - f"target_tolerance={target_tolerance}m" + f"target_tolerance={target_tolerance}m, pregrasp_distance={pregrasp_distance}m, " + f"ee_to_camera_transform={ee_to_camera_transform.to_list()}" ) + def _create_ee_to_camera_transform(self) -> np.ndarray: + """ + Create 6DOF transform matrix from EE to camera frame. 
+ + Returns: + 4x4 transformation matrix from EE to camera + """ + # Extract position and rotation from 6DOF vector + pos = self.ee_to_camera_transform_vec.to_list()[:3] + rot = self.ee_to_camera_transform_vec.to_list()[3:6] + + # Create transformation matrix + T_ee_to_cam = np.eye(4) + T_ee_to_cam[0:3, 3] = pos + + # Apply rotation (using Rodrigues formula) + if np.linalg.norm(rot) > 1e-6: + rot_matrix = cv2.Rodrigues(np.array(rot))[0] + T_ee_to_cam[0:3, 0:3] = rot_matrix + + return T_ee_to_cam + def set_manipulator_origin(self, camera_pose: Pose): """ Set the manipulator frame origin based on current camera pose. @@ -111,6 +147,42 @@ def set_manipulator_origin(self, camera_pose: Pose): f"{camera_pose.pos.y:.3f}, {camera_pose.pos.z:.3f})" ) + def _apply_pregrasp_distance(self, target_pose: Pose) -> Pose: + """ + Apply pregrasp distance to target pose by moving back towards robot origin. + + Args: + target_pose: Target pose in robot frame + + Returns: + Modified target pose with pregrasp distance applied + """ + # Get approach vector (from target position towards robot origin) + target_pos = np.array([target_pose.pos.x, target_pose.pos.y, target_pose.pos.z]) + robot_origin = np.array([0.0, 0.0, 0.0]) # Robot origin in robot frame + approach_vector = robot_origin - target_pos # Vector pointing towards robot + + # Normalize approach vector + approach_magnitude = np.linalg.norm(approach_vector) + if approach_magnitude > 1e-6: # Avoid division by zero + norm_approach_vector = approach_vector / approach_magnitude + else: + norm_approach_vector = np.array([0.0, 0.0, 0.0]) + + # Move back by pregrasp distance towards robot + offset_vector = self.pregrasp_distance * norm_approach_vector + + # Apply offset to target position + new_position = Vector( + [ + target_pose.pos.x + offset_vector[0], + target_pose.pos.y + offset_vector[1], + target_pose.pos.z + offset_vector[2], + ] + ) + + return Pose(new_position, target_pose.rot) + def _update_target_robot_frame(self): """Update 
current target with robot frame coordinates.""" if not self.current_target or "position" not in self.current_target: @@ -118,7 +190,7 @@ def _update_target_robot_frame(self): # Get target position in ZED world frame target_pos = self.current_target["position"] - target_pose_zed = Pose(target_pos, Vector(0, 0, 0)) + target_pose_zed = Pose(target_pos, Vector([0.0, 0.0, 0.0])) # Transform to manipulator frame target_pose_manip = apply_transform(target_pose_zed, self.manipulator_origin) @@ -126,9 +198,15 @@ def _update_target_robot_frame(self): # Calculate orientation pointing at origin (in robot frame) yaw_to_origin = calculate_yaw_to_origin(target_pose_manip.pos) + # Create target pose with proper orientation + target_pose_robot = Pose(target_pose_manip.pos, Vector([0.0, 1.57, yaw_to_origin])) + + # Apply pregrasp distance + target_pose_pregrasp = self._apply_pregrasp_distance(target_pose_robot) + # Update target with robot frame pose - self.current_target["robot_position"] = target_pose_manip.pos - self.current_target["robot_rotation"] = Vector(0.0, 0.0, yaw_to_origin) # Level grasp + self.current_target["robot_position"] = target_pose_pregrasp.pos + self.current_target["robot_rotation"] = target_pose_pregrasp.rot def set_target(self, target_object: Dict[str, Any]) -> bool: """ @@ -217,6 +295,27 @@ def update_target_tracking(self, new_detections: List[Dict[str, Any]]) -> bool: return True return False + def _get_ee_pose_from_camera(self, camera_pose: Pose) -> Pose: + """ + Get end-effector pose from camera pose using 6DOF EE to camera transform. 
+ + Args: + camera_pose: Current camera pose in robot frame + + Returns: + End-effector pose in robot frame + """ + # Transform camera pose to EE frame + camera_transform = pose_to_transform_matrix(camera_pose) + ee_transform = camera_transform @ np.linalg.inv(self.ee_to_camera_transform) + + # Extract position and rotation + ee_pos = Vector(ee_transform[0:3, 3]) + ee_rot_matrix = ee_transform[0:3, 0:3] + ee_rot = Vector(cv2.Rodrigues(ee_rot_matrix)[0].flatten()) + + return Pose(ee_pos, ee_rot) + def compute_control( self, camera_pose: Pose, new_detections: Optional[List[Dict[str, Any]]] = None ) -> Tuple[Optional[Vector], Optional[Vector], bool, bool]: @@ -246,40 +345,55 @@ def compute_control( if new_detections is not None: self.update_target_tracking(new_detections) + print(f"Camera pose: {camera_pose}") + # Transform camera pose to robot frame camera_pose_robot = apply_transform(camera_pose, self.manipulator_origin) + # Get EE pose from camera pose + ee_pose_robot = self._get_ee_pose_from_camera(camera_pose_robot) + # Get target in robot frame target_pos = self.current_target.get("robot_position") - target_rot = self.current_target.get("robot_rotation", Vector(0, 0, 0)) + target_rot = self.current_target.get("robot_rotation") - if target_pos is None: - # Shouldn't happen but handle gracefully - self._update_target_robot_frame() - target_pos = self.current_target.get("robot_position", Vector(0, 0, 0)) - target_rot = self.current_target.get("robot_rotation", Vector(0, 0, 0)) + if target_pos is None or target_rot is None: + logger.warning("Target position or rotation not available") + return None, None, False, False - # Calculate position error (target - camera) - error = target_pos - camera_pose_robot.pos + # Calculate position error (target - EE position) + error = target_pos - ee_pose_robot.pos self.last_position_error = error # Compute velocity command with proportional control - velocity_cmd = error * self.position_gain + velocity_cmd = Vector( + [ + error.x 
* self.position_gain, + error.y * self.position_gain, + error.z * self.position_gain, + ] + ) # Limit velocity magnitude vel_magnitude = np.linalg.norm([velocity_cmd.x, velocity_cmd.y, velocity_cmd.z]) if vel_magnitude > self.max_velocity: scale = self.max_velocity / vel_magnitude - velocity_cmd = velocity_cmd * scale + velocity_cmd = Vector( + [ + float(velocity_cmd.x * scale), + float(velocity_cmd.y * scale), + float(velocity_cmd.z * scale), + ] + ) self.last_velocity_cmd = velocity_cmd # Compute angular velocity for orientation control - angular_velocity_cmd = self._compute_angular_velocity(target_rot, camera_pose_robot) + angular_velocity_cmd = self._compute_angular_velocity(target_rot, ee_pose_robot) # Check if target reached error_magnitude = np.linalg.norm([error.x, error.y, error.z]) - target_reached = error_magnitude < self.target_tolerance + target_reached = bool(error_magnitude < self.target_tolerance) self.last_target_reached = target_reached # Clear target only if it's reached @@ -298,7 +412,7 @@ def _compute_angular_velocity(self, target_rot: Vector, current_pose: Pose) -> V Args: target_rot: Target orientation (roll, pitch, yaw) - current_pose: Current camera/EE pose + current_pose: Current EE pose Returns: Angular velocity command as Vector @@ -314,13 +428,15 @@ def _compute_angular_velocity(self, target_rot: Vector, current_pose: Pose) -> V while yaw_error < -np.pi: yaw_error += 2 * np.pi - self.last_rotation_error = Vector(roll_error, pitch_error, yaw_error) + self.last_rotation_error = Vector([roll_error, pitch_error, yaw_error]) # Apply proportional control angular_velocity = Vector( - roll_error * self.rotation_gain, - pitch_error * self.rotation_gain, - yaw_error * self.rotation_gain, + [ + roll_error * self.rotation_gain, + pitch_error * self.rotation_gain, + yaw_error * self.rotation_gain, + ] ) # Limit angular velocity magnitude @@ -335,6 +451,63 @@ def _compute_angular_velocity(self, target_rot: Vector, current_pose: Pose) -> V return 
angular_velocity + def get_camera_pose_robot_frame(self, camera_pose_zed: Pose) -> Optional[Pose]: + """ + Get camera pose in robot frame coordinates. + + Args: + camera_pose_zed: Camera pose in ZED world frame + + Returns: + Camera pose in robot frame or None if no origin set + """ + if self.manipulator_origin is None: + return None + + camera_pose_manip = apply_transform(camera_pose_zed, self.manipulator_origin) + return camera_pose_manip + + def get_ee_pose_robot_frame(self, camera_pose_zed: Pose) -> Optional[Pose]: + """ + Get end-effector pose in robot frame coordinates. + + Args: + camera_pose_zed: Camera pose in ZED world frame + + Returns: + End-effector pose in robot frame or None if no origin set + """ + if self.manipulator_origin is None: + return None + + camera_pose_robot = apply_transform(camera_pose_zed, self.manipulator_origin) + return self._get_ee_pose_from_camera(camera_pose_robot) + + def get_object_pose_robot_frame( + self, object_pos_zed: Vector + ) -> Optional[Tuple[Vector, Vector]]: + """ + Get object pose in robot frame coordinates with orientation. 
+ + Args: + object_pos_zed: Object position in ZED world frame + + Returns: + Tuple of (position, rotation) in robot frame or None if no origin set + """ + if self.manipulator_origin is None: + return None + + # Transform position + obj_pose_zed = Pose(object_pos_zed, Vector([0.0, 0.0, 0.0])) + obj_pose_manip = apply_transform(obj_pose_zed, self.manipulator_origin) + + # Calculate orientation pointing at origin + yaw_to_origin = calculate_yaw_to_origin(obj_pose_manip.pos) + orientation = Vector([0.0, 0.0, yaw_to_origin]) # Level grasp + + return obj_pose_manip.pos, orientation + def create_status_overlay( self, image: np.ndarray, camera_intrinsics: Optional[list] = None ) -> np.ndarray: @@ -353,7 +526,7 @@ def create_status_overlay( # Status panel if self.current_target: - panel_height = 140 # Increased for rotation display + panel_height = 140 # Adjusted panel height panel_y = height - panel_height overlay = viz_img.copy() cv2.rectangle(overlay, (0, panel_y), (width, height), (0, 0, 0), -1) @@ -436,6 +609,18 @@ def create_status_overlay( 1, ) + # Add config info + ee_transform = self.ee_to_camera_transform_vec.to_list() + cv2.putText( + viz_img, + f"Pregrasp: {self.pregrasp_distance:.3f}m | EE Transform: [{ee_transform[0]:.2f},{ee_transform[1]:.2f},{ee_transform[2]:.2f}]", + (10, y + 125), + cv2.FONT_HERSHEY_SIMPLEX, + 0.4, + (255, 255, 255), + 1, + ) + if self.last_target_reached: cv2.putText( viz_img, diff --git a/tests/test_ibvs.py b/tests/test_ibvs.py index 03eb80f6ae..a33651a160 100644 --- a/tests/test_ibvs.py +++ b/tests/test_ibvs.py @@ -78,7 +78,7 @@ def main(): # Initialize processors detector = Detection3DProcessor(intrinsics) - pbvs = PBVSController(position_gain=0.3, rotation_gain=0.2, target_tolerance=0.1) + pbvs = PBVSController(position_gain=0.3, rotation_gain=0.2, target_tolerance=0.025) # Setup window cv2.namedWindow("PBVS") From 0f5a21423dff5fbd23d6bfb2728f6cfc6b80787f Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Tue, 15 Jul 2025 17:57:06 -0700 
Subject: [PATCH 54/89] refactored transform function into transform utils --- dimos/manipulation/ibvs/detection3d.py | 18 +- dimos/manipulation/ibvs/pbvs.py | 18 +- dimos/manipulation/ibvs/utils.py | 232 +----------------------- dimos/utils/transform_utils.py | 233 +++++++++++++++++++++++++ 4 files changed, 252 insertions(+), 249 deletions(-) diff --git a/dimos/manipulation/ibvs/detection3d.py b/dimos/manipulation/ibvs/detection3d.py index 928a27a879..63df184d2e 100644 --- a/dimos/manipulation/ibvs/detection3d.py +++ b/dimos/manipulation/ibvs/detection3d.py @@ -29,11 +29,11 @@ from dimos.types.pose import Pose from dimos.types.vector import Vector from dimos.types.manipulation import ObjectData -from dimos.manipulation.ibvs.utils import ( - estimate_object_depth, - optical_to_robot_convention, - pose_to_transform_matrix, - transform_matrix_to_pose, +from dimos.manipulation.ibvs.utils import estimate_object_depth +from dimos.utils.transform_utils import ( + optical_to_robot_frame, + pose_to_matrix, + matrix_to_pose, ) logger = setup_logger("dimos.perception.detection3d") @@ -219,19 +219,19 @@ def _transform_to_world( ) # Transform object pose from optical frame to world frame convention - obj_pose_world_frame = optical_to_robot_convention(obj_pose_optical) + obj_pose_world_frame = optical_to_robot_frame(obj_pose_optical) # Create transformation matrix from camera pose - T_world_camera = pose_to_transform_matrix(camera_pose) + T_world_camera = pose_to_matrix(camera_pose) # Create transformation matrix from object pose (relative to camera) - T_camera_object = pose_to_transform_matrix(obj_pose_world_frame) + T_camera_object = pose_to_matrix(obj_pose_world_frame) # Combine transformations: T_world_object = T_world_camera * T_camera_object T_world_object = T_world_camera @ T_camera_object # Convert back to pose - world_pose = transform_matrix_to_pose(T_world_object) + world_pose = matrix_to_pose(T_world_object) return world_pose diff --git 
a/dimos/manipulation/ibvs/pbvs.py b/dimos/manipulation/ibvs/pbvs.py index 7f27099caa..e593248773 100644 --- a/dimos/manipulation/ibvs/pbvs.py +++ b/dimos/manipulation/ibvs/pbvs.py @@ -24,11 +24,11 @@ from dimos.types.pose import Pose from dimos.types.vector import Vector from dimos.utils.logging_config import setup_logger -from dimos.manipulation.ibvs.utils import ( - pose_to_transform_matrix, +from dimos.utils.transform_utils import ( + pose_to_matrix, apply_transform, - optical_to_robot_convention, - calculate_yaw_to_origin, + optical_to_robot_frame, + yaw_towards_point, ) logger = setup_logger("dimos.manipulation.pbvs") @@ -139,7 +139,7 @@ def set_manipulator_origin(self, camera_pose: Pose): # Create transform matrix from ZED world to manipulator origin # This is the inverse of the camera pose at origin - T_world_to_origin = pose_to_transform_matrix(camera_pose) + T_world_to_origin = pose_to_matrix(camera_pose) self.manipulator_origin = np.linalg.inv(T_world_to_origin) logger.info( @@ -196,7 +196,7 @@ def _update_target_robot_frame(self): target_pose_manip = apply_transform(target_pose_zed, self.manipulator_origin) # Calculate orientation pointing at origin (in robot frame) - yaw_to_origin = calculate_yaw_to_origin(target_pose_manip.pos) + yaw_to_origin = yaw_towards_point(target_pose_manip.pos) # Create target pose with proper orientation target_pose_robot = Pose(target_pose_manip.pos, Vector([0.0, 1.57, yaw_to_origin])) @@ -306,7 +306,7 @@ def _get_ee_pose_from_camera(self, camera_pose: Pose) -> Pose: End-effector pose in robot frame """ # Transform camera pose to EE frame - camera_transform = pose_to_transform_matrix(camera_pose) + camera_transform = pose_to_matrix(camera_pose) ee_transform = camera_transform @ np.linalg.inv(self.ee_to_camera_transform) # Extract position and rotation @@ -345,8 +345,6 @@ def compute_control( if new_detections is not None: self.update_target_tracking(new_detections) - print(f"Camera pose: {camera_pose}") - # Transform camera 
pose to robot frame camera_pose_robot = apply_transform(camera_pose, self.manipulator_origin) @@ -503,7 +501,7 @@ def get_object_pose_robot_frame( obj_pose_manip = apply_transform(obj_pose_zed, self.manipulator_origin) # Calculate orientation pointing at origin - yaw_to_origin = calculate_yaw_to_origin(obj_pose_manip.pos) + yaw_to_origin = yaw_towards_point(obj_pose_manip.pos) orientation = Vector([0.0, 0.0, yaw_to_origin]) # Level grasp return obj_pose_manip.pos, orientation diff --git a/dimos/manipulation/ibvs/utils.py b/dimos/manipulation/ibvs/utils.py index 8befbdbad7..d9094af4b7 100644 --- a/dimos/manipulation/ibvs/utils.py +++ b/dimos/manipulation/ibvs/utils.py @@ -13,10 +13,10 @@ # limitations under the License. import numpy as np -from typing import Dict, Any, Optional, Tuple, List +from typing import Dict, Any, Optional, List + from dimos.types.pose import Pose from dimos.types.vector import Vector -import cv2 def parse_zed_pose(zed_pose_data: Dict[str, Any]) -> Optional[Pose]: @@ -47,234 +47,6 @@ def parse_zed_pose(zed_pose_data: Dict[str, Any]) -> Optional[Pose]: return Pose(pos_vector, rot_vector) -def pose_to_transform_matrix(pose: Pose) -> np.ndarray: - """ - Convert pose to 4x4 homogeneous transform matrix. 
- - Args: - pose: Pose object with position and rotation (euler angles) - - Returns: - 4x4 transformation matrix - """ - # Extract position - tx, ty, tz = pose.pos.x, pose.pos.y, pose.pos.z - - # Extract euler angles - roll, pitch, yaw = pose.rot.x, pose.rot.y, pose.rot.z - - # Create rotation matrices - cos_roll, sin_roll = np.cos(roll), np.sin(roll) - cos_pitch, sin_pitch = np.cos(pitch), np.sin(pitch) - cos_yaw, sin_yaw = np.cos(yaw), np.sin(yaw) - - # Roll (X), Pitch (Y), Yaw (Z) - ZYX convention - R_x = np.array([[1, 0, 0], [0, cos_roll, -sin_roll], [0, sin_roll, cos_roll]]) - - R_y = np.array([[cos_pitch, 0, sin_pitch], [0, 1, 0], [-sin_pitch, 0, cos_pitch]]) - - R_z = np.array([[cos_yaw, -sin_yaw, 0], [sin_yaw, cos_yaw, 0], [0, 0, 1]]) - - R = R_z @ R_y @ R_x - - # Create 4x4 transform - T = np.eye(4) - T[:3, :3] = R - T[:3, 3] = [tx, ty, tz] - - return T - - -def transform_matrix_to_pose(T: np.ndarray) -> Pose: - """ - Convert 4x4 transformation matrix to Pose object. - - Args: - T: 4x4 transformation matrix - - Returns: - Pose object with position and rotation (euler angles) - """ - # Extract position - pos = Vector(T[0, 3], T[1, 3], T[2, 3]) - - # Extract rotation (euler angles from rotation matrix) - R = T[:3, :3] - roll = np.arctan2(R[2, 1], R[2, 2]) - pitch = np.arctan2(-R[2, 0], np.sqrt(R[2, 1] ** 2 + R[2, 2] ** 2)) - yaw = np.arctan2(R[1, 0], R[0, 0]) - - rot = Vector(roll, pitch, yaw) - - return Pose(pos, rot) - - -def apply_transform(pose: Pose, transform_matrix: np.ndarray) -> Pose: - """ - Apply a transformation matrix to a pose. 
- - Args: - pose: Input pose - transform_matrix: 4x4 transformation matrix to apply - - Returns: - Transformed pose - """ - # Convert pose to matrix - T_pose = pose_to_transform_matrix(pose) - - # Apply transform - T_result = transform_matrix @ T_pose - - # Convert back to pose - return transform_matrix_to_pose(T_result) - - -def optical_to_robot_convention(pose: Pose) -> Pose: - """ - Convert pose from ZED camera convention to robot arm convention. - - ZED Camera Coordinates: - - X: Right - - Y: Down - - Z: Forward (away from camera) - - Robot/ROS Coordinates: - - X: Forward - - Y: Left - - Z: Up - - Args: - pose: Pose in ZED camera convention - - Returns: - Pose in robot arm convention - """ - # Position transformation - robot_x = pose.pos.z # Forward = ZED Z - robot_y = -pose.pos.x # Left = -ZED X - robot_z = -pose.pos.y # Up = -ZED Y - - # Rotation transformation using rotation matrices - # First, create rotation matrix from ZED Euler angles - roll_zed, pitch_zed, yaw_zed = pose.rot.x, pose.rot.y, pose.rot.z - - # Create rotation matrix for ZED frame (ZYX convention) - cr, sr = np.cos(roll_zed), np.sin(roll_zed) - cp, sp = np.cos(pitch_zed), np.sin(pitch_zed) - cy, sy = np.cos(yaw_zed), np.sin(yaw_zed) - - # Roll (X), Pitch (Y), Yaw (Z) - ZYX convention - R_x = np.array([[1, 0, 0], [0, cr, -sr], [0, sr, cr]]) - - R_y = np.array([[cp, 0, sp], [0, 1, 0], [-sp, 0, cp]]) - - R_z = np.array([[cy, -sy, 0], [sy, cy, 0], [0, 0, 1]]) - - R_zed = R_z @ R_y @ R_x - - # Coordinate frame transformation matrix from ZED to Robot - # X_robot = Z_zed, Y_robot = -X_zed, Z_robot = -Y_zed - T_frame = np.array( - [ - [0, 0, 1], # X_robot = Z_zed - [-1, 0, 0], # Y_robot = -X_zed - [0, -1, 0], - ] - ) # Z_robot = -Y_zed - - # Transform the rotation matrix - R_robot = T_frame @ R_zed @ T_frame.T - - # Extract Euler angles from robot rotation matrix - # Using ZYX convention for robot frame as well - robot_roll = np.arctan2(R_robot[2, 1], R_robot[2, 2]) - robot_pitch = 
np.arctan2(-R_robot[2, 0], np.sqrt(R_robot[2, 1] ** 2 + R_robot[2, 2] ** 2)) - robot_yaw = np.arctan2(R_robot[1, 0], R_robot[0, 0]) - - # Normalize angles to [-π, π] - robot_roll = np.arctan2(np.sin(robot_roll), np.cos(robot_roll)) - robot_pitch = np.arctan2(np.sin(robot_pitch), np.cos(robot_pitch)) - robot_yaw = np.arctan2(np.sin(robot_yaw), np.cos(robot_yaw)) - - return Pose(Vector(robot_x, robot_y, robot_z), Vector(robot_roll, robot_pitch, robot_yaw)) - - -def robot_to_optical_convention(pose: Pose) -> Pose: - """ - Convert pose from robot arm convention to ZED camera convention. - This is the inverse of optical_to_robot_convention. - - Args: - pose: Pose in robot arm convention - - Returns: - Pose in ZED camera convention - """ - # Position transformation (inverse) - zed_x = -pose.pos.y # Right = -Left - zed_y = -pose.pos.z # Down = -Up - zed_z = pose.pos.x # Forward = Forward - - # Rotation transformation using rotation matrices - # First, create rotation matrix from Robot Euler angles - roll_robot, pitch_robot, yaw_robot = pose.rot.x, pose.rot.y, pose.rot.z - - # Create rotation matrix for Robot frame (ZYX convention) - cr, sr = np.cos(roll_robot), np.sin(roll_robot) - cp, sp = np.cos(pitch_robot), np.sin(pitch_robot) - cy, sy = np.cos(yaw_robot), np.sin(yaw_robot) - - # Roll (X), Pitch (Y), Yaw (Z) - ZYX convention - R_x = np.array([[1, 0, 0], [0, cr, -sr], [0, sr, cr]]) - - R_y = np.array([[cp, 0, sp], [0, 1, 0], [-sp, 0, cp]]) - - R_z = np.array([[cy, -sy, 0], [sy, cy, 0], [0, 0, 1]]) - - R_robot = R_z @ R_y @ R_x - - # Coordinate frame transformation matrix from Robot to ZED (inverse of ZED to Robot) - # This is the transpose of the forward transformation - T_frame_inv = np.array( - [ - [0, -1, 0], # X_zed = -Y_robot - [0, 0, -1], # Y_zed = -Z_robot - [1, 0, 0], - ] - ) # Z_zed = X_robot - - # Transform the rotation matrix - R_zed = T_frame_inv @ R_robot @ T_frame_inv.T - - # Extract Euler angles from ZED rotation matrix - # Using ZYX convention for ZED 
frame as well - zed_roll = np.arctan2(R_zed[2, 1], R_zed[2, 2]) - zed_pitch = np.arctan2(-R_zed[2, 0], np.sqrt(R_zed[2, 1] ** 2 + R_zed[2, 2] ** 2)) - zed_yaw = np.arctan2(R_zed[1, 0], R_zed[0, 0]) - - # Normalize angles - zed_roll = np.arctan2(np.sin(zed_roll), np.cos(zed_roll)) - zed_pitch = np.arctan2(np.sin(zed_pitch), np.cos(zed_pitch)) - zed_yaw = np.arctan2(np.sin(zed_yaw), np.cos(zed_yaw)) - - return Pose(Vector(zed_x, zed_y, zed_z), Vector(zed_roll, zed_pitch, zed_yaw)) - - -def calculate_yaw_to_origin(position: Vector) -> float: - """ - Calculate yaw angle to point away from origin (0,0,0) - Assumes robot frame where X is forward and Y is left. - - Args: - position: Current position in robot frame - - Returns: - Yaw angle in radians to point away from origin - """ - return np.arctan2(position.y, position.x) - - def estimate_object_depth( depth_image: np.ndarray, segmentation_mask: Optional[np.ndarray], bbox: List[float] ) -> float: diff --git a/dimos/utils/transform_utils.py b/dimos/utils/transform_utils.py index 31d3840884..3c53b44042 100644 --- a/dimos/utils/transform_utils.py +++ b/dimos/utils/transform_utils.py @@ -15,8 +15,10 @@ import numpy as np from typing import Tuple, Dict, Any import logging +import cv2 from dimos.types.vector import Vector +from dimos.types.pose import Pose logger = logging.getLogger(__name__) @@ -31,6 +33,237 @@ def distance_angle_to_goal_xy(distance: float, angle: float) -> Tuple[float, flo return distance * np.cos(angle), distance * np.sin(angle) +def pose_to_matrix(pose: Pose) -> np.ndarray: + """ + Convert pose to 4x4 homogeneous transform matrix. 
+ + Args: + pose: Pose object with position and rotation (euler angles) + + Returns: + 4x4 transformation matrix + """ + # Extract position + tx, ty, tz = pose.pos.x, pose.pos.y, pose.pos.z + + # Extract euler angles + roll, pitch, yaw = pose.rot.x, pose.rot.y, pose.rot.z + + # Create rotation matrices + cos_roll, sin_roll = np.cos(roll), np.sin(roll) + cos_pitch, sin_pitch = np.cos(pitch), np.sin(pitch) + cos_yaw, sin_yaw = np.cos(yaw), np.sin(yaw) + + # Roll (X), Pitch (Y), Yaw (Z) - ZYX convention + R_x = np.array([[1, 0, 0], [0, cos_roll, -sin_roll], [0, sin_roll, cos_roll]]) + + R_y = np.array([[cos_pitch, 0, sin_pitch], [0, 1, 0], [-sin_pitch, 0, cos_pitch]]) + + R_z = np.array([[cos_yaw, -sin_yaw, 0], [sin_yaw, cos_yaw, 0], [0, 0, 1]]) + + R = R_z @ R_y @ R_x + + # Create 4x4 transform + T = np.eye(4) + T[:3, :3] = R + T[:3, 3] = [tx, ty, tz] + + return T + + +def matrix_to_pose(T: np.ndarray) -> Pose: + """ + Convert 4x4 transformation matrix to Pose object. + + Args: + T: 4x4 transformation matrix + + Returns: + Pose object with position and rotation (euler angles) + """ + # Extract position + pos = Vector(T[0, 3], T[1, 3], T[2, 3]) + + # Extract rotation (euler angles from rotation matrix) + R = T[:3, :3] + roll = np.arctan2(R[2, 1], R[2, 2]) + pitch = np.arctan2(-R[2, 0], np.sqrt(R[2, 1] ** 2 + R[2, 2] ** 2)) + yaw = np.arctan2(R[1, 0], R[0, 0]) + + rot = Vector(roll, pitch, yaw) + + return Pose(pos, rot) + + +def apply_transform(pose: Pose, transform_matrix: np.ndarray) -> Pose: + """ + Apply a transformation matrix to a pose. + + Args: + pose: Input pose + transform_matrix: 4x4 transformation matrix to apply + + Returns: + Transformed pose + """ + # Convert pose to matrix + T_pose = pose_to_matrix(pose) + + # Apply transform + T_result = transform_matrix @ T_pose + + # Convert back to pose + return matrix_to_pose(T_result) + + +def optical_to_robot_frame(pose: Pose) -> Pose: + """ + Convert pose from optical camera frame to robot frame convention. 
+ + Optical Camera Frame (e.g., ZED): + - X: Right + - Y: Down + - Z: Forward (away from camera) + + Robot Frame (ROS/REP-103): + - X: Forward + - Y: Left + - Z: Up + + Args: + pose: Pose in optical camera frame + + Returns: + Pose in robot frame + """ + # Position transformation + robot_x = pose.pos.z # Forward = Camera Z + robot_y = -pose.pos.x # Left = -Camera X + robot_z = -pose.pos.y # Up = -Camera Y + + # Rotation transformation using rotation matrices + # First, create rotation matrix from optical frame Euler angles + roll_optical, pitch_optical, yaw_optical = pose.rot.x, pose.rot.y, pose.rot.z + + # Create rotation matrix for optical frame (ZYX convention) + cr, sr = np.cos(roll_optical), np.sin(roll_optical) + cp, sp = np.cos(pitch_optical), np.sin(pitch_optical) + cy, sy = np.cos(yaw_optical), np.sin(yaw_optical) + + # Roll (X), Pitch (Y), Yaw (Z) - ZYX convention + R_x = np.array([[1, 0, 0], [0, cr, -sr], [0, sr, cr]]) + + R_y = np.array([[cp, 0, sp], [0, 1, 0], [-sp, 0, cp]]) + + R_z = np.array([[cy, -sy, 0], [sy, cy, 0], [0, 0, 1]]) + + R_optical = R_z @ R_y @ R_x + + # Coordinate frame transformation matrix from optical to robot + # X_robot = Z_optical, Y_robot = -X_optical, Z_robot = -Y_optical + T_frame = np.array( + [ + [0, 0, 1], # X_robot = Z_optical + [-1, 0, 0], # Y_robot = -X_optical + [0, -1, 0], + ] + ) # Z_robot = -Y_optical + + # Transform the rotation matrix + R_robot = T_frame @ R_optical @ T_frame.T + + # Extract Euler angles from robot rotation matrix + # Using ZYX convention for robot frame as well + robot_roll = np.arctan2(R_robot[2, 1], R_robot[2, 2]) + robot_pitch = np.arctan2(-R_robot[2, 0], np.sqrt(R_robot[2, 1] ** 2 + R_robot[2, 2] ** 2)) + robot_yaw = np.arctan2(R_robot[1, 0], R_robot[0, 0]) + + # Normalize angles to [-π, π] + robot_roll = normalize_angle(robot_roll) + robot_pitch = normalize_angle(robot_pitch) + robot_yaw = normalize_angle(robot_yaw) + + return Pose(Vector(robot_x, robot_y, robot_z), Vector(robot_roll, 
robot_pitch, robot_yaw)) + + +def robot_to_optical_frame(pose: Pose) -> Pose: + """ + Convert pose from robot frame to optical camera frame convention. + This is the inverse of optical_to_robot_frame. + + Args: + pose: Pose in robot frame + + Returns: + Pose in optical camera frame + """ + # Position transformation (inverse) + optical_x = -pose.pos.y # Right = -Left + optical_y = -pose.pos.z # Down = -Up + optical_z = pose.pos.x # Forward = Forward + + # Rotation transformation using rotation matrices + # First, create rotation matrix from Robot Euler angles + roll_robot, pitch_robot, yaw_robot = pose.rot.x, pose.rot.y, pose.rot.z + + # Create rotation matrix for Robot frame (ZYX convention) + cr, sr = np.cos(roll_robot), np.sin(roll_robot) + cp, sp = np.cos(pitch_robot), np.sin(pitch_robot) + cy, sy = np.cos(yaw_robot), np.sin(yaw_robot) + + # Roll (X), Pitch (Y), Yaw (Z) - ZYX convention + R_x = np.array([[1, 0, 0], [0, cr, -sr], [0, sr, cr]]) + + R_y = np.array([[cp, 0, sp], [0, 1, 0], [-sp, 0, cp]]) + + R_z = np.array([[cy, -sy, 0], [sy, cy, 0], [0, 0, 1]]) + + R_robot = R_z @ R_y @ R_x + + # Coordinate frame transformation matrix from Robot to optical (inverse of optical to Robot) + # This is the transpose of the forward transformation + T_frame_inv = np.array( + [ + [0, -1, 0], # X_optical = -Y_robot + [0, 0, -1], # Y_optical = -Z_robot + [1, 0, 0], + ] + ) # Z_optical = X_robot + + # Transform the rotation matrix + R_optical = T_frame_inv @ R_robot @ T_frame_inv.T + + # Extract Euler angles from optical rotation matrix + # Using ZYX convention for optical frame as well + optical_roll = np.arctan2(R_optical[2, 1], R_optical[2, 2]) + optical_pitch = np.arctan2(-R_optical[2, 0], np.sqrt(R_optical[2, 1] ** 2 + R_optical[2, 2] ** 2)) + optical_yaw = np.arctan2(R_optical[1, 0], R_optical[0, 0]) + + # Normalize angles + optical_roll = normalize_angle(optical_roll) + optical_pitch = normalize_angle(optical_pitch) + optical_yaw = normalize_angle(optical_yaw) + + 
return Pose(Vector(optical_x, optical_y, optical_z), Vector(optical_roll, optical_pitch, optical_yaw)) + + +def yaw_towards_point(position: Vector, target_point: Vector = Vector(0.0, 0.0, 0.0)) -> float: + """ + Calculate yaw angle from target point to position (away from target). + This is commonly used for object orientation in grasping applications. + Assumes robot frame where X is forward and Y is left. + + Args: + position: Current position in robot frame + target_point: Reference point (default: origin) + + Returns: + Yaw angle in radians pointing from target_point to position + """ + direction = position - target_point + return np.arctan2(direction.y, direction.x) + + def transform_robot_to_map( robot_position: Vector, robot_rotation: Vector, position: Vector, rotation: Vector ) -> Tuple[Vector, Vector]: From 9c7313600eea7ecc02e8fbe36690e6fdf6542e2c Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Tue, 15 Jul 2025 18:43:58 -0700 Subject: [PATCH 55/89] Use Scipy for Quaternion to Euler --- dimos/msgs/geometry_msgs/Quaternion.py | 28 +++++++------------------- 1 file changed, 7 insertions(+), 21 deletions(-) diff --git a/dimos/msgs/geometry_msgs/Quaternion.py b/dimos/msgs/geometry_msgs/Quaternion.py index 9879e1e263..9369ef99b3 100644 --- a/dimos/msgs/geometry_msgs/Quaternion.py +++ b/dimos/msgs/geometry_msgs/Quaternion.py @@ -22,6 +22,7 @@ import numpy as np from dimos_lcm.geometry_msgs import Quaternion as LCMQuaternion from plum import dispatch +from scipy.spatial.transform import Rotation as R from dimos.msgs.geometry_msgs.Vector3 import Vector3 @@ -117,27 +118,12 @@ def to_euler(self) -> Vector3: Returns: Vector3: Euler angles as (roll, pitch, yaw) in radians """ - # Convert quaternion to Euler angles using ZYX convention (yaw, pitch, roll) - # Source: https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles - - # Roll (x-axis rotation) - sinr_cosp = 2 * (self.w * self.x + self.y * self.z) - cosr_cosp = 1 - 2 * (self.x * self.x + self.y 
* self.y) - roll = np.arctan2(sinr_cosp, cosr_cosp) - - # Pitch (y-axis rotation) - sinp = 2 * (self.w * self.y - self.z * self.x) - if abs(sinp) >= 1: - pitch = np.copysign(np.pi / 2, sinp) # Use 90 degrees if out of range - else: - pitch = np.arcsin(sinp) - - # Yaw (z-axis rotation) - siny_cosp = 2 * (self.w * self.z + self.x * self.y) - cosy_cosp = 1 - 2 * (self.y * self.y + self.z * self.z) - yaw = np.arctan2(siny_cosp, cosy_cosp) - - return Vector3(roll, pitch, yaw) + # Use scipy for accurate quaternion to euler conversion + quat = [self.x, self.y, self.z, self.w] + rotation = R.from_quat(quat) + euler_angles = rotation.as_euler('xyz') # roll, pitch, yaw + + return Vector3(euler_angles[0], euler_angles[1], euler_angles[2]) def __getitem__(self, idx: int) -> float: """Allow indexing into quaternion components: 0=x, 1=y, 2=z, 3=w.""" From b0e41fd4f3bf63ef2853a51b10f429a6766baa2a Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Tue, 15 Jul 2025 18:46:08 -0700 Subject: [PATCH 56/89] switched to using LCM Pose and scipy for pose transforms --- dimos/manipulation/ibvs/detection3d.py | 23 ++-- dimos/manipulation/ibvs/pbvs.py | 128 +++++++++++++------- dimos/manipulation/ibvs/utils.py | 15 +-- dimos/utils/transform_utils.py | 156 ++++++++----------------- tests/test_ibvs.py | 15 +-- 5 files changed, 158 insertions(+), 179 deletions(-) diff --git a/dimos/manipulation/ibvs/detection3d.py b/dimos/manipulation/ibvs/detection3d.py index 63df184d2e..508e8b4db4 100644 --- a/dimos/manipulation/ibvs/detection3d.py +++ b/dimos/manipulation/ibvs/detection3d.py @@ -20,13 +20,14 @@ from typing import Dict, List, Optional, Any import numpy as np import cv2 +from scipy.spatial.transform import Rotation as R from dimos.utils.logging_config import setup_logger from dimos.perception.segmentation.sam_2d_seg import Sam2DSegmenter from dimos.perception.pointcloud.utils import extract_centroids_from_masks from dimos.perception.detection2d.utils import plot_results, 
calculate_object_size_from_bbox -from dimos.types.pose import Pose +from dimos.msgs.geometry_msgs import Pose, Vector3, Quaternion from dimos.types.vector import Vector from dimos.types.manipulation import ObjectData from dimos.manipulation.ibvs.utils import estimate_object_depth @@ -187,12 +188,12 @@ def process_frame( world_pose = self._transform_to_world( obj_cam_pos, obj_cam_orientation, camera_pose ) - obj_data["world_position"] = world_pose.pos - obj_data["position"] = world_pose.pos # Use world position - obj_data["rotation"] = world_pose.rot # Use world rotation + obj_data["world_position"] = world_pose.position + obj_data["position"] = world_pose.position # Use world position + obj_data["rotation"] = world_pose.orientation # Use world rotation else: # If no camera pose, use camera coordinates - obj_data["position"] = Vector(obj_cam_pos[0], obj_cam_pos[1], obj_cam_pos[2]) + obj_data["position"] = Vector3(obj_cam_pos[0], obj_cam_pos[1], obj_cam_pos[2]) detections.append(obj_data) @@ -213,9 +214,13 @@ def _transform_to_world( Object pose in world frame as Pose """ # Create object pose in optical frame + # Convert euler angles to quaternion + quat = R.from_euler('xyz', obj_orientation).as_quat() # [x, y, z, w] + obj_orientation_quat = Quaternion(quat[0], quat[1], quat[2], quat[3]) + obj_pose_optical = Pose( - Vector(obj_pos[0], obj_pos[1], obj_pos[2]), - Vector([obj_orientation[0], obj_orientation[1], obj_orientation[2]]), + Vector3(obj_pos[0], obj_pos[1], obj_pos[2]), + obj_orientation_quat ) # Transform object pose from optical frame to world frame convention @@ -284,7 +289,7 @@ def visualize_detections( bbox = det["bbox"] - if isinstance(display_position, Vector): + if isinstance(display_position, Vector3): display_xyz = np.array( [display_position.x, display_position.y, display_position.z] ) @@ -348,7 +353,7 @@ def get_closest_detection( # Sort by depth (Z coordinate) def get_z_coord(d): pos = d["position"] - if isinstance(pos, Vector): + if 
isinstance(pos, Vector3): return abs(pos.z) return abs(pos["z"]) diff --git a/dimos/manipulation/ibvs/pbvs.py b/dimos/manipulation/ibvs/pbvs.py index e593248773..58f7dc5839 100644 --- a/dimos/manipulation/ibvs/pbvs.py +++ b/dimos/manipulation/ibvs/pbvs.py @@ -21,7 +21,8 @@ from typing import Optional, Tuple, Dict, Any, List import cv2 -from dimos.types.pose import Pose +from scipy.spatial.transform import Rotation as R +from dimos.msgs.geometry_msgs import Pose, Vector3, Quaternion from dimos.types.vector import Vector from dimos.utils.logging_config import setup_logger from dimos.utils.transform_utils import ( @@ -114,16 +115,16 @@ def _create_ee_to_camera_transform(self) -> np.ndarray: """ # Extract position and rotation from 6DOF vector pos = self.ee_to_camera_transform_vec.to_list()[:3] - rot = self.ee_to_camera_transform_vec.to_list()[3:6] + rot = self.ee_to_camera_transform_vec.to_list()[3:6] # euler angles: [rx, ry, rz] # Create transformation matrix T_ee_to_cam = np.eye(4) T_ee_to_cam[0:3, 3] = pos - # Apply rotation (using Rodrigues formula) + # Apply rotation using scipy (treating as euler angles) if np.linalg.norm(rot) > 1e-6: - rot_matrix = cv2.Rodrigues(np.array(rot))[0] - T_ee_to_cam[0:3, 0:3] = rot_matrix + rotation = R.from_euler('xyz', rot) + T_ee_to_cam[0:3, 0:3] = rotation.as_matrix() return T_ee_to_cam @@ -143,8 +144,8 @@ def set_manipulator_origin(self, camera_pose: Pose): self.manipulator_origin = np.linalg.inv(T_world_to_origin) logger.info( - f"Set manipulator origin at pose: pos=({camera_pose.pos.x:.3f}, " - f"{camera_pose.pos.y:.3f}, {camera_pose.pos.z:.3f})" + f"Set manipulator origin at pose: pos=({camera_pose.position.x:.3f}, " + f"{camera_pose.position.y:.3f}, {camera_pose.position.z:.3f})" ) def _apply_pregrasp_distance(self, target_pose: Pose) -> Pose: @@ -158,7 +159,7 @@ def _apply_pregrasp_distance(self, target_pose: Pose) -> Pose: Modified target pose with pregrasp distance applied """ # Get approach vector (from target position 
towards robot origin) - target_pos = np.array([target_pose.pos.x, target_pose.pos.y, target_pose.pos.z]) + target_pos = np.array([target_pose.position.x, target_pose.position.y, target_pose.position.z]) robot_origin = np.array([0.0, 0.0, 0.0]) # Robot origin in robot frame approach_vector = robot_origin - target_pos # Vector pointing towards robot @@ -173,15 +174,13 @@ def _apply_pregrasp_distance(self, target_pose: Pose) -> Pose: offset_vector = self.pregrasp_distance * norm_approach_vector # Apply offset to target position - new_position = Vector( - [ - target_pose.pos.x + offset_vector[0], - target_pose.pos.y + offset_vector[1], - target_pose.pos.z + offset_vector[2], - ] + new_position = Vector3( + target_pose.position.x + offset_vector[0], + target_pose.position.y + offset_vector[1], + target_pose.position.z + offset_vector[2] ) - return Pose(new_position, target_pose.rot) + return Pose(new_position, target_pose.orientation) def _update_target_robot_frame(self): """Update current target with robot frame coordinates.""" @@ -190,23 +189,30 @@ def _update_target_robot_frame(self): # Get target position in ZED world frame target_pos = self.current_target["position"] - target_pose_zed = Pose(target_pos, Vector([0.0, 0.0, 0.0])) + target_pose_zed = Pose(target_pos, Quaternion()) # Identity quaternion # Transform to manipulator frame target_pose_manip = apply_transform(target_pose_zed, self.manipulator_origin) # Calculate orientation pointing at origin (in robot frame) - yaw_to_origin = yaw_towards_point(target_pose_manip.pos) + yaw_to_origin = yaw_towards_point(Vector(target_pose_manip.position.x, + target_pose_manip.position.y, + target_pose_manip.position.z)) # Create target pose with proper orientation - target_pose_robot = Pose(target_pose_manip.pos, Vector([0.0, 1.57, yaw_to_origin])) + # Convert euler angles to quaternion using scipy + euler = [0.0, 1.57, yaw_to_origin] # roll=0, pitch=90deg, yaw=calculated + quat = R.from_euler('xyz', euler).as_quat() # [x, 
y, z, w] + target_orientation = Quaternion(quat[0], quat[1], quat[2], quat[3]) + + target_pose_robot = Pose(target_pose_manip.position, target_orientation) # Apply pregrasp distance target_pose_pregrasp = self._apply_pregrasp_distance(target_pose_robot) # Update target with robot frame pose - self.current_target["robot_position"] = target_pose_pregrasp.pos - self.current_target["robot_rotation"] = target_pose_pregrasp.rot + self.current_target["robot_position"] = target_pose_pregrasp.position + self.current_target["robot_rotation"] = target_pose_pregrasp.orientation def set_target(self, target_object: Dict[str, Any]) -> bool: """ @@ -263,7 +269,7 @@ def update_target_tracking(self, new_detections: List[Dict[str, Any]]) -> bool: # Get current target position (in ZED world frame for matching) target_pos = self.current_target["position"] - if isinstance(target_pos, Vector): + if isinstance(target_pos, (Vector, Vector3)): target_xyz = np.array([target_pos.x, target_pos.y, target_pos.z]) else: target_xyz = np.array([target_pos["x"], target_pos["y"], target_pos["z"]]) @@ -277,7 +283,7 @@ def update_target_tracking(self, new_detections: List[Dict[str, Any]]) -> bool: continue det_pos = detection["position"] - if isinstance(det_pos, Vector): + if isinstance(det_pos, (Vector, Vector3)): det_xyz = np.array([det_pos.x, det_pos.y, det_pos.z]) else: det_xyz = np.array([det_pos["x"], det_pos["y"], det_pos["z"]]) @@ -310,11 +316,22 @@ def _get_ee_pose_from_camera(self, camera_pose: Pose) -> Pose: ee_transform = camera_transform @ np.linalg.inv(self.ee_to_camera_transform) # Extract position and rotation - ee_pos = Vector(ee_transform[0:3, 3]) + ee_pos = Vector3(ee_transform[0:3, 3]) ee_rot_matrix = ee_transform[0:3, 0:3] - ee_rot = Vector(cv2.Rodrigues(ee_rot_matrix)[0].flatten()) - - return Pose(ee_pos, ee_rot) + + # Convert rotation matrix to quaternion + + # Ensure the rotation matrix is valid (orthogonal with det=1) + try: + rotation = R.from_matrix(ee_rot_matrix) + quat = 
rotation.as_quat() # [x, y, z, w] + ee_orientation = Quaternion(quat[0], quat[1], quat[2], quat[3]) + except ValueError as e: + logger.warning(f"Invalid rotation matrix in EE pose calculation: {e}") + # Fallback to identity quaternion + ee_orientation = Quaternion(0.0, 0.0, 0.0, 1.0) + + return Pose(ee_pos, ee_orientation) def compute_control( self, camera_pose: Pose, new_detections: Optional[List[Dict[str, Any]]] = None @@ -360,7 +377,11 @@ def compute_control( return None, None, False, False # Calculate position error (target - EE position) - error = target_pos - ee_pose_robot.pos + error = Vector3( + target_pos.x - ee_pose_robot.position.x, + target_pos.y - ee_pose_robot.position.y, + target_pos.z - ee_pose_robot.position.z + ) self.last_position_error = error # Compute velocity command with proportional control @@ -403,28 +424,39 @@ def compute_control( return velocity_cmd, angular_velocity_cmd, target_reached, True - def _compute_angular_velocity(self, target_rot: Vector, current_pose: Pose) -> Vector: + def _compute_angular_velocity(self, target_rot: Quaternion, current_pose: Pose) -> Vector: """ Compute angular velocity commands for orientation control. - Aims for level grasping with appropriate yaw. + Uses quaternion error computation for better numerical stability. 
Args: - target_rot: Target orientation (roll, pitch, yaw) + target_rot: Target orientation (quaternion) current_pose: Current EE pose Returns: Angular velocity command as Vector """ - # Calculate rotation errors - roll_error = target_rot.x - current_pose.rot.x - pitch_error = target_rot.y - current_pose.rot.y - yaw_error = target_rot.z - current_pose.rot.z - - # Normalize yaw error to [-pi, pi] - while yaw_error > np.pi: - yaw_error -= 2 * np.pi - while yaw_error < -np.pi: - yaw_error += 2 * np.pi + # Use quaternion error for better numerical stability + + # Convert to scipy Rotation objects + target_rot_scipy = R.from_quat([target_rot.x, target_rot.y, target_rot.z, target_rot.w]) + current_rot_scipy = R.from_quat([ + current_pose.orientation.x, + current_pose.orientation.y, + current_pose.orientation.z, + current_pose.orientation.w + ]) + + # Compute rotation error: error = target * current^(-1) + error_rot = target_rot_scipy * current_rot_scipy.inv() + + # Convert to axis-angle representation for control + error_axis_angle = error_rot.as_rotvec() + + # Use axis-angle directly as angular velocity error (small angle approximation) + roll_error = error_axis_angle[0] + pitch_error = error_axis_angle[1] + yaw_error = error_axis_angle[2] self.last_rotation_error = Vector([roll_error, pitch_error, yaw_error]) @@ -497,14 +529,20 @@ def get_object_pose_robot_frame( return None # Transform position - obj_pose_zed = Pose(object_pos_zed, Vector([0.0, 0.0, 0.0])) + obj_pose_zed = Pose(object_pos_zed, Quaternion()) # Identity quaternion obj_pose_manip = apply_transform(obj_pose_zed, self.manipulator_origin) # Calculate orientation pointing at origin - yaw_to_origin = yaw_towards_point(obj_pose_manip.pos) - orientation = Vector([0.0, 0.0, yaw_to_origin]) # Level grasp - - return obj_pose_manip.pos, orientation + yaw_to_origin = yaw_towards_point(Vector(obj_pose_manip.position.x, + obj_pose_manip.position.y, + obj_pose_manip.position.z)) + + # Convert euler angles to quaternion 
+ euler = [0.0, 0.0, yaw_to_origin] # Level grasp + quat = R.from_euler('xyz', euler).as_quat() # [x, y, z, w] + orientation = Quaternion(quat[0], quat[1], quat[2], quat[3]) + + return obj_pose_manip.position, orientation def create_status_overlay( self, image: np.ndarray, camera_intrinsics: Optional[list] = None diff --git a/dimos/manipulation/ibvs/utils.py b/dimos/manipulation/ibvs/utils.py index d9094af4b7..581d34dc8c 100644 --- a/dimos/manipulation/ibvs/utils.py +++ b/dimos/manipulation/ibvs/utils.py @@ -15,8 +15,7 @@ import numpy as np from typing import Dict, Any, Optional, List -from dimos.types.pose import Pose -from dimos.types.vector import Vector +from dimos.msgs.geometry_msgs import Pose, Vector3, Quaternion def parse_zed_pose(zed_pose_data: Dict[str, Any]) -> Optional[Pose]: @@ -31,20 +30,18 @@ def parse_zed_pose(zed_pose_data: Dict[str, Any]) -> Optional[Pose]: - valid: Whether pose is valid Returns: - Pose object with position and rotation, or None if invalid + Pose object with position and orientation, or None if invalid """ if not zed_pose_data or not zed_pose_data.get("valid", False): return None # Extract position position = zed_pose_data.get("position", [0, 0, 0]) - pos_vector = Vector(position[0], position[1], position[2]) + pos_vector = Vector3(position[0], position[1], position[2]) - # Extract euler angles (roll, pitch, yaw) - euler = zed_pose_data.get("euler_angles", [0, 0, 0]) - rot_vector = Vector(euler[0], euler[1], euler[2]) # roll, pitch, yaw - - return Pose(pos_vector, rot_vector) + quat = zed_pose_data["rotation"] + orientation = Quaternion(quat[0], quat[1], quat[2], quat[3]) + return Pose(pos_vector, orientation) def estimate_object_depth( diff --git a/dimos/utils/transform_utils.py b/dimos/utils/transform_utils.py index 3c53b44042..8d39f79b7c 100644 --- a/dimos/utils/transform_utils.py +++ b/dimos/utils/transform_utils.py @@ -15,10 +15,10 @@ import numpy as np from typing import Tuple, Dict, Any import logging -import cv2 +from 
scipy.spatial.transform import Rotation from dimos.types.vector import Vector -from dimos.types.pose import Pose +from dimos.msgs.geometry_msgs import Pose, Vector3, Quaternion logger = logging.getLogger(__name__) @@ -38,30 +38,18 @@ def pose_to_matrix(pose: Pose) -> np.ndarray: Convert pose to 4x4 homogeneous transform matrix. Args: - pose: Pose object with position and rotation (euler angles) + pose: Pose object with position and orientation (quaternion) Returns: 4x4 transformation matrix """ # Extract position - tx, ty, tz = pose.pos.x, pose.pos.y, pose.pos.z + tx, ty, tz = pose.position.x, pose.position.y, pose.position.z - # Extract euler angles - roll, pitch, yaw = pose.rot.x, pose.rot.y, pose.rot.z - - # Create rotation matrices - cos_roll, sin_roll = np.cos(roll), np.sin(roll) - cos_pitch, sin_pitch = np.cos(pitch), np.sin(pitch) - cos_yaw, sin_yaw = np.cos(yaw), np.sin(yaw) - - # Roll (X), Pitch (Y), Yaw (Z) - ZYX convention - R_x = np.array([[1, 0, 0], [0, cos_roll, -sin_roll], [0, sin_roll, cos_roll]]) - - R_y = np.array([[cos_pitch, 0, sin_pitch], [0, 1, 0], [-sin_pitch, 0, cos_pitch]]) - - R_z = np.array([[cos_yaw, -sin_yaw, 0], [sin_yaw, cos_yaw, 0], [0, 0, 1]]) - - R = R_z @ R_y @ R_x + # Create rotation matrix from quaternion using scipy + quat = [pose.orientation.x, pose.orientation.y, pose.orientation.z, pose.orientation.w] + rotation = Rotation.from_quat(quat) + R = rotation.as_matrix() # Create 4x4 transform T = np.eye(4) @@ -79,20 +67,19 @@ def matrix_to_pose(T: np.ndarray) -> Pose: T: 4x4 transformation matrix Returns: - Pose object with position and rotation (euler angles) + Pose object with position and orientation (quaternion) """ # Extract position - pos = Vector(T[0, 3], T[1, 3], T[2, 3]) + pos = Vector3(T[0, 3], T[1, 3], T[2, 3]) - # Extract rotation (euler angles from rotation matrix) + # Extract rotation matrix and convert to quaternion R = T[:3, :3] - roll = np.arctan2(R[2, 1], R[2, 2]) - pitch = np.arctan2(-R[2, 0], np.sqrt(R[2, 1] 
** 2 + R[2, 2] ** 2)) - yaw = np.arctan2(R[1, 0], R[0, 0]) + rotation = Rotation.from_matrix(R) + quat = rotation.as_quat() # Returns [x, y, z, w] + + orientation = Quaternion(quat[0], quat[1], quat[2], quat[3]) - rot = Vector(roll, pitch, yaw) - - return Pose(pos, rot) + return Pose(pos, orientation) def apply_transform(pose: Pose, transform_matrix: np.ndarray) -> Pose: @@ -137,53 +124,33 @@ def optical_to_robot_frame(pose: Pose) -> Pose: Pose in robot frame """ # Position transformation - robot_x = pose.pos.z # Forward = Camera Z - robot_y = -pose.pos.x # Left = -Camera X - robot_z = -pose.pos.y # Up = -Camera Y - - # Rotation transformation using rotation matrices - # First, create rotation matrix from optical frame Euler angles - roll_optical, pitch_optical, yaw_optical = pose.rot.x, pose.rot.y, pose.rot.z - - # Create rotation matrix for optical frame (ZYX convention) - cr, sr = np.cos(roll_optical), np.sin(roll_optical) - cp, sp = np.cos(pitch_optical), np.sin(pitch_optical) - cy, sy = np.cos(yaw_optical), np.sin(yaw_optical) - - # Roll (X), Pitch (Y), Yaw (Z) - ZYX convention - R_x = np.array([[1, 0, 0], [0, cr, -sr], [0, sr, cr]]) + robot_x = pose.position.z # Forward = Camera Z + robot_y = -pose.position.x # Left = -Camera X + robot_z = -pose.position.y # Up = -Camera Y - R_y = np.array([[cp, 0, sp], [0, 1, 0], [-sp, 0, cp]]) - - R_z = np.array([[cy, -sy, 0], [sy, cy, 0], [0, 0, 1]]) - - R_optical = R_z @ R_y @ R_x + # Rotation transformation using quaternions + # First convert quaternion to rotation matrix + quat_optical = [pose.orientation.x, pose.orientation.y, pose.orientation.z, pose.orientation.w] + R_optical = Rotation.from_quat(quat_optical).as_matrix() # Coordinate frame transformation matrix from optical to robot # X_robot = Z_optical, Y_robot = -X_optical, Z_robot = -Y_optical - T_frame = np.array( - [ - [0, 0, 1], # X_robot = Z_optical - [-1, 0, 0], # Y_robot = -X_optical - [0, -1, 0], - ] - ) # Z_robot = -Y_optical + T_frame = np.array([ + [0, 
0, 1], # X_robot = Z_optical + [-1, 0, 0], # Y_robot = -X_optical + [0, -1, 0] # Z_robot = -Y_optical + ]) # Transform the rotation matrix R_robot = T_frame @ R_optical @ T_frame.T - # Extract Euler angles from robot rotation matrix - # Using ZYX convention for robot frame as well - robot_roll = np.arctan2(R_robot[2, 1], R_robot[2, 2]) - robot_pitch = np.arctan2(-R_robot[2, 0], np.sqrt(R_robot[2, 1] ** 2 + R_robot[2, 2] ** 2)) - robot_yaw = np.arctan2(R_robot[1, 0], R_robot[0, 0]) - - # Normalize angles to [-π, π] - robot_roll = normalize_angle(robot_roll) - robot_pitch = normalize_angle(robot_pitch) - robot_yaw = normalize_angle(robot_yaw) + # Convert back to quaternion + quat_robot = Rotation.from_matrix(R_robot).as_quat() # [x, y, z, w] - return Pose(Vector(robot_x, robot_y, robot_z), Vector(robot_roll, robot_pitch, robot_yaw)) + return Pose( + Vector3(robot_x, robot_y, robot_z), + Quaternion(quat_robot[0], quat_robot[1], quat_robot[2], quat_robot[3]) + ) def robot_to_optical_frame(pose: Pose) -> Pose: @@ -198,53 +165,32 @@ def robot_to_optical_frame(pose: Pose) -> Pose: Pose in optical camera frame """ # Position transformation (inverse) - optical_x = -pose.pos.y # Right = -Left - optical_y = -pose.pos.z # Down = -Up - optical_z = pose.pos.x # Forward = Forward - - # Rotation transformation using rotation matrices - # First, create rotation matrix from Robot Euler angles - roll_robot, pitch_robot, yaw_robot = pose.rot.x, pose.rot.y, pose.rot.z - - # Create rotation matrix for Robot frame (ZYX convention) - cr, sr = np.cos(roll_robot), np.sin(roll_robot) - cp, sp = np.cos(pitch_robot), np.sin(pitch_robot) - cy, sy = np.cos(yaw_robot), np.sin(yaw_robot) + optical_x = -pose.position.y # Right = -Left + optical_y = -pose.position.z # Down = -Up + optical_z = pose.position.x # Forward = Forward - # Roll (X), Pitch (Y), Yaw (Z) - ZYX convention - R_x = np.array([[1, 0, 0], [0, cr, -sr], [0, sr, cr]]) - - R_y = np.array([[cp, 0, sp], [0, 1, 0], [-sp, 0, cp]]) - - R_z 
= np.array([[cy, -sy, 0], [sy, cy, 0], [0, 0, 1]]) - - R_robot = R_z @ R_y @ R_x + # Rotation transformation using quaternions + quat_robot = [pose.orientation.x, pose.orientation.y, pose.orientation.z, pose.orientation.w] + R_robot = Rotation.from_quat(quat_robot).as_matrix() # Coordinate frame transformation matrix from Robot to optical (inverse of optical to Robot) # This is the transpose of the forward transformation - T_frame_inv = np.array( - [ - [0, -1, 0], # X_optical = -Y_robot - [0, 0, -1], # Y_optical = -Z_robot - [1, 0, 0], - ] - ) # Z_optical = X_robot + T_frame_inv = np.array([ + [0, -1, 0], # X_optical = -Y_robot + [0, 0, -1], # Y_optical = -Z_robot + [1, 0, 0] # Z_optical = X_robot + ]) # Transform the rotation matrix R_optical = T_frame_inv @ R_robot @ T_frame_inv.T - # Extract Euler angles from optical rotation matrix - # Using ZYX convention for optical frame as well - optical_roll = np.arctan2(R_optical[2, 1], R_optical[2, 2]) - optical_pitch = np.arctan2(-R_optical[2, 0], np.sqrt(R_optical[2, 1] ** 2 + R_optical[2, 2] ** 2)) - optical_yaw = np.arctan2(R_optical[1, 0], R_optical[0, 0]) - - # Normalize angles - optical_roll = normalize_angle(optical_roll) - optical_pitch = normalize_angle(optical_pitch) - optical_yaw = normalize_angle(optical_yaw) + # Convert back to quaternion + quat_optical = Rotation.from_matrix(R_optical).as_quat() # [x, y, z, w] - return Pose(Vector(optical_x, optical_y, optical_z), Vector(optical_roll, optical_pitch, optical_yaw)) + return Pose( + Vector3(optical_x, optical_y, optical_z), + Quaternion(quat_optical[0], quat_optical[1], quat_optical[2], quat_optical[3]) + ) def yaw_towards_point(position: Vector, target_point: Vector = Vector(0.0, 0.0, 0.0)) -> float: diff --git a/tests/test_ibvs.py b/tests/test_ibvs.py index a33651a160..86b5b2b563 100644 --- a/tests/test_ibvs.py +++ b/tests/test_ibvs.py @@ -127,13 +127,6 @@ def main(): viz, "TARGET", (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2 ) - # Print 
velocity commands for debugging (only if origin set) - if vel_cmd and ang_vel_cmd: - print(f"Linear vel: ({vel_cmd.x:.3f}, {vel_cmd.y:.3f}, {vel_cmd.z:.3f}) m/s") - print( - f"Angular vel: ({ang_vel_cmd.x:.3f}, {ang_vel_cmd.y:.3f}, {ang_vel_cmd.z:.3f}) rad/s" - ) - # Convert back to BGR for OpenCV display viz_bgr = cv2.cvtColor(viz, cv2.COLOR_RGB2BGR) @@ -143,11 +136,11 @@ def main(): if pbvs.manipulator_origin is not None: cam_robot = pbvs.get_camera_pose_robot_frame(camera_pose) if cam_robot: - pose_text = f"Camera [Robot]: ({cam_robot.pos.x:.2f}, {cam_robot.pos.y:.2f}, {cam_robot.pos.z:.2f})m" + pose_text = f"Camera [Robot]: ({cam_robot.position.x:.2f}, {cam_robot.position.y:.2f}, {cam_robot.position.z:.2f})m" else: - pose_text = f"Camera [ZED]: ({camera_pose.pos.x:.2f}, {camera_pose.pos.y:.2f}, {camera_pose.pos.z:.2f})m" + pose_text = f"Camera [ZED]: ({camera_pose.position.x:.2f}, {camera_pose.position.y:.2f}, {camera_pose.position.z:.2f})m" else: - pose_text = f"Camera [ZED]: ({camera_pose.pos.x:.2f}, {camera_pose.pos.y:.2f}, {camera_pose.pos.z:.2f})m" + pose_text = f"Camera [ZED]: ({camera_pose.position.x:.2f}, {camera_pose.position.y:.2f}, {camera_pose.position.z:.2f})m" cv2.putText( viz_bgr, pose_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 0), 1 @@ -221,7 +214,7 @@ def main(): elif key == ord("o") and camera_pose: pbvs.set_manipulator_origin(camera_pose) print( - f"Set manipulator origin at: ({camera_pose.pos.x:.3f}, {camera_pose.pos.y:.3f}, {camera_pose.pos.z:.3f})" + f"Set manipulator origin at: ({camera_pose.position.x:.3f}, {camera_pose.position.y:.3f}, {camera_pose.position.z:.3f})" ) except KeyboardInterrupt: From bb32bb4b03cbc8d2b67f51f66dbfbb68c36abe67 Mon Sep 17 00:00:00 2001 From: alexlin2 <44330195+alexlin2@users.noreply.github.com> Date: Wed, 16 Jul 2025 01:47:02 +0000 Subject: [PATCH 57/89] CI code cleanup --- dimos/manipulation/ibvs/detection3d.py | 9 ++-- dimos/manipulation/ibvs/pbvs.py | 64 +++++++++++++++----------- 
dimos/msgs/geometry_msgs/Quaternion.py | 4 +- dimos/utils/transform_utils.py | 32 +++++++------ 4 files changed, 59 insertions(+), 50 deletions(-) diff --git a/dimos/manipulation/ibvs/detection3d.py b/dimos/manipulation/ibvs/detection3d.py index 508e8b4db4..caf693c78e 100644 --- a/dimos/manipulation/ibvs/detection3d.py +++ b/dimos/manipulation/ibvs/detection3d.py @@ -215,13 +215,10 @@ def _transform_to_world( """ # Create object pose in optical frame # Convert euler angles to quaternion - quat = R.from_euler('xyz', obj_orientation).as_quat() # [x, y, z, w] + quat = R.from_euler("xyz", obj_orientation).as_quat() # [x, y, z, w] obj_orientation_quat = Quaternion(quat[0], quat[1], quat[2], quat[3]) - - obj_pose_optical = Pose( - Vector3(obj_pos[0], obj_pos[1], obj_pos[2]), - obj_orientation_quat - ) + + obj_pose_optical = Pose(Vector3(obj_pos[0], obj_pos[1], obj_pos[2]), obj_orientation_quat) # Transform object pose from optical frame to world frame convention obj_pose_world_frame = optical_to_robot_frame(obj_pose_optical) diff --git a/dimos/manipulation/ibvs/pbvs.py b/dimos/manipulation/ibvs/pbvs.py index 58f7dc5839..c34d84d86b 100644 --- a/dimos/manipulation/ibvs/pbvs.py +++ b/dimos/manipulation/ibvs/pbvs.py @@ -123,7 +123,7 @@ def _create_ee_to_camera_transform(self) -> np.ndarray: # Apply rotation using scipy (treating as euler angles) if np.linalg.norm(rot) > 1e-6: - rotation = R.from_euler('xyz', rot) + rotation = R.from_euler("xyz", rot) T_ee_to_cam[0:3, 0:3] = rotation.as_matrix() return T_ee_to_cam @@ -159,7 +159,9 @@ def _apply_pregrasp_distance(self, target_pose: Pose) -> Pose: Modified target pose with pregrasp distance applied """ # Get approach vector (from target position towards robot origin) - target_pos = np.array([target_pose.position.x, target_pose.position.y, target_pose.position.z]) + target_pos = np.array( + [target_pose.position.x, target_pose.position.y, target_pose.position.z] + ) robot_origin = np.array([0.0, 0.0, 0.0]) # Robot origin in 
robot frame approach_vector = robot_origin - target_pos # Vector pointing towards robot @@ -177,7 +179,7 @@ def _apply_pregrasp_distance(self, target_pose: Pose) -> Pose: new_position = Vector3( target_pose.position.x + offset_vector[0], target_pose.position.y + offset_vector[1], - target_pose.position.z + offset_vector[2] + target_pose.position.z + offset_vector[2], ) return Pose(new_position, target_pose.orientation) @@ -195,16 +197,20 @@ def _update_target_robot_frame(self): target_pose_manip = apply_transform(target_pose_zed, self.manipulator_origin) # Calculate orientation pointing at origin (in robot frame) - yaw_to_origin = yaw_towards_point(Vector(target_pose_manip.position.x, - target_pose_manip.position.y, - target_pose_manip.position.z)) + yaw_to_origin = yaw_towards_point( + Vector( + target_pose_manip.position.x, + target_pose_manip.position.y, + target_pose_manip.position.z, + ) + ) # Create target pose with proper orientation # Convert euler angles to quaternion using scipy euler = [0.0, 1.57, yaw_to_origin] # roll=0, pitch=90deg, yaw=calculated - quat = R.from_euler('xyz', euler).as_quat() # [x, y, z, w] + quat = R.from_euler("xyz", euler).as_quat() # [x, y, z, w] target_orientation = Quaternion(quat[0], quat[1], quat[2], quat[3]) - + target_pose_robot = Pose(target_pose_manip.position, target_orientation) # Apply pregrasp distance @@ -318,9 +324,9 @@ def _get_ee_pose_from_camera(self, camera_pose: Pose) -> Pose: # Extract position and rotation ee_pos = Vector3(ee_transform[0:3, 3]) ee_rot_matrix = ee_transform[0:3, 0:3] - + # Convert rotation matrix to quaternion - + # Ensure the rotation matrix is valid (orthogonal with det=1) try: rotation = R.from_matrix(ee_rot_matrix) @@ -380,7 +386,7 @@ def compute_control( error = Vector3( target_pos.x - ee_pose_robot.position.x, target_pos.y - ee_pose_robot.position.y, - target_pos.z - ee_pose_robot.position.z + target_pos.z - ee_pose_robot.position.z, ) self.last_position_error = error @@ -437,25 +443,27 @@ 
def _compute_angular_velocity(self, target_rot: Quaternion, current_pose: Pose) Angular velocity command as Vector """ # Use quaternion error for better numerical stability - + # Convert to scipy Rotation objects target_rot_scipy = R.from_quat([target_rot.x, target_rot.y, target_rot.z, target_rot.w]) - current_rot_scipy = R.from_quat([ - current_pose.orientation.x, - current_pose.orientation.y, - current_pose.orientation.z, - current_pose.orientation.w - ]) - + current_rot_scipy = R.from_quat( + [ + current_pose.orientation.x, + current_pose.orientation.y, + current_pose.orientation.z, + current_pose.orientation.w, + ] + ) + # Compute rotation error: error = target * current^(-1) error_rot = target_rot_scipy * current_rot_scipy.inv() - + # Convert to axis-angle representation for control error_axis_angle = error_rot.as_rotvec() - + # Use axis-angle directly as angular velocity error (small angle approximation) roll_error = error_axis_angle[0] - pitch_error = error_axis_angle[1] + pitch_error = error_axis_angle[1] yaw_error = error_axis_angle[2] self.last_rotation_error = Vector([roll_error, pitch_error, yaw_error]) @@ -529,17 +537,17 @@ def get_object_pose_robot_frame( return None # Transform position - obj_pose_zed = Pose(object_pos_zed, Quaternion()) # Identity quaternion + obj_pose_zed = Pose(object_pos_zed, Quaternion()) # Identity quaternion obj_pose_manip = apply_transform(obj_pose_zed, self.manipulator_origin) # Calculate orientation pointing at origin - yaw_to_origin = yaw_towards_point(Vector(obj_pose_manip.position.x, - obj_pose_manip.position.y, - obj_pose_manip.position.z)) - + yaw_to_origin = yaw_towards_point( + Vector(obj_pose_manip.position.x, obj_pose_manip.position.y, obj_pose_manip.position.z) + ) + # Convert euler angles to quaternion euler = [0.0, 0.0, yaw_to_origin] # Level grasp - quat = R.from_euler('xyz', euler).as_quat() # [x, y, z, w] + quat = R.from_euler("xyz", euler).as_quat() # [x, y, z, w] orientation = Quaternion(quat[0], quat[1], 
quat[2], quat[3]) return obj_pose_manip.position, orientation diff --git a/dimos/msgs/geometry_msgs/Quaternion.py b/dimos/msgs/geometry_msgs/Quaternion.py index 9369ef99b3..a7bb5543c1 100644 --- a/dimos/msgs/geometry_msgs/Quaternion.py +++ b/dimos/msgs/geometry_msgs/Quaternion.py @@ -121,8 +121,8 @@ def to_euler(self) -> Vector3: # Use scipy for accurate quaternion to euler conversion quat = [self.x, self.y, self.z, self.w] rotation = R.from_quat(quat) - euler_angles = rotation.as_euler('xyz') # roll, pitch, yaw - + euler_angles = rotation.as_euler("xyz") # roll, pitch, yaw + return Vector3(euler_angles[0], euler_angles[1], euler_angles[2]) def __getitem__(self, idx: int) -> float: diff --git a/dimos/utils/transform_utils.py b/dimos/utils/transform_utils.py index 8d39f79b7c..46237ce0be 100644 --- a/dimos/utils/transform_utils.py +++ b/dimos/utils/transform_utils.py @@ -76,7 +76,7 @@ def matrix_to_pose(T: np.ndarray) -> Pose: R = T[:3, :3] rotation = Rotation.from_matrix(R) quat = rotation.as_quat() # Returns [x, y, z, w] - + orientation = Quaternion(quat[0], quat[1], quat[2], quat[3]) return Pose(pos, orientation) @@ -135,11 +135,13 @@ def optical_to_robot_frame(pose: Pose) -> Pose: # Coordinate frame transformation matrix from optical to robot # X_robot = Z_optical, Y_robot = -X_optical, Z_robot = -Y_optical - T_frame = np.array([ - [0, 0, 1], # X_robot = Z_optical - [-1, 0, 0], # Y_robot = -X_optical - [0, -1, 0] # Z_robot = -Y_optical - ]) + T_frame = np.array( + [ + [0, 0, 1], # X_robot = Z_optical + [-1, 0, 0], # Y_robot = -X_optical + [0, -1, 0], # Z_robot = -Y_optical + ] + ) # Transform the rotation matrix R_robot = T_frame @ R_optical @ T_frame.T @@ -149,7 +151,7 @@ def optical_to_robot_frame(pose: Pose) -> Pose: return Pose( Vector3(robot_x, robot_y, robot_z), - Quaternion(quat_robot[0], quat_robot[1], quat_robot[2], quat_robot[3]) + Quaternion(quat_robot[0], quat_robot[1], quat_robot[2], quat_robot[3]), ) @@ -167,7 +169,7 @@ def 
robot_to_optical_frame(pose: Pose) -> Pose: # Position transformation (inverse) optical_x = -pose.position.y # Right = -Left optical_y = -pose.position.z # Down = -Up - optical_z = pose.position.x # Forward = Forward + optical_z = pose.position.x # Forward = Forward # Rotation transformation using quaternions quat_robot = [pose.orientation.x, pose.orientation.y, pose.orientation.z, pose.orientation.w] @@ -175,11 +177,13 @@ def robot_to_optical_frame(pose: Pose) -> Pose: # Coordinate frame transformation matrix from Robot to optical (inverse of optical to Robot) # This is the transpose of the forward transformation - T_frame_inv = np.array([ - [0, -1, 0], # X_optical = -Y_robot - [0, 0, -1], # Y_optical = -Z_robot - [1, 0, 0] # Z_optical = X_robot - ]) + T_frame_inv = np.array( + [ + [0, -1, 0], # X_optical = -Y_robot + [0, 0, -1], # Y_optical = -Z_robot + [1, 0, 0], # Z_optical = X_robot + ] + ) # Transform the rotation matrix R_optical = T_frame_inv @ R_robot @ T_frame_inv.T @@ -189,7 +193,7 @@ def robot_to_optical_frame(pose: Pose) -> Pose: return Pose( Vector3(optical_x, optical_y, optical_z), - Quaternion(quat_optical[0], quat_optical[1], quat_optical[2], quat_optical[3]) + Quaternion(quat_optical[0], quat_optical[1], quat_optical[2], quat_optical[3]), ) From 8a54752829a2fca43239a099afb50177b35054e4 Mon Sep 17 00:00:00 2001 From: mustafab0 Date: Wed, 16 Jul 2025 22:55:19 +0000 Subject: [PATCH 58/89] added timeout to controller --- It | 0 build/lib/dimos/__init__.py | 1 + build/lib/dimos/agents/__init__.py | 0 build/lib/dimos/agents/agent.py | 904 ++++++++++ build/lib/dimos/agents/agent_config.py | 55 + .../dimos/agents/agent_ctransformers_gguf.py | 210 +++ .../dimos/agents/agent_huggingface_local.py | 235 +++ .../dimos/agents/agent_huggingface_remote.py | 143 ++ build/lib/dimos/agents/cerebras_agent.py | 608 +++++++ build/lib/dimos/agents/claude_agent.py | 735 +++++++++ build/lib/dimos/agents/memory/__init__.py | 0 build/lib/dimos/agents/memory/base.py | 133 ++ 
build/lib/dimos/agents/memory/chroma_impl.py | 167 ++ .../dimos/agents/memory/image_embedding.py | 263 +++ .../dimos/agents/memory/spatial_vector_db.py | 268 +++ .../agents/memory/test_image_embedding.py | 212 +++ .../lib/dimos/agents/memory/visual_memory.py | 182 +++ build/lib/dimos/agents/planning_agent.py | 317 ++++ .../dimos/agents/prompt_builder/__init__.py | 0 build/lib/dimos/agents/prompt_builder/impl.py | 221 +++ build/lib/dimos/agents/tokenizer/__init__.py | 0 build/lib/dimos/agents/tokenizer/base.py | 37 + .../agents/tokenizer/huggingface_tokenizer.py | 88 + .../agents/tokenizer/openai_tokenizer.py | 88 + build/lib/dimos/core/__init__.py | 103 ++ build/lib/dimos/core/colors.py | 43 + build/lib/dimos/core/core.py | 260 +++ build/lib/dimos/core/module.py | 172 ++ build/lib/dimos/core/o3dpickle.py | 38 + build/lib/dimos/core/test_core.py | 199 +++ build/lib/dimos/core/transport.py | 102 ++ build/lib/dimos/environment/__init__.py | 0 .../dimos/environment/agent_environment.py | 139 ++ .../dimos/environment/colmap_environment.py | 89 + build/lib/dimos/environment/environment.py | 172 ++ build/lib/dimos/exceptions/__init__.py | 0 .../exceptions/agent_memory_exceptions.py | 89 + build/lib/dimos/hardware/__init__.py | 0 build/lib/dimos/hardware/camera.py | 52 + build/lib/dimos/hardware/end_effector.py | 21 + build/lib/dimos/hardware/interface.py | 51 + build/lib/dimos/hardware/piper_arm.py | 372 +++++ build/lib/dimos/hardware/sensor.py | 35 + build/lib/dimos/hardware/stereo_camera.py | 26 + .../dimos/hardware/test_simple_module(1).py | 90 + build/lib/dimos/hardware/ufactory.py | 32 + build/lib/dimos/hardware/zed_camera.py | 514 ++++++ build/lib/dimos/manipulation/__init__.py | 0 .../dimos/manipulation/manip_aio_pipeline.py | 590 +++++++ .../dimos/manipulation/manip_aio_processer.py | 411 +++++ .../manipulation/manipulation_history.py | 418 +++++ .../manipulation/manipulation_interface.py | 292 ++++ .../manipulation/test_manipulation_history.py | 461 ++++++ 
build/lib/dimos/models/__init__.py | 0 build/lib/dimos/models/depth/__init__.py | 0 build/lib/dimos/models/depth/metric3d.py | 173 ++ build/lib/dimos/models/labels/__init__.py | 0 build/lib/dimos/models/labels/llava-34b.py | 92 ++ .../lib/dimos/models/manipulation/__init__.py | 0 build/lib/dimos/models/pointcloud/__init__.py | 0 .../models/pointcloud/pointcloud_utils.py | 214 +++ .../lib/dimos/models/segmentation/__init__.py | 0 .../lib/dimos/models/segmentation/clipseg.py | 32 + build/lib/dimos/models/segmentation/sam.py | 35 + .../models/segmentation/segment_utils.py | 73 + build/lib/dimos/msgs/__init__.py | 0 build/lib/dimos/msgs/geometry_msgs/Pose.py | 181 ++ .../dimos/msgs/geometry_msgs/PoseStamped.py | 76 + .../dimos/msgs/geometry_msgs/Quaternion.py | 167 ++ build/lib/dimos/msgs/geometry_msgs/Twist.py | 73 + build/lib/dimos/msgs/geometry_msgs/Vector3.py | 467 ++++++ .../lib/dimos/msgs/geometry_msgs/__init__.py | 4 + .../lib/dimos/msgs/geometry_msgs/test_Pose.py | 555 +++++++ .../msgs/geometry_msgs/test_Quaternion.py | 210 +++ .../dimos/msgs/geometry_msgs/test_Vector3.py | 462 ++++++ .../dimos/msgs/geometry_msgs/test_publish.py | 54 + build/lib/dimos/msgs/sensor_msgs/Image.py | 372 +++++ .../lib/dimos/msgs/sensor_msgs/PointCloud2.py | 213 +++ build/lib/dimos/msgs/sensor_msgs/__init__.py | 2 + .../msgs/sensor_msgs/test_PointCloud2.py | 81 + .../lib/dimos/msgs/sensor_msgs/test_image.py | 63 + build/lib/dimos/perception/__init__.py | 0 build/lib/dimos/perception/common/__init__.py | 3 + .../lib/dimos/perception/common/cuboid_fit.py | 331 ++++ .../perception/common/detection2d_tracker.py | 385 +++++ .../perception/common/export_tensorrt.py | 57 + build/lib/dimos/perception/common/ibvs.py | 280 ++++ build/lib/dimos/perception/common/utils.py | 364 +++++ .../dimos/perception/detection2d/__init__.py | 2 + .../perception/detection2d/detic_2d_det.py | 414 +++++ .../detection2d/test_yolo_2d_det.py | 177 ++ .../lib/dimos/perception/detection2d/utils.py | 338 ++++ 
.../perception/detection2d/yolo_2d_det.py | 157 ++ .../perception/grasp_generation/__init__.py | 1 + .../grasp_generation/grasp_generation.py | 228 +++ .../perception/grasp_generation/utils.py | 621 +++++++ .../perception/object_detection_stream.py | 373 +++++ build/lib/dimos/perception/object_tracker.py | 357 ++++ build/lib/dimos/perception/person_tracker.py | 154 ++ .../dimos/perception/pointcloud/__init__.py | 3 + .../dimos/perception/pointcloud/cuboid_fit.py | 414 +++++ .../pointcloud/pointcloud_filtering.py | 674 ++++++++ .../lib/dimos/perception/pointcloud/utils.py | 1451 +++++++++++++++++ .../dimos/perception/segmentation/__init__.py | 2 + .../perception/segmentation/image_analyzer.py | 161 ++ .../perception/segmentation/sam_2d_seg.py | 335 ++++ .../segmentation/test_sam_2d_seg.py | 214 +++ .../dimos/perception/segmentation/utils.py | 315 ++++ build/lib/dimos/perception/semantic_seg.py | 245 +++ .../dimos/perception/spatial_perception.py | 438 +++++ .../dimos/perception/test_spatial_memory.py | 214 +++ build/lib/dimos/perception/visual_servoing.py | 500 ++++++ build/lib/dimos/robot/__init__.py | 0 build/lib/dimos/robot/connection_interface.py | 70 + build/lib/dimos/robot/foxglove_bridge.py | 49 + .../robot/frontier_exploration/__init__.py | 1 + .../qwen_frontier_predictor.py | 368 +++++ .../test_wavefront_frontier_goal_selector.py | 297 ++++ .../dimos/robot/frontier_exploration/utils.py | 188 +++ .../wavefront_frontier_goal_selector.py | 665 ++++++++ .../dimos/robot/global_planner/__init__.py | 1 + build/lib/dimos/robot/global_planner/algo.py | 273 ++++ .../lib/dimos/robot/global_planner/planner.py | 96 ++ .../lib/dimos/robot/local_planner/__init__.py | 7 + .../robot/local_planner/local_planner.py | 1442 ++++++++++++++++ build/lib/dimos/robot/local_planner/simple.py | 265 +++ .../robot/local_planner/vfh_local_planner.py | 435 +++++ build/lib/dimos/robot/position_stream.py | 162 ++ build/lib/dimos/robot/recorder.py | 159 ++ build/lib/dimos/robot/robot.py | 
435 +++++ build/lib/dimos/robot/ros_command_queue.py | 471 ++++++ build/lib/dimos/robot/ros_control.py | 867 ++++++++++ build/lib/dimos/robot/ros_observable_topic.py | 240 +++ build/lib/dimos/robot/ros_transform.py | 243 +++ .../dimos/robot/test_ros_observable_topic.py | 255 +++ build/lib/dimos/robot/unitree/__init__.py | 0 build/lib/dimos/robot/unitree/unitree_go2.py | 208 +++ .../robot/unitree/unitree_ros_control.py | 157 ++ .../lib/dimos/robot/unitree/unitree_skills.py | 314 ++++ .../dimos/robot/unitree_webrtc/__init__.py | 0 .../dimos/robot/unitree_webrtc/connection.py | 309 ++++ .../robot/unitree_webrtc/testing/__init__.py | 0 .../robot/unitree_webrtc/testing/helpers.py | 168 ++ .../robot/unitree_webrtc/testing/mock.py | 91 ++ .../robot/unitree_webrtc/testing/multimock.py | 142 ++ .../robot/unitree_webrtc/testing/test_mock.py | 62 + .../unitree_webrtc/testing/test_multimock.py | 111 ++ .../robot/unitree_webrtc/type/__init__.py | 0 .../dimos/robot/unitree_webrtc/type/lidar.py | 138 ++ .../robot/unitree_webrtc/type/lowstate.py | 93 ++ .../dimos/robot/unitree_webrtc/type/map.py | 150 ++ .../robot/unitree_webrtc/type/odometry.py | 108 ++ .../robot/unitree_webrtc/type/test_lidar.py | 51 + .../robot/unitree_webrtc/type/test_map.py | 80 + .../unitree_webrtc/type/test_odometry.py | 109 ++ .../unitree_webrtc/type/test_timeseries.py | 44 + .../robot/unitree_webrtc/type/timeseries.py | 146 ++ .../dimos/robot/unitree_webrtc/type/vector.py | 448 +++++ .../dimos/robot/unitree_webrtc/unitree_go2.py | 224 +++ .../robot/unitree_webrtc/unitree_skills.py | 279 ++++ build/lib/dimos/simulation/__init__.py | 15 + build/lib/dimos/simulation/base/__init__.py | 0 .../dimos/simulation/base/simulator_base.py | 48 + .../lib/dimos/simulation/base/stream_base.py | 116 ++ .../lib/dimos/simulation/genesis/__init__.py | 4 + .../lib/dimos/simulation/genesis/simulator.py | 158 ++ build/lib/dimos/simulation/genesis/stream.py | 143 ++ build/lib/dimos/simulation/isaac/__init__.py | 4 + 
build/lib/dimos/simulation/isaac/simulator.py | 43 + build/lib/dimos/simulation/isaac/stream.py | 136 ++ build/lib/dimos/skills/__init__.py | 0 build/lib/dimos/skills/kill_skill.py | 62 + build/lib/dimos/skills/navigation.py | 699 ++++++++ build/lib/dimos/skills/observe.py | 192 +++ build/lib/dimos/skills/observe_stream.py | 245 +++ build/lib/dimos/skills/rest/__init__.py | 0 build/lib/dimos/skills/rest/rest.py | 99 ++ build/lib/dimos/skills/skills.py | 340 ++++ build/lib/dimos/skills/speak.py | 166 ++ build/lib/dimos/skills/unitree/__init__.py | 1 + .../lib/dimos/skills/unitree/unitree_speak.py | 280 ++++ .../dimos/skills/visual_navigation_skills.py | 148 ++ build/lib/dimos/stream/__init__.py | 0 build/lib/dimos/stream/audio/__init__.py | 0 build/lib/dimos/stream/audio/base.py | 114 ++ .../dimos/stream/audio/node_key_recorder.py | 336 ++++ .../lib/dimos/stream/audio/node_microphone.py | 131 ++ .../lib/dimos/stream/audio/node_normalizer.py | 220 +++ build/lib/dimos/stream/audio/node_output.py | 187 +++ .../lib/dimos/stream/audio/node_simulated.py | 221 +++ .../dimos/stream/audio/node_volume_monitor.py | 176 ++ build/lib/dimos/stream/audio/pipelines.py | 52 + build/lib/dimos/stream/audio/utils.py | 26 + build/lib/dimos/stream/audio/volume.py | 108 ++ build/lib/dimos/stream/data_provider.py | 183 +++ build/lib/dimos/stream/frame_processor.py | 300 ++++ build/lib/dimos/stream/ros_video_provider.py | 112 ++ build/lib/dimos/stream/rtsp_video_provider.py | 380 +++++ build/lib/dimos/stream/stream_merger.py | 45 + build/lib/dimos/stream/video_operators.py | 622 +++++++ build/lib/dimos/stream/video_provider.py | 235 +++ .../dimos/stream/video_providers/__init__.py | 0 .../dimos/stream/video_providers/unitree.py | 167 ++ build/lib/dimos/stream/videostream.py | 41 + build/lib/dimos/types/__init__.py | 0 build/lib/dimos/types/constants.py | 24 + build/lib/dimos/types/costmap.py | 534 ++++++ build/lib/dimos/types/label.py | 39 + build/lib/dimos/types/manipulation.py | 155 ++ 
build/lib/dimos/types/path.py | 414 +++++ build/lib/dimos/types/pose.py | 149 ++ build/lib/dimos/types/robot_capabilities.py | 27 + build/lib/dimos/types/robot_location.py | 130 ++ build/lib/dimos/types/ros_polyfill.py | 103 ++ build/lib/dimos/types/sample.py | 572 +++++++ build/lib/dimos/types/segmentation.py | 44 + build/lib/dimos/types/test_pose.py | 323 ++++ build/lib/dimos/types/test_timestamped.py | 26 + build/lib/dimos/types/test_vector.py | 384 +++++ build/lib/dimos/types/timestamped.py | 55 + build/lib/dimos/types/vector.py | 460 ++++++ build/lib/dimos/web/__init__.py | 0 .../lib/dimos/web/dimos_interface/__init__.py | 7 + .../dimos/web/dimos_interface/api/__init__.py | 0 .../dimos/web/dimos_interface/api/server.py | 362 ++++ build/lib/dimos/web/edge_io.py | 26 + build/lib/dimos/web/fastapi_server.py | 224 +++ build/lib/dimos/web/flask_server.py | 95 ++ build/lib/dimos/web/robot_web_interface.py | 35 + build/lib/tests/__init__.py | 1 + .../tests/agent_manip_flow_fastapi_test.py | 153 ++ .../lib/tests/agent_manip_flow_flask_test.py | 195 +++ build/lib/tests/agent_memory_test.py | 61 + build/lib/tests/colmap_test.py | 25 + build/lib/tests/run.py | 361 ++++ build/lib/tests/run_go2_ros.py | 178 ++ build/lib/tests/run_navigation_only.py | 191 +++ build/lib/tests/simple_agent_test.py | 39 + build/lib/tests/test_agent.py | 60 + build/lib/tests/test_agent_alibaba.py | 59 + .../tests/test_agent_ctransformers_gguf.py | 44 + .../lib/tests/test_agent_huggingface_local.py | 72 + .../test_agent_huggingface_local_jetson.py | 73 + .../tests/test_agent_huggingface_remote.py | 64 + build/lib/tests/test_audio_agent.py | 39 + build/lib/tests/test_audio_robot_agent.py | 51 + build/lib/tests/test_cerebras_unitree_ros.py | 118 ++ build/lib/tests/test_claude_agent_query.py | 29 + .../tests/test_claude_agent_skills_query.py | 135 ++ build/lib/tests/test_command_pose_unitree.py | 82 + build/lib/tests/test_header.py | 58 + build/lib/tests/test_huggingface_llm_agent.py | 116 ++ 
build/lib/tests/test_ibvs.py | 229 +++ build/lib/tests/test_manipulation_agent.py | 346 ++++ ...est_manipulation_perception_pipeline.py.py | 167 ++ ...test_manipulation_pipeline_single_frame.py | 248 +++ ..._manipulation_pipeline_single_frame_lcm.py | 431 +++++ build/lib/tests/test_move_vel_unitree.py | 32 + .../tests/test_navigate_to_object_robot.py | 137 ++ build/lib/tests/test_navigation_skills.py | 269 +++ ...bject_detection_agent_data_query_stream.py | 191 +++ .../lib/tests/test_object_detection_stream.py | 240 +++ .../lib/tests/test_object_tracking_webcam.py | 222 +++ .../tests/test_object_tracking_with_qwen.py | 216 +++ build/lib/tests/test_observe_stream_skill.py | 131 ++ .../lib/tests/test_person_following_robot.py | 113 ++ .../lib/tests/test_person_following_webcam.py | 230 +++ .../test_planning_agent_web_interface.py | 180 ++ build/lib/tests/test_planning_robot_agent.py | 177 ++ build/lib/tests/test_pointcloud_filtering.py | 105 ++ build/lib/tests/test_qwen_image_query.py | 49 + build/lib/tests/test_robot.py | 86 + build/lib/tests/test_rtsp_video_provider.py | 146 ++ build/lib/tests/test_semantic_seg_robot.py | 151 ++ .../tests/test_semantic_seg_robot_agent.py | 141 ++ build/lib/tests/test_semantic_seg_webcam.py | 140 ++ build/lib/tests/test_skills.py | 185 +++ build/lib/tests/test_skills_rest.py | 73 + build/lib/tests/test_spatial_memory.py | 297 ++++ build/lib/tests/test_spatial_memory_query.py | 297 ++++ build/lib/tests/test_standalone_chromadb.py | 87 + build/lib/tests/test_standalone_fastapi.py | 81 + .../lib/tests/test_standalone_hugging_face.py | 147 ++ .../lib/tests/test_standalone_openai_json.py | 108 ++ .../test_standalone_openai_json_struct.py | 92 ++ ...test_standalone_openai_json_struct_func.py | 177 ++ ...lone_openai_json_struct_func_playground.py | 222 +++ .../lib/tests/test_standalone_project_out.py | 141 ++ build/lib/tests/test_standalone_rxpy_01.py | 133 ++ build/lib/tests/test_unitree_agent.py | 318 ++++ 
.../test_unitree_agent_queries_fastapi.py | 105 ++ build/lib/tests/test_unitree_ros_v0.0.4.py | 198 +++ build/lib/tests/test_webrtc_queue.py | 156 ++ build/lib/tests/test_websocketvis.py | 152 ++ build/lib/tests/test_zed_setup.py | 182 +++ build/lib/tests/visualization_script.py | 1041 ++++++++++++ build/lib/tests/zed_neural_depth_demo.py | 450 +++++ 297 files changed, 54832 insertions(+) create mode 100644 It create mode 100644 build/lib/dimos/__init__.py create mode 100644 build/lib/dimos/agents/__init__.py create mode 100644 build/lib/dimos/agents/agent.py create mode 100644 build/lib/dimos/agents/agent_config.py create mode 100644 build/lib/dimos/agents/agent_ctransformers_gguf.py create mode 100644 build/lib/dimos/agents/agent_huggingface_local.py create mode 100644 build/lib/dimos/agents/agent_huggingface_remote.py create mode 100644 build/lib/dimos/agents/cerebras_agent.py create mode 100644 build/lib/dimos/agents/claude_agent.py create mode 100644 build/lib/dimos/agents/memory/__init__.py create mode 100644 build/lib/dimos/agents/memory/base.py create mode 100644 build/lib/dimos/agents/memory/chroma_impl.py create mode 100644 build/lib/dimos/agents/memory/image_embedding.py create mode 100644 build/lib/dimos/agents/memory/spatial_vector_db.py create mode 100644 build/lib/dimos/agents/memory/test_image_embedding.py create mode 100644 build/lib/dimos/agents/memory/visual_memory.py create mode 100644 build/lib/dimos/agents/planning_agent.py create mode 100644 build/lib/dimos/agents/prompt_builder/__init__.py create mode 100644 build/lib/dimos/agents/prompt_builder/impl.py create mode 100644 build/lib/dimos/agents/tokenizer/__init__.py create mode 100644 build/lib/dimos/agents/tokenizer/base.py create mode 100644 build/lib/dimos/agents/tokenizer/huggingface_tokenizer.py create mode 100644 build/lib/dimos/agents/tokenizer/openai_tokenizer.py create mode 100644 build/lib/dimos/core/__init__.py create mode 100644 build/lib/dimos/core/colors.py create mode 100644 
build/lib/dimos/core/core.py create mode 100644 build/lib/dimos/core/module.py create mode 100644 build/lib/dimos/core/o3dpickle.py create mode 100644 build/lib/dimos/core/test_core.py create mode 100644 build/lib/dimos/core/transport.py create mode 100644 build/lib/dimos/environment/__init__.py create mode 100644 build/lib/dimos/environment/agent_environment.py create mode 100644 build/lib/dimos/environment/colmap_environment.py create mode 100644 build/lib/dimos/environment/environment.py create mode 100644 build/lib/dimos/exceptions/__init__.py create mode 100644 build/lib/dimos/exceptions/agent_memory_exceptions.py create mode 100644 build/lib/dimos/hardware/__init__.py create mode 100644 build/lib/dimos/hardware/camera.py create mode 100644 build/lib/dimos/hardware/end_effector.py create mode 100644 build/lib/dimos/hardware/interface.py create mode 100644 build/lib/dimos/hardware/piper_arm.py create mode 100644 build/lib/dimos/hardware/sensor.py create mode 100644 build/lib/dimos/hardware/stereo_camera.py create mode 100644 build/lib/dimos/hardware/test_simple_module(1).py create mode 100644 build/lib/dimos/hardware/ufactory.py create mode 100644 build/lib/dimos/hardware/zed_camera.py create mode 100644 build/lib/dimos/manipulation/__init__.py create mode 100644 build/lib/dimos/manipulation/manip_aio_pipeline.py create mode 100644 build/lib/dimos/manipulation/manip_aio_processer.py create mode 100644 build/lib/dimos/manipulation/manipulation_history.py create mode 100644 build/lib/dimos/manipulation/manipulation_interface.py create mode 100644 build/lib/dimos/manipulation/test_manipulation_history.py create mode 100644 build/lib/dimos/models/__init__.py create mode 100644 build/lib/dimos/models/depth/__init__.py create mode 100644 build/lib/dimos/models/depth/metric3d.py create mode 100644 build/lib/dimos/models/labels/__init__.py create mode 100644 build/lib/dimos/models/labels/llava-34b.py create mode 100644 build/lib/dimos/models/manipulation/__init__.py 
create mode 100644 build/lib/dimos/models/pointcloud/__init__.py create mode 100644 build/lib/dimos/models/pointcloud/pointcloud_utils.py create mode 100644 build/lib/dimos/models/segmentation/__init__.py create mode 100644 build/lib/dimos/models/segmentation/clipseg.py create mode 100644 build/lib/dimos/models/segmentation/sam.py create mode 100644 build/lib/dimos/models/segmentation/segment_utils.py create mode 100644 build/lib/dimos/msgs/__init__.py create mode 100644 build/lib/dimos/msgs/geometry_msgs/Pose.py create mode 100644 build/lib/dimos/msgs/geometry_msgs/PoseStamped.py create mode 100644 build/lib/dimos/msgs/geometry_msgs/Quaternion.py create mode 100644 build/lib/dimos/msgs/geometry_msgs/Twist.py create mode 100644 build/lib/dimos/msgs/geometry_msgs/Vector3.py create mode 100644 build/lib/dimos/msgs/geometry_msgs/__init__.py create mode 100644 build/lib/dimos/msgs/geometry_msgs/test_Pose.py create mode 100644 build/lib/dimos/msgs/geometry_msgs/test_Quaternion.py create mode 100644 build/lib/dimos/msgs/geometry_msgs/test_Vector3.py create mode 100644 build/lib/dimos/msgs/geometry_msgs/test_publish.py create mode 100644 build/lib/dimos/msgs/sensor_msgs/Image.py create mode 100644 build/lib/dimos/msgs/sensor_msgs/PointCloud2.py create mode 100644 build/lib/dimos/msgs/sensor_msgs/__init__.py create mode 100644 build/lib/dimos/msgs/sensor_msgs/test_PointCloud2.py create mode 100644 build/lib/dimos/msgs/sensor_msgs/test_image.py create mode 100644 build/lib/dimos/perception/__init__.py create mode 100644 build/lib/dimos/perception/common/__init__.py create mode 100644 build/lib/dimos/perception/common/cuboid_fit.py create mode 100644 build/lib/dimos/perception/common/detection2d_tracker.py create mode 100644 build/lib/dimos/perception/common/export_tensorrt.py create mode 100644 build/lib/dimos/perception/common/ibvs.py create mode 100644 build/lib/dimos/perception/common/utils.py create mode 100644 build/lib/dimos/perception/detection2d/__init__.py create 
mode 100644 build/lib/dimos/perception/detection2d/detic_2d_det.py create mode 100644 build/lib/dimos/perception/detection2d/test_yolo_2d_det.py create mode 100644 build/lib/dimos/perception/detection2d/utils.py create mode 100644 build/lib/dimos/perception/detection2d/yolo_2d_det.py create mode 100644 build/lib/dimos/perception/grasp_generation/__init__.py create mode 100644 build/lib/dimos/perception/grasp_generation/grasp_generation.py create mode 100644 build/lib/dimos/perception/grasp_generation/utils.py create mode 100644 build/lib/dimos/perception/object_detection_stream.py create mode 100644 build/lib/dimos/perception/object_tracker.py create mode 100644 build/lib/dimos/perception/person_tracker.py create mode 100644 build/lib/dimos/perception/pointcloud/__init__.py create mode 100644 build/lib/dimos/perception/pointcloud/cuboid_fit.py create mode 100644 build/lib/dimos/perception/pointcloud/pointcloud_filtering.py create mode 100644 build/lib/dimos/perception/pointcloud/utils.py create mode 100644 build/lib/dimos/perception/segmentation/__init__.py create mode 100644 build/lib/dimos/perception/segmentation/image_analyzer.py create mode 100644 build/lib/dimos/perception/segmentation/sam_2d_seg.py create mode 100644 build/lib/dimos/perception/segmentation/test_sam_2d_seg.py create mode 100644 build/lib/dimos/perception/segmentation/utils.py create mode 100644 build/lib/dimos/perception/semantic_seg.py create mode 100644 build/lib/dimos/perception/spatial_perception.py create mode 100644 build/lib/dimos/perception/test_spatial_memory.py create mode 100644 build/lib/dimos/perception/visual_servoing.py create mode 100644 build/lib/dimos/robot/__init__.py create mode 100644 build/lib/dimos/robot/connection_interface.py create mode 100644 build/lib/dimos/robot/foxglove_bridge.py create mode 100644 build/lib/dimos/robot/frontier_exploration/__init__.py create mode 100644 build/lib/dimos/robot/frontier_exploration/qwen_frontier_predictor.py create mode 100644 
build/lib/dimos/robot/frontier_exploration/test_wavefront_frontier_goal_selector.py create mode 100644 build/lib/dimos/robot/frontier_exploration/utils.py create mode 100644 build/lib/dimos/robot/frontier_exploration/wavefront_frontier_goal_selector.py create mode 100644 build/lib/dimos/robot/global_planner/__init__.py create mode 100644 build/lib/dimos/robot/global_planner/algo.py create mode 100644 build/lib/dimos/robot/global_planner/planner.py create mode 100644 build/lib/dimos/robot/local_planner/__init__.py create mode 100644 build/lib/dimos/robot/local_planner/local_planner.py create mode 100644 build/lib/dimos/robot/local_planner/simple.py create mode 100644 build/lib/dimos/robot/local_planner/vfh_local_planner.py create mode 100644 build/lib/dimos/robot/position_stream.py create mode 100644 build/lib/dimos/robot/recorder.py create mode 100644 build/lib/dimos/robot/robot.py create mode 100644 build/lib/dimos/robot/ros_command_queue.py create mode 100644 build/lib/dimos/robot/ros_control.py create mode 100644 build/lib/dimos/robot/ros_observable_topic.py create mode 100644 build/lib/dimos/robot/ros_transform.py create mode 100644 build/lib/dimos/robot/test_ros_observable_topic.py create mode 100644 build/lib/dimos/robot/unitree/__init__.py create mode 100644 build/lib/dimos/robot/unitree/unitree_go2.py create mode 100644 build/lib/dimos/robot/unitree/unitree_ros_control.py create mode 100644 build/lib/dimos/robot/unitree/unitree_skills.py create mode 100644 build/lib/dimos/robot/unitree_webrtc/__init__.py create mode 100644 build/lib/dimos/robot/unitree_webrtc/connection.py create mode 100644 build/lib/dimos/robot/unitree_webrtc/testing/__init__.py create mode 100644 build/lib/dimos/robot/unitree_webrtc/testing/helpers.py create mode 100644 build/lib/dimos/robot/unitree_webrtc/testing/mock.py create mode 100644 build/lib/dimos/robot/unitree_webrtc/testing/multimock.py create mode 100644 build/lib/dimos/robot/unitree_webrtc/testing/test_mock.py create mode 
100644 build/lib/dimos/robot/unitree_webrtc/testing/test_multimock.py create mode 100644 build/lib/dimos/robot/unitree_webrtc/type/__init__.py create mode 100644 build/lib/dimos/robot/unitree_webrtc/type/lidar.py create mode 100644 build/lib/dimos/robot/unitree_webrtc/type/lowstate.py create mode 100644 build/lib/dimos/robot/unitree_webrtc/type/map.py create mode 100644 build/lib/dimos/robot/unitree_webrtc/type/odometry.py create mode 100644 build/lib/dimos/robot/unitree_webrtc/type/test_lidar.py create mode 100644 build/lib/dimos/robot/unitree_webrtc/type/test_map.py create mode 100644 build/lib/dimos/robot/unitree_webrtc/type/test_odometry.py create mode 100644 build/lib/dimos/robot/unitree_webrtc/type/test_timeseries.py create mode 100644 build/lib/dimos/robot/unitree_webrtc/type/timeseries.py create mode 100644 build/lib/dimos/robot/unitree_webrtc/type/vector.py create mode 100644 build/lib/dimos/robot/unitree_webrtc/unitree_go2.py create mode 100644 build/lib/dimos/robot/unitree_webrtc/unitree_skills.py create mode 100644 build/lib/dimos/simulation/__init__.py create mode 100644 build/lib/dimos/simulation/base/__init__.py create mode 100644 build/lib/dimos/simulation/base/simulator_base.py create mode 100644 build/lib/dimos/simulation/base/stream_base.py create mode 100644 build/lib/dimos/simulation/genesis/__init__.py create mode 100644 build/lib/dimos/simulation/genesis/simulator.py create mode 100644 build/lib/dimos/simulation/genesis/stream.py create mode 100644 build/lib/dimos/simulation/isaac/__init__.py create mode 100644 build/lib/dimos/simulation/isaac/simulator.py create mode 100644 build/lib/dimos/simulation/isaac/stream.py create mode 100644 build/lib/dimos/skills/__init__.py create mode 100644 build/lib/dimos/skills/kill_skill.py create mode 100644 build/lib/dimos/skills/navigation.py create mode 100644 build/lib/dimos/skills/observe.py create mode 100644 build/lib/dimos/skills/observe_stream.py create mode 100644 
build/lib/dimos/skills/rest/__init__.py create mode 100644 build/lib/dimos/skills/rest/rest.py create mode 100644 build/lib/dimos/skills/skills.py create mode 100644 build/lib/dimos/skills/speak.py create mode 100644 build/lib/dimos/skills/unitree/__init__.py create mode 100644 build/lib/dimos/skills/unitree/unitree_speak.py create mode 100644 build/lib/dimos/skills/visual_navigation_skills.py create mode 100644 build/lib/dimos/stream/__init__.py create mode 100644 build/lib/dimos/stream/audio/__init__.py create mode 100644 build/lib/dimos/stream/audio/base.py create mode 100644 build/lib/dimos/stream/audio/node_key_recorder.py create mode 100644 build/lib/dimos/stream/audio/node_microphone.py create mode 100644 build/lib/dimos/stream/audio/node_normalizer.py create mode 100644 build/lib/dimos/stream/audio/node_output.py create mode 100644 build/lib/dimos/stream/audio/node_simulated.py create mode 100644 build/lib/dimos/stream/audio/node_volume_monitor.py create mode 100644 build/lib/dimos/stream/audio/pipelines.py create mode 100644 build/lib/dimos/stream/audio/utils.py create mode 100644 build/lib/dimos/stream/audio/volume.py create mode 100644 build/lib/dimos/stream/data_provider.py create mode 100644 build/lib/dimos/stream/frame_processor.py create mode 100644 build/lib/dimos/stream/ros_video_provider.py create mode 100644 build/lib/dimos/stream/rtsp_video_provider.py create mode 100644 build/lib/dimos/stream/stream_merger.py create mode 100644 build/lib/dimos/stream/video_operators.py create mode 100644 build/lib/dimos/stream/video_provider.py create mode 100644 build/lib/dimos/stream/video_providers/__init__.py create mode 100644 build/lib/dimos/stream/video_providers/unitree.py create mode 100644 build/lib/dimos/stream/videostream.py create mode 100644 build/lib/dimos/types/__init__.py create mode 100644 build/lib/dimos/types/constants.py create mode 100644 build/lib/dimos/types/costmap.py create mode 100644 build/lib/dimos/types/label.py create mode 100644 
build/lib/dimos/types/manipulation.py create mode 100644 build/lib/dimos/types/path.py create mode 100644 build/lib/dimos/types/pose.py create mode 100644 build/lib/dimos/types/robot_capabilities.py create mode 100644 build/lib/dimos/types/robot_location.py create mode 100644 build/lib/dimos/types/ros_polyfill.py create mode 100644 build/lib/dimos/types/sample.py create mode 100644 build/lib/dimos/types/segmentation.py create mode 100644 build/lib/dimos/types/test_pose.py create mode 100644 build/lib/dimos/types/test_timestamped.py create mode 100644 build/lib/dimos/types/test_vector.py create mode 100644 build/lib/dimos/types/timestamped.py create mode 100644 build/lib/dimos/types/vector.py create mode 100644 build/lib/dimos/web/__init__.py create mode 100644 build/lib/dimos/web/dimos_interface/__init__.py create mode 100644 build/lib/dimos/web/dimos_interface/api/__init__.py create mode 100644 build/lib/dimos/web/dimos_interface/api/server.py create mode 100644 build/lib/dimos/web/edge_io.py create mode 100644 build/lib/dimos/web/fastapi_server.py create mode 100644 build/lib/dimos/web/flask_server.py create mode 100644 build/lib/dimos/web/robot_web_interface.py create mode 100644 build/lib/tests/__init__.py create mode 100644 build/lib/tests/agent_manip_flow_fastapi_test.py create mode 100644 build/lib/tests/agent_manip_flow_flask_test.py create mode 100644 build/lib/tests/agent_memory_test.py create mode 100644 build/lib/tests/colmap_test.py create mode 100644 build/lib/tests/run.py create mode 100644 build/lib/tests/run_go2_ros.py create mode 100644 build/lib/tests/run_navigation_only.py create mode 100644 build/lib/tests/simple_agent_test.py create mode 100644 build/lib/tests/test_agent.py create mode 100644 build/lib/tests/test_agent_alibaba.py create mode 100644 build/lib/tests/test_agent_ctransformers_gguf.py create mode 100644 build/lib/tests/test_agent_huggingface_local.py create mode 100644 build/lib/tests/test_agent_huggingface_local_jetson.py create 
mode 100644 build/lib/tests/test_agent_huggingface_remote.py create mode 100644 build/lib/tests/test_audio_agent.py create mode 100644 build/lib/tests/test_audio_robot_agent.py create mode 100644 build/lib/tests/test_cerebras_unitree_ros.py create mode 100644 build/lib/tests/test_claude_agent_query.py create mode 100644 build/lib/tests/test_claude_agent_skills_query.py create mode 100644 build/lib/tests/test_command_pose_unitree.py create mode 100644 build/lib/tests/test_header.py create mode 100644 build/lib/tests/test_huggingface_llm_agent.py create mode 100644 build/lib/tests/test_ibvs.py create mode 100644 build/lib/tests/test_manipulation_agent.py create mode 100644 build/lib/tests/test_manipulation_perception_pipeline.py.py create mode 100644 build/lib/tests/test_manipulation_pipeline_single_frame.py create mode 100644 build/lib/tests/test_manipulation_pipeline_single_frame_lcm.py create mode 100644 build/lib/tests/test_move_vel_unitree.py create mode 100644 build/lib/tests/test_navigate_to_object_robot.py create mode 100644 build/lib/tests/test_navigation_skills.py create mode 100644 build/lib/tests/test_object_detection_agent_data_query_stream.py create mode 100644 build/lib/tests/test_object_detection_stream.py create mode 100644 build/lib/tests/test_object_tracking_webcam.py create mode 100644 build/lib/tests/test_object_tracking_with_qwen.py create mode 100644 build/lib/tests/test_observe_stream_skill.py create mode 100644 build/lib/tests/test_person_following_robot.py create mode 100644 build/lib/tests/test_person_following_webcam.py create mode 100644 build/lib/tests/test_planning_agent_web_interface.py create mode 100644 build/lib/tests/test_planning_robot_agent.py create mode 100644 build/lib/tests/test_pointcloud_filtering.py create mode 100644 build/lib/tests/test_qwen_image_query.py create mode 100644 build/lib/tests/test_robot.py create mode 100644 build/lib/tests/test_rtsp_video_provider.py create mode 100644 
build/lib/tests/test_semantic_seg_robot.py create mode 100644 build/lib/tests/test_semantic_seg_robot_agent.py create mode 100644 build/lib/tests/test_semantic_seg_webcam.py create mode 100644 build/lib/tests/test_skills.py create mode 100644 build/lib/tests/test_skills_rest.py create mode 100644 build/lib/tests/test_spatial_memory.py create mode 100644 build/lib/tests/test_spatial_memory_query.py create mode 100644 build/lib/tests/test_standalone_chromadb.py create mode 100644 build/lib/tests/test_standalone_fastapi.py create mode 100644 build/lib/tests/test_standalone_hugging_face.py create mode 100644 build/lib/tests/test_standalone_openai_json.py create mode 100644 build/lib/tests/test_standalone_openai_json_struct.py create mode 100644 build/lib/tests/test_standalone_openai_json_struct_func.py create mode 100644 build/lib/tests/test_standalone_openai_json_struct_func_playground.py create mode 100644 build/lib/tests/test_standalone_project_out.py create mode 100644 build/lib/tests/test_standalone_rxpy_01.py create mode 100644 build/lib/tests/test_unitree_agent.py create mode 100644 build/lib/tests/test_unitree_agent_queries_fastapi.py create mode 100644 build/lib/tests/test_unitree_ros_v0.0.4.py create mode 100644 build/lib/tests/test_webrtc_queue.py create mode 100644 build/lib/tests/test_websocketvis.py create mode 100644 build/lib/tests/test_zed_setup.py create mode 100644 build/lib/tests/visualization_script.py create mode 100644 build/lib/tests/zed_neural_depth_demo.py diff --git a/It b/It new file mode 100644 index 0000000000..e69de29bb2 diff --git a/build/lib/dimos/__init__.py b/build/lib/dimos/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/build/lib/dimos/__init__.py @@ -0,0 +1 @@ + diff --git a/build/lib/dimos/agents/__init__.py b/build/lib/dimos/agents/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/build/lib/dimos/agents/agent.py b/build/lib/dimos/agents/agent.py new file mode 100644 index 
0000000000..1ce2216fe7 --- /dev/null +++ b/build/lib/dimos/agents/agent.py @@ -0,0 +1,904 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Agent framework for LLM-based autonomous systems. + +This module provides a flexible foundation for creating agents that can: +- Process image and text inputs through LLM APIs +- Store and retrieve contextual information using semantic memory +- Handle tool/function calling +- Process streaming inputs asynchronously + +The module offers base classes (Agent, LLMAgent) and concrete implementations +like OpenAIAgent that connect to specific LLM providers. 
+""" + +from __future__ import annotations + +# Standard library imports +import json +import os +import threading +from typing import Any, Tuple, Optional, Union + +# Third-party imports +from dotenv import load_dotenv +from openai import NOT_GIVEN, OpenAI +from pydantic import BaseModel +from reactivex import Observer, create, Observable, empty, operators as RxOps, just +from reactivex.disposable import CompositeDisposable, Disposable +from reactivex.scheduler import ThreadPoolScheduler +from reactivex.subject import Subject + +# Local imports +from dimos.agents.memory.base import AbstractAgentSemanticMemory +from dimos.agents.memory.chroma_impl import OpenAISemanticMemory +from dimos.agents.prompt_builder.impl import PromptBuilder +from dimos.agents.tokenizer.base import AbstractTokenizer +from dimos.agents.tokenizer.openai_tokenizer import OpenAITokenizer +from dimos.skills.skills import AbstractSkill, SkillLibrary +from dimos.stream.frame_processor import FrameProcessor +from dimos.stream.stream_merger import create_stream_merger +from dimos.stream.video_operators import Operators as MyOps, VideoOperators as MyVidOps +from dimos.utils.threadpool import get_scheduler +from dimos.utils.logging_config import setup_logger + +# Initialize environment variables +load_dotenv() + +# Initialize logger for the agent module +logger = setup_logger("dimos.agents") + +# Constants +_TOKEN_BUDGET_PARTS = 4 # Number of parts to divide token budget +_MAX_SAVED_FRAMES = 100 # Maximum number of frames to save + + +# ----------------------------------------------------------------------------- +# region Agent Base Class +# ----------------------------------------------------------------------------- +class Agent: + """Base agent that manages memory and subscriptions.""" + + def __init__( + self, + dev_name: str = "NA", + agent_type: str = "Base", + agent_memory: Optional[AbstractAgentSemanticMemory] = None, + pool_scheduler: Optional[ThreadPoolScheduler] = None, + ): + """ + 
Initializes a new instance of the Agent. + + Args: + dev_name (str): The device name of the agent. + agent_type (str): The type of the agent (e.g., 'Base', 'Vision'). + agent_memory (AbstractAgentSemanticMemory): The memory system for the agent. + pool_scheduler (ThreadPoolScheduler): The scheduler to use for thread pool operations. + If None, the global scheduler from get_scheduler() will be used. + """ + self.dev_name = dev_name + self.agent_type = agent_type + self.agent_memory = agent_memory or OpenAISemanticMemory() + self.disposables = CompositeDisposable() + self.pool_scheduler = pool_scheduler if pool_scheduler else get_scheduler() + + def dispose_all(self): + """Disposes of all active subscriptions managed by this agent.""" + if self.disposables: + self.disposables.dispose() + else: + logger.info("No disposables to dispose.") + + +# endregion Agent Base Class + + +# ----------------------------------------------------------------------------- +# region LLMAgent Base Class (Generic LLM Agent) +# ----------------------------------------------------------------------------- +class LLMAgent(Agent): + """Generic LLM agent containing common logic for LLM-based agents. + + This class implements functionality for: + - Updating the query + - Querying the agent's memory (for RAG) + - Building prompts via a prompt builder + - Handling tooling callbacks in responses + - Subscribing to image and query streams + - Emitting responses as an observable stream + + Subclasses must implement the `_send_query` method, which is responsible + for sending the prompt to a specific LLM API. + + Attributes: + query (str): The current query text to process. + prompt_builder (PromptBuilder): Handles construction of prompts. + system_query (str): System prompt for RAG context situations. + image_detail (str): Detail level for image processing ('low','high','auto'). + max_input_tokens_per_request (int): Maximum input token count. 
+ max_output_tokens_per_request (int): Maximum output token count. + max_tokens_per_request (int): Total maximum token count. + rag_query_n (int): Number of results to fetch from memory. + rag_similarity_threshold (float): Minimum similarity for RAG results. + frame_processor (FrameProcessor): Processes video frames. + output_dir (str): Directory for output files. + response_subject (Subject): Subject that emits agent responses. + process_all_inputs (bool): Whether to process every input emission (True) or + skip emissions when the agent is busy processing a previous input (False). + """ + + logging_file_memory_lock = threading.Lock() + + def __init__( + self, + dev_name: str = "NA", + agent_type: str = "LLM", + agent_memory: Optional[AbstractAgentSemanticMemory] = None, + pool_scheduler: Optional[ThreadPoolScheduler] = None, + process_all_inputs: bool = False, + system_query: Optional[str] = None, + max_output_tokens_per_request: int = 16384, + max_input_tokens_per_request: int = 128000, + input_query_stream: Optional[Observable] = None, + input_data_stream: Optional[Observable] = None, + input_video_stream: Optional[Observable] = None, + ): + """ + Initializes a new instance of the LLMAgent. + + Args: + dev_name (str): The device name of the agent. + agent_type (str): The type of the agent. + agent_memory (AbstractAgentSemanticMemory): The memory system for the agent. + pool_scheduler (ThreadPoolScheduler): The scheduler to use for thread pool operations. + If None, the global scheduler from get_scheduler() will be used. + process_all_inputs (bool): Whether to process every input emission (True) or + skip emissions when the agent is busy processing a previous input (False). + """ + super().__init__(dev_name, agent_type, agent_memory, pool_scheduler) + # These attributes can be configured by a subclass if needed. 
+ self.query: Optional[str] = None + self.prompt_builder: Optional[PromptBuilder] = None + self.system_query: Optional[str] = system_query + self.image_detail: str = "low" + self.max_input_tokens_per_request: int = max_input_tokens_per_request + self.max_output_tokens_per_request: int = max_output_tokens_per_request + self.max_tokens_per_request: int = ( + self.max_input_tokens_per_request + self.max_output_tokens_per_request + ) + self.rag_query_n: int = 4 + self.rag_similarity_threshold: float = 0.45 + self.frame_processor: Optional[FrameProcessor] = None + self.output_dir: str = os.path.join(os.getcwd(), "assets", "agent") + self.process_all_inputs: bool = process_all_inputs + os.makedirs(self.output_dir, exist_ok=True) + + # Subject for emitting responses + self.response_subject = Subject() + + # Conversation history for maintaining context between calls + self.conversation_history = [] + + # Initialize input streams + self.input_video_stream = input_video_stream + self.input_query_stream = ( + input_query_stream + if (input_data_stream is None) + else ( + input_query_stream.pipe( + RxOps.with_latest_from(input_data_stream), + RxOps.map( + lambda combined: { + "query": combined[0], + "objects": combined[1] + if len(combined) > 1 + else "No object data available", + } + ), + RxOps.map( + lambda data: f"{data['query']}\n\nCurrent objects detected:\n{data['objects']}" + ), + RxOps.do_action( + lambda x: print(f"\033[34mEnriched query: {x.split(chr(10))[0]}\033[0m") + or [print(f"\033[34m{line}\033[0m") for line in x.split(chr(10))[1:]] + ), + ) + ) + ) + + # Setup stream subscriptions based on inputs provided + if (self.input_video_stream is not None) and (self.input_query_stream is not None): + self.merged_stream = create_stream_merger( + data_input_stream=self.input_video_stream, text_query_stream=self.input_query_stream + ) + + logger.info("Subscribing to merged input stream...") + # Define a query extractor for the merged stream + query_extractor = lambda 
emission: (emission[0], emission[1][0]) + self.disposables.add( + self.subscribe_to_image_processing( + self.merged_stream, query_extractor=query_extractor + ) + ) + else: + # If no merged stream, fall back to individual streams + if self.input_video_stream is not None: + logger.info("Subscribing to input video stream...") + self.disposables.add(self.subscribe_to_image_processing(self.input_video_stream)) + if self.input_query_stream is not None: + logger.info("Subscribing to input query stream...") + self.disposables.add(self.subscribe_to_query_processing(self.input_query_stream)) + + def _update_query(self, incoming_query: Optional[str]) -> None: + """Updates the query if an incoming query is provided. + + Args: + incoming_query (str): The new query text. + """ + if incoming_query is not None: + self.query = incoming_query + + def _get_rag_context(self) -> Tuple[str, str]: + """Queries the agent memory to retrieve RAG context. + + Returns: + Tuple[str, str]: A tuple containing the formatted results (for logging) + and condensed results (for use in the prompt). + """ + results = self.agent_memory.query( + query_texts=self.query, + n_results=self.rag_query_n, + similarity_threshold=self.rag_similarity_threshold, + ) + formatted_results = "\n".join( + f"Document ID: {doc.id}\nMetadata: {doc.metadata}\nContent: {doc.page_content}\nScore: {score}\n" + for (doc, score) in results + ) + condensed_results = " | ".join(f"{doc.page_content}" for (doc, _) in results) + logger.info(f"Agent Memory Query Results:\n{formatted_results}") + logger.info("=== Results End ===") + return formatted_results, condensed_results + + def _build_prompt( + self, + base64_image: Optional[str], + dimensions: Optional[Tuple[int, int]], + override_token_limit: bool, + condensed_results: str, + ) -> list: + """Builds a prompt message using the prompt builder. + + Args: + base64_image (str): Optional Base64-encoded image. + dimensions (Tuple[int, int]): Optional image dimensions. 
+ override_token_limit (bool): Whether to override token limits. + condensed_results (str): The condensed RAG context. + + Returns: + list: A list of message dictionaries to be sent to the LLM. + """ + # Budget for each component of the prompt + budgets = { + "system_prompt": self.max_input_tokens_per_request // _TOKEN_BUDGET_PARTS, + "user_query": self.max_input_tokens_per_request // _TOKEN_BUDGET_PARTS, + "image": self.max_input_tokens_per_request // _TOKEN_BUDGET_PARTS, + "rag": self.max_input_tokens_per_request // _TOKEN_BUDGET_PARTS, + } + + # Define truncation policies for each component + policies = { + "system_prompt": "truncate_end", + "user_query": "truncate_middle", + "image": "do_not_truncate", + "rag": "truncate_end", + } + + return self.prompt_builder.build( + user_query=self.query, + override_token_limit=override_token_limit, + base64_image=base64_image, + image_width=dimensions[0] if dimensions is not None else None, + image_height=dimensions[1] if dimensions is not None else None, + image_detail=self.image_detail, + rag_context=condensed_results, + system_prompt=self.system_query, + budgets=budgets, + policies=policies, + ) + + def _handle_tooling(self, response_message, messages): + """Handles tooling callbacks in the response message. + + If tool calls are present, the corresponding functions are executed and + a follow-up query is sent. + + Args: + response_message: The response message containing tool calls. + messages (list): The original list of messages sent. + + Returns: + The final response message after processing tool calls, if any. + """ + + # TODO: Make this more generic or move implementation to OpenAIAgent. + # This is presently OpenAI-specific. 
+ def _tooling_callback(message, messages, response_message, skill_library: SkillLibrary): + has_called_tools = False + new_messages = [] + for tool_call in message.tool_calls: + has_called_tools = True + name = tool_call.function.name + args = json.loads(tool_call.function.arguments) + result = skill_library.call(name, **args) + logger.info(f"Function Call Results: {result}") + new_messages.append( + { + "role": "tool", + "tool_call_id": tool_call.id, + "content": str(result), + "name": name, + } + ) + if has_called_tools: + logger.info("Sending Another Query.") + messages.append(response_message) + messages.extend(new_messages) + # Delegate to sending the query again. + return self._send_query(messages) + else: + logger.info("No Need for Another Query.") + return None + + if response_message.tool_calls is not None: + return _tooling_callback( + response_message, messages, response_message, self.skill_library + ) + return None + + def _observable_query( + self, + observer: Observer, + base64_image: Optional[str] = None, + dimensions: Optional[Tuple[int, int]] = None, + override_token_limit: bool = False, + incoming_query: Optional[str] = None, + ): + """Prepares and sends a query to the LLM, emitting the response to the observer. + + Args: + observer (Observer): The observer to emit responses to. + base64_image (str): Optional Base64-encoded image. + dimensions (Tuple[int, int]): Optional image dimensions. + override_token_limit (bool): Whether to override token limits. + incoming_query (str): Optional query to update the agent's query. + + Raises: + Exception: Propagates any exceptions encountered during processing. 
+ """ + try: + self._update_query(incoming_query) + _, condensed_results = self._get_rag_context() + messages = self._build_prompt( + base64_image, dimensions, override_token_limit, condensed_results + ) + # logger.debug(f"Sending Query: {messages}") + logger.info("Sending Query.") + response_message = self._send_query(messages) + logger.info(f"Received Response: {response_message}") + if response_message is None: + raise Exception("Response message does not exist.") + + # TODO: Make this more generic. The parsed tag and tooling handling may be OpenAI-specific. + # If no skill library is provided or there are no tool calls, emit the response directly. + if ( + self.skill_library is None + or self.skill_library.get_tools() in (None, NOT_GIVEN) + or response_message.tool_calls is None + ): + final_msg = ( + response_message.parsed + if hasattr(response_message, "parsed") and response_message.parsed + else ( + response_message.content + if hasattr(response_message, "content") + else response_message + ) + ) + observer.on_next(final_msg) + self.response_subject.on_next(final_msg) + else: + response_message_2 = self._handle_tooling(response_message, messages) + final_msg = ( + response_message_2 if response_message_2 is not None else response_message + ) + if isinstance(final_msg, BaseModel): # TODO: Test + final_msg = str(final_msg.content) + observer.on_next(final_msg) + self.response_subject.on_next(final_msg) + observer.on_completed() + except Exception as e: + logger.error(f"Query failed in {self.dev_name}: {e}") + observer.on_error(e) + self.response_subject.on_error(e) + + def _send_query(self, messages: list) -> Any: + """Sends the query to the LLM API. + + This method must be implemented by subclasses with specifics of the LLM API. + + Args: + messages (list): The prompt messages to be sent. + + Returns: + Any: The response message from the LLM. + + Raises: + NotImplementedError: Always, unless overridden. 
+ """ + raise NotImplementedError("Subclasses must implement _send_query method.") + + def _log_response_to_file(self, response, output_dir: str = None): + """Logs the LLM response to a file. + + Args: + response: The response message to log. + output_dir (str): The directory where the log file is stored. + """ + if output_dir is None: + output_dir = self.output_dir + if response is not None: + with self.logging_file_memory_lock: + log_path = os.path.join(output_dir, "memory.txt") + with open(log_path, "a") as file: + file.write(f"{self.dev_name}: {response}\n") + logger.info(f"LLM Response [{self.dev_name}]: {response}") + + def subscribe_to_image_processing( + self, frame_observable: Observable, query_extractor=None + ) -> Disposable: + """Subscribes to a stream of video frames for processing. + + This method sets up a subscription to process incoming video frames. + Each frame is encoded and then sent to the LLM by directly calling the + _observable_query method. The response is then logged to a file. + + Args: + frame_observable (Observable): An observable emitting video frames or + (query, frame) tuples if query_extractor is provided. + query_extractor (callable, optional): Function to extract query and frame from + each emission. If None, assumes emissions are + raw frames and uses self.system_query. + + Returns: + Disposable: A disposable representing the subscription. + """ + # Initialize frame processor if not already set + if self.frame_processor is None: + self.frame_processor = FrameProcessor(delete_on_init=True) + + print_emission_args = {"enabled": True, "dev_name": self.dev_name, "counts": {}} + + def _process_frame(emission) -> Observable: + """ + Processes a frame or (query, frame) tuple. 
+ """ + # Extract query and frame + if query_extractor: + query, frame = query_extractor(emission) + else: + query = self.system_query + frame = emission + return just(frame).pipe( + MyOps.print_emission(id="B", **print_emission_args), + RxOps.observe_on(self.pool_scheduler), + MyOps.print_emission(id="C", **print_emission_args), + RxOps.subscribe_on(self.pool_scheduler), + MyOps.print_emission(id="D", **print_emission_args), + MyVidOps.with_jpeg_export( + self.frame_processor, + suffix=f"{self.dev_name}_frame_", + save_limit=_MAX_SAVED_FRAMES, + ), + MyOps.print_emission(id="E", **print_emission_args), + MyVidOps.encode_image(), + MyOps.print_emission(id="F", **print_emission_args), + RxOps.filter( + lambda base64_and_dims: base64_and_dims is not None + and base64_and_dims[0] is not None + and base64_and_dims[1] is not None + ), + MyOps.print_emission(id="G", **print_emission_args), + RxOps.flat_map( + lambda base64_and_dims: create( + lambda observer, _: self._observable_query( + observer, + base64_image=base64_and_dims[0], + dimensions=base64_and_dims[1], + incoming_query=query, + ) + ) + ), # Use the extracted query + MyOps.print_emission(id="H", **print_emission_args), + ) + + # Use a mutable flag to ensure only one frame is processed at a time. 
+ is_processing = [False] + + def process_if_free(emission): + if not self.process_all_inputs and is_processing[0]: + # Drop frame if a request is in progress and process_all_inputs is False + return empty() + else: + is_processing[0] = True + return _process_frame(emission).pipe( + MyOps.print_emission(id="I", **print_emission_args), + RxOps.observe_on(self.pool_scheduler), + MyOps.print_emission(id="J", **print_emission_args), + RxOps.subscribe_on(self.pool_scheduler), + MyOps.print_emission(id="K", **print_emission_args), + RxOps.do_action( + on_completed=lambda: is_processing.__setitem__(0, False), + on_error=lambda e: is_processing.__setitem__(0, False), + ), + MyOps.print_emission(id="L", **print_emission_args), + ) + + observable = frame_observable.pipe( + MyOps.print_emission(id="A", **print_emission_args), + RxOps.flat_map(process_if_free), + MyOps.print_emission(id="M", **print_emission_args), + ) + + disposable = observable.subscribe( + on_next=lambda response: self._log_response_to_file(response, self.output_dir), + on_error=lambda e: logger.error(f"Error encountered: {e}"), + on_completed=lambda: logger.info(f"Stream processing completed for {self.dev_name}"), + ) + self.disposables.add(disposable) + return disposable + + def subscribe_to_query_processing(self, query_observable: Observable) -> Disposable: + """Subscribes to a stream of queries for processing. + + This method sets up a subscription to process incoming queries by directly + calling the _observable_query method. The responses are logged to a file. + + Args: + query_observable (Observable): An observable emitting queries. + + Returns: + Disposable: A disposable representing the subscription. + """ + print_emission_args = {"enabled": False, "dev_name": self.dev_name, "counts": {}} + + def _process_query(query) -> Observable: + """ + Processes a single query by logging it and passing it to _observable_query. + Returns an observable that emits the LLM response. 
+ """ + return just(query).pipe( + MyOps.print_emission(id="Pr A", **print_emission_args), + RxOps.flat_map( + lambda query: create( + lambda observer, _: self._observable_query(observer, incoming_query=query) + ) + ), + MyOps.print_emission(id="Pr B", **print_emission_args), + ) + + # A mutable flag indicating whether a query is currently being processed. + is_processing = [False] + + def process_if_free(query): + logger.info(f"Processing Query: {query}") + if not self.process_all_inputs and is_processing[0]: + # Drop query if a request is already in progress and process_all_inputs is False + return empty() + else: + is_processing[0] = True + logger.info("Processing Query.") + return _process_query(query).pipe( + MyOps.print_emission(id="B", **print_emission_args), + RxOps.observe_on(self.pool_scheduler), + MyOps.print_emission(id="C", **print_emission_args), + RxOps.subscribe_on(self.pool_scheduler), + MyOps.print_emission(id="D", **print_emission_args), + RxOps.do_action( + on_completed=lambda: is_processing.__setitem__(0, False), + on_error=lambda e: is_processing.__setitem__(0, False), + ), + MyOps.print_emission(id="E", **print_emission_args), + ) + + observable = query_observable.pipe( + MyOps.print_emission(id="A", **print_emission_args), + RxOps.flat_map(lambda query: process_if_free(query)), + MyOps.print_emission(id="F", **print_emission_args), + ) + + disposable = observable.subscribe( + on_next=lambda response: self._log_response_to_file(response, self.output_dir), + on_error=lambda e: logger.error(f"Error processing query for {self.dev_name}: {e}"), + on_completed=lambda: logger.info(f"Stream processing completed for {self.dev_name}"), + ) + self.disposables.add(disposable) + return disposable + + def get_response_observable(self) -> Observable: + """Gets an observable that emits responses from this agent. + + Returns: + Observable: An observable that emits string responses from the agent. 
+ """ + return self.response_subject.pipe( + RxOps.observe_on(self.pool_scheduler), + RxOps.subscribe_on(self.pool_scheduler), + RxOps.share(), + ) + + def run_observable_query(self, query_text: str, **kwargs) -> Observable: + """Creates an observable that processes a one-off text query to Agent and emits the response. + + This method provides a simple way to send a text query and get an observable + stream of the response. It's designed for one-off queries rather than + continuous processing of input streams. Useful for testing and development. + + Args: + query_text (str): The query text to process. + **kwargs: Additional arguments to pass to _observable_query. Supported args vary by agent type. + For example, ClaudeAgent supports: base64_image, dimensions, override_token_limit, + reset_conversation, thinking_budget_tokens + + Returns: + Observable: An observable that emits the response as a string. + """ + return create( + lambda observer, _: self._observable_query( + observer, incoming_query=query_text, **kwargs + ) + ) + + def dispose_all(self): + """Disposes of all active subscriptions managed by this agent.""" + super().dispose_all() + self.response_subject.on_completed() + + +# endregion LLMAgent Base Class (Generic LLM Agent) + + +# ----------------------------------------------------------------------------- +# region OpenAIAgent Subclass (OpenAI-Specific Implementation) +# ----------------------------------------------------------------------------- +class OpenAIAgent(LLMAgent): + """OpenAI agent implementation that uses OpenAI's API for processing. + + This class implements the _send_query method to interact with OpenAI's API. + It also sets up OpenAI-specific parameters, such as the client, model name, + tokenizer, and response model. 
+ """ + + def __init__( + self, + dev_name: str, + agent_type: str = "Vision", + query: str = "What do you see?", + input_query_stream: Optional[Observable] = None, + input_data_stream: Optional[Observable] = None, + input_video_stream: Optional[Observable] = None, + output_dir: str = os.path.join(os.getcwd(), "assets", "agent"), + agent_memory: Optional[AbstractAgentSemanticMemory] = None, + system_query: Optional[str] = None, + max_input_tokens_per_request: int = 128000, + max_output_tokens_per_request: int = 16384, + model_name: str = "gpt-4o", + prompt_builder: Optional[PromptBuilder] = None, + tokenizer: Optional[AbstractTokenizer] = None, + rag_query_n: int = 4, + rag_similarity_threshold: float = 0.45, + skills: Optional[Union[AbstractSkill, list[AbstractSkill], SkillLibrary]] = None, + response_model: Optional[BaseModel] = None, + frame_processor: Optional[FrameProcessor] = None, + image_detail: str = "low", + pool_scheduler: Optional[ThreadPoolScheduler] = None, + process_all_inputs: Optional[bool] = None, + openai_client: Optional[OpenAI] = None, + ): + """ + Initializes a new instance of the OpenAIAgent. + + Args: + dev_name (str): The device name of the agent. + agent_type (str): The type of the agent. + query (str): The default query text. + input_query_stream (Observable): An observable for query input. + input_data_stream (Observable): An observable for data input. + input_video_stream (Observable): An observable for video frames. + output_dir (str): Directory for output files. + agent_memory (AbstractAgentSemanticMemory): The memory system. + system_query (str): The system prompt to use with RAG context. + max_input_tokens_per_request (int): Maximum tokens for input. + max_output_tokens_per_request (int): Maximum tokens for output. + model_name (str): The OpenAI model name to use. + prompt_builder (PromptBuilder): Custom prompt builder. + tokenizer (AbstractTokenizer): Custom tokenizer for token counting. 
+ rag_query_n (int): Number of results to fetch in RAG queries. + rag_similarity_threshold (float): Minimum similarity for RAG results. + skills (Union[AbstractSkill, List[AbstractSkill], SkillLibrary]): Skills available to the agent. + response_model (BaseModel): Optional Pydantic model for responses. + frame_processor (FrameProcessor): Custom frame processor. + image_detail (str): Detail level for images ("low", "high", "auto"). + pool_scheduler (ThreadPoolScheduler): The scheduler to use for thread pool operations. + If None, the global scheduler from get_scheduler() will be used. + process_all_inputs (bool): Whether to process all inputs or skip when busy. + If None, defaults to True for text queries and merged streams, False for video streams. + openai_client (OpenAI): The OpenAI client to use. This can be used to specify + a custom OpenAI client if targetting another provider. + """ + # Determine appropriate default for process_all_inputs if not provided + if process_all_inputs is None: + if input_query_stream is not None: + process_all_inputs = True + else: + process_all_inputs = False + + super().__init__( + dev_name=dev_name, + agent_type=agent_type, + agent_memory=agent_memory, + pool_scheduler=pool_scheduler, + process_all_inputs=process_all_inputs, + system_query=system_query, + input_query_stream=input_query_stream, + input_data_stream=input_data_stream, + input_video_stream=input_video_stream, + ) + self.client = openai_client or OpenAI() + self.query = query + self.output_dir = output_dir + os.makedirs(self.output_dir, exist_ok=True) + + # Configure skill library. 
+ self.skills = skills + self.skill_library = None + if isinstance(self.skills, SkillLibrary): + self.skill_library = self.skills + elif isinstance(self.skills, list): + self.skill_library = SkillLibrary() + for skill in self.skills: + self.skill_library.add(skill) + elif isinstance(self.skills, AbstractSkill): + self.skill_library = SkillLibrary() + self.skill_library.add(self.skills) + + self.response_model = response_model if response_model is not None else NOT_GIVEN + self.model_name = model_name + self.tokenizer = tokenizer or OpenAITokenizer(model_name=self.model_name) + self.prompt_builder = prompt_builder or PromptBuilder( + self.model_name, tokenizer=self.tokenizer + ) + self.rag_query_n = rag_query_n + self.rag_similarity_threshold = rag_similarity_threshold + self.image_detail = image_detail + self.max_output_tokens_per_request = max_output_tokens_per_request + self.max_input_tokens_per_request = max_input_tokens_per_request + self.max_tokens_per_request = max_input_tokens_per_request + max_output_tokens_per_request + + # Add static context to memory. 
+ self._add_context_to_memory() + + self.frame_processor = frame_processor or FrameProcessor(delete_on_init=True) + + logger.info("OpenAI Agent Initialized.") + + def _add_context_to_memory(self): + """Adds initial context to the agent's memory.""" + context_data = [ + ( + "id0", + "Optical Flow is a technique used to track the movement of objects in a video sequence.", + ), + ( + "id1", + "Edge Detection is a technique used to identify the boundaries of objects in an image.", + ), + ("id2", "Video is a sequence of frames captured at regular intervals."), + ( + "id3", + "Colors in Optical Flow are determined by the movement of light, and can be used to track the movement of objects.", + ), + ( + "id4", + "Json is a data interchange format that is easy for humans to read and write, and easy for machines to parse and generate.", + ), + ] + for doc_id, text in context_data: + self.agent_memory.add_vector(doc_id, text) + + def _send_query(self, messages: list) -> Any: + """Sends the query to OpenAI's API. + + Depending on whether a response model is provided, the appropriate API + call is made. + + Args: + messages (list): The prompt messages to send. + + Returns: + The response message from OpenAI. + + Raises: + Exception: If no response message is returned. + ConnectionError: If there's an issue connecting to the API. + ValueError: If the messages or other parameters are invalid. 
+ """ + try: + if self.response_model is not NOT_GIVEN: + response = self.client.beta.chat.completions.parse( + model=self.model_name, + messages=messages, + response_format=self.response_model, + tools=( + self.skill_library.get_tools() + if self.skill_library is not None + else NOT_GIVEN + ), + max_tokens=self.max_output_tokens_per_request, + ) + else: + response = self.client.chat.completions.create( + model=self.model_name, + messages=messages, + max_tokens=self.max_output_tokens_per_request, + tools=( + self.skill_library.get_tools() + if self.skill_library is not None + else NOT_GIVEN + ), + ) + response_message = response.choices[0].message + if response_message is None: + logger.error("Response message does not exist.") + raise Exception("Response message does not exist.") + return response_message + except ConnectionError as ce: + logger.error(f"Connection error with API: {ce}") + raise + except ValueError as ve: + logger.error(f"Invalid parameters: {ve}") + raise + except Exception as e: + logger.error(f"Unexpected error in API call: {e}") + raise + + def stream_query(self, query_text: str) -> Observable: + """Creates an observable that processes a text query and emits the response. + + This method provides a simple way to send a text query and get an observable + stream of the response. It's designed for one-off queries rather than + continuous processing of input streams. + + Args: + query_text (str): The query text to process. + + Returns: + Observable: An observable that emits the response as a string. + """ + return create( + lambda observer, _: self._observable_query(observer, incoming_query=query_text) + ) + + +# endregion OpenAIAgent Subclass (OpenAI-Specific Implementation) diff --git a/build/lib/dimos/agents/agent_config.py b/build/lib/dimos/agents/agent_config.py new file mode 100644 index 0000000000..0ffbcd2983 --- /dev/null +++ b/build/lib/dimos/agents/agent_config.py @@ -0,0 +1,55 @@ +# Copyright 2025 Dimensional Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List +from dimos.agents.agent import Agent + + +class AgentConfig: + def __init__(self, agents: List[Agent] = None): + """ + Initialize an AgentConfig with a list of agents. + + Args: + agents (List[Agent], optional): List of Agent instances. Defaults to empty list. + """ + self.agents = agents if agents is not None else [] + + def add_agent(self, agent: Agent): + """ + Add an agent to the configuration. + + Args: + agent (Agent): Agent instance to add + """ + self.agents.append(agent) + + def remove_agent(self, agent: Agent): + """ + Remove an agent from the configuration. + + Args: + agent (Agent): Agent instance to remove + """ + if agent in self.agents: + self.agents.remove(agent) + + def get_agents(self) -> List[Agent]: + """ + Get the list of configured agents. + + Returns: + List[Agent]: List of configured agents + """ + return self.agents diff --git a/build/lib/dimos/agents/agent_ctransformers_gguf.py b/build/lib/dimos/agents/agent_ctransformers_gguf.py new file mode 100644 index 0000000000..32d6fc59ca --- /dev/null +++ b/build/lib/dimos/agents/agent_ctransformers_gguf.py @@ -0,0 +1,210 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +# Standard library imports +import logging +import os +from typing import Any, Optional + +# Third-party imports +from dotenv import load_dotenv +from reactivex import Observable, create +from reactivex.scheduler import ThreadPoolScheduler +from reactivex.subject import Subject +import torch + +# Local imports +from dimos.agents.agent import LLMAgent +from dimos.agents.memory.base import AbstractAgentSemanticMemory +from dimos.agents.prompt_builder.impl import PromptBuilder +from dimos.utils.logging_config import setup_logger + +# Initialize environment variables +load_dotenv() + +# Initialize logger for the agent module +logger = setup_logger("dimos.agents", level=logging.DEBUG) + +from ctransformers import AutoModelForCausalLM as CTransformersModel + + +class CTransformersTokenizerAdapter: + def __init__(self, model): + self.model = model + + def encode(self, text, **kwargs): + return self.model.tokenize(text) + + def decode(self, token_ids, **kwargs): + return self.model.detokenize(token_ids) + + def token_count(self, text): + return len(self.tokenize_text(text)) if text else 0 + + def tokenize_text(self, text): + return self.model.tokenize(text) + + def detokenize_text(self, tokenized_text): + try: + return self.model.detokenize(tokenized_text) + except Exception as e: + raise ValueError(f"Failed to detokenize text. 
Error: {str(e)}") + + def apply_chat_template(self, conversation, tokenize=False, add_generation_prompt=True): + prompt = "" + for message in conversation: + role = message["role"] + content = message["content"] + if role == "system": + prompt += f"<|system|>\n{content}\n" + elif role == "user": + prompt += f"<|user|>\n{content}\n" + elif role == "assistant": + prompt += f"<|assistant|>\n{content}\n" + if add_generation_prompt: + prompt += "<|assistant|>\n" + return prompt + + +# CTransformers Agent Class +class CTransformersGGUFAgent(LLMAgent): + def __init__( + self, + dev_name: str, + agent_type: str = "HF-LLM", + model_name: str = "TheBloke/Llama-2-7B-GGUF", + model_file: str = "llama-2-7b.Q4_K_M.gguf", + model_type: str = "llama", + gpu_layers: int = 50, + device: str = "auto", + query: str = "How many r's are in the word 'strawberry'?", + input_query_stream: Optional[Observable] = None, + input_video_stream: Optional[Observable] = None, + output_dir: str = os.path.join(os.getcwd(), "assets", "agent"), + agent_memory: Optional[AbstractAgentSemanticMemory] = None, + system_query: Optional[str] = "You are a helpful assistant.", + max_output_tokens_per_request: int = 10, + max_input_tokens_per_request: int = 250, + prompt_builder: Optional[PromptBuilder] = None, + pool_scheduler: Optional[ThreadPoolScheduler] = None, + process_all_inputs: Optional[bool] = None, + ): + # Determine appropriate default for process_all_inputs if not provided + if process_all_inputs is None: + # Default to True for text queries, False for video streams + if input_query_stream is not None and input_video_stream is None: + process_all_inputs = True + else: + process_all_inputs = False + + super().__init__( + dev_name=dev_name, + agent_type=agent_type, + agent_memory=agent_memory, + pool_scheduler=pool_scheduler, + process_all_inputs=process_all_inputs, + system_query=system_query, + max_output_tokens_per_request=max_output_tokens_per_request, + 
max_input_tokens_per_request=max_input_tokens_per_request, + ) + + self.query = query + self.output_dir = output_dir + os.makedirs(self.output_dir, exist_ok=True) + + self.model_name = model_name + self.device = device + if self.device == "auto": + self.device = "cuda" if torch.cuda.is_available() else "cpu" + if self.device == "cuda": + print(f"Using GPU: {torch.cuda.get_device_name(0)}") + else: + print("GPU not available, using CPU") + print(f"Device: {self.device}") + + self.model = CTransformersModel.from_pretrained( + model_name, model_file=model_file, model_type=model_type, gpu_layers=gpu_layers + ) + + self.tokenizer = CTransformersTokenizerAdapter(self.model) + + self.prompt_builder = prompt_builder or PromptBuilder( + self.model_name, tokenizer=self.tokenizer + ) + + self.max_output_tokens_per_request = max_output_tokens_per_request + + # self.stream_query(self.query).subscribe(lambda x: print(x)) + + self.input_video_stream = input_video_stream + self.input_query_stream = input_query_stream + + # Ensure only one input stream is provided. + if self.input_video_stream is not None and self.input_query_stream is not None: + raise ValueError( + "More than one input stream provided. Please provide only one input stream." 
+ ) + + if self.input_video_stream is not None: + logger.info("Subscribing to input video stream...") + self.disposables.add(self.subscribe_to_image_processing(self.input_video_stream)) + if self.input_query_stream is not None: + logger.info("Subscribing to input query stream...") + self.disposables.add(self.subscribe_to_query_processing(self.input_query_stream)) + + def _send_query(self, messages: list) -> Any: + try: + _BLUE_PRINT_COLOR: str = "\033[34m" + _RESET_COLOR: str = "\033[0m" + + # === FIX: Flatten message content === + flat_messages = [] + for msg in messages: + role = msg["role"] + content = msg["content"] + if isinstance(content, list): + # Assume it's a list of {'type': 'text', 'text': ...} + text_parts = [c["text"] for c in content if isinstance(c, dict) and "text" in c] + content = " ".join(text_parts) + flat_messages.append({"role": role, "content": content}) + + print(f"{_BLUE_PRINT_COLOR}Messages: {flat_messages}{_RESET_COLOR}") + + print("Applying chat template...") + prompt_text = self.tokenizer.apply_chat_template( + conversation=flat_messages, tokenize=False, add_generation_prompt=True + ) + print("Chat template applied.") + print(f"Prompt text:\n{prompt_text}") + + response = self.model(prompt_text, max_new_tokens=self.max_output_tokens_per_request) + print("Model response received.") + return response + + except Exception as e: + logger.error(f"Error during HuggingFace query: {e}") + return "Error processing request." + + def stream_query(self, query_text: str) -> Subject: + """ + Creates an observable that processes a text query and emits the response. 
+ """ + return create( + lambda observer, _: self._observable_query(observer, incoming_query=query_text) + ) + + +# endregion HuggingFaceLLMAgent Subclass (HuggingFace-Specific Implementation) diff --git a/build/lib/dimos/agents/agent_huggingface_local.py b/build/lib/dimos/agents/agent_huggingface_local.py new file mode 100644 index 0000000000..14f970c3bc --- /dev/null +++ b/build/lib/dimos/agents/agent_huggingface_local.py @@ -0,0 +1,235 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +# Standard library imports +import logging +import os +from typing import Any, Optional + +# Third-party imports +from dotenv import load_dotenv +from reactivex import Observable, create +from reactivex.scheduler import ThreadPoolScheduler +from reactivex.subject import Subject +import torch +from transformers import AutoModelForCausalLM + +# Local imports +from dimos.agents.agent import LLMAgent +from dimos.agents.memory.base import AbstractAgentSemanticMemory +from dimos.agents.memory.chroma_impl import LocalSemanticMemory +from dimos.agents.prompt_builder.impl import PromptBuilder +from dimos.agents.tokenizer.base import AbstractTokenizer +from dimos.agents.tokenizer.huggingface_tokenizer import HuggingFaceTokenizer +from dimos.utils.logging_config import setup_logger + +# Initialize environment variables +load_dotenv() + +# Initialize logger for the agent module +logger = setup_logger("dimos.agents", level=logging.DEBUG) + + +# HuggingFaceLLMAgent Class +class HuggingFaceLocalAgent(LLMAgent): + def __init__( + self, + dev_name: str, + agent_type: str = "HF-LLM", + model_name: str = "Qwen/Qwen2.5-3B", + device: str = "auto", + query: str = "How many r's are in the word 'strawberry'?", + input_query_stream: Optional[Observable] = None, + input_video_stream: Optional[Observable] = None, + output_dir: str = os.path.join(os.getcwd(), "assets", "agent"), + agent_memory: Optional[AbstractAgentSemanticMemory] = None, + system_query: Optional[str] = None, + max_output_tokens_per_request: int = None, + max_input_tokens_per_request: int = None, + prompt_builder: Optional[PromptBuilder] = None, + tokenizer: Optional[AbstractTokenizer] = None, + image_detail: str = "low", + pool_scheduler: Optional[ThreadPoolScheduler] = None, + process_all_inputs: Optional[bool] = None, + ): + # Determine appropriate default for process_all_inputs if not provided + if process_all_inputs is None: + # Default to True for text queries, False for video 
streams + if input_query_stream is not None and input_video_stream is None: + process_all_inputs = True + else: + process_all_inputs = False + + super().__init__( + dev_name=dev_name, + agent_type=agent_type, + agent_memory=agent_memory or LocalSemanticMemory(), + pool_scheduler=pool_scheduler, + process_all_inputs=process_all_inputs, + system_query=system_query, + ) + + self.query = query + self.output_dir = output_dir + os.makedirs(self.output_dir, exist_ok=True) + + self.model_name = model_name + self.device = device + if self.device == "auto": + self.device = "cuda" if torch.cuda.is_available() else "cpu" + if self.device == "cuda": + print(f"Using GPU: {torch.cuda.get_device_name(0)}") + else: + print("GPU not available, using CPU") + print(f"Device: {self.device}") + + self.tokenizer = tokenizer or HuggingFaceTokenizer(self.model_name) + + self.prompt_builder = prompt_builder or PromptBuilder( + self.model_name, tokenizer=self.tokenizer + ) + + self.model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype=torch.float16 if self.device == "cuda" else torch.float32, + device_map=self.device, + ) + + self.max_output_tokens_per_request = max_output_tokens_per_request + + # self.stream_query(self.query).subscribe(lambda x: print(x)) + + self.input_video_stream = input_video_stream + self.input_query_stream = input_query_stream + + # Ensure only one input stream is provided. + if self.input_video_stream is not None and self.input_query_stream is not None: + raise ValueError( + "More than one input stream provided. Please provide only one input stream." 
+ ) + + if self.input_video_stream is not None: + logger.info("Subscribing to input video stream...") + self.disposables.add(self.subscribe_to_image_processing(self.input_video_stream)) + if self.input_query_stream is not None: + logger.info("Subscribing to input query stream...") + self.disposables.add(self.subscribe_to_query_processing(self.input_query_stream)) + + def _send_query(self, messages: list) -> Any: + _BLUE_PRINT_COLOR: str = "\033[34m" + _RESET_COLOR: str = "\033[0m" + + try: + # Log the incoming messages + print(f"{_BLUE_PRINT_COLOR}Messages: {str(messages)}{_RESET_COLOR}") + + # Process with chat template + try: + print("Applying chat template...") + prompt_text = self.tokenizer.tokenizer.apply_chat_template( + conversation=[{"role": "user", "content": str(messages)}], + tokenize=False, + add_generation_prompt=True, + ) + print("Chat template applied.") + + # Tokenize the prompt + print("Preparing model inputs...") + model_inputs = self.tokenizer.tokenizer([prompt_text], return_tensors="pt").to( + self.model.device + ) + print("Model inputs prepared.") + + # Generate the response + print("Generating response...") + generated_ids = self.model.generate( + **model_inputs, max_new_tokens=self.max_output_tokens_per_request + ) + + # Extract the generated tokens (excluding the input prompt tokens) + print("Processing generated output...") + generated_ids = [ + output_ids[len(input_ids) :] + for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) + ] + + # Convert tokens back to text + response = self.tokenizer.tokenizer.batch_decode( + generated_ids, skip_special_tokens=True + )[0] + print("Response successfully generated.") + + return response + + except AttributeError as e: + # Handle case where tokenizer doesn't have the expected methods + logger.warning(f"Chat template not available: {e}. 
Using simple format.") + # Continue with execution and use simple format + + except Exception as e: + # Log any other errors but continue execution + logger.warning( + f"Error in chat template processing: {e}. Falling back to simple format." + ) + + # Fallback approach for models without chat template support + # This code runs if the try block above raises an exception + print("Using simple prompt format...") + + # Convert messages to a simple text format + if ( + isinstance(messages, list) + and messages + and isinstance(messages[0], dict) + and "content" in messages[0] + ): + prompt_text = messages[0]["content"] + else: + prompt_text = str(messages) + + # Tokenize the prompt + model_inputs = self.tokenizer.tokenize_text(prompt_text) + model_inputs = torch.tensor([model_inputs], device=self.model.device) + + # Generate the response + generated_ids = self.model.generate( + input_ids=model_inputs, max_new_tokens=self.max_output_tokens_per_request + ) + + # Extract the generated tokens + generated_ids = generated_ids[0][len(model_inputs[0]) :] + + # Convert tokens back to text + response = self.tokenizer.detokenize_text(generated_ids.tolist()) + print("Response generated using simple format.") + + return response + + except Exception as e: + # Catch all other errors + logger.error(f"Error during query processing: {e}", exc_info=True) + return "Error processing request. Please try again." + + def stream_query(self, query_text: str) -> Subject: + """ + Creates an observable that processes a text query and emits the response. 
+ """ + return create( + lambda observer, _: self._observable_query(observer, incoming_query=query_text) + ) + + +# endregion HuggingFaceLLMAgent Subclass (HuggingFace-Specific Implementation) diff --git a/build/lib/dimos/agents/agent_huggingface_remote.py b/build/lib/dimos/agents/agent_huggingface_remote.py new file mode 100644 index 0000000000..d98b277706 --- /dev/null +++ b/build/lib/dimos/agents/agent_huggingface_remote.py @@ -0,0 +1,143 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +# Standard library imports +import logging +import os +from typing import Any, Optional + +# Third-party imports +from dotenv import load_dotenv +from huggingface_hub import InferenceClient +from reactivex import create, Observable +from reactivex.scheduler import ThreadPoolScheduler +from reactivex.subject import Subject + +# Local imports +from dimos.agents.agent import LLMAgent +from dimos.agents.memory.base import AbstractAgentSemanticMemory +from dimos.agents.prompt_builder.impl import PromptBuilder +from dimos.agents.tokenizer.base import AbstractTokenizer +from dimos.agents.tokenizer.huggingface_tokenizer import HuggingFaceTokenizer +from dimos.utils.logging_config import setup_logger + +# Initialize environment variables +load_dotenv() + +# Initialize logger for the agent module +logger = setup_logger("dimos.agents", level=logging.DEBUG) + + +# HuggingFaceLLMAgent Class +class HuggingFaceRemoteAgent(LLMAgent): + def __init__( + self, + dev_name: str, + agent_type: str = "HF-LLM", + model_name: str = "Qwen/QwQ-32B", + query: str = "How many r's are in the word 'strawberry'?", + input_query_stream: Optional[Observable] = None, + input_video_stream: Optional[Observable] = None, + output_dir: str = os.path.join(os.getcwd(), "assets", "agent"), + agent_memory: Optional[AbstractAgentSemanticMemory] = None, + system_query: Optional[str] = None, + max_output_tokens_per_request: int = 16384, + prompt_builder: Optional[PromptBuilder] = None, + tokenizer: Optional[AbstractTokenizer] = None, + image_detail: str = "low", + pool_scheduler: Optional[ThreadPoolScheduler] = None, + process_all_inputs: Optional[bool] = None, + api_key: Optional[str] = None, + hf_provider: Optional[str] = None, + hf_base_url: Optional[str] = None, + ): + # Determine appropriate default for process_all_inputs if not provided + if process_all_inputs is None: + # Default to True for text queries, False for video streams + if input_query_stream is not None 
and input_video_stream is None: + process_all_inputs = True + else: + process_all_inputs = False + + super().__init__( + dev_name=dev_name, + agent_type=agent_type, + agent_memory=agent_memory, + pool_scheduler=pool_scheduler, + process_all_inputs=process_all_inputs, + system_query=system_query, + ) + + self.query = query + self.output_dir = output_dir + os.makedirs(self.output_dir, exist_ok=True) + + self.model_name = model_name + self.prompt_builder = prompt_builder or PromptBuilder( + self.model_name, tokenizer=tokenizer or HuggingFaceTokenizer(self.model_name) + ) + + self.model_name = model_name + + self.max_output_tokens_per_request = max_output_tokens_per_request + + self.api_key = api_key or os.getenv("HF_TOKEN") + self.provider = hf_provider or "hf-inference" + self.base_url = hf_base_url or os.getenv("HUGGINGFACE_PRV_ENDPOINT") + self.client = InferenceClient( + provider=self.provider, + base_url=self.base_url, + api_key=self.api_key, + ) + + # self.stream_query(self.query).subscribe(lambda x: print(x)) + + self.input_video_stream = input_video_stream + self.input_query_stream = input_query_stream + + # Ensure only one input stream is provided. + if self.input_video_stream is not None and self.input_query_stream is not None: + raise ValueError( + "More than one input stream provided. Please provide only one input stream." 
+ ) + + if self.input_video_stream is not None: + logger.info("Subscribing to input video stream...") + self.disposables.add(self.subscribe_to_image_processing(self.input_video_stream)) + if self.input_query_stream is not None: + logger.info("Subscribing to input query stream...") + self.disposables.add(self.subscribe_to_query_processing(self.input_query_stream)) + + def _send_query(self, messages: list) -> Any: + try: + completion = self.client.chat.completions.create( + model=self.model_name, + messages=messages, + max_tokens=self.max_output_tokens_per_request, + ) + + return completion.choices[0].message + except Exception as e: + logger.error(f"Error during HuggingFace query: {e}") + return "Error processing request." + + def stream_query(self, query_text: str) -> Subject: + """ + Creates an observable that processes a text query and emits the response. + """ + return create( + lambda observer, _: self._observable_query(observer, incoming_query=query_text) + ) diff --git a/build/lib/dimos/agents/cerebras_agent.py b/build/lib/dimos/agents/cerebras_agent.py new file mode 100644 index 0000000000..854beb848d --- /dev/null +++ b/build/lib/dimos/agents/cerebras_agent.py @@ -0,0 +1,608 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Cerebras agent implementation for the DIMOS agent framework. 
+ +This module provides a CerebrasAgent class that implements the LLMAgent interface +for Cerebras inference API using the official Cerebras Python SDK. +""" + +from __future__ import annotations + +import os +import threading +import copy +from typing import Any, Dict, List, Optional, Union, Tuple +import logging +import json +import re +import time + +from cerebras.cloud.sdk import Cerebras +from dotenv import load_dotenv +from pydantic import BaseModel +from reactivex import Observable +from reactivex.observer import Observer +from reactivex.scheduler import ThreadPoolScheduler + +# Local imports +from dimos.agents.agent import LLMAgent +from dimos.agents.memory.base import AbstractAgentSemanticMemory +from dimos.agents.prompt_builder.impl import PromptBuilder +from dimos.agents.tokenizer.base import AbstractTokenizer +from dimos.skills.skills import AbstractSkill, SkillLibrary +from dimos.stream.frame_processor import FrameProcessor +from dimos.utils.logging_config import setup_logger +from dimos.agents.tokenizer.openai_tokenizer import OpenAITokenizer + +# Initialize environment variables +load_dotenv() + +# Initialize logger for the Cerebras agent +logger = setup_logger("dimos.agents.cerebras") + + +# Response object compatible with LLMAgent +class CerebrasResponseMessage(dict): + def __init__( + self, + content="", + tool_calls=None, + ): + self.content = content + self.tool_calls = tool_calls or [] + self.parsed = None + + # Initialize as dict with the proper structure + super().__init__(self.to_dict()) + + def __str__(self): + # Return a string representation for logging + if self.content: + return self.content + elif self.tool_calls: + # Return JSON representation of the first tool call + if self.tool_calls: + tool_call = self.tool_calls[0] + tool_json = { + "name": tool_call.function.name, + "arguments": json.loads(tool_call.function.arguments), + } + return json.dumps(tool_json) + return "[No content]" + + def to_dict(self): + """Convert to dictionary 
format for JSON serialization.""" + result = {"role": "assistant", "content": self.content or ""} + + if self.tool_calls: + result["tool_calls"] = [] + for tool_call in self.tool_calls: + result["tool_calls"].append( + { + "id": tool_call.id, + "type": "function", + "function": { + "name": tool_call.function.name, + "arguments": tool_call.function.arguments, + }, + } + ) + + return result + + +class CerebrasAgent(LLMAgent): + """Cerebras agent implementation using the official Cerebras Python SDK. + + This class implements the _send_query method to interact with Cerebras API + using their official SDK, allowing most of the LLMAgent logic to be reused. + """ + + def __init__( + self, + dev_name: str, + agent_type: str = "Vision", + query: str = "What do you see?", + input_query_stream: Optional[Observable] = None, + input_video_stream: Optional[Observable] = None, + input_data_stream: Optional[Observable] = None, + output_dir: str = os.path.join(os.getcwd(), "assets", "agent"), + agent_memory: Optional[AbstractAgentSemanticMemory] = None, + system_query: Optional[str] = None, + max_input_tokens_per_request: int = 128000, + max_output_tokens_per_request: int = 16384, + model_name: str = "llama-4-scout-17b-16e-instruct", + skills: Optional[Union[AbstractSkill, list[AbstractSkill], SkillLibrary]] = None, + response_model: Optional[BaseModel] = None, + frame_processor: Optional[FrameProcessor] = None, + image_detail: str = "low", + pool_scheduler: Optional[ThreadPoolScheduler] = None, + process_all_inputs: Optional[bool] = None, + tokenizer: Optional[AbstractTokenizer] = None, + prompt_builder: Optional[PromptBuilder] = None, + ): + """ + Initializes a new instance of the CerebrasAgent. + + Args: + dev_name (str): The device name of the agent. + agent_type (str): The type of the agent. + query (str): The default query text. + input_query_stream (Observable): An observable for query input. + input_video_stream (Observable): An observable for video frames. 
+ input_data_stream (Observable): An observable for data input. + output_dir (str): Directory for output files. + agent_memory (AbstractAgentSemanticMemory): The memory system. + system_query (str): The system prompt to use with RAG context. + max_input_tokens_per_request (int): Maximum tokens for input. + max_output_tokens_per_request (int): Maximum tokens for output. + model_name (str): The Cerebras model name to use. Available options: + - llama-4-scout-17b-16e-instruct (default, fastest) + - llama3.1-8b + - llama-3.3-70b + - qwen-3-32b + - deepseek-r1-distill-llama-70b (private preview) + skills (Union[AbstractSkill, List[AbstractSkill], SkillLibrary]): Skills available to the agent. + response_model (BaseModel): Optional Pydantic model for structured responses. + frame_processor (FrameProcessor): Custom frame processor. + image_detail (str): Detail level for images ("low", "high", "auto"). + pool_scheduler (ThreadPoolScheduler): The scheduler to use for thread pool operations. + process_all_inputs (bool): Whether to process all inputs or skip when busy. + tokenizer (AbstractTokenizer): The tokenizer for the agent. + prompt_builder (PromptBuilder): The prompt builder for the agent. 
+ """ + # Determine appropriate default for process_all_inputs if not provided + if process_all_inputs is None: + # Default to True for text queries, False for video streams + if input_query_stream is not None and input_video_stream is None: + process_all_inputs = True + else: + process_all_inputs = False + + super().__init__( + dev_name=dev_name, + agent_type=agent_type, + agent_memory=agent_memory, + pool_scheduler=pool_scheduler, + process_all_inputs=process_all_inputs, + system_query=system_query, + input_query_stream=input_query_stream, + input_video_stream=input_video_stream, + input_data_stream=input_data_stream, + ) + + # Initialize Cerebras client + self.client = Cerebras() + + self.query = query + self.output_dir = output_dir + os.makedirs(self.output_dir, exist_ok=True) + + # Initialize conversation history for multi-turn conversations + self.conversation_history = [] + self._history_lock = threading.Lock() + + # Configure skills + self.skills = skills + self.skill_library = None + if isinstance(self.skills, SkillLibrary): + self.skill_library = self.skills + elif isinstance(self.skills, list): + self.skill_library = SkillLibrary() + for skill in self.skills: + self.skill_library.add(skill) + elif isinstance(self.skills, AbstractSkill): + self.skill_library = SkillLibrary() + self.skill_library.add(self.skills) + + self.response_model = response_model + self.model_name = model_name + self.image_detail = image_detail + self.max_output_tokens_per_request = max_output_tokens_per_request + self.max_input_tokens_per_request = max_input_tokens_per_request + self.max_tokens_per_request = max_input_tokens_per_request + max_output_tokens_per_request + + # Add static context to memory. 
+ self._add_context_to_memory() + + # Initialize tokenizer and prompt builder + self.tokenizer = tokenizer or OpenAITokenizer( + model_name="gpt-4o" + ) # Use GPT-4 tokenizer for better accuracy + self.prompt_builder = prompt_builder or PromptBuilder( + model_name=self.model_name, + max_tokens=self.max_input_tokens_per_request, + tokenizer=self.tokenizer, + ) + + logger.info("Cerebras Agent Initialized.") + + def _add_context_to_memory(self): + """Adds initial context to the agent's memory.""" + context_data = [ + ( + "id0", + "Optical Flow is a technique used to track the movement of objects in a video sequence.", + ), + ( + "id1", + "Edge Detection is a technique used to identify the boundaries of objects in an image.", + ), + ("id2", "Video is a sequence of frames captured at regular intervals."), + ( + "id3", + "Colors in Optical Flow are determined by the movement of light, and can be used to track the movement of objects.", + ), + ( + "id4", + "Json is a data interchange format that is easy for humans to read and write, and easy for machines to parse and generate.", + ), + ] + for doc_id, text in context_data: + self.agent_memory.add_vector(doc_id, text) + + def _build_prompt( + self, + messages: list, + base64_image: Optional[Union[str, List[str]]] = None, + dimensions: Optional[Tuple[int, int]] = None, + override_token_limit: bool = False, + condensed_results: str = "", + ) -> list: + """Builds a prompt message specifically for Cerebras API. + + Args: + messages (list): Existing messages list to build upon. + base64_image (Union[str, List[str]]): Optional Base64-encoded image(s). + dimensions (Tuple[int, int]): Optional image dimensions. + override_token_limit (bool): Whether to override token limits. + condensed_results (str): The condensed RAG context. + + Returns: + list: Messages formatted for Cerebras API. 
+ """ + # Add system message if provided and not already in history + if self.system_query and (not messages or messages[0].get("role") != "system"): + messages.insert(0, {"role": "system", "content": self.system_query}) + logger.info("Added system message to conversation") + + # Append user query while handling RAG + if condensed_results: + user_message = {"role": "user", "content": f"{condensed_results}\n\n{self.query}"} + logger.info("Created user message with RAG context") + else: + user_message = {"role": "user", "content": self.query} + + messages.append(user_message) + + if base64_image is not None: + # Handle both single image (str) and multiple images (List[str]) + images = [base64_image] if isinstance(base64_image, str) else base64_image + + # For Cerebras, we'll add images inline with text (OpenAI-style format) + for img in images: + img_content = [ + {"type": "text", "text": "Here is an image to analyze:"}, + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{img}", + "detail": self.image_detail, + }, + }, + ] + messages.append({"role": "user", "content": img_content}) + + logger.info(f"Added {len(images)} image(s) to conversation") + + # Use new truncation function + messages = self._truncate_messages(messages, override_token_limit) + + return messages + + def _truncate_messages(self, messages: list, override_token_limit: bool = False) -> list: + """Truncate messages if total tokens exceed 16k using existing truncate_tokens method. 
+ + Args: + messages (list): List of message dictionaries + override_token_limit (bool): Whether to skip truncation + + Returns: + list: Messages with content truncated if needed + """ + if override_token_limit: + return messages + + total_tokens = 0 + for message in messages: + if isinstance(message.get("content"), str): + total_tokens += self.prompt_builder.tokenizer.token_count(message["content"]) + elif isinstance(message.get("content"), list): + for item in message["content"]: + if item.get("type") == "text": + total_tokens += self.prompt_builder.tokenizer.token_count(item["text"]) + elif item.get("type") == "image_url": + total_tokens += 85 + + if total_tokens > 16000: + excess_tokens = total_tokens - 16000 + current_tokens = total_tokens + + # Start from oldest messages and truncate until under 16k + for i in range(len(messages)): + if current_tokens <= 16000: + break + + msg = messages[i] + if msg.get("role") == "system": + continue + + if isinstance(msg.get("content"), str): + original_tokens = self.prompt_builder.tokenizer.token_count(msg["content"]) + # Calculate how much to truncate from this message + tokens_to_remove = min(excess_tokens, original_tokens // 3) + new_max_tokens = max(50, original_tokens - tokens_to_remove) + + msg["content"] = self.prompt_builder.truncate_tokens( + msg["content"], new_max_tokens, "truncate_end" + ) + + new_tokens = self.prompt_builder.tokenizer.token_count(msg["content"]) + tokens_saved = original_tokens - new_tokens + current_tokens -= tokens_saved + excess_tokens -= tokens_saved + + logger.info( + f"Truncated older messages using truncate_tokens, final tokens: {current_tokens}" + ) + else: + logger.info(f"No truncation needed, total tokens: {total_tokens}") + + return messages + + def clean_cerebras_schema(self, schema: dict) -> dict: + """Simple schema cleaner that removes unsupported fields for Cerebras API.""" + if not isinstance(schema, dict): + return schema + + # Removing the problematic fields that pydantic 
generates + cleaned = {} + unsupported_fields = { + "minItems", + "maxItems", + "uniqueItems", + "exclusiveMinimum", + "exclusiveMaximum", + "minimum", + "maximum", + } + + for key, value in schema.items(): + if key in unsupported_fields: + continue # Skip unsupported fields + elif isinstance(value, dict): + cleaned[key] = self.clean_cerebras_schema(value) + elif isinstance(value, list): + cleaned[key] = [ + self.clean_cerebras_schema(item) if isinstance(item, dict) else item + for item in value + ] + else: + cleaned[key] = value + + return cleaned + + def create_tool_call( + self, name: str = None, arguments: dict = None, call_id: str = None, content: str = None + ): + """Create a tool call object from either direct parameters or JSON content.""" + # If content is provided, parse it as JSON + if content: + logger.info(f"Creating tool call from content: {content}") + try: + content_json = json.loads(content) + if ( + isinstance(content_json, dict) + and "name" in content_json + and "arguments" in content_json + ): + name = content_json["name"] + arguments = content_json["arguments"] + else: + return None + except json.JSONDecodeError: + logger.warning("Content appears to be JSON but failed to parse") + return None + + # Create the tool call object + if name and arguments is not None: + timestamp = int(time.time() * 1000000) # microsecond precision + tool_id = f"call_{timestamp}" + + logger.info(f"Creating tool call with timestamp ID: {tool_id}") + return type( + "ToolCall", + (), + { + "id": tool_id, + "function": type( + "Function", (), {"name": name, "arguments": json.dumps(arguments)} + ), + }, + ) + + return None + + def _send_query(self, messages: list) -> CerebrasResponseMessage: + """Sends the query to Cerebras API using the official Cerebras SDK. + + Args: + messages (list): The prompt messages to send. + + Returns: + The response message from Cerebras wrapped in our CerebrasResponseMessage class. 
+ + Raises: + Exception: If no response message is returned from the API. + ConnectionError: If there's an issue connecting to the API. + ValueError: If the messages or other parameters are invalid. + """ + try: + # Prepare API call parameters + api_params = { + "model": self.model_name, + "messages": messages, + # "max_tokens": self.max_output_tokens_per_request, + } + + # Add tools if available + if self.skill_library and self.skill_library.get_tools(): + tools = self.skill_library.get_tools() + for tool in tools: + if "function" in tool and "parameters" in tool["function"]: + tool["function"]["parameters"] = self.clean_cerebras_schema( + tool["function"]["parameters"] + ) + api_params["tools"] = tools + api_params["tool_choice"] = "auto" + + if self.response_model is not None: + api_params["response_format"] = { + "type": "json_object", + "schema": self.response_model, + } + + # Make the API call + response = self.client.chat.completions.create(**api_params) + + raw_message = response.choices[0].message + if raw_message is None: + logger.error("Response message does not exist.") + raise Exception("Response message does not exist.") + + # Process response into final format + content = raw_message.content + tool_calls = getattr(raw_message, "tool_calls", None) + + # If no structured tool calls from API, try parsing content as JSON tool call + if not tool_calls and content and content.strip().startswith("{"): + parsed_tool_call = self.create_tool_call(content=content) + if parsed_tool_call: + tool_calls = [parsed_tool_call] + content = None + + return CerebrasResponseMessage(content=content, tool_calls=tool_calls) + + except ConnectionError as ce: + logger.error(f"Connection error with Cerebras API: {ce}") + raise + except ValueError as ve: + logger.error(f"Invalid parameters for Cerebras API: {ve}") + raise + except Exception as e: + # Log the raw API parameters when an error occurs. default=str keeps a + # non-JSON value (e.g. the response_model class placed in response_format) + # from raising inside this handler and masking the original error. + logger.error(f"Raw API parameters: {json.dumps(api_params, indent=2, default=str)}") +
logger.error(f"Unexpected error in Cerebras API call: {e}") + raise + + def _observable_query( + self, + observer: Observer, + base64_image: Optional[str] = None, + dimensions: Optional[Tuple[int, int]] = None, + override_token_limit: bool = False, + incoming_query: Optional[str] = None, + reset_conversation: bool = False, + ): + """Main query handler that manages conversation history and Cerebras interactions. + + This method follows ClaudeAgent's pattern for efficient conversation history management. + + Args: + observer (Observer): The observer to emit responses to. + base64_image (str): Optional Base64-encoded image. + dimensions (Tuple[int, int]): Optional image dimensions. + override_token_limit (bool): Whether to override token limits. + incoming_query (str): Optional query to update the agent's query. + reset_conversation (bool): Whether to reset the conversation history. + """ + try: + # Reset conversation history if requested + if reset_conversation: + self.conversation_history = [] + logger.info("Conversation history reset") + + # Work on a local deep copy so the shared history is only replaced on success + messages = copy.deepcopy(self.conversation_history) + + # Update query and get context + self._update_query(incoming_query) + _, condensed_results = self._get_rag_context() + + # Build prompt + messages = self._build_prompt( + messages, base64_image, dimensions, override_token_limit, condensed_results + ) + + while True: + logger.info("Sending Query.") + response_message = self._send_query(messages) + logger.info(f"Received Response: {response_message}") + + if response_message is None: + raise Exception("Response message does not exist.") + + # If no skill library or no tool calls, we're done. + # CerebrasResponseMessage normalises tool_calls to a list, so test for + # emptiness rather than None. + if ( + self.skill_library is None + or self.skill_library.get_tools() is None + or not response_message.tool_calls + ): + final_msg = ( + response_message.parsed + if hasattr(response_message, "parsed") and response_message.parsed + else ( + response_message.content
+ if hasattr(response_message, "content") + else response_message + ) + ) + messages.append(response_message) + break + + logger.info(f"Assistant requested {len(response_message.tool_calls)} tool call(s)") + next_response = self._handle_tooling(response_message, messages) + + if next_response is None: + final_msg = response_message.content or "" + break + + response_message = next_response + + with self._history_lock: + self.conversation_history = messages + logger.info( + f"Updated conversation history (total: {len(self.conversation_history)} messages)" + ) + + # Emit the final message content to the observer + observer.on_next(final_msg) + self.response_subject.on_next(final_msg) + observer.on_completed() + + except Exception as e: + logger.error(f"Query failed in {self.dev_name}: {e}") + observer.on_error(e) + self.response_subject.on_error(e) diff --git a/build/lib/dimos/agents/claude_agent.py b/build/lib/dimos/agents/claude_agent.py new file mode 100644 index 0000000000..e87b1f47b4 --- /dev/null +++ b/build/lib/dimos/agents/claude_agent.py @@ -0,0 +1,735 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Claude agent implementation for the DIMOS agent framework. + +This module provides a ClaudeAgent class that implements the LLMAgent interface +for Anthropic's Claude models. It handles conversion between the DIMOS skill format +and Claude's tools format. 
+""" + +from __future__ import annotations + +import json +import os +from typing import Any, Dict, List, Optional, Tuple, Union + +import anthropic +from dotenv import load_dotenv +from pydantic import BaseModel +from reactivex import Observable +from reactivex.scheduler import ThreadPoolScheduler + +# Local imports +from dimos.agents.agent import LLMAgent +from dimos.agents.memory.base import AbstractAgentSemanticMemory +from dimos.agents.prompt_builder.impl import PromptBuilder +from dimos.skills.skills import AbstractSkill, SkillLibrary +from dimos.stream.frame_processor import FrameProcessor +from dimos.utils.logging_config import setup_logger + +# Initialize environment variables +load_dotenv() + +# Initialize logger for the Claude agent +logger = setup_logger("dimos.agents.claude") + + +# Response object compatible with LLMAgent +class ResponseMessage: + def __init__(self, content="", tool_calls=None, thinking_blocks=None): + self.content = content + self.tool_calls = tool_calls or [] + self.thinking_blocks = thinking_blocks or [] + self.parsed = None + + def __str__(self): + # Return a string representation for logging + parts = [] + + # Include content if available + if self.content: + parts.append(self.content) + + # Include tool calls if available + if self.tool_calls: + tool_names = [tc.function.name for tc in self.tool_calls] + parts.append(f"[Tools called: {', '.join(tool_names)}]") + + return "\n".join(parts) if parts else "[No content]" + + +class ClaudeAgent(LLMAgent): + """Claude agent implementation that uses Anthropic's API for processing. + + This class implements the _send_query method to interact with Anthropic's API + and overrides _build_prompt to create Claude-formatted messages directly. 
+ """ + + def __init__( + self, + dev_name: str, + agent_type: str = "Vision", + query: str = "What do you see?", + input_query_stream: Optional[Observable] = None, + input_video_stream: Optional[Observable] = None, + input_data_stream: Optional[Observable] = None, + output_dir: str = os.path.join(os.getcwd(), "assets", "agent"), + agent_memory: Optional[AbstractAgentSemanticMemory] = None, + system_query: Optional[str] = None, + max_input_tokens_per_request: int = 128000, + max_output_tokens_per_request: int = 16384, + model_name: str = "claude-3-7-sonnet-20250219", + prompt_builder: Optional[PromptBuilder] = None, + rag_query_n: int = 4, + rag_similarity_threshold: float = 0.45, + skills: Optional[AbstractSkill] = None, + response_model: Optional[BaseModel] = None, + frame_processor: Optional[FrameProcessor] = None, + image_detail: str = "low", + pool_scheduler: Optional[ThreadPoolScheduler] = None, + process_all_inputs: Optional[bool] = None, + thinking_budget_tokens: Optional[int] = 2000, + ): + """ + Initializes a new instance of the ClaudeAgent. + + Args: + dev_name (str): The device name of the agent. + agent_type (str): The type of the agent. + query (str): The default query text. + input_query_stream (Observable): An observable for query input. + input_video_stream (Observable): An observable for video frames. + output_dir (str): Directory for output files. + agent_memory (AbstractAgentSemanticMemory): The memory system. + system_query (str): The system prompt to use with RAG context. + max_input_tokens_per_request (int): Maximum tokens for input. + max_output_tokens_per_request (int): Maximum tokens for output. + model_name (str): The Claude model name to use. + prompt_builder (PromptBuilder): Custom prompt builder (not used in Claude implementation). + rag_query_n (int): Number of results to fetch in RAG queries. + rag_similarity_threshold (float): Minimum similarity for RAG results. + skills (AbstractSkill): Skills available to the agent. 
+ response_model (BaseModel): Optional Pydantic model for responses. + frame_processor (FrameProcessor): Custom frame processor. + image_detail (str): Detail level for images ("low", "high", "auto"). + pool_scheduler (ThreadPoolScheduler): The scheduler to use for thread pool operations. + process_all_inputs (bool): Whether to process all inputs or skip when busy. + thinking_budget_tokens (int): Number of tokens to allocate for Claude's thinking. 0 disables thinking. + """ + # Determine appropriate default for process_all_inputs if not provided + if process_all_inputs is None: + # Default to True for text queries, False for video streams + if input_query_stream is not None and input_video_stream is None: + process_all_inputs = True + else: + process_all_inputs = False + + super().__init__( + dev_name=dev_name, + agent_type=agent_type, + agent_memory=agent_memory, + pool_scheduler=pool_scheduler, + process_all_inputs=process_all_inputs, + system_query=system_query, + input_query_stream=input_query_stream, + input_video_stream=input_video_stream, + input_data_stream=input_data_stream, + ) + + self.client = anthropic.Anthropic() + self.query = query + self.output_dir = output_dir + os.makedirs(self.output_dir, exist_ok=True) + + # Claude-specific parameters + self.thinking_budget_tokens = thinking_budget_tokens + self.claude_api_params = {} # Will store params for Claude API calls + + # Configure skills + self.skills = skills + self.skill_library = None # Required for error 'ClaudeAgent' object has no attribute 'skill_library' due to skills refactor + if isinstance(self.skills, SkillLibrary): + self.skill_library = self.skills + elif isinstance(self.skills, list): + self.skill_library = SkillLibrary() + for skill in self.skills: + self.skill_library.add(skill) + elif isinstance(self.skills, AbstractSkill): + self.skill_library = SkillLibrary() + self.skill_library.add(self.skills) + + self.response_model = response_model + self.model_name = model_name + 
self.rag_query_n = rag_query_n + self.rag_similarity_threshold = rag_similarity_threshold + self.image_detail = image_detail + self.max_output_tokens_per_request = max_output_tokens_per_request + self.max_input_tokens_per_request = max_input_tokens_per_request + self.max_tokens_per_request = max_input_tokens_per_request + max_output_tokens_per_request + + # Add static context to memory. + self._add_context_to_memory() + + self.frame_processor = frame_processor or FrameProcessor(delete_on_init=True) + + # Ensure only one input stream is provided. + if self.input_video_stream is not None and self.input_query_stream is not None: + raise ValueError( + "More than one input stream provided. Please provide only one input stream." + ) + + logger.info("Claude Agent Initialized.") + + def _add_context_to_memory(self): + """Adds initial context to the agent's memory.""" + context_data = [ + ( + "id0", + "Optical Flow is a technique used to track the movement of objects in a video sequence.", + ), + ( + "id1", + "Edge Detection is a technique used to identify the boundaries of objects in an image.", + ), + ("id2", "Video is a sequence of frames captured at regular intervals."), + ( + "id3", + "Colors in Optical Flow are determined by the movement of light, and can be used to track the movement of objects.", + ), + ( + "id4", + "Json is a data interchange format that is easy for humans to read and write, and easy for machines to parse and generate.", + ), + ] + for doc_id, text in context_data: + self.agent_memory.add_vector(doc_id, text) + + def _convert_tools_to_claude_format(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Converts DIMOS tools to Claude format. + + Args: + tools: List of tools in DIMOS format. + + Returns: + List of tools in Claude format. 
+ """ + if not tools: + return [] + + claude_tools = [] + + for tool in tools: + # Skip if not a function + if tool.get("type") != "function": + continue + + function = tool.get("function", {}) + name = function.get("name") + description = function.get("description", "") + parameters = function.get("parameters", {}) + + claude_tool = { + "name": name, + "description": description, + "input_schema": { + "type": "object", + "properties": parameters.get("properties", {}), + "required": parameters.get("required", []), + }, + } + + claude_tools.append(claude_tool) + + return claude_tools + + def _build_prompt( + self, + messages: list, + base64_image: Optional[Union[str, List[str]]] = None, + dimensions: Optional[Tuple[int, int]] = None, + override_token_limit: bool = False, + rag_results: str = "", + thinking_budget_tokens: int = None, + ) -> list: + """Builds a prompt message specifically for Claude API, using local messages copy.""" + """Builds a prompt message specifically for Claude API. + + This method creates messages in Claude's format directly, without using + any OpenAI-specific formatting or token counting. + + Args: + base64_image (Union[str, List[str]]): Optional Base64-encoded image(s). + dimensions (Tuple[int, int]): Optional image dimensions. + override_token_limit (bool): Whether to override token limits. + rag_results (str): The condensed RAG context. + thinking_budget_tokens (int): Number of tokens to allocate for Claude's thinking. + + Returns: + dict: A dict containing Claude API parameters. 
+ """ + + # Append user query to conversation history while handling RAG + if rag_results: + messages.append({"role": "user", "content": f"{rag_results}\n\n{self.query}"}) + logger.info( + f"Added new user message to conversation history with RAG context (now has {len(messages)} messages)" + ) + else: + messages.append({"role": "user", "content": self.query}) + logger.info( + f"Added new user message to conversation history (now has {len(messages)} messages)" + ) + + if base64_image is not None: + # Handle both single image (str) and multiple images (List[str]) + images = [base64_image] if isinstance(base64_image, str) else base64_image + + # Add each image as a separate entry in conversation history + for img in images: + img_content = [ + { + "type": "image", + "source": {"type": "base64", "media_type": "image/jpeg", "data": img}, + } + ] + messages.append({"role": "user", "content": img_content}) + + if images: + logger.info( + f"Added {len(images)} image(s) as separate entries to conversation history" + ) + + # Create Claude parameters with basic settings + claude_params = { + "model": self.model_name, + "max_tokens": self.max_output_tokens_per_request, + "temperature": 0, # Add temperature to make responses more deterministic + "messages": messages, + } + + # Add system prompt as a top-level parameter (not as a message) + if self.system_query: + claude_params["system"] = self.system_query + + # Store the parameters for use in _send_query + self.claude_api_params = claude_params.copy() + + # Add tools if skills are available. Use the SkillLibrary built in __init__ + # (self.skill_library) rather than the raw `skills` argument, which may be a + # list or a single skill without a get_tools() method. + if self.skill_library and self.skill_library.get_tools(): + tools = self._convert_tools_to_claude_format(self.skill_library.get_tools()) + if tools: # Only add if we have valid tools + claude_params["tools"] = tools + # Enable tool calling with proper format + claude_params["tool_choice"] = {"type": "auto"} + + # Add thinking if enabled and hard code required temperature = 1 + if thinking_budget_tokens is not None and thinking_budget_tokens != 0: +
claude_params["thinking"] = {"type": "enabled", "budget_tokens": thinking_budget_tokens} + claude_params["temperature"] = ( + 1 # Required to be 1 when thinking is enabled # Default to 0 for deterministic responses + ) + + # Store the parameters for use in _send_query and return them + self.claude_api_params = claude_params.copy() + return messages, claude_params + + def _send_query(self, messages: list, claude_params: dict) -> Any: + """Sends the query to Anthropic's API using streaming for better thinking visualization. + + Args: + messages: Dict with 'claude_prompt' key containing Claude API parameters. + + Returns: + The response message in a format compatible with LLMAgent's expectations. + """ + try: + # Get Claude parameters + claude_params = claude_params.get("claude_prompt", None) or self.claude_api_params + + # Log request parameters with truncated base64 data + logger.debug(self._debug_api_call(claude_params)) + + # Initialize response containers + text_content = "" + tool_calls = [] + thinking_blocks = [] + + # Log the start of streaming and the query + logger.info("Sending streaming request to Claude API") + + # Log the query to memory.txt + with open(os.path.join(self.output_dir, "memory.txt"), "a") as f: + f.write(f"\n\nQUERY: {self.query}\n\n") + f.flush() + + # Stream the response + with self.client.messages.stream(**claude_params) as stream: + print("\n==== CLAUDE API RESPONSE STREAM STARTED ====") + + # Open the memory file once for the entire stream processing + with open(os.path.join(self.output_dir, "memory.txt"), "a") as memory_file: + # Track the current block being processed + current_block = {"type": None, "id": None, "content": "", "signature": None} + + for event in stream: + # Log each event to console + # print(f"EVENT: {event.type}") + # print(json.dumps(event.model_dump(), indent=2, default=str)) + + if event.type == "content_block_start": + # Initialize a new content block + block_type = event.content_block.type + current_block = { 
+ "type": block_type, + "id": event.index, + "content": "", + "signature": None, + } + logger.debug(f"Starting {block_type} block...") + + elif event.type == "content_block_delta": + if event.delta.type == "thinking_delta": + # Accumulate thinking content (append, as the text_delta branch does, + # so the fallback thinking block below keeps the full text) + current_block["content"] += event.delta.thinking + memory_file.write(f"{event.delta.thinking}") + memory_file.flush() # Ensure content is written immediately + + elif event.delta.type == "text_delta": + # Accumulate text content + text_content += event.delta.text + current_block["content"] += event.delta.text + memory_file.write(f"{event.delta.text}") + memory_file.flush() + + elif event.delta.type == "signature_delta": + # Store signature for thinking blocks + current_block["signature"] = event.delta.signature + memory_file.write( + f"\n[Signature received for block {current_block['id']}]\n" + ) + memory_file.flush() + + elif event.type == "content_block_stop": + # Store completed blocks + if current_block["type"] == "thinking": + # IMPORTANT: Store the complete event.content_block to ensure we preserve + # the exact format that Claude expects in subsequent requests + if hasattr(event, "content_block"): + # Use the exact thinking block as provided by Claude + thinking_blocks.append(event.content_block.model_dump()) + memory_file.write( + f"\nTHINKING COMPLETE: block {current_block['id']}\n" + ) + else: + # Fallback to constructed thinking block if content_block missing + thinking_block = { + "type": "thinking", + "thinking": current_block["content"], + "signature": current_block["signature"], + } + thinking_blocks.append(thinking_block) + memory_file.write( + f"\nTHINKING COMPLETE: block {current_block['id']}\n" + ) + + elif current_block["type"] == "redacted_thinking": + # Handle redacted thinking blocks + if hasattr(event, "content_block") and hasattr( + event.content_block, "data" + ): + redacted_block = { + "type": "redacted_thinking", + "data": event.content_block.data, + } +
thinking_blocks.append(redacted_block) + + elif current_block["type"] == "tool_use": + # Process tool use blocks when they're complete + if hasattr(event, "content_block"): + tool_block = event.content_block + tool_id = tool_block.id + tool_name = tool_block.name + tool_input = tool_block.input + + # Create a tool call object for LLMAgent compatibility + tool_call_obj = type( + "ToolCall", + (), + { + "id": tool_id, + "function": type( + "Function", + (), + { + "name": tool_name, + "arguments": json.dumps(tool_input), + }, + ), + }, + ) + tool_calls.append(tool_call_obj) + + # Write tool call information to memory.txt + memory_file.write(f"\n\nTOOL CALL: {tool_name}\n") + memory_file.write( + f"ARGUMENTS: {json.dumps(tool_input, indent=2)}\n" + ) + + # Reset current block + current_block = { + "type": None, + "id": None, + "content": "", + "signature": None, + } + memory_file.flush() + + elif ( + event.type == "message_delta" and event.delta.stop_reason == "tool_use" + ): + # When a tool use is detected + logger.info("Tool use stop reason detected in stream") + + # Mark the end of the response in memory.txt + memory_file.write("\n\nRESPONSE COMPLETE\n\n") + memory_file.flush() + + print("\n==== CLAUDE API RESPONSE STREAM COMPLETED ====") + + # Final response + logger.info( + f"Claude streaming complete. 
Text: {len(text_content)} chars, Tool calls: {len(tool_calls)}, Thinking blocks: {len(thinking_blocks)}" + ) + + # Return the complete response with all components + return ResponseMessage( + content=text_content, + tool_calls=tool_calls if tool_calls else None, + thinking_blocks=thinking_blocks if thinking_blocks else None, + ) + + except ConnectionError as ce: + logger.error(f"Connection error with Anthropic API: {ce}") + raise + except ValueError as ve: + logger.error(f"Invalid parameters for Anthropic API: {ve}") + raise + except Exception as e: + logger.error(f"Unexpected error in Anthropic API call: {e}") + logger.exception(e) # This will print the full traceback + raise + + def _observable_query( + self, + observer: Observer, + base64_image: Optional[str] = None, + dimensions: Optional[Tuple[int, int]] = None, + override_token_limit: bool = False, + incoming_query: Optional[str] = None, + reset_conversation: bool = False, + thinking_budget_tokens: int = None, + ): + """Main query handler that manages conversation history and Claude interactions. + + This is the primary method for handling all queries, whether they come through + direct_query or through the observable pattern. It manages the conversation + history, builds prompts, and handles tool calls. 
+ + Args: + observer (Observer): The observer to emit responses to + base64_image (Optional[str]): Optional Base64-encoded image + dimensions (Optional[Tuple[int, int]]): Optional image dimensions + override_token_limit (bool): Whether to override token limits + incoming_query (Optional[str]): Optional query to update the agent's query + reset_conversation (bool): Whether to reset the conversation history + """ + + try: + logger.info("_observable_query called in claude") + import copy + + # Reset conversation history if requested + if reset_conversation: + self.conversation_history = [] + + # Create a local copy of conversation history and record its length + messages = copy.deepcopy(self.conversation_history) + base_len = len(messages) + + # Update query and get context + self._update_query(incoming_query) + _, rag_results = self._get_rag_context() + + # Build prompt and get Claude parameters + budget = ( + thinking_budget_tokens + if thinking_budget_tokens is not None + else self.thinking_budget_tokens + ) + messages, claude_params = self._build_prompt( + messages, base64_image, dimensions, override_token_limit, rag_results, budget + ) + + # Send query and get response + response_message = self._send_query(messages, claude_params) + + if response_message is None: + logger.error("Received None response from Claude API") + observer.on_next("") + observer.on_completed() + return + # Add thinking blocks and text content to conversation history + content_blocks = [] + if response_message.thinking_blocks: + content_blocks.extend(response_message.thinking_blocks) + if response_message.content: + content_blocks.append({"type": "text", "text": response_message.content}) + if content_blocks: + messages.append({"role": "assistant", "content": content_blocks}) + + # Handle tool calls if present + if response_message.tool_calls: + self._handle_tooling(response_message, messages) + + # At the end, append only new messages (including tool-use/results) to the global conversation 
history under a lock + import threading + + if not hasattr(self, "_history_lock"): + self._history_lock = threading.Lock() + with self._history_lock: + for msg in messages[base_len:]: + self.conversation_history.append(msg) + + # After merging, run tooling callback (outside lock) + if response_message.tool_calls: + self._tooling_callback(response_message) + + # Send response to observers + result = response_message.content or "" + observer.on_next(result) + self.response_subject.on_next(result) + observer.on_completed() + except Exception as e: + logger.error(f"Query failed in {self.dev_name}: {e}") + # Send a user-friendly error message instead of propagating the error + error_message = "I apologize, but I'm having trouble processing your request right now. Please try again." + observer.on_next(error_message) + self.response_subject.on_next(error_message) + observer.on_completed() + + def _handle_tooling(self, response_message, messages): + """Executes tools and appends tool-use/result blocks to messages.""" + if not hasattr(response_message, "tool_calls") or not response_message.tool_calls: + logger.info("No tool calls found in response message") + return None + + if len(response_message.tool_calls) > 1: + logger.warning( + "Multiple tool calls detected in response message. Not a tested feature." 
+ ) + + # Execute all tools first and collect their results + for tool_call in response_message.tool_calls: + logger.info(f"Processing tool call: {tool_call.function.name}") + tool_use_block = { + "type": "tool_use", + "id": tool_call.id, + "name": tool_call.function.name, + "input": json.loads(tool_call.function.arguments), + } + messages.append({"role": "assistant", "content": [tool_use_block]}) + + try: + # Execute the tool + args = json.loads(tool_call.function.arguments) + tool_result = self.skills.call(tool_call.function.name, **args) + + # Check if the result is an error message + if isinstance(tool_result, str) and ( + "Error executing skill" in tool_result or "is not available" in tool_result + ): + # Log the error but provide a user-friendly message + logger.error(f"Tool execution failed: {tool_result}") + tool_result = "I apologize, but I'm having trouble executing that action right now. Please try again or ask for something else." + + # Add tool result to conversation history + if tool_result: + messages.append( + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": tool_call.id, + "content": f"{tool_result}", + } + ], + } + ) + except Exception as e: + logger.error(f"Unexpected error executing tool {tool_call.function.name}: {e}") + # Add error result to conversation history + messages.append( + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": tool_call.id, + "content": "I apologize, but I encountered an error while trying to execute that action. 
Please try again.", + } + ], + } + ) + + def _tooling_callback(self, response_message): + """Runs the observable query for each tool call in the current response_message""" + if not hasattr(response_message, "tool_calls") or not response_message.tool_calls: + return + + try: + for tool_call in response_message.tool_calls: + tool_name = tool_call.function.name + tool_id = tool_call.id + self.run_observable_query( + query_text=f"Tool {tool_name}, ID: {tool_id} execution complete. Please summarize the results and continue.", + thinking_budget_tokens=0, + ).run() + except Exception as e: + logger.error(f"Error in tooling callback: {e}") + # Continue processing even if the callback fails + pass + + def _debug_api_call(self, claude_params: dict): + """Debugging function to log API calls with truncated base64 data.""" + # Remove tools to reduce verbosity + import copy + + log_params = copy.deepcopy(claude_params) + if "tools" in log_params: + del log_params["tools"] + + # Truncate base64 data in images - much cleaner approach + if "messages" in log_params: + for msg in log_params["messages"]: + if "content" in msg: + for content in msg["content"]: + if isinstance(content, dict) and content.get("type") == "image": + source = content.get("source", {}) + if source.get("type") == "base64" and "data" in source: + data = source["data"] + source["data"] = f"{data[:50]}..." + return json.dumps(log_params, indent=2, default=str) diff --git a/build/lib/dimos/agents/memory/__init__.py b/build/lib/dimos/agents/memory/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/build/lib/dimos/agents/memory/base.py b/build/lib/dimos/agents/memory/base.py new file mode 100644 index 0000000000..af8cbf689f --- /dev/null +++ b/build/lib/dimos/agents/memory/base.py @@ -0,0 +1,133 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
class AbstractAgentSemanticMemory:  # planned parent: AbstractAgentMemory
    """Interface for vector-backed semantic memory stores.

    The constructor validates the requested connection type and then calls ``connect()``
    (remote) or ``create()`` (local). Note that this class has no ``abc.ABC`` base, so the
    ``@abstractmethod`` markers below document intent but are not enforced when a subclass
    is instantiated.
    """

    def __init__(self, connection_type="local", **kwargs):
        """Set up logging, store connection parameters and open the data store.

        Args:
            connection_type (str): ``'local'`` to create a local store via ``create()``,
                ``'remote'`` to attach to a remote one via ``connect()``.
            **kwargs: Arbitrary connection parameters, kept on ``self.connection_params``.

        Raises:
            UnknownConnectionTypeError: If ``connection_type`` is neither ``'local'`` nor
                ``'remote'``.
            AgentMemoryConnectionError: If ``connect()`` / ``create()`` raises.
        """
        self.logger = setup_logger(self.__class__.__name__)
        self.logger.info("Initializing AgentMemory with connection type: %s", connection_type)
        self.connection_params = kwargs
        # Handle to the underlying database, local or remote; starts out unset.
        self.db_connection = None

        if connection_type not in ("local", "remote"):
            unknown_type = UnknownConnectionTypeError(
                f"Invalid connection_type {connection_type}. Expected 'local' or 'remote'."
            )
            self.logger.error(str(unknown_type))
            raise unknown_type

        try:
            initialize = self.connect if connection_type == "remote" else self.create
            initialize()
        except Exception as exc:
            self.logger.error(
                "Failed to initialize database connection: %s", str(exc), exc_info=True
            )
            raise AgentMemoryConnectionError(
                "Initialization failed due to an unexpected error.", cause=exc
            ) from exc

    @abstractmethod
    def connect(self):
        """Attach to an existing data store using the parameters given at construction."""

    @abstractmethod
    def create(self):
        """Create a local instance of the data store."""

    # --- Create -------------------------------------------------------------------------
    @abstractmethod
    def add_vector(self, vector_id, vector_data):
        """Store a new vector.

        Args:
            vector_id (any): Unique identifier of the vector.
            vector_data (any): Payload to store under that identifier.
        """

    # --- Read ---------------------------------------------------------------------------
    @abstractmethod
    def get_vector(self, vector_id):
        """Fetch a stored vector by identifier.

        Args:
            vector_id (any): Identifier of the vector to fetch.
        """

    @abstractmethod
    def query(self, query_texts, n_results=4, similarity_threshold=None):
        """Run a semantic search against the vector store.

        Args:
            query_texts (Union[str, List[str]]): Query text, or list of query texts.
            n_results (int, optional): Number of results to return. Defaults to 4.
            similarity_threshold (float, optional): Minimum similarity score in
                [0.0, 1.0] for a result to be included. Defaults to None.

        Returns:
            List[Tuple[Document, Optional[float]]]: One tuple per hit, holding the
            retrieved document and its similarity score (None when not applicable).

        Raises:
            ValueError: If query_texts is empty or invalid.
            ConnectionError: If the database connection fails during the query.
        """

    # --- Update -------------------------------------------------------------------------
    @abstractmethod
    def update_vector(self, vector_id, new_vector_data):
        """Replace the data stored for an existing vector.

        Args:
            vector_id (any): Identifier of the vector to update.
            new_vector_data (any): Data that replaces the current content.
        """

    # --- Delete -------------------------------------------------------------------------
    @abstractmethod
    def delete_vector(self, vector_id):
        """Remove a vector from the store.

        Args:
            vector_id (any): Identifier of the vector to delete.
        """
class ChromaAgentSemanticMemory(AbstractAgentSemanticMemory):
    """Base class for Chroma-based semantic memory implementations.

    Subclasses provide ``create()``, which must set ``self.embeddings`` and
    ``self.db_connection``; every data operation here goes through that store object.
    """

    def __init__(self, collection_name="my_collection"):
        """Initialize the connection to the local Chroma DB.

        Args:
            collection_name (str): Name of the Chroma collection to use.
        """
        self.collection_name = collection_name
        self.db_connection = None
        self.embeddings = None
        super().__init__(connection_type="local")

    def _require_connection(self):
        """Return the store, raising if ``create()`` has not set one.

        Uses an identity check rather than truthiness so that a store object which is
        falsy while empty is not mistaken for a missing one.
        """
        if self.db_connection is None:
            raise RuntimeError("Collection not initialized. Call connect() first.")
        return self.db_connection

    def connect(self):
        # Stub: remote connections are not implemented for the Chroma backend.
        return super().connect()

    def create(self):
        """Create the embedding function and initialize the Chroma database.

        This method must be implemented by child classes.
        """
        raise NotImplementedError("Child classes must implement this method")

    def add_vector(self, vector_id, vector_data):
        """Add a text entry to the collection under ``vector_id``.

        Raises:
            RuntimeError: If the store has not been initialized.
        """
        self._require_connection().add_texts(
            ids=[vector_id],
            texts=[vector_data],
            metadatas=[{"name": vector_id}],
        )

    def get_vector(self, vector_id):
        """Retrieve an entry (including its embedding) by identifier.

        Raises:
            RuntimeError: If the store has not been initialized.
        """
        return self._require_connection().get(include=["embeddings"], ids=[vector_id])

    def query(self, query_texts, n_results=4, similarity_threshold=None):
        """Query the collection with a text and return up to ``n_results`` matches.

        Returns:
            A list of ``(document, score)`` tuples. ``score`` is the relevance score when
            ``similarity_threshold`` is given, otherwise ``None``.

        Raises:
            RuntimeError: If the store has not been initialized.
            ValueError: If ``similarity_threshold`` is outside ``[0, 1]``.
        """
        store = self._require_connection()

        if similarity_threshold is None:
            documents = store.similarity_search(query=query_texts, k=n_results)
            return [(doc, None) for doc in documents]

        if not (0 <= similarity_threshold <= 1):
            raise ValueError("similarity_threshold must be between 0 and 1.")
        return store.similarity_search_with_relevance_scores(
            query=query_texts, k=n_results, score_threshold=similarity_threshold
        )

    def update_vector(self, vector_id, new_vector_data):
        """Replace the entry stored under ``vector_id`` with ``new_vector_data``.

        Implemented as delete-then-add using the same store calls as ``delete_vector``
        and ``add_vector``. (Previously this delegated to ``connect()`` and silently did
        nothing.)

        Raises:
            RuntimeError: If the store has not been initialized.
        """
        store = self._require_connection()
        store.delete(ids=[vector_id])
        store.add_texts(
            ids=[vector_id],
            texts=[new_vector_data],
            metadatas=[{"name": vector_id}],
        )

    def delete_vector(self, vector_id):
        """Delete an entry from the collection using its identifier.

        Raises:
            RuntimeError: If the store has not been initialized.
        """
        self._require_connection().delete(ids=[vector_id])
class LocalSemanticMemory(ChromaAgentSemanticMemory):
    """Semantic memory implementation using a locally run embedding model."""

    def __init__(
        self, collection_name="my_collection", model_name="sentence-transformers/all-MiniLM-L6-v2"
    ):
        """Initialize the local semantic memory using SentenceTransformer.

        Args:
            collection_name (str): Name of the Chroma collection
            model_name (str): Embeddings model
        """
        self.model_name = model_name
        super().__init__(collection_name=collection_name)

    def create(self):
        """Create the local embedding model and initialize the ChromaDB client.

        Raises:
            ImportError: If the ``sentence_transformers`` package is not installed.
        """
        # Function-scope import: `SentenceTransformer` is used below but is not imported
        # at module level, so without this line the call fails with NameError.
        from sentence_transformers import SentenceTransformer

        # Use CUDA if available, otherwise fall back to CPU
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {device}")
        self.model = SentenceTransformer(self.model_name, device=device)

        # Adapter exposing the two methods the vector store calls on an embedding function.
        class SentenceTransformerEmbeddings:
            def __init__(self, model):
                self.model = model

            def embed_query(self, text):
                """Embed a single query text."""
                return self.model.encode(text, normalize_embeddings=True).tolist()

            def embed_documents(self, texts):
                """Embed multiple documents/texts."""
                return self.model.encode(texts, normalize_embeddings=True).tolist()

        self.embeddings = SentenceTransformerEmbeddings(self.model)

        # Create the database
        self.db_connection = Chroma(
            collection_name=self.collection_name,
            embedding_function=self.embeddings,
            collection_metadata={"hnsw:space": "cosine"},
        )
class ImageEmbeddingProvider:
    """
    A provider for generating vector embeddings from images.

    This class uses pre-trained models to convert images into vector embeddings
    that can be stored in a vector database and used for similarity search.
    With the CLIP model it can also embed text (see ``get_text_embedding``).

    Note: when the model is unavailable or inference fails, the embedding methods log
    an error and return a random vector of length ``dimensions`` instead of raising.
    """

    def __init__(self, model_name: str = "clip", dimensions: int = 512):
        """
        Initialize the image embedding provider.

        Args:
            model_name: Name of the embedding model to use ("clip" or "resnet")
            dimensions: Length of the random fallback vectors returned on failure

        Raises:
            ImportError: If onnxruntime, torch or transformers cannot be imported.
            ValueError: If ``model_name`` is not supported.
        """
        self.model_name = model_name
        self.dimensions = dimensions
        self.model = None
        self.processor = None

        self._initialize_model()

        logger.info(f"ImageEmbeddingProvider initialized with model {model_name}")

    def _initialize_model(self):
        """Load the model and its processor for ``self.model_name``.

        Raises:
            ImportError: If a required package is missing (re-raised after logging).
            ValueError: If ``self.model_name`` is not "clip" or "resnet".
        """
        try:
            import onnxruntime as ort
            import torch
            from transformers import AutoFeatureExtractor, AutoModel, CLIPProcessor

            if self.model_name == "clip":
                model_id = get_data("models_clip") / "model.onnx"
                processor_id = "openai/clip-vit-base-patch32"
                # Pass a plain string: not every onnxruntime release accepts path objects.
                self.model = ort.InferenceSession(str(model_id))
                self.processor = CLIPProcessor.from_pretrained(processor_id)
                logger.info(f"Loaded CLIP model: {model_id}")
            elif self.model_name == "resnet":
                model_id = "microsoft/resnet-50"
                self.model = AutoModel.from_pretrained(model_id)
                self.processor = AutoFeatureExtractor.from_pretrained(model_id)
                logger.info(f"Loaded ResNet model: {model_id}")
            else:
                raise ValueError(f"Unsupported model: {self.model_name}")
        except ImportError as e:
            logger.error(f"Failed to import required modules: {e}")
            logger.error("Please install with: pip install onnxruntime transformers torch")
            # Leave the provider in its "not initialized" state before re-raising
            self.model = None
            self.processor = None
            raise

    def get_embedding(self, image: Union[np.ndarray, str, bytes]) -> np.ndarray:
        """
        Generate an embedding vector for the provided image.

        Args:
            image: The image to embed: a numpy array (OpenCV BGR format), a file path,
                a base64-encoded string, or raw encoded image bytes

        Returns:
            A 1-D numpy array containing the L2-normalized embedding vector. On
            inference failure a random vector of length ``dimensions`` is returned.

        Raises:
            ValueError: If ``image`` cannot be converted to a PIL image.
        """
        if self.model is None or self.processor is None:
            logger.error("Model not initialized. Using fallback random embedding.")
            return np.random.randn(self.dimensions).astype(np.float32)

        pil_image = self._prepare_image(image)

        try:
            import torch

            if self.model_name == "clip":
                inputs = self.processor(images=pil_image, return_tensors="np")

                with torch.no_grad():
                    ort_inputs = {
                        inp.name: inputs[inp.name]
                        for inp in self.model.get_inputs()
                        if inp.name in inputs
                    }

                    # If required, add dummy text inputs
                    input_names = [i.name for i in self.model.get_inputs()]
                    batch_size = inputs["pixel_values"].shape[0]
                    if "input_ids" in input_names:
                        ort_inputs["input_ids"] = np.zeros((batch_size, 1), dtype=np.int64)
                    if "attention_mask" in input_names:
                        ort_inputs["attention_mask"] = np.ones((batch_size, 1), dtype=np.int64)

                    # Run inference
                    ort_outputs = self.model.run(None, ort_inputs)

                    # Look up correct output name
                    output_names = [o.name for o in self.model.get_outputs()]
                    if "image_embeds" in output_names:
                        image_embedding = ort_outputs[output_names.index("image_embeds")]
                    else:
                        raise RuntimeError(f"No 'image_embeds' found in outputs: {output_names}")

                    embedding = image_embedding / np.linalg.norm(
                        image_embedding, axis=1, keepdims=True
                    )
                    embedding = embedding[0]

            elif self.model_name == "resnet":
                inputs = self.processor(images=pil_image, return_tensors="pt")

                with torch.no_grad():
                    outputs = self.model(**inputs)

                # ResNet yields a spatial feature map, not a token sequence, so there is
                # no [CLS] token to pick. Use the spatially pooled features and flatten
                # them to one vector per image.
                pooled = outputs.pooler_output
                embedding = pooled.reshape(pooled.shape[0], -1).numpy()[0]
            else:
                logger.warning(f"Unsupported model: {self.model_name}. Using random embedding.")
                embedding = np.random.randn(self.dimensions).astype(np.float32)

            # L2-normalize (a no-op for the already normalized CLIP branch)
            embedding = embedding / np.linalg.norm(embedding)

            logger.debug(f"Generated embedding with shape {embedding.shape}")
            return embedding

        except Exception as e:
            logger.error(f"Error generating embedding: {e}")
            return np.random.randn(self.dimensions).astype(np.float32)

    def get_text_embedding(self, text: str) -> np.ndarray:
        """
        Generate an embedding vector for the provided text.

        Only supported with the CLIP model; for any other model, or on failure, a
        random vector of length ``dimensions`` is returned.

        Args:
            text: The text to embed

        Returns:
            A 1-D numpy array containing the embedding vector
        """
        if self.model is None or self.processor is None:
            logger.error("Model not initialized. Using fallback random embedding.")
            return np.random.randn(self.dimensions).astype(np.float32)

        if self.model_name != "clip":
            logger.warning(
                f"Text embeddings are only supported with CLIP model, not {self.model_name}. Using random embedding."
            )
            return np.random.randn(self.dimensions).astype(np.float32)

        try:
            import torch

            inputs = self.processor(text=[text], return_tensors="np", padding=True)

            with torch.no_grad():
                # Prepare ONNX input dict (handle only what's needed)
                ort_inputs = {
                    inp.name: inputs[inp.name]
                    for inp in self.model.get_inputs()
                    if inp.name in inputs
                }
                # Determine which inputs are expected by the ONNX model
                input_names = [i.name for i in self.model.get_inputs()]
                batch_size = inputs["input_ids"].shape[0]  # pulled from text input

                # If the model expects pixel_values (i.e., fused model), add dummy vision input
                if "pixel_values" in input_names:
                    ort_inputs["pixel_values"] = np.zeros(
                        (batch_size, 3, 224, 224), dtype=np.float32
                    )

                # Run inference
                ort_outputs = self.model.run(None, ort_inputs)

                # Prefer the 'text_embeds' output; otherwise fall back to the first output
                output_names = [o.name for o in self.model.get_outputs()]
                if "text_embeds" in output_names:
                    text_embedding = ort_outputs[output_names.index("text_embeds")]
                else:
                    text_embedding = ort_outputs[0]  # fallback to first output

                # Normalize
                text_embedding = text_embedding / np.linalg.norm(
                    text_embedding, axis=1, keepdims=True
                )
                text_embedding = text_embedding[0]  # drop the batch axis

            logger.debug(
                f"Generated text embedding with shape {text_embedding.shape} for text: '{text}'"
            )
            return text_embedding

        except Exception as e:
            logger.error(f"Error generating text embedding: {e}")
            return np.random.randn(self.dimensions).astype(np.float32)

    def _prepare_image(self, image: Union[np.ndarray, str, bytes]) -> Image.Image:
        """
        Convert the input image to PIL format required by the models.

        Args:
            image: A numpy array (3-channel arrays are converted BGR -> RGB), a file
                path, a base64-encoded string, or raw encoded image bytes

        Returns:
            PIL Image object

        Raises:
            ValueError: If a string is neither an existing file nor decodable base64
                image data, or if the input type is unsupported.
        """
        if isinstance(image, np.ndarray):
            if len(image.shape) == 3 and image.shape[2] == 3:
                image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            else:
                image_rgb = image

            return Image.fromarray(image_rgb)

        if isinstance(image, str):
            if os.path.isfile(image):
                return Image.open(image)
            try:
                image_data = base64.b64decode(image)
                return Image.open(io.BytesIO(image_data))
            except Exception as e:
                logger.error(f"Failed to decode image string: {e}")
                # Chain the cause so the underlying decode error stays visible
                raise ValueError("Invalid image string format") from e

        if isinstance(image, bytes):
            return Image.open(io.BytesIO(image))

        raise ValueError(f"Unsupported image format: {type(image)}")
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Spatial vector database for storing and querying images with XY locations. + +This module extends the ChromaDB implementation to support storing images with +their XY locations and querying by location or image similarity. +""" + +import numpy as np +from typing import List, Dict, Tuple, Any +import chromadb + +from dimos.agents.memory.visual_memory import VisualMemory +from dimos.utils.logging_config import setup_logger + +logger = setup_logger("dimos.agents.memory.spatial_vector_db") + + +class SpatialVectorDB: + """ + A vector database for storing and querying images mapped to X,Y,theta absolute locations for SpatialMemory. + + This class extends the ChromaDB implementation to support storing images with + their absolute locations and querying by location, text, or image cosine semantic similarity. + """ + + def __init__( + self, collection_name: str = "spatial_memory", chroma_client=None, visual_memory=None + ): + """ + Initialize the spatial vector database. + + Args: + collection_name: Name of the vector database collection + chroma_client: Optional ChromaDB client for persistence. If None, an in-memory client is used. + visual_memory: Optional VisualMemory instance for storing images. If None, a new one is created. 
+ """ + self.collection_name = collection_name + + # Use provided client or create in-memory client + self.client = chroma_client if chroma_client is not None else chromadb.Client() + + # Check if collection already exists - in newer ChromaDB versions list_collections returns names directly + existing_collections = self.client.list_collections() + + # Handle different versions of ChromaDB API + try: + collection_exists = collection_name in existing_collections + except: + try: + collection_exists = collection_name in [c.name for c in existing_collections] + except: + try: + self.client.get_collection(name=collection_name) + collection_exists = True + except Exception: + collection_exists = False + + # Get or create the collection + self.image_collection = self.client.get_or_create_collection( + name=collection_name, metadata={"hnsw:space": "cosine"} + ) + + # Use provided visual memory or create a new one + self.visual_memory = visual_memory if visual_memory is not None else VisualMemory() + + # Log initialization info with details about whether using existing collection + client_type = "persistent" if chroma_client is not None else "in-memory" + try: + count = len(self.image_collection.get(include=[])["ids"]) + if collection_exists: + logger.info( + f"Using EXISTING {client_type} collection '{collection_name}' with {count} entries" + ) + else: + logger.info(f"Created NEW {client_type} collection '{collection_name}'") + except Exception as e: + logger.info( + f"Initialized {client_type} collection '{collection_name}' (count error: {str(e)})" + ) + + def add_image_vector( + self, vector_id: str, image: np.ndarray, embedding: np.ndarray, metadata: Dict[str, Any] + ) -> None: + """ + Add an image with its embedding and metadata to the vector database. 
def query_by_embedding(self, embedding: np.ndarray, limit: int = 5) -> List[Dict]:
    """
    Query the vector database for images similar to the provided embedding.

    Args:
        embedding: Query embedding vector
        limit: Maximum number of results to return

    Returns:
        List of results, each containing the image and its metadata
    """
    results = self.image_collection.query(
        query_embeddings=[embedding.tolist()], n_results=limit
    )
    return self._process_query_results(results)


# TODO: implement efficient nearest neighbor search
def query_by_location(
    self, x: float, y: float, radius: float = 2.0, limit: int = 5
) -> List[Dict]:
    """
    Query the vector database for images near the specified location.

    Every stored item is fetched and filtered in Python (linear scan).

    Args:
        x: X coordinate
        y: Y coordinate
        radius: Search radius in meters
        limit: Maximum number of results to return

    Returns:
        List of results sorted by ascending distance, each containing the
        image, its metadata and its distance to (x, y)
    """
    results = self.image_collection.get()

    if not results or not results["ids"]:
        return []

    hits = []
    for item_id, metadata in zip(results["ids"], results.get("metadatas") or []):
        # Items stored without metadata come back as None; they have no position.
        if not metadata:
            continue

        item_x = metadata.get("x")
        item_y = metadata.get("y")
        if item_x is None or item_y is None:
            continue

        distance = float(np.hypot(x - item_x, y - item_y))
        if distance <= radius:
            hits.append((distance, item_id, metadata))

    hits.sort(key=lambda hit: hit[0])
    nearest = hits[:limit]

    return self._process_query_results(
        {
            "ids": [item_id for _, item_id, _ in nearest],
            "metadatas": [metadata for _, _, metadata in nearest],
            "distances": [distance for distance, _, _ in nearest],
        }
    )


def _process_query_results(self, results) -> List[Dict]:
    """
    Process query results to include decoded images.

    Accepts both result layouts: flat lists (as built by query_by_location /
    returned by ``collection.get``) and lists nested one level per query
    embedding (as returned by ``collection.query``). Nested lists are
    flattened so that every hit is returned, not just the first one.

    Args:
        results: Mapping with "ids" and optionally "metadatas" / "distances"

    Returns:
        One dict per hit with keys "metadata", "id", optionally "distance",
        and "image" (None when the image is not in visual memory)
    """
    if not results or not results.get("ids"):
        return []

    def flatten(values):
        if not values:
            return []
        if isinstance(values[0], list):
            return [item for group in values for item in group]
        return list(values)

    ids = flatten(results["ids"])
    metadatas = flatten(results.get("metadatas"))
    distances = flatten(results.get("distances"))

    processed_results = []
    for index, lookup_id in enumerate(ids):
        # Metadata is reported regardless of image availability.
        result = {
            "metadata": metadatas[index] if index < len(metadatas) else {},
            "id": lookup_id,
        }

        if index < len(distances):
            result["distance"] = distances[index]

        result["image"] = self.visual_memory.get(lookup_id)
        processed_results.append(result)

    return processed_results


def query_by_text(self, text: str, limit: int = 5) -> List[Dict]:
    """
    Query the vector database for images matching the provided text description.

    This method uses CLIP's text-to-image matching capability to find images
    that semantically match the text query (e.g., "where is the kitchen").

    Args:
        text: Text query to search for
        limit: Maximum number of results to return

    Returns:
        List of results, each containing the image, its metadata, and similarity score
    """
    from dimos.agents.memory.image_embedding import ImageEmbeddingProvider

    # NOTE(review): a new embedding provider is constructed on every call;
    # consider caching it if text queries are frequent.
    embedding_provider = ImageEmbeddingProvider(model_name="clip")
    text_embedding = embedding_provider.get_text_embedding(text)

    results = self.image_collection.query(
        query_embeddings=[text_embedding.tolist()],
        n_results=limit,
        include=["documents", "metadatas", "distances"],
    )

    matches = self._process_query_results(results)
    # Count the processed hits: the raw "ids" field has one entry per query.
    logger.info(f"Text query: '{text}' returned {len(matches)} results")
    return matches


def get_all_locations(self) -> List[Tuple[float, float, float]]:
    """
    Get all locations stored in the database.

    Returns:
        (x, y, z) for every item whose metadata has both "x" and "y";
        z defaults to 0 when absent
    """
    # Only metadata is needed, so embeddings are not requested.
    results = self.image_collection.get(include=["metadatas"])

    if not results or not results.get("metadatas"):
        return []

    locations = []
    for metadata in results["metadatas"]:
        if isinstance(metadata, list) and metadata and isinstance(metadata[0], dict):
            metadata = metadata[0]  # Handle nested metadata

        if isinstance(metadata, dict) and "x" in metadata and "y" in metadata:
            locations.append((metadata["x"], metadata["y"], metadata.get("z", 0)))

    return locations
@pytest.mark.heavy
class TestImageEmbedding:
    """Test class for CLIP image embedding functionality.

    ``test_clip_embedding_similarity`` reuses embeddings that
    ``test_clip_embedding_process_video`` stores on the class, and skips when
    they are not available.
    """

    def test_clip_embedding_initialization(self):
        """Test CLIP embedding provider initializes correctly."""
        # Only a failure to construct the model is a reason to skip; the
        # assertions below must be able to fail the test.
        try:
            embedding_provider = ImageEmbeddingProvider(model_name="clip", dimensions=512)
        except Exception as e:
            pytest.skip(f"Skipping test due to model initialization error: {e}")

        assert embedding_provider.model is not None, "CLIP model failed to initialize"
        assert embedding_provider.processor is not None, "CLIP processor failed to initialize"
        assert embedding_provider.model_name == "clip", "Model name should be 'clip'"
        assert embedding_provider.dimensions == 512, "Embedding dimensions should be 512"

    def test_clip_embedding_process_video(self):
        """Test CLIP embedding provider can process video frames and return embeddings."""
        from dimos.utils.data import get_data

        video_path = get_data("assets") / "trimmed_video_office.mov"

        embedding_provider = ImageEmbeddingProvider(model_name="clip", dimensions=512)

        assert os.path.exists(video_path), f"Test video not found: {video_path}"
        video_provider = VideoProvider(dev_name="test_video", video_source=video_path)

        video_stream = video_provider.capture_video_as_observable(realtime=False, fps=15)

        def process_frame(frame):
            try:
                embedding = embedding_provider.get_embedding(frame)
                print(
                    f"Generated CLIP embedding with shape: {embedding.shape}, norm: {np.linalg.norm(embedding):.4f}"
                )
                return {"frame": frame, "embedding": embedding}
            except Exception as e:
                print(f"Error in process_frame: {e}")
                return None

        embedding_stream = video_stream.pipe(ops.map(process_frame))

        results = []
        stream_errors = []
        target_frames = 10
        # Filled in after subscribe(); on_next may already run during subscribe().
        holder = {"subscription": None}

        def on_next(result):
            if not result or len(results) >= target_frames:  # Skip None / surplus frames
                return

            results.append(result)

            # Stop processing after target frames
            if len(results) >= target_frames and holder["subscription"] is not None:
                holder["subscription"].dispose()

        # Errors are recorded and reported from the test thread below, because
        # pytest.fail() raised inside a stream callback would not fail the test.
        subscription = embedding_stream.subscribe(
            on_next=on_next, on_error=stream_errors.append
        )
        holder["subscription"] = subscription

        try:
            timeout = 60.0
            start_time = time.time()
            while (
                len(results) < target_frames
                and not stream_errors
                and time.time() - start_time < timeout
            ):
                time.sleep(0.5)
                print(f"Processed {len(results)}/{target_frames} frames")
        finally:
            # Clean up subscription
            subscription.dispose()
            video_provider.dispose_all()

        if stream_errors:
            pytest.fail(f"Error in embedding stream: {stream_errors[0]}")

        if not results:
            pytest.skip("No embeddings generated, but test connection established correctly")

        print(f"Processed {len(results)} frames with CLIP embeddings")

        # Check properties of first embedding
        first_result = results[0]
        assert "embedding" in first_result, "Result doesn't contain embedding"
        assert "frame" in first_result, "Result doesn't contain frame"

        # Check embedding shape and normalization
        embedding = first_result["embedding"]
        assert isinstance(embedding, np.ndarray), "Embedding is not a numpy array"
        assert embedding.shape == (512,), (
            f"Embedding has wrong shape: {embedding.shape}, expected (512,)"
        )
        assert abs(np.linalg.norm(embedding) - 1.0) < 1e-5, "Embedding is not normalized"

        # Save the first two embeddings for the similarity test
        if len(results) > 1:
            TestImageEmbedding.test_embeddings = {
                "embedding1": results[0]["embedding"],
                "embedding2": results[1]["embedding"],
            }
            print("Saved embeddings for similarity testing")

        print("CLIP embedding test passed successfully!")

    def test_clip_embedding_similarity(self):
        """Test CLIP embedding similarity search and text-to-image queries."""
        # Skip if previous test didn't generate embeddings
        if not hasattr(TestImageEmbedding, "test_embeddings"):
            pytest.skip("No embeddings available from previous test")

        embedding1 = TestImageEmbedding.test_embeddings["embedding1"]
        embedding2 = TestImageEmbedding.test_embeddings["embedding2"]

        # Initialize embedding provider for text embeddings
        embedding_provider = ImageEmbeddingProvider(model_name="clip", dimensions=512)

        # Test frame-to-frame similarity
        if embedding1 is not None and embedding2 is not None:
            # The dot product is used directly as the similarity score.
            similarity = np.dot(embedding1, embedding2)
            print(f"Similarity between first two frames: {similarity:.4f}")

            assert -1.0 <= similarity <= 1.0, f"Similarity out of valid range: {similarity}"

        # Test text-to-image similarity
        if embedding1 is not None:
            text_queries = ["a video frame", "a person", "an outdoor scene", "a kitchen"]

            for text_query in text_queries:
                text_embedding = embedding_provider.get_text_embedding(text_query)

                # Check text embedding properties
                assert isinstance(text_embedding, np.ndarray), (
                    "Text embedding is not a numpy array"
                )
                assert text_embedding.shape == (512,), (
                    f"Text embedding has wrong shape: {text_embedding.shape}"
                )
                assert abs(np.linalg.norm(text_embedding) - 1.0) < 1e-5, (
                    "Text embedding is not normalized"
                )

                # Compute similarity between frame and text
                text_similarity = np.dot(embedding1, text_embedding)
                print(f"Similarity between frame and '{text_query}': {text_similarity:.4f}")

                assert -1.0 <= text_similarity <= 1.0, (
                    f"Text-image similarity out of range: {text_similarity}"
                )

        print("CLIP embedding similarity tests passed successfully!")


if __name__ == "__main__":
    pytest.main(["-v", "--disable-warnings", __file__])
class VisualMemory:
    """
    A class for storing and retrieving visual memories (images) with persistence.

    This class handles the storage, encoding, and retrieval of images associated
    with vector database entries. Images are kept in memory as base64-encoded
    JPEG strings keyed by ID, and the whole mapping can be pickled to / loaded
    from disk.
    """

    def __init__(self, output_dir: Optional[str] = None):
        """
        Initialize the visual memory system.

        Args:
            output_dir: Directory to store the serialized image data. When
                omitted, ``save`` is a no-op.
        """
        self.images = {}  # Maps IDs to encoded images
        self.output_dir = output_dir

        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
            logger.info(f"VisualMemory initialized with output directory: {output_dir}")
        else:
            logger.info("VisualMemory initialized with no persistence directory")

    def add(self, image_id: str, image: np.ndarray) -> None:
        """
        Add an image to visual memory.

        The image is JPEG-encoded (lossy) and stored as a base64 string. An
        encoding failure is logged and the image is not stored.

        Args:
            image_id: Unique identifier for the image
            image: The image data as a numpy array
        """
        success, encoded_image = cv2.imencode(".jpg", image)
        if not success:
            logger.error(f"Failed to encode image {image_id}")
            return

        self.images[image_id] = base64.b64encode(encoded_image.tobytes()).decode("utf-8")
        logger.debug(f"Added image {image_id} to visual memory")

    def get(self, image_id: str) -> Optional[np.ndarray]:
        """
        Retrieve an image from visual memory.

        Args:
            image_id: Unique identifier for the image

        Returns:
            The decoded image as a numpy array, or None if the ID is unknown
            or the stored data cannot be decoded
        """
        if image_id not in self.images:
            logger.warning(
                f"Image not found in storage for ID {image_id}. Incomplete or corrupted image storage."
            )
            return None

        try:
            image_bytes = base64.b64decode(self.images[image_id])
            image_array = np.frombuffer(image_bytes, dtype=np.uint8)
            image = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
        except Exception as e:
            logger.warning(f"Failed to decode image for ID {image_id}: {str(e)}")
            return None

        # imdecode signals undecodable data by returning None rather than raising.
        if image is None:
            logger.warning(f"Failed to decode image for ID {image_id}: invalid image data")
        return image

    def contains(self, image_id: str) -> bool:
        """
        Check if an image ID exists in visual memory.

        Args:
            image_id: Unique identifier for the image

        Returns:
            True if the image exists, False otherwise
        """
        return image_id in self.images

    def count(self) -> int:
        """
        Get the number of images in visual memory.

        Returns:
            The number of images stored
        """
        return len(self.images)

    def save(self, filename: Optional[str] = None) -> str:
        """
        Save the visual memory to disk.

        Args:
            filename: Optional filename to save to. If None, uses a default name in the output directory.

        Returns:
            The path where the data was saved, or "" when there is no output
            directory or writing failed
        """
        if not self.output_dir:
            logger.warning("No output directory specified for VisualMemory. Cannot save.")
            return ""

        output_path = os.path.join(self.output_dir, filename or "visual_memory.pkl")

        try:
            with open(output_path, "wb") as f:
                pickle.dump(self.images, f)
        except Exception as e:
            logger.error(f"Failed to save visual memory: {str(e)}")
            return ""

        logger.info(f"Saved {len(self.images)} images to {output_path}")
        return output_path

    @classmethod
    def load(cls, path: str, output_dir: Optional[str] = None) -> "VisualMemory":
        """
        Load visual memory from disk.

        A missing, unreadable or malformed file yields an empty instance; the
        problem is logged rather than raised.

        Args:
            path: Path to the saved visual memory file
            output_dir: Optional output directory for the new instance

        Returns:
            A new VisualMemory instance with the loaded data
        """
        instance = cls(output_dir=output_dir)

        if not os.path.exists(path):
            logger.warning(f"Visual memory file {path} not found")
            return instance

        try:
            with open(path, "rb") as f:
                # SECURITY: unpickling can execute arbitrary code. Only load
                # files written by VisualMemory.save() from a trusted location.
                loaded = pickle.load(f)
        except Exception as e:
            logger.error(f"Failed to load visual memory: {str(e)}")
            return instance

        # Every other method assumes a dict of id -> base64 string.
        if not isinstance(loaded, dict):
            logger.error(
                f"Failed to load visual memory: unexpected payload type {type(loaded).__name__}"
            )
            return instance

        instance.images = loaded
        logger.info(f"Loaded {len(instance.images)} images from {path}")
        return instance

    def clear(self) -> None:
        """Clear all images from memory."""
        self.images = {}
        logger.info("Visual memory cleared")
# For response validation
class PlanningAgentResponse(BaseModel):
    """Structured reply of the planning agent.

    ``content`` is always a list of strings: the plan steps for ``"plan"``
    responses, the message line(s) for ``"dialogue"`` responses.
    """

    type: Literal["dialogue", "plan"]
    content: List[str]
    needs_confirmation: bool


class PlanningAgent(OpenAIAgent):
    """Agent that plans and breaks down tasks through dialogue.

    This agent specializes in:
    1. Understanding complex tasks through dialogue
    2. Breaking tasks into concrete, executable steps
    3. Refining plans based on user feedback
    4. Streaming individual steps to ExecutionAgents

    The agent maintains conversation state and can refine plans until
    the user confirms they are ready to execute.
    """

    def __init__(
        self,
        dev_name: str = "PlanningAgent",
        model_name: str = "gpt-4",
        input_query_stream: Optional[Observable] = None,
        use_terminal: bool = False,
        skills: Optional[AbstractSkill] = None,
    ):
        """Initialize the planning agent.

        Args:
            dev_name: Name identifier for the agent
            model_name: OpenAI model to use
            input_query_stream: Observable stream of user queries
            use_terminal: Whether to enable terminal input
            skills: Available skills/functions for the agent
        """
        # Planning state
        self.conversation_history = []
        self.current_plan = []
        self.plan_confirmed = False
        self.latest_response = None

        # Build system prompt
        skills_list = []
        if skills is not None:
            skills_list = skills.get_tools()

        system_query = dedent(f"""
            You are a Robot planning assistant that helps break down tasks into concrete, executable steps.
            Your goal is to:
            1. Break down the task into clear, sequential steps
            2. Refine the plan based on user feedback as needed
            3. Only finalize the plan when the user explicitly confirms

            You have the following skills at your disposal:
            {skills_list}

            IMPORTANT: You MUST ALWAYS respond with ONLY valid JSON in the following format, with no additional text or explanation:
            {{
                "type": "dialogue" | "plan",
                "content": string | list[string],
                "needs_confirmation": boolean
            }}

            Your goal is to:
            1. Understand the user's task through dialogue
            2. Break it down into clear, sequential steps
            3. Refine the plan based on user feedback
            4. Only finalize the plan when the user explicitly confirms

            For dialogue responses, use:
            {{
                "type": "dialogue",
                "content": "Your message to the user",
                "needs_confirmation": false
            }}

            For plan proposals, use:
            {{
                "type": "plan",
                "content": ["Execute", "Execute", ...],
                "needs_confirmation": true
            }}

            Remember: ONLY output valid JSON, no other text.""")

        # Initialize OpenAIAgent with our configuration
        super().__init__(
            dev_name=dev_name,
            agent_type="Planning",
            query="",  # Will be set by process_user_input
            model_name=model_name,
            input_query_stream=input_query_stream,
            system_query=system_query,
            max_output_tokens_per_request=1000,
            response_model=PlanningAgentResponse,
        )
        logger.info("Planning agent initialized")

        # Set up terminal mode if requested
        self.use_terminal = use_terminal
        # NOTE(review): the terminal interface is force-disabled here, so the
        # branch below never runs regardless of the constructor argument.
        # Kept as-is to preserve behavior; confirm whether this is intentional.
        use_terminal = False
        if use_terminal:
            # Start terminal interface in a separate thread
            logger.info("Starting terminal interface in a separate thread")
            terminal_thread = threading.Thread(target=self.start_terminal_interface, daemon=True)
            terminal_thread.start()

    @staticmethod
    def _content_to_text(content) -> str:
        """Render response content (a string or a list of strings) as one string."""
        if isinstance(content, str):
            return content
        return "\n".join(str(item) for item in content)

    def _handle_response(self, response) -> None:
        """Handle the agent's response and update state.

        Args:
            response: ParsedChatCompletionMessage containing PlanningAgentResponse,
                or a PlanningAgentResponse itself
        """
        logger.debug(f"handle response ({type(response)}): {response}")

        # Extract the PlanningAgentResponse from parsed field if available
        planning_response = response.parsed if hasattr(response, "parsed") else response
        if planning_response is None:
            logger.warning("Response carried no parsed PlanningAgentResponse; ignoring it")
            return

        # Convert to dict for storage in conversation history
        response_dict = planning_response.model_dump()
        self.conversation_history.append(response_dict)

        # If it's a plan, update current plan
        if planning_response.type == "plan":
            logger.info(f"Updating current plan: {planning_response.content}")
            self.current_plan = planning_response.content

        # Store latest response
        self.latest_response = response_dict

    def _stream_plan(self) -> None:
        """Stream each step of the confirmed plan, then complete the response subject."""
        logger.info("Starting to stream plan steps")
        logger.debug(f"Current plan: {self.current_plan}")

        for i, step in enumerate(self.current_plan, 1):
            logger.info(f"Streaming step {i}: {step}")
            # Add a small delay between steps to ensure they're processed
            time.sleep(0.5)
            try:
                self.response_subject.on_next(str(step))
                logger.debug(f"Successfully emitted step {i} to response_subject")
            except Exception as e:
                logger.error(f"Error emitting step {i}: {e}")

        logger.info("Plan streaming completed")
        self.response_subject.on_completed()

    def _send_query(self, messages: list) -> PlanningAgentResponse:
        """Send query to OpenAI and parse the response.

        Extends OpenAIAgent's _send_query to handle planning-specific response formats.

        Args:
            messages: List of message dictionaries

        Returns:
            PlanningAgentResponse: Validated response with type, content, and
            needs_confirmation. On failure, a dialogue response carrying the
            error text is returned instead of raising.
        """
        try:
            return super()._send_query(messages)
        except Exception as e:
            logger.error(f"Caught exception in _send_query: {str(e)}")
            # content must be a list; a bare string fails model validation.
            return PlanningAgentResponse(
                type="dialogue", content=[f"Error: {str(e)}"], needs_confirmation=False
            )

    def process_user_input(self, user_input: str) -> None:
        """Process user input and generate appropriate response.

        A confirmation ("yes", "y", "confirm") while a plan exists streams the
        plan; anything else is sent to the model together with the
        conversation history.

        Args:
            user_input: The user's message
        """
        if not user_input:
            return

        # Check for plan confirmation
        if self.current_plan and user_input.strip().lower() in ("yes", "y", "confirm"):
            logger.info("Plan confirmation received")
            self.plan_confirmed = True
            confirmation_msg = PlanningAgentResponse(
                type="dialogue",
                content=["Plan confirmed! Streaming steps to execution..."],
                needs_confirmation=False,
            )
            self._handle_response(confirmation_msg)
            self._stream_plan()
            return

        # Add the new user input to conversation history
        self.conversation_history.append({"type": "user_message", "content": user_input})

        # Build messages for OpenAI: system prompt followed by the complete
        # conversation history (user and assistant turns).
        messages = [
            {"role": "system", "content": self.system_query}  # Using system_query from OpenAIAgent
        ]
        for msg in self.conversation_history:
            if msg["type"] == "user_message":
                messages.append({"role": "user", "content": msg["content"]})
            elif msg["type"] == "dialogue":
                # Stored content is a list of strings; send it as plain text.
                messages.append(
                    {"role": "assistant", "content": self._content_to_text(msg["content"])}
                )
            elif msg["type"] == "plan":
                plan_text = "Here's my proposed plan:\n" + "\n".join(
                    f"{i + 1}. {step}" for i, step in enumerate(msg["content"])
                )
                messages.append({"role": "assistant", "content": plan_text})

        # Get and handle response
        response = self._send_query(messages)
        self._handle_response(response)

    def start_terminal_interface(self):
        """Start the terminal interface for input/output."""

        time.sleep(5)  # buffer time for clean terminal interface printing
        print("=" * 50)
        print("\nDimOS Action PlanningAgent\n")
        print("I have access to your Robot() and Robot Skills()")
        print(
            "Describe your task and I'll break it down into steps using your skills as a reference."
        )
        print("Once you're happy with the plan, type 'yes' to execute it.")
        print("Type 'quit' to exit.\n")

        while True:
            try:
                print("=" * 50)
                user_input = input("USER > ")
                if user_input.lower() in ["quit", "exit"]:
                    break

                self.process_user_input(user_input)

                # Empty input produces no response; prompt again instead of
                # failing on the missing response.
                latest = self.latest_response
                if latest is None:
                    continue

                # Display response
                if latest["type"] == "dialogue":
                    print(f"\nPlanner: {self._content_to_text(latest['content'])}")
                elif latest["type"] == "plan":
                    print("\nProposed Plan:")
                    for i, step in enumerate(latest["content"], 1):
                        print(f"{i}. {step}")
                    if latest["needs_confirmation"]:
                        print("\nDoes this plan look good? (yes/no)")

                if self.plan_confirmed:
                    print("\nPlan confirmed! Streaming steps to execution...")
                    break

            except KeyboardInterrupt:
                print("\nStopping...")
                break
            except Exception as e:
                print(f"\nError: {e}")
                break

    def get_response_observable(self) -> Observable:
        """Gets an observable that emits responses from this agent.

        This method processes the response stream from the parent class,
        extracting content from `PlanningAgentResponse` objects and flattening
        any lists of plan steps for emission.

        Returns:
            Observable: An observable that emits one string per plan step or
            dialogue line from the agent.
        """

        def extract_content(response) -> List[str]:
            if not isinstance(response, PlanningAgentResponse):
                return [str(response)]  # Wrap non-PlanningAgentResponse in a list
            content = response.content
            # content is normally already a list; tolerate a bare string too.
            return [content] if isinstance(content, str) else list(content)

        # Get base observable from parent class
        base_observable = super().get_response_observable()

        # Process the stream: extract content and flatten plan lists
        return base_observable.pipe(
            ops.map(extract_content),
            ops.flat_map(lambda items: items),  # Flatten the list of items
        )
# TODO: Build out testing and logging


class PromptBuilder:
    """Builds OpenAI-style chat messages that fit within a token limit."""

    DEFAULT_SYSTEM_PROMPT = dedent("""
        You are an AI assistant capable of understanding and analyzing both visual and textual information.
        Your task is to provide accurate and insightful responses based on the data provided to you.
        Use the following information to assist the user with their query. Do not rely on any internal
        knowledge or make assumptions beyond the provided data.

        Visual Context: You may have been given an image to analyze. Use the visual details to enhance your response.
        Textual Context: There may be some text retrieved from a relevant database to assist you

        Instructions:
        - Combine insights from both the image and the text to answer the user's question.
        - If the information is insufficient to provide a complete answer, acknowledge the limitation.
        - Maintain a professional and informative tone in your response.
    """)

    def __init__(
        self,
        model_name="gpt-4o",
        max_tokens=128000,
        tokenizer: Optional["AbstractTokenizer"] = None,
    ):
        """
        Initialize the prompt builder.
        Args:
            model_name (str): Model used (e.g., 'gpt-4o', 'gpt-4', 'gpt-3.5-turbo').
            max_tokens (int): Maximum tokens allowed in the input prompt.
            tokenizer (AbstractTokenizer): The tokenizer to use for token counting and truncation.
                Defaults to an OpenAITokenizer for ``model_name``.
        """
        self.model_name = model_name
        self.max_tokens = max_tokens
        self.tokenizer: AbstractTokenizer = tokenizer or OpenAITokenizer(model_name=self.model_name)

    def truncate_tokens(self, text, max_tokens, strategy):
        """
        Truncate text to fit within max_tokens using a specified strategy.
        Args:
            text (str): Input text to truncate.
            max_tokens (int): Maximum tokens allowed. Values below zero are treated as zero.
            strategy (str): Truncation strategy ('truncate_head', 'truncate_middle', 'truncate_end', 'do_not_truncate').
        Returns:
            str: Truncated text. Text already within the limit is returned unchanged.
        Raises:
            ValueError: If the text needs truncation and the strategy is unknown.
        """
        if strategy == "do_not_truncate" or not text:
            return text

        max_tokens = max(0, max_tokens)
        tokens = self.tokenizer.tokenize_text(text)
        total = len(tokens)
        if total <= max_tokens:
            return text

        # Slices are written with explicit start offsets: a negative-zero
        # slice (tokens[-0:]) would return the whole list instead of nothing.
        if strategy == "truncate_head":
            truncated = tokens[total - max_tokens :]
        elif strategy == "truncate_end":
            truncated = tokens[:max_tokens]
        elif strategy == "truncate_middle":
            tail_len = max_tokens // 2
            head_len = max_tokens - tail_len  # odd budgets give the extra token to the head
            truncated = tokens[:head_len] + tokens[total - tail_len :]
        else:
            raise ValueError(f"Unknown truncation strategy: {strategy}")

        return self.tokenizer.detokenize_text(truncated)

    def build(
        self,
        system_prompt=None,
        user_query=None,
        base64_image=None,
        image_width=None,
        image_height=None,
        image_detail="low",
        rag_context=None,
        budgets=None,
        policies=None,
        override_token_limit=False,
    ):
        """
        Builds a dynamic prompt tailored to token limits, respecting budgets and policies.

        Args:
            system_prompt (str): System-level instructions. Defaults to DEFAULT_SYSTEM_PROMPT.
            user_query (str): User's query. Required.
            base64_image (str, optional): Base64-encoded image string.
            image_width (int, optional): Width of the image.
            image_height (int, optional): Height of the image.
            image_detail (str, optional): Detail level for the image ("low" or "high").
            rag_context (str, optional): Retrieved context.
            budgets (dict, optional): Token budgets for each input type. Defaults to equal allocation;
                keys that are not supplied keep their default.
            policies (dict, optional): Truncation policies for each input type; keys that are not
                supplied keep their default.
            override_token_limit (bool, optional): Whether to override the token limit. Defaults to False.

        Returns:
            list: Messages array ready to send to the OpenAI API.

        Raises:
            ValueError: If user_query is None.
        """
        if user_query is None:
            raise ValueError("User query is required.")

        quarter = self.max_tokens // 4
        budgets = {
            "system_prompt": quarter,
            "user_query": quarter,
            "image": quarter,
            "rag": quarter,
            **(budgets or {}),
        }
        policies = {
            "system_prompt": "truncate_end",
            "user_query": "truncate_middle",
            "image": "do_not_truncate",
            "rag": "truncate_end",
            **(policies or {}),
        }

        # Validate and sanitize image_detail
        if image_detail not in {"low", "high"}:
            image_detail = "low"  # Default to "low" if invalid or None

        # Determine which system prompt to use
        if system_prompt is None:
            system_prompt = self.DEFAULT_SYSTEM_PROMPT

        rag_context = rag_context or ""

        # region Token Counts
        if not override_token_limit:
            rag_token_cnt = self.tokenizer.token_count(rag_context)
            system_prompt_token_cnt = self.tokenizer.token_count(system_prompt)
            user_query_token_cnt = self.tokenizer.token_count(user_query)
            image_token_cnt = (
                self.tokenizer.image_token_count(image_width, image_height, image_detail)
                if base64_image
                else 0
            )
        else:
            rag_token_cnt = 0
            system_prompt_token_cnt = 0
            user_query_token_cnt = 0
            image_token_cnt = 0
        # endregion Token Counts

        # Create a component dictionary for dynamic allocation
        components = {
            "system_prompt": {"text": system_prompt, "tokens": system_prompt_token_cnt},
            "user_query": {"text": user_query, "tokens": user_query_token_cnt},
            "image": {"text": None, "tokens": image_token_cnt},
            "rag": {"text": rag_context, "tokens": rag_token_cnt},
        }

        if not override_token_limit:
            # Adjust budgets and apply truncation, in component order, until
            # the total fits.
            total_tokens = sum(comp["tokens"] for comp in components.values())
            excess_tokens = total_tokens - self.max_tokens
            for key, component in components.items():
                if excess_tokens <= 0:
                    break
                # Nothing to cut for protected or text-less components (the image).
                if policies[key] == "do_not_truncate" or not component["text"]:
                    continue
                max_allowed = max(0, budgets[key] - excess_tokens)
                component["text"] = self.truncate_tokens(
                    component["text"], max_allowed, policies[key]
                )
                tokens_after = self.tokenizer.token_count(component["text"])
                excess_tokens -= component["tokens"] - tokens_after
                component["tokens"] = tokens_after

        # Build the `messages` structure (OpenAI specific)
        messages = [{"role": "system", "content": components["system_prompt"]["text"]}]

        if components["rag"]["text"]:
            user_content = [
                {
                    "type": "text",
                    "text": f"{components['rag']['text']}\n\n{components['user_query']['text']}",
                }
            ]
        else:
            user_content = [{"type": "text", "text": components["user_query"]["text"]}]

        if base64_image:
            user_content.append(
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{base64_image}",
                        "detail": image_detail,
                    },
                }
            )
        messages.append({"role": "user", "content": user_content})

        return messages
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod

# TODO: Add a class for specific tokenizer exceptions
# TODO: Build out testing and logging
# TODO: Create proper doc strings after multiple tokenizers are implemented


class AbstractTokenizer(ABC):
    """Interface every tokenizer backend has to implement.

    All four methods are abstract, so the class cannot be instantiated
    directly; the base implementations do nothing and return ``None``.
    """

    @abstractmethod
    def tokenize_text(self, text):
        """Turn ``text`` into the backend's token sequence."""

    @abstractmethod
    def detokenize_text(self, tokenized_text):
        """Turn a token sequence back into text."""

    @abstractmethod
    def token_count(self, text):
        """Return how many tokens ``text`` occupies."""

    @abstractmethod
    def image_token_count(self, image_width, image_height, image_detail="low"):
        """Return how many tokens an image of the given size occupies."""


# ---- file: build/lib/dimos/agents/tokenizer/huggingface_tokenizer.py ----
# Copyright 2025 Dimensional Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from transformers import AutoTokenizer
from dimos.agents.tokenizer.base import AbstractTokenizer
from dimos.utils.logging_config import setup_logger


class HuggingFaceTokenizer(AbstractTokenizer):
    """Tokenizer backed by a Hugging Face ``AutoTokenizer``."""

    def __init__(self, model_name: str = "Qwen/Qwen2.5-0.5B", **kwargs):
        """Load the tokenizer for ``model_name``.

        Raises:
            ValueError: if ``AutoTokenizer.from_pretrained`` fails for any reason.
        """
        super().__init__(**kwargs)

        # Initialize the tokenizer for the Hugging Face model
        self.model_name = model_name
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        except Exception as e:
            raise ValueError(
                f"Failed to initialize tokenizer for model {self.model_name}. Error: {str(e)}"
            ) from e

    def tokenize_text(self, text):
        """
        Tokenize a text string using the Hugging Face tokenizer.
        """
        return self.tokenizer.encode(text)

    def detokenize_text(self, tokenized_text):
        """
        Detokenize a token sequence using the Hugging Face tokenizer.

        Raises:
            ValueError: if decoding fails.
        """
        try:
            return self.tokenizer.decode(tokenized_text, errors="ignore")
        except Exception as e:
            raise ValueError(f"Failed to detokenize text. Error: {str(e)}") from e

    def token_count(self, text):
        """
        Return the token count of ``text``; empty or ``None`` text counts as 0.
        """
        return len(self.tokenize_text(text)) if text else 0

    @staticmethod
    def image_token_count(image_width, image_height, image_detail="high"):
        """
        Estimate the number of tokens an image consumes.

        Low detail is a flat 85 tokens. High detail costs 170 tokens per
        512x512 tile (partial tiles count as whole tiles) plus 85 base tokens,
        after the image has been shrunk to fit 2048x2048 and then shrunk so
        its shortest side is at most 768px. This follows the OpenAI vision
        accounting; NOTE(review): whether it matches a given Hugging Face
        model is not visible here.

        Raises:
            ValueError: if ``image_detail`` is neither 'low' nor 'high', or if
                the dimensions are missing or not positive for 'high'.
        """
        logger = setup_logger("dimos.agents.tokenizer.HuggingFaceTokenizer.image_token_count")

        if image_detail == "low":
            return 85
        if image_detail != "high":
            raise ValueError("Detail specification of image is not 'low' or 'high'")

        # Image dimensions
        logger.debug(f"Image Width: {image_width}, Image Height: {image_height}")
        if image_width is None or image_height is None:
            raise ValueError(
                "Image width and height must be provided for high detail image token count calculation."
            )
        if image_width <= 0 or image_height <= 0:
            # Previously a zero dimension crashed with ZeroDivisionError below.
            raise ValueError("Image width and height must be positive.")

        # Scale image to fit within 2048 x 2048
        max_dimension = max(image_width, image_height)
        if max_dimension > 2048:
            scale_factor = 2048 / max_dimension
            image_width = max(1, int(image_width * scale_factor))
            image_height = max(1, int(image_height * scale_factor))

        # Shrink (never enlarge) so the shortest side is at most 768px
        min_dimension = min(image_width, image_height)
        if min_dimension > 768:
            scale_factor = 768 / min_dimension
            image_width = max(1, int(image_width * scale_factor))
            image_height = max(1, int(image_height * scale_factor))

        # Count the 512px tiles needed to cover the image (ceiling division:
        # a partially covered tile is billed as a full tile).
        tiles_wide = -(-image_width // 512)
        tiles_high = -(-image_height // 512)
        return 170 * tiles_wide * tiles_high + 85


# ---- file: build/lib/dimos/agents/tokenizer/openai_tokenizer.py ----
# Copyright 2025 Dimensional Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Hoisted: a __future__ import must be the first statement of a module. It
# originally heads the dimos/core/__init__.py section further down.
from __future__ import annotations

import tiktoken
from dimos.agents.tokenizer.base import AbstractTokenizer
from dimos.utils.logging_config import setup_logger


class OpenAITokenizer(AbstractTokenizer):
    """Tokenizer backed by ``tiktoken`` for the OpenAI model family."""

    def __init__(self, model_name: str = "gpt-4o", **kwargs):
        """Look up the tiktoken encoding for ``model_name``.

        Raises:
            ValueError: if ``tiktoken.encoding_for_model`` fails for any reason.
        """
        super().__init__(**kwargs)

        # Initialize the tokenizer for the openai set of models
        self.model_name = model_name
        try:
            self.tokenizer = tiktoken.encoding_for_model(self.model_name)
        except Exception as e:
            raise ValueError(
                f"Failed to initialize tokenizer for model {self.model_name}. Error: {str(e)}"
            ) from e

    def tokenize_text(self, text):
        """
        Tokenize a text string using the openai tokenizer.
        """
        return self.tokenizer.encode(text)

    def detokenize_text(self, tokenized_text):
        """
        Detokenize a token sequence using the openai tokenizer.

        Raises:
            ValueError: if decoding fails.
        """
        try:
            return self.tokenizer.decode(tokenized_text, errors="ignore")
        except Exception as e:
            raise ValueError(f"Failed to detokenize text. Error: {str(e)}") from e

    def token_count(self, text):
        """
        Return the token count of ``text``; empty or ``None`` text counts as 0.
        """
        return len(self.tokenize_text(text)) if text else 0

    @staticmethod
    def image_token_count(image_width, image_height, image_detail="high"):
        """
        Calculate the number of tokens an image consumes.

        Low detail is a flat 85 tokens. High detail costs 170 tokens per
        512x512 tile (partial tiles count as whole tiles) plus 85 base tokens,
        after the image has been shrunk to fit 2048x2048 and then shrunk so
        its shortest side is at most 768px. Example: 1024x1024 -> 768x768 ->
        2x2 tiles -> 765 tokens.

        Raises:
            ValueError: if ``image_detail`` is neither 'low' nor 'high', or if
                the dimensions are missing or not positive for 'high'.
        """
        logger = setup_logger("dimos.agents.tokenizer.openai.image_token_count")

        if image_detail == "low":
            return 85
        if image_detail != "high":
            raise ValueError("Detail specification of image is not 'low' or 'high'")

        # Image dimensions
        logger.debug(f"Image Width: {image_width}, Image Height: {image_height}")
        if image_width is None or image_height is None:
            raise ValueError(
                "Image width and height must be provided for high detail image token count calculation."
            )
        if image_width <= 0 or image_height <= 0:
            # Previously a zero dimension crashed with ZeroDivisionError below.
            raise ValueError("Image width and height must be positive.")

        # Scale image to fit within 2048 x 2048
        max_dimension = max(image_width, image_height)
        if max_dimension > 2048:
            scale_factor = 2048 / max_dimension
            image_width = max(1, int(image_width * scale_factor))
            image_height = max(1, int(image_height * scale_factor))

        # Shrink (never enlarge) so the shortest side is at most 768px
        min_dimension = min(image_width, image_height)
        if min_dimension > 768:
            scale_factor = 768 / min_dimension
            image_width = max(1, int(image_width * scale_factor))
            image_height = max(1, int(image_height * scale_factor))

        # Count the 512px tiles needed to cover the image (ceiling division:
        # a partially covered tile is billed as a full tile).
        tiles_wide = -(-image_width // 512)
        tiles_high = -(-image_height // 512)
        return 170 * tiles_wide * tiles_high + 85


# ---- file: build/lib/dimos/core/__init__.py ----
import multiprocessing as mp
import time
from typing import Optional

from dask.distributed import Client, LocalCluster
from rich.console import Console

import dimos.core.colors as colors
from dimos.core.core import In, Out, RemoteOut, rpc
from dimos.core.module import Module, ModuleBase
from dimos.core.transport import LCMTransport, ZenohTransport, pLCMTransport
from dimos.protocol.rpc.lcmrpc import LCMRPC
from dimos.protocol.rpc.spec import RPC


def patch_actor(actor, cls): ...


class RPCClient:
    """Client-side proxy for a deployed actor.

    Names listed in the actor class's ``rpcs`` are routed through an
    ``LCMRPC`` channel; every other attribute is looked up on the wrapped
    actor instance.
    """

    def __init__(self, actor_instance, actor_class):
        self.rpc = LCMRPC()
        self.actor_class = actor_class
        self.remote_name = actor_class.__name__
        self.actor_instance = actor_instance
        self.rpcs = actor_class.rpcs.keys()
        self.rpc.start()

    def __reduce__(self):
        # Return the class and the arguments needed to reconstruct the object
        return (
            self.__class__,
            (self.actor_instance, self.actor_class),
        )

    # passthrough
    def __getattr__(self, name: str):
        """Resolve attributes that normal lookup did not find.

        Raises:
            AttributeError: for the proxy's own bookkeeping names, so that a
                half-initialised instance cannot recurse into itself.
        """
        # Check if accessing a known safe attribute to avoid recursion
        if name in {
            "__class__",
            "__init__",
            "__dict__",
            "__getattr__",
            "rpcs",
            "remote_name",
            "remote_instance",
            "actor_instance",
        }:
            raise AttributeError(f"{name} is not found.")

        if name in self.rpcs:
            return lambda *args: self.rpc.call_sync(f"{self.remote_name}/{name}", args)

        # Use the getattr() protocol instead of calling the dunder directly:
        # it also finds ordinary attributes and works for objects that do not
        # define __getattr__ themselves.
        return getattr(self.actor_instance, name)


def patchdask(dask_client: "Client"):
    """Attach a ``deploy`` helper to ``dask_client`` and return the client."""

    def deploy(
        actor_class,
        *args,
        **kwargs,
    ):
        """Instantiate ``actor_class`` as a dask actor and wrap it in an RPCClient."""
        console = Console()
        with console.status(f"deploying [green]{actor_class.__name__}", spinner="arc"):
            actor = dask_client.submit(
                actor_class,
                *args,
                **kwargs,
                actor=True,
            ).result()

            # Hand the actor its own reference; the call returns the worker name.
            worker = actor.set_ref(actor).result()
            print(f"deployed: {colors.green(actor)} @ {colors.blue('worker ' + str(worker))}")

            return RPCClient(actor, actor_class)

    dask_client.deploy = deploy
    return dask_client


def start(n: Optional[int] = None) -> "Client":
    """Start a local dask cluster and return a client patched with ``deploy``.

    Args:
        n: number of workers; ``None`` or 0 means one per CPU.
    """
    console = Console()
    if not n:
        n = mp.cpu_count()
    with console.status(
        f"[green]Initializing dimos local cluster with [bright_blue]{n} workers", spinner="arc"
    ):
        cluster = LocalCluster(
            n_workers=n,
            threads_per_worker=4,
        )
        client = Client(cluster)

    console.print(f"[green]Initialized dimos local cluster with [bright_blue]{n} workers")
    return patchdask(client)


def stop(client: "Client"):
    """Close the client and then the cluster it is attached to."""
    client.close()
    client.cluster.close()


# ---- file: build/lib/dimos/core/colors.py ----
# Copyright 2025 Dimensional Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def green(text: str) -> str:
    """Return the given text in green color."""
    return f"\033[92m{text}\033[0m"


def blue(text: str) -> str:
    """Return the given text in blue color."""
    return f"\033[94m{text}\033[0m"


def red(text: str) -> str:
    """Return the given text in red color."""
    return f"\033[91m{text}\033[0m"


def yellow(text: str) -> str:
    """Return the given text in yellow color."""
    return f"\033[93m{text}\033[0m"


def cyan(text: str) -> str:
    """Return the given text in cyan color."""
    return f"\033[96m{text}\033[0m"


def orange(text: str) -> str:
    """Return the given text in orange color."""
    return f"\033[38;5;208m{text}\033[0m"


# ---- file: build/lib/dimos/core/core.py ----
#!/usr/bin/env python3
# Copyright 2025 Dimensional Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import enum
import inspect
import traceback
from typing import (
    Any,
    Callable,
    Dict,
    Generic,
    List,
    Optional,
    Protocol,
    TypeVar,
    get_args,
    get_origin,
    get_type_hints,
)

from dask.distributed import Actor

import dimos.core.colors as colors
from dimos.core.o3dpickle import register_picklers

# Import-time side effect: registers the Open3D point-cloud pickler.
register_picklers()
T = TypeVar("T")


class Transport(Protocol[T]):
    """Structural interface a stream transport has to provide."""

    # used by local Output
    def broadcast(self, selfstream: Out[T], value: T): ...

    # used by local Input
    # (callback return value is annotated with typing.Any; the builtin any()
    # function is not a type)
    def subscribe(self, selfstream: In[T], callback: Callable[[T], Any]) -> None: ...


class DaskTransport(Transport[T]):
    """Default transport that moves messages through dask actor calls.

    ``subscribers`` holds remote input streams on the publishing side
    (filled by ``dask_register_subscriber``) and plain callbacks on the
    receiving side (filled by ``subscribe``).
    """

    subscribers: List[Callable[[T], None]]
    _started: bool = False

    def __init__(self):
        self.subscribers = []

    def __str__(self) -> str:
        return colors.yellow("DaskTransport")

    def __reduce__(self):
        # Pickles to a fresh, empty transport; subscribers are not carried over.
        return (DaskTransport, ())

    def broadcast(self, selfstream: RemoteIn[T], msg: T) -> None:
        """Deliver ``msg`` to every registered remote input, one blocking call each."""
        for subscriber in self.subscribers:
            # there is some sort of a bug here with losing worker loop
            # print(subscriber.owner, subscriber.owner._worker, subscriber.owner._client)
            # subscriber.owner._try_bind_worker_client()
            # print(subscriber.owner, subscriber.owner._worker, subscriber.owner._client)

            subscriber.owner.dask_receive_msg(subscriber.name, msg).result()

    def dask_receive_msg(self, msg) -> None:
        """Run every local callback with ``msg``; a failing callback is printed, not raised."""
        for subscriber in self.subscribers:
            try:
                subscriber(msg)
            except Exception as e:
                print(
                    colors.red("Error in DaskTransport subscriber callback:"),
                    e,
                    traceback.format_exc(),
                )

    # for outputs
    def dask_register_subscriber(self, remoteInput: RemoteIn[T]) -> None:
        self.subscribers.append(remoteInput)

    # for inputs
    def subscribe(self, selfstream: In[T], callback: Callable[[T], None]) -> None:
        """Register ``callback``; the first call also registers the input with its output."""
        if not self._started:
            selfstream.connection.owner.dask_register_subscriber(
                selfstream.connection.name, selfstream
            ).result()
            self._started = True
        self.subscribers.append(callback)


class State(enum.Enum):
    UNBOUND = "unbound"  # descriptor defined but not bound
    READY = "ready"  # bound to owner but not yet connected
    CONNECTED = "connected"  # input bound to an output
    FLOWING = "flowing"  # runtime: data observed


class Stream(Generic[T]):
    """Base class of all streams: a named, typed endpoint owned by a module."""

    _transport: Optional[Transport]

    def __init__(
        self,
        type: type[T],
        name: str,
        owner: Optional[Any] = None,
        transport: Optional[Transport] = None,
    ):
        self.name = name
        self.owner = owner
        self.type = type
        if transport:
            self._transport = transport
        if not hasattr(self, "_transport"):
            self._transport = None

    @property
    def type_name(self) -> str:
        return getattr(self.type, "__name__", repr(self.type))

    def _color_fn(self) -> Callable[[str], str]:
        """Pick the colour helper matching the current state (subclasses define ``state``)."""
        if self.state == State.UNBOUND:
            return colors.orange
        if self.state == State.READY:
            return colors.blue
        if self.state == State.CONNECTED:
            return colors.green
        return lambda s: s

    def __str__(self) -> str:  # noqa: D401
        return (
            self.__class__.__name__
            + " "
            + self._color_fn()(f"{self.name}[{self.type_name}]")
            + " @ "
            + (
                colors.orange(self.owner)
                if isinstance(self.owner, Actor)
                else colors.green(self.owner)
            )
            + ("" if not self._transport else " via " + str(self._transport))
        )


class Out(Stream[T]):
    """Local output stream; gets a ``DaskTransport`` unless one is supplied."""

    _transport: Transport

    def __init__(self, *argv, **kwargs):
        super().__init__(*argv, **kwargs)
        if not hasattr(self, "_transport") or self._transport is None:
            self._transport = DaskTransport()

    @property
    def transport(self) -> Transport[T]:
        return self._transport

    @property
    def state(self) -> State:  # noqa: D401
        return State.UNBOUND if self.owner is None else State.READY

    def __reduce__(self):  # noqa: D401
        """Pickle as a ``RemoteOut`` that points at the owner's actor reference."""
        if self.owner is None or not hasattr(self.owner, "ref"):
            raise ValueError("Cannot serialise Out without an owner ref")
        return (
            RemoteOut,
            (
                self.type,
                self.name,
                self.owner.ref,
                self._transport,
            ),
        )

    def publish(self, msg):
        self._transport.broadcast(self, msg)


class RemoteStream(Stream[T]):
    """Stream handle living outside the owning module; setters call the owner remotely."""

    @property
    def state(self) -> State:  # noqa: D401
        return State.UNBOUND if self.owner is None else State.READY

    # this won't work but nvm
    @property
    def transport(self) -> Transport[T]:
        return self._transport

    @transport.setter
    def transport(self, value: Transport[T]) -> None:
        # Update the owning module first (blocking), then the local copy.
        self.owner.set_transport(self.name, value).result()
        self._transport = value


class RemoteOut(RemoteStream[T]):
    def connect(self, other: RemoteIn[T]):
        return other.connect(self)


class In(Stream[T]):
    """Local input stream; borrows the transport of the output it is connected to."""

    connection: Optional[RemoteOut[T]] = None
    _transport: Transport

    def __str__(self):
        mystr = super().__str__()

        if not self.connection:
            return mystr

        return (mystr + " ◀─").ljust(60, "─") + f" {self.connection}"

    def __reduce__(self):  # noqa: D401
        """Pickle as a ``RemoteIn`` that points at the owner's actor reference."""
        if self.owner is None or not hasattr(self.owner, "ref"):
            # Message used to say "Out" (copy-paste from Out.__reduce__).
            raise ValueError("Cannot serialise In without an owner ref")
        return (RemoteIn, (self.type, self.name, self.owner.ref, self._transport))

    @property
    def transport(self) -> Transport[T]:
        """Own transport if set, otherwise the connected output's transport (then cached).

        Raises:
            AttributeError: if no transport is set and the input is not connected.
        """
        if not self._transport:
            if self.connection is None:
                # Same exception type as the former None dereference, with a
                # message that names the actual problem.
                raise AttributeError(
                    f"In stream {self.name!r} has no transport and is not connected"
                )
            self._transport = self.connection.transport
        return self._transport

    @property
    def state(self) -> State:  # noqa: D401
        return State.UNBOUND if self.owner is None else State.READY

    def subscribe(self, cb):
        self.transport.subscribe(self, cb)


class RemoteIn(RemoteStream[T]):
    def connect(self, other: RemoteOut[T]) -> None:
        return self.owner.connect_stream(self.name, other).result()

    # this won't work but that's ok
    @property
    def transport(self) -> Transport[T]:
        return self._transport

    @transport.setter
    def transport(self, value: Transport[T]) -> None:
        self.owner.set_transport(self.name, value).result()
        self._transport = value

    def publish(self, msg):
        self.transport.broadcast(self, msg)


def rpc(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Mark ``fn`` as remotely callable by tagging it with ``__rpc__``."""
    fn.__rpc__ = True  # type: ignore[attr-defined]
    return fn


daskTransport = DaskTransport()  # singleton instance for use in Out/RemoteOut


# ---- file: build/lib/dimos/core/module.py ----
# Copyright 2025 Dimensional Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import (
    Any,
    Callable,
    get_args,
    get_origin,
    get_type_hints,
)

from dask.distributed import Actor, get_worker

from dimos.core import colors
from dimos.core.core import In, Out, RemoteIn, RemoteOut, T, Transport
from dimos.protocol.rpc.lcmrpc import LCMRPC


class _classproperty:
    """Read-only descriptor whose value is computed from the owning class.

    Replaces stacking ``@classmethod`` on ``@property``, which was deprecated
    in Python 3.11 and no longer works from 3.13 on. Works for both
    ``Cls.attr`` and ``instance.attr``.
    """

    def __init__(self, fget):
        self.fget = fget

    def __get__(self, instance, owner=None):
        if owner is None:
            owner = type(instance)
        return self.fget(owner)


class ModuleBase:
    """Common base of modules: RPC serving plus stream/RPC introspection."""

    def __init__(self, *args, **kwargs):
        # get_worker() raises ValueError outside a dask worker; in that case
        # the module is created without an RPC server.
        try:
            get_worker()
            self.rpc = LCMRPC()
            self.rpc.serve_module_rpc(self)
            self.rpc.start()
        except ValueError:
            return

    @property
    def outputs(self) -> dict[str, Out]:
        """Public instance attributes that are ``Out`` streams, by name."""
        return {
            name: s
            for name, s in self.__dict__.items()
            if isinstance(s, Out) and not name.startswith("_")
        }

    @property
    def inputs(self) -> dict[str, In]:
        """Public instance attributes that are ``In`` streams, by name."""
        return {
            name: s
            for name, s in self.__dict__.items()
            if isinstance(s, In) and not name.startswith("_")
        }

    @_classproperty
    def rpcs(cls) -> dict[str, Callable]:
        """Public callables of the class that carry the ``__rpc__`` marker, by name."""
        return {
            name: getattr(cls, name)
            for name in dir(cls)
            if not name.startswith("_")
            and name != "rpcs"  # Exclude the rpcs property itself to prevent recursion
            and callable(getattr(cls, name, None))
            and hasattr(getattr(cls, name), "__rpc__")
        }

    def io(self) -> str:
        """Render inputs, the module name in a box, outputs and RPCs as a text diagram."""

        def _box(name: str) -> list[str]:
            return [
                "┌┴" + "─" * (len(name) + 1) + "┐",
                f"│ {name} │",
                "└┬" + "─" * (len(name) + 1) + "┘",
            ]

        # can't modify __str__ on a function like we are doing for I/O
        # so we have a separate repr function here
        def repr_rpc(fn: Callable) -> str:
            sig = inspect.signature(fn)
            # Remove 'self' parameter
            params = [p for name, p in sig.parameters.items() if name != "self"]

            # Format parameters with colored types
            param_strs = []
            for param in params:
                param_str = param.name
                if param.annotation is not inspect.Parameter.empty:
                    type_name = getattr(param.annotation, "__name__", str(param.annotation))
                    param_str += ": " + colors.green(type_name)
                if param.default is not inspect.Parameter.empty:
                    param_str += f" = {param.default}"
                param_strs.append(param_str)

            # Format return type
            return_annotation = ""
            if sig.return_annotation is not inspect.Signature.empty:
                return_type = getattr(sig.return_annotation, "__name__", str(sig.return_annotation))
                return_annotation = " -> " + colors.green(return_type)

            return (
                "RPC " + colors.blue(fn.__name__) + f"({', '.join(param_strs)})" + return_annotation
            )

        ret = [
            *(f" ├─ {stream}" for stream in self.inputs.values()),
            *_box(self.__class__.__name__),
            *(f" ├─ {stream}" for stream in self.outputs.values()),
            " │",
            *(f" ├─ {repr_rpc(rpc)}" for rpc in self.rpcs.values()),
        ]

        return "\n".join(ret)


class DaskModule(ModuleBase):
    """Module deployed as a dask actor; builds its streams from type annotations."""

    ref: Actor
    worker: int

    def __init__(self, *args, **kwargs):
        self.ref = None

        # Every attribute annotated Out[...] / In[...] becomes a stream
        # instance bound to this module; a bare Out / In falls back to Any.
        for name, ann in get_type_hints(self, include_extras=True).items():
            origin = get_origin(ann)
            if origin is Out:
                inner, *_ = get_args(ann) or (Any,)
                stream = Out(inner, name, self)
                setattr(self, name, stream)
            elif origin is In:
                inner, *_ = get_args(ann) or (Any,)
                stream = In(inner, name, self)
                setattr(self, name, stream)
        super().__init__(*args, **kwargs)

    def set_ref(self, ref) -> int:
        """Store the actor reference and return the name of the hosting worker."""
        worker = get_worker()
        self.ref = ref
        self.worker = worker.name
        return worker.name

    def __str__(self):
        return f"{self.__class__.__name__}"

    # called from remote
    def set_transport(self, stream_name: str, transport: Transport):
        """Replace the transport of the named stream.

        Raises:
            ValueError: if no attribute of that name exists.
            TypeError: if the attribute is not an ``In`` or ``Out`` stream.
        """
        stream = getattr(self, stream_name, None)
        if not stream:
            raise ValueError(f"{stream_name} not found in {self.__class__.__name__}")

        if not isinstance(stream, (Out, In)):
            raise TypeError(f"Output {stream_name} is not a valid stream")

        stream._transport = transport
        return True

    # called from remote
    def connect_stream(self, input_name: str, remote_stream: RemoteOut[T]):
        """Point the named input at ``remote_stream``.

        Raises:
            ValueError: if no attribute of that name exists.
            TypeError: if the attribute is not an ``In`` stream.
        """
        input_stream = getattr(self, input_name, None)
        if not input_stream:
            raise ValueError(f"{input_name} not found in {self.__class__.__name__}")
        if not isinstance(input_stream, In):
            raise TypeError(f"Input {input_name} is not a valid stream")
        input_stream.connection = remote_stream

    def dask_receive_msg(self, input_name: str, msg: Any):
        getattr(self, input_name).transport.dask_receive_msg(msg)

    def dask_register_subscriber(self, output_name: str, subscriber: RemoteIn[T]):
        getattr(self, output_name).transport.dask_register_subscriber(subscriber)


# global setting
Module = DaskModule


# ---- file: build/lib/dimos/core/o3dpickle.py ----
# Copyright 2025 Dimensional Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copyreg

import numpy as np
import open3d as o3d


def reduce_external(obj):
    """Reduce an Open3D point cloud to numpy arrays for pickling.

    Colours and normals are included when the cloud has them; previously only
    the points survived a pickle round trip.
    """
    # Convert Vector3dVector to numpy array for pickling
    points_array = np.asarray(obj.points)
    colors_array = np.asarray(obj.colors) if obj.has_colors() else None
    normals_array = np.asarray(obj.normals) if obj.has_normals() else None
    return (reconstruct_pointcloud, (points_array, colors_array, normals_array))


def reconstruct_pointcloud(points_array, colors_array=None, normals_array=None):
    """Rebuild a point cloud from the arrays produced by ``reduce_external``.

    The optional arguments default to ``None`` so pickles written with the
    older points-only format still load.
    """
    # Create new PointCloud and assign the points
    pc = o3d.geometry.PointCloud()
    pc.points = o3d.utility.Vector3dVector(points_array)
    if colors_array is not None:
        pc.colors = o3d.utility.Vector3dVector(colors_array)
    if normals_array is not None:
        pc.normals = o3d.utility.Vector3dVector(normals_array)
    return pc


def register_picklers():
    """Register ``reduce_external`` with ``copyreg`` for Open3D point clouds."""
    # Register for the actual PointCloud class that gets instantiated
    # We need to create a dummy PointCloud to get its actual class
    _dummy_pc = o3d.geometry.PointCloud()
    copyreg.pickle(_dummy_pc.__class__, reduce_external)


# ---- file: build/lib/dimos/core/test_core.py ----
# Copyright 2025 Dimensional Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
from threading import Event, Thread

import pytest

from dimos.core import (
    In,
    LCMTransport,
    Module,
    Out,
    RemoteOut,
    ZenohTransport,
    pLCMTransport,
    rpc,
    start,
    stop,
)
from dimos.robot.unitree_webrtc.type.lidar import LidarMessage
from dimos.robot.unitree_webrtc.type.odometry import Odometry
from dimos.types.vector import Vector
from dimos.utils.testing import SensorReplay

# never delete this line


@pytest.fixture
def dimos():
    """Fixture to create a Dimos client for testing."""
    client = start(2)
    yield client
    stop(client)


class RobotClient(Module):
    """Test module that replays recorded odometry and lidar and counts `mov` messages."""

    odometry: Out[Odometry] = None
    lidar: Out[LidarMessage] = None
    mov: In[Vector] = None

    mov_msg_count = 0

    def mov_callback(self, msg):
        self.mov_msg_count += 1

    def __init__(self):
        super().__init__()
        self._stop_event = Event()
        self._thread = None

    def start(self):
        """Start the replay thread, then subscribe to `mov`."""
        self._thread = Thread(target=self.odomloop)
        self._thread.start()
        self.mov.subscribe(self.mov_callback)

    def odomloop(self):
        """Publish one odometry and one lidar message every 0.1 s until stopped."""
        odomdata = SensorReplay("raw_odometry_rotate_walk", autocast=Odometry.from_msg)
        lidardata = SensorReplay("office_lidar", autocast=LidarMessage.from_msg)

        lidariter = lidardata.iterate()
        self._stop_event.clear()
        # Outer loop restarts the odometry replay once it is exhausted.
        while not self._stop_event.is_set():
            for odom in odomdata.iterate():
                if self._stop_event.is_set():
                    return
                print(odom)
                # Stamp the publish time so receivers can print the latency.
                odom.pubtime = time.perf_counter()
                self.odometry.publish(odom)

                lidarmsg = next(lidariter)
                lidarmsg.pubtime = time.perf_counter()
                self.lidar.publish(lidarmsg)
                time.sleep(0.1)

    def stop(self):
        """Signal the replay thread to stop and wait briefly for it."""
        self._stop_event.set()
        if self._thread and self._thread.is_alive():
            self._thread.join(timeout=1.0)  # Wait up to 1 second for clean shutdown


class Navigation(Module):
    """Test module that counts odometry and lidar messages and echoes positions on `mov`."""

    mov: Out[Vector] = None
    lidar: In[LidarMessage] = None
    target_position: In[Vector] = None
    odometry: In[Odometry] = None

    odom_msg_count = 0
    lidar_msg_count = 0

    @rpc
    def navigate_to(self, target: Vector) -> bool: ...

    def __init__(self):
        super().__init__()

    @rpc
    def start(self):
        """Subscribe the counting callbacks to `odometry` and `lidar`."""

        def _odom(msg):
            self.odom_msg_count += 1
            # Latency in milliseconds since the publisher stamped the message.
            print("RCV:", (time.perf_counter() - msg.pubtime) * 1000, msg)
            self.mov.publish(msg.position)

        self.odometry.subscribe(_odom)

        def _lidar(msg):
            self.lidar_msg_count += 1
            if hasattr(msg, "pubtime"):
                print("RCV:", (time.perf_counter() - msg.pubtime) * 1000, msg)
            else:
                print("RCV: unknown time", msg)

        self.lidar.subscribe(_lidar)


def test_classmethods():
    """`rpcs` must be readable on the class and on an instance, with identical content."""
    # Test class property access
    class_rpcs = Navigation.rpcs
    print("Class rpcs:", class_rpcs)

    # Test instance property access
    nav = Navigation()
    instance_rpcs = nav.rpcs
    print("Instance rpcs:", instance_rpcs)

    # Assertions
    assert isinstance(class_rpcs, dict), "Class rpcs should be a dictionary"
    assert isinstance(instance_rpcs, dict), "Instance rpcs should be a dictionary"
    assert class_rpcs == instance_rpcs, "Class and instance rpcs should be identical"

    # Check that we have the expected RPC methods
    assert "navigate_to" in class_rpcs, "navigate_to should be in rpcs"
    assert "start" in class_rpcs, "start should be in rpcs"
    assert len(class_rpcs) == 2, "Should have exactly 2 RPC methods"

    # Check that the values are callable
    assert callable(class_rpcs["navigate_to"]), "navigate_to should be callable"
    assert callable(class_rpcs["start"]), "start should be callable"

    # Check that they have the __rpc__ attribute
    assert hasattr(class_rpcs["navigate_to"], "__rpc__"), (
        "navigate_to should have __rpc__ attribute"
    )
    assert hasattr(class_rpcs["start"], "__rpc__"), "start should have __rpc__ attribute"


@pytest.mark.tool
def test_deployment(dimos):
    """Deploy both modules, wire their streams, run for one second and check message counts."""
    robot = dimos.deploy(RobotClient)
    target_stream = RemoteOut[Vector](Vector, "target")

    print("\n")
    print("lidar stream", robot.lidar)
    print("target stream", target_stream)
    print("odom stream", robot.odometry)

    nav = dimos.deploy(Navigation)

    # this one encodes proper LCM messages
    robot.lidar.transport = LCMTransport("/lidar", LidarMessage)
    # odometry & mov using just a pickle over LCM
    robot.odometry.transport = pLCMTransport("/odom")
    nav.mov.transport = pLCMTransport("/mov")

    nav.lidar.connect(robot.lidar)
    nav.odometry.connect(robot.odometry)
    robot.mov.connect(nav.mov)

    print("\n" + robot.io().result() + "\n")
    print("\n" + nav.io().result() + "\n")
    robot.start().result()
    nav.start().result()

    time.sleep(1)
    robot.stop().result()

    print("robot.mov_msg_count", robot.mov_msg_count)
    print("nav.odom_msg_count", nav.odom_msg_count)
    print("nav.lidar_msg_count", nav.lidar_msg_count)

    # Publisher sleeps 0.1 s per iteration, so about 10 messages are expected in 1 s.
    assert robot.mov_msg_count >= 8
    assert nav.odom_msg_count >= 8
    assert nav.lidar_msg_count >= 8


if __name__ == "__main__":
    # NOTE(review): the client started here is never passed to stop().
    client = start(1)  # single process for CI memory
    test_deployment(client)


# ---- file: build/lib/dimos/core/transport.py ----
# Copyright 2025 Dimensional Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
T = TypeVar("T")


class PubSubTransport(Transport[T]):
    """Base class for transports that address a publish/subscribe backend by topic."""

    # Topic handle; the subclasses below store either a plain string or an ``LCMTopic``.
    topic: Any

    def __init__(self, topic: Any):
        self.topic = topic

    def __str__(self) -> str:
        return (
            colors.green(f"{self.__class__.__name__}(")
            + colors.blue(self.topic)
            + colors.green(")")
        )


class _LazyLCMTransport(PubSubTransport[T]):
    """Shared behaviour of the LCM-backed transports.

    The concrete subclass constructor assigns ``self.lcm``. The backend is
    started on the first ``broadcast`` or ``subscribe`` call rather than at
    construction time.
    """

    _started: bool = False

    def _ensure_started(self) -> None:
        """Start ``self.lcm`` exactly once per transport instance."""
        if not self._started:
            self.lcm.start()
            self._started = True

    def broadcast(self, _, msg):
        """Publish ``msg`` on this transport's topic (the first argument is ignored)."""
        self._ensure_started()
        self.lcm.publish(self.topic, msg)

    def subscribe(self, selfstream: In[T], callback: Callable[[T], None]) -> None:
        """Invoke ``callback(msg)`` for every message received on this transport's topic."""
        self._ensure_started()
        # The backend hands over (msg, topic); callers only want the message.
        self.lcm.subscribe(self.topic, lambda msg, topic: callback(msg))


class pLCMTransport(_LazyLCMTransport[T]):
    """Transport backed by ``PickleLCM``, addressed by a plain string topic."""

    def __init__(self, topic: str, **kwargs):
        super().__init__(topic)
        self.lcm = PickleLCM(**kwargs)

    def __reduce__(self):
        # Pickle as "rebuild from the topic"; the live LCM handle is not serialized.
        # NOTE(review): constructor kwargs are not carried across pickling.
        return (pLCMTransport, (self.topic,))


class LCMTransport(_LazyLCMTransport[T]):
    """Transport backed by ``LCM``, addressed by an ``LCMTopic`` (topic name + message type)."""

    def __init__(self, topic: str, type: type, **kwargs):
        super().__init__(LCMTopic(topic, type))
        self.lcm = LCM(**kwargs)

    def __reduce__(self):
        # Pickle as "rebuild from topic name and message type"; the LCM handle is not serialized.
        # NOTE(review): constructor kwargs are not carried across pickling.
        return (LCMTransport, (self.topic.topic, self.topic.lcm_type))


class ZenohTransport(PubSubTransport[T]): ...
class AgentEnvironment(Environment):
    """Environment backed by an in-memory list of image frames.

    Frames are loaded from image paths/arrays or from a video file. The
    ``generate_*`` methods are placeholders that raise ``NotImplementedError``;
    the ``get_*`` methods return whatever has been stored for the current
    frame index, or an empty value when nothing has been stored.
    """

    def __init__(self):
        super().__init__()
        self.environment_type = "agent"
        self.frames = []
        self.current_frame_idx = 0
        # Per-frame caches, indexed by ``current_frame_idx`` in the getters below.
        self._depth_maps = []
        self._segmentations = []
        self._point_clouds = []

    def initialize_from_images(self, images: Union[List[str], List[np.ndarray]]) -> bool:
        """Initialize environment from a list of image paths or numpy arrays.

        Args:
            images: List of image paths or numpy arrays representing frames

        Returns:
            bool: True if initialization successful, False otherwise
        """
        try:
            self.frames = []
            for img in images:
                if isinstance(img, str):
                    frame = cv2.imread(img)
                    if frame is None:
                        raise ValueError(f"Failed to load image: {img}")
                    self.frames.append(frame)
                elif isinstance(img, np.ndarray):
                    # Copy so later edits by the caller do not alter stored frames.
                    self.frames.append(img.copy())
                else:
                    raise ValueError(f"Unsupported image type: {type(img)}")
            return True
        except Exception as e:
            # Best-effort API: report the failure and signal it through the return value.
            print(f"Failed to initialize from images: {e}")
            return False

    def initialize_from_file(self, file_path: str) -> bool:
        """Initialize environment from a video file.

        Args:
            file_path: Path to the video file

        Returns:
            bool: True if at least one frame was read, False otherwise
        """
        try:
            if not Path(file_path).exists():
                raise FileNotFoundError(f"Video file not found: {file_path}")

            cap = cv2.VideoCapture(file_path)
            self.frames = []
            try:
                while cap.isOpened():
                    ret, frame = cap.read()
                    if not ret:
                        break
                    self.frames.append(frame)
            finally:
                # Release the capture handle even if reading raised.
                cap.release()

            return len(self.frames) > 0
        except Exception as e:
            print(f"Failed to initialize from video: {e}")
            return False

    def initialize_from_directory(self, directory_path: str) -> bool:
        """Initialize environment from a directory of images."""
        # TODO: Implement directory initialization
        raise NotImplementedError("Directory initialization not yet implemented")

    def label_objects(self) -> List[str]:
        """Implementation of abstract method to label objects."""
        # TODO: Implement object labeling using a detection model
        raise NotImplementedError("Object labeling not yet implemented")

    def get_visualization(self, format_type):
        """Implementation of abstract method to return a visualization.

        ``Environment`` declares this method abstract; without an override the
        class cannot be instantiated at all.
        """
        # TODO: Implement visualization export
        raise NotImplementedError("Visualization not yet implemented")

    def generate_segmentations(
        self, model: str = None, objects: List[str] = None, *args, **kwargs
    ) -> List[np.ndarray]:
        """Generate segmentations for the current frame."""
        # TODO: Implement segmentation generation using specified model
        raise NotImplementedError("Segmentation generation not yet implemented")

    def get_segmentations(self) -> List[np.ndarray]:
        """Return pre-computed segmentations for the current frame."""
        if self._segmentations:
            return self._segmentations[self.current_frame_idx]
        return []

    def generate_point_cloud(self, object: str = None, *args, **kwargs) -> np.ndarray:
        """Generate point cloud from the current frame."""
        # TODO: Implement point cloud generation
        raise NotImplementedError("Point cloud generation not yet implemented")

    def get_point_cloud(self, object: str = None) -> np.ndarray:
        """Return pre-computed point cloud."""
        if self._point_clouds:
            return self._point_clouds[self.current_frame_idx]
        return np.array([])

    def generate_depth_map(
        self, stereo: bool = None, monocular: bool = None, model: str = None, *args, **kwargs
    ) -> np.ndarray:
        """Generate depth map for the current frame."""
        # TODO: Implement depth map generation using specified method
        raise NotImplementedError("Depth map generation not yet implemented")

    def get_depth_map(self) -> np.ndarray:
        """Return pre-computed depth map for the current frame."""
        if self._depth_maps:
            return self._depth_maps[self.current_frame_idx]
        return np.array([])

    def get_frame_count(self) -> int:
        """Return the total number of frames."""
        return len(self.frames)

    def get_current_frame_index(self) -> int:
        """Return the current frame index."""
        return self.current_frame_idx
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# UNDER DEVELOPMENT 🚧🚧🚧 + +import cv2 +import pycolmap +from pathlib import Path +from dimos.environment.environment import Environment + + +class COLMAPEnvironment(Environment): + def initialize_from_images(self, image_dir): + """Initialize the environment from a set of image frames or video.""" + image_dir = Path(image_dir) + output_path = Path("colmap_output") + output_path.mkdir(exist_ok=True) + mvs_path = output_path / "mvs" + database_path = output_path / "database.db" + + # Step 1: Feature extraction + pycolmap.extract_features(database_path, image_dir) + + # Step 2: Feature matching + pycolmap.match_exhaustive(database_path) + + # Step 3: Sparse reconstruction + maps = pycolmap.incremental_mapping(database_path, image_dir, output_path) + maps[0].write(output_path) + + # Step 4: Dense reconstruction (optional) + pycolmap.undistort_images(mvs_path, output_path, image_dir) + pycolmap.patch_match_stereo(mvs_path) # Requires compilation with CUDA + pycolmap.stereo_fusion(mvs_path / "dense.ply", mvs_path) + + return maps + + def initialize_from_video(self, video_path, frame_output_dir): + """Extract frames from a video and initialize the environment.""" + video_path = Path(video_path) + frame_output_dir = Path(frame_output_dir) + frame_output_dir.mkdir(exist_ok=True) + + # Extract frames from the video + self._extract_frames_from_video(video_path, frame_output_dir) + + # Initialize from the extracted frames + return self.initialize_from_images(frame_output_dir) + + def _extract_frames_from_video(self, video_path, frame_output_dir): + """Extract frames 
class Environment(ABC):
    """Abstract interface every environment representation must implement.

    Concrete subclasses provide labelling, visualization, segmentation,
    point-cloud and depth-map access. The two ``initialize_from_*`` hooks are
    optional and raise ``NotImplementedError`` unless overridden.
    """

    def __init__(self):
        self.environment_type = None
        self.graph = None

    @abstractmethod
    def label_objects(self) -> list[str]:
        """Label every object in the environment.

        Returns:
            One string label per object found in the environment.
        """

    @abstractmethod
    def get_visualization(self, format_type):
        """Return a visualization in the requested format (images, NERFs, other 3D file types)."""

    @abstractmethod
    def generate_segmentations(
        self, model: str = None, objects: list[str] = None, *args, **kwargs
    ) -> list[np.ndarray]:
        """Run a neural segmentation model over the requested objects.

        Args:
            model (str, optional): Name of the segmentation model to use (SegmentAnything, etc.).
            objects (list[str], optional): Names of the specific objects to segment.
            *args: Extra positional arguments, interpreted by the subclass.
            **kwargs: Extra keyword arguments, interpreted by the subclass.

        Returns:
            list of numpy.ndarray: One binary mask per segmented object area.
        """

    @abstractmethod
    def get_segmentations(self) -> list[np.ndarray]:
        """Return segmentations produced by a method such as 'segment anything'.

        Returns:
            list of numpy.ndarray: One binary mask per segmented object area.
        """

    @abstractmethod
    def generate_point_cloud(self, object: str = None, *args, **kwargs) -> np.ndarray:
        """Build a point cloud for one object or for the whole environment.

        Args:
            object (str, optional): Object to build the point cloud for; ``None``
                means the entire environment.
            *args: Extra positional arguments, interpreted by the subclass.
            **kwargs: Extra keyword arguments, interpreted by the subclass.

        Returns:
            np.ndarray: Point cloud of shape (n, 3), one [x, y, z] row per point.
        """

    @abstractmethod
    def get_point_cloud(self, object: str = None) -> np.ndarray:
        """Return the point cloud of one object or of the whole environment.

        Args:
            object (str, optional): Object to fetch the point cloud for; ``None``
                means the entire environment.

        Returns:
            np.ndarray: Point cloud of shape (n, 3), one [x, y, z] row per point.
        """

    @abstractmethod
    def generate_depth_map(
        self, stereo: bool = None, monocular: bool = None, model: str = None, *args, **kwargs
    ) -> np.ndarray:
        """Build a depth map with a stereo or a monocular method.

        Args:
            stereo (bool, optional): Whether a stereo camera is available for
                ground-truth depth map generation.
            monocular (bool, optional): Whether to use a monocular camera with
                neural depth estimation.
            model (str, optional): Name of the monocular depth model (Metric3D, ZoeDepth, etc.).
            *args: Extra positional arguments, interpreted by the subclass.
            **kwargs: Extra keyword arguments, interpreted by the subclass.

        Returns:
            np.ndarray: 2D array of shape (height, width) holding the depth at each pixel.
        """

    @abstractmethod
    def get_depth_map(self) -> np.ndarray:
        """Return a depth map of the environment.

        Returns:
            np.ndarray: 2D array of shape (height, width) holding the depth at
            each pixel. Range and units depend on the implementation and on
            the sensor or method that produced the map.
        """

    def initialize_from_images(self, images):
        """Initialize the environment from a set of image frames or video."""
        raise NotImplementedError("This method is not implemented for this environment type.")

    def initialize_from_file(self, file_path):
        """Initialize the environment from a spatial file.

        Intended file types: GLTF/GLB, FBX, OBJ, USD/USDA/USDC, STL,
        COLLADA (DAE), Alembic (ABC), PLY, 3DS, VRML/X3D.

        Args:
            file_path (str): Path to the spatial file.

        Raises:
            NotImplementedError: If the subclass does not override this method.
        """
        raise NotImplementedError("This method is not implemented for this environment type.")
class AgentMemoryError(Exception):
    """
    Base class for all exceptions raised by AgentMemory operations.
    All custom exceptions related to AgentMemory should inherit from this class.

    Args:
        message (str): Human-readable message describing the error.
    """

    def __init__(self, message="Error in AgentMemory operation"):
        super().__init__(message)
        # Python 3 exceptions have no built-in ``message`` attribute; store it
        # so subclasses can format it in ``__str__``.
        self.message = message


class AgentMemoryConnectionError(AgentMemoryError):
    """
    Exception raised for errors attempting to connect to the database.
    This includes failures due to network issues, authentication errors, or incorrect connection parameters.

    Args:
        message (str): Human-readable message describing the error.
        cause (Exception, optional): Original exception, if any, that led to this error.
    """

    def __init__(self, message="Failed to connect to the database", cause=None):
        super().__init__(message)
        # Always define ``cause`` so ``__str__`` works when none was supplied.
        self.cause = cause
        # ``format_exc`` describes the exception currently being handled, so it
        # is only meaningful when this error is built inside an ``except`` block.
        self.traceback = traceback.format_exc() if cause is not None else None

    def __str__(self):
        if self.cause is not None:
            return f"{self.message}\nCaused by: {repr(self.cause)}"
        return str(self.message)


class UnknownConnectionTypeError(AgentMemoryConnectionError):
    """
    Exception raised when an unknown or unsupported connection type is specified during AgentMemory setup.

    Args:
        message (str): Human-readable message explaining that an unknown connection type was used.
    """

    def __init__(self, message="Unknown connection type used in AgentMemory connection"):
        super().__init__(message)


class DataRetrievalError(AgentMemoryError):
    """
    Exception raised for errors retrieving data from the database.
    This could occur due to query failures, timeouts, or corrupt data issues.

    Args:
        message (str): Human-readable message describing the data retrieval error.
    """

    def __init__(self, message="Error in retrieving data during AgentMemory operation"):
        super().__init__(message)


class DataNotFoundError(DataRetrievalError):
    """
    Exception raised when the requested data is not found in the database.
    This is used when a query completes successfully but returns no result for the specified identifier.

    Args:
        vector_id (int or str): The identifier for the vector that was not found.
        message (str, optional): Human-readable message providing more detail. If not provided, a default message is generated.
    """

    def __init__(self, vector_id, message=None):
        message = message or f"Requested data for vector ID {vector_id} was not found."
        super().__init__(message)
        self.vector_id = vector_id
class Camera(AbstractSensor):
    """Camera sensor described by resolution, focal length and physical sensor size."""

    def __init__(self, resolution=None, focal_length=None, sensor_size=None, sensor_type="Camera"):
        super().__init__(sensor_type)
        self.resolution = resolution  # (width, height) in pixels
        self.focal_length = focal_length  # in millimeters
        self.sensor_size = sensor_size  # (width, height) in millimeters

    def get_sensor_type(self):
        """Return the sensor type string."""
        return self.sensor_type

    def calculate_intrinsics(self):
        """Derive focal lengths (in pixels) and the principal point from the camera geometry.

        Returns:
            dict: ``focal_length_x``, ``focal_length_y``, ``principal_point_x``
            and ``principal_point_y``.

        Raises:
            ValueError: If resolution, focal length or sensor size is missing.
        """
        if not (self.resolution and self.focal_length and self.sensor_size):
            raise ValueError("Resolution, focal length, and sensor size must be provided")

        width_px = self.resolution[0]
        height_px = self.resolution[1]

        # Physical size of one pixel along each axis.
        pixel_size_x = self.sensor_size[0] / width_px
        pixel_size_y = self.sensor_size[1] / height_px

        intrinsics = {}
        # Focal length expressed in pixels.
        intrinsics["focal_length_x"] = self.focal_length / pixel_size_x
        intrinsics["focal_length_y"] = self.focal_length / pixel_size_y
        # The principal point is assumed to be at the image center.
        intrinsics["principal_point_x"] = width_px / 2
        intrinsics["principal_point_y"] = height_px / 2
        return intrinsics

    def get_intrinsics(self):
        """Return the intrinsics computed by ``calculate_intrinsics``."""
        return self.calculate_intrinsics()
class EndEffector:
    """Minimal description of a robot end effector, identified only by its type."""

    def __init__(self, effector_type=None):
        # Free-form type identifier; ``None`` when unspecified.
        self.effector_type = effector_type

    def get_effector_type(self):
        """Return the effector type given at construction (or assigned later)."""
        kind = self.effector_type
        return kind
class HardwareInterface:
    """Container for a robot's hardware: end effector, sensors and arm architecture."""

    def __init__(
        self,
        end_effector: EndEffector = None,
        sensors: list = None,
        arm_architecture: UFactory7DOFArm = None,
    ):
        self.end_effector = end_effector
        # Give each instance its own list when none is supplied.
        self.sensors = [] if sensors is None else sensors
        self.arm_architecture = arm_architecture

    def get_configuration(self):
        """Return the current hardware configuration.

        Sensors are reported by their type string, not as objects.
        """
        sensor_types = []
        for sensor in self.sensors:
            sensor_types.append(sensor.get_sensor_type())

        return {
            "end_effector": self.end_effector,
            "sensors": sensor_types,
            "arm_architecture": self.arm_architecture,
        }

    def set_configuration(self, configuration):
        """Set the hardware configuration; keys missing from ``configuration`` keep their value."""
        self.end_effector = configuration.get("end_effector", self.end_effector)
        self.sensors = configuration.get("sensors", self.sensors)
        self.arm_architecture = configuration.get("arm_architecture", self.arm_architecture)

    def add_sensor(self, sensor):
        """Add a sensor to the hardware interface.

        Raises:
            ValueError: If ``sensor`` is neither a ``Camera`` nor a ``StereoCamera``.
        """
        if not isinstance(sensor, (Camera, StereoCamera)):
            raise ValueError("Sensor must be a Camera or StereoCamera instance.")
        self.sensors.append(sensor)
class PiperArm:
    """Thin driver around the Piper SDK interface (``C_PiperInterface_V2``).

    Construction brings up the CAN interface, connects, resets and enables the
    arm, moves it to a fixed start pose and loads the kinematic chain used for
    Jacobian-based velocity control.
    """

    # Scale between radians and the SDK joint units: 1000 * 180 / 3.1415926.
    _JOINT_FACTOR = 57295.7795

    def __init__(self, arm_name: str = "arm"):
        self.init_can()
        self.arm = C_PiperInterface_V2()
        self.arm.ConnectPort()
        time.sleep(0.1)
        self.resetArm()
        time.sleep(0.1)
        self.enable()
        self.gotoZero()
        time.sleep(1)
        self.init_vel_controller()

    def init_can(self):
        """Run the CAN activation script and report (without raising) if it fails."""
        result = subprocess.run(
            ["bash", "dimos/hardware/can_activate.sh"],
            stdout=subprocess.PIPE,  # capture stdout
            stderr=subprocess.PIPE,  # capture stderr
            text=True,  # return strings instead of bytes
        )
        if result.returncode != 0:
            # Surface the failure instead of silently discarding the result.
            print(
                f"[PiperArm] can_activate.sh exited with code {result.returncode}: "
                f"{result.stderr.strip()}"
            )

    def enable(self):
        """Poll ``EnablePiper`` until it reports success, then select the motion mode."""
        while not self.arm.EnablePiper():
            pass
        time.sleep(0.01)
        print(f"[PiperArm] Enabled")
        self.arm.MotionCtrl_2(0x01, 0x01, 80, 0xAD)

    def gotoZero(self):
        """Move the end effector to the fixed start pose and command the gripper."""
        factor = 1000
        position = [57.0, 0.0, 250.0, 0, 85.0, .0, 0]
        X = round(position[0] * factor)
        Y = round(position[1] * factor)
        Z = round(position[2] * factor)
        RX = round(position[3] * factor)
        RY = round(position[4] * factor)
        RZ = round(position[5] * factor)
        joint_6 = round(position[6] * factor)
        print(X, Y, Z, RX, RY, RZ)
        self.arm.MotionCtrl_2(0x01, 0x00, 100, 0x00)
        self.arm.EndPoseCtrl(X, Y, Z, RX, RY, RZ)
        self.arm.GripperCtrl(abs(joint_6), 1000, 0x01, 0)

    def softStop(self):
        """Return to the start pose, then issue the stop commands and wait."""
        self.gotoZero()
        time.sleep(1)
        self.arm.MotionCtrl_2(0x01, 0x00, 100)
        self.arm.MotionCtrl_1(0x01, 0, 0)
        time.sleep(5)

    def cmd_EE_pose(self, x, y, z, r, p, y_):
        """Command end-effector to target pose in space (position + Euler angles)"""
        factor = 1000
        # Each component is scaled by 1000 and truncated to an int for the SDK.
        pose = [x * factor, y * factor, z * factor, r * factor, p * factor, y_ * factor]
        self.arm.MotionCtrl_2(0x01, 0x00, 100, 0xAD)
        self.arm.EndPoseCtrl(*(int(component) for component in pose))

    def get_EE_pose(self):
        """Return the current end-effector pose as (x, y, z, r, p, y).

        Every raw SDK value is divided by 1000.
        NOTE(review): the original comments state this yields mm / degrees -
        confirm against the SDK documentation.
        """
        end_pose = self.arm.GetArmEndPoseMsgs().end_pose
        x = end_pose.X_axis / 1000.0
        y = end_pose.Y_axis / 1000.0
        z = end_pose.Z_axis / 1000.0
        r = end_pose.RX_axis / 1000.0
        p = end_pose.RY_axis / 1000.0
        y_rot = end_pose.RZ_axis / 1000.0

        return (x, y, z, r, p, y_rot)

    def cmd_gripper_ctrl(self, position):
        """Command end-effector gripper"""
        position = position * 1000

        self.arm.GripperCtrl(abs(round(position)), 1000, 0x01, 0)
        print(f"[PiperArm] Commanding gripper position: {position}")

    def resetArm(self):
        """Send the reset command sequence to the arm."""
        self.arm.MotionCtrl_1(0x02, 0, 0)
        self.arm.MotionCtrl_2(0, 0, 0, 0xAD)
        print(f"[PiperArm] Resetting arm")

    def init_vel_controller(self):
        """Load the URDF kinematic chain and precompute the Jacobian at the zero pose."""
        # Read the URDF text and close the file; kinpy takes the XML data itself.
        with open("dimos/hardware/piper_description.urdf") as urdf_file:
            urdf_data = urdf_file.read()
        self.chain = kp.build_serial_chain_from_urdf(urdf_data, "gripper_base")
        self.J = self.chain.jacobian(np.zeros(6))
        self.J_pinv = np.linalg.pinv(self.J)
        self.dt = 0.01

    def _read_joint_angles(self):
        """Return the six joint readings divided by ``_JOINT_FACTOR`` (i.e. in radians)."""
        joint_state = self.arm.GetArmJointMsgs().joint_state
        raw = np.array(
            [
                joint_state.joint_1,
                joint_state.joint_2,
                joint_state.joint_3,
                joint_state.joint_4,
                joint_state.joint_5,
                joint_state.joint_6,
            ]
        )
        return raw / self._JOINT_FACTOR

    def step_cmd_vel(self, twist, dt):
        """Apply one velocity-control step without sleeping.

        Integrates the six-component end-effector velocity ``twist`` over
        ``dt`` through the Jacobian pseudo-inverse at the current joint angles
        and sends the resulting joint targets.

        Args:
            twist: Sequence ``(x_dot, y_dot, z_dot, R_dot, P_dot, Y_dot)``.
            dt: Integration step in seconds.
        """
        q = self._read_joint_angles()
        J = self.chain.jacobian(q)
        self.J_pinv = np.linalg.pinv(J)
        dq = self.J_pinv @ np.asarray(twist, dtype=float) * dt
        newq = (q + dq) * self._JOINT_FACTOR  # back to SDK joint units

        self.arm.MotionCtrl_2(0x01, 0x01, 100, 0xAD)
        self.arm.JointCtrl(*(int(round(value)) for value in newq))

    def cmd_vel(self, x_dot, y_dot, z_dot, R_dot, P_dot, Y_dot):
        """Apply one velocity-control step of length ``self.dt``, then sleep ``self.dt``."""
        self.step_cmd_vel((x_dot, y_dot, z_dot, R_dot, P_dot, Y_dot), self.dt)
        time.sleep(self.dt)

    def cmd_vel_ee(self, x_dot, y_dot, z_dot, RX_dot, PY_dot, YZ_dot):
        """Step the end-effector pose by the given velocity over ``self.dt``, then sleep.

        NOTE(review): the velocities are scaled by 1000 before being added to
        the pose from ``get_EE_pose`` and ``cmd_EE_pose`` scales by 1000 again -
        confirm the intended units.
        """
        factor = 1000
        velocity = np.array([x_dot, y_dot, z_dot, RX_dot, PY_dot, YZ_dot]) * factor

        target = np.array(self.get_EE_pose()) + velocity * self.dt
        self.cmd_EE_pose(target[0], target[1], target[2], target[3], target[4], target[5])
        time.sleep(self.dt)

    def disable(self):
        """Soft-stop the arm, poll ``DisablePiper`` until it returns falsy, then disconnect."""
        self.softStop()

        while self.arm.DisablePiper():
            pass
        time.sleep(0.01)
        self.arm.DisconnectPort()


class VelocityController(Module):
    """Module that turns the latest ``Twist`` on ``cmd_vel`` into periodic arm velocity steps.

    Args:
        arm: A ``PiperArm``; the control loop uses its ``step_cmd_vel`` method.
        period: Loop period in seconds, also used as the integration step.
    """

    cmd_vel: In[Twist] = None

    def __init__(self, arm, period=0.01, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.arm = arm
        self.period = period
        self.latest_cmd = None

    @rpc
    def start(self):
        """Subscribe to ``cmd_vel`` and start the control loop in a daemon thread."""
        self.cmd_vel.subscribe(self.handle_cmd_vel)

        def control_loop():
            while True:
                cmd = self.latest_cmd
                # Nothing to apply until the first command arrives.
                if cmd is not None:
                    twist = (
                        cmd.linear.x,
                        cmd.linear.y,
                        cmd.linear.z,
                        cmd.angular.x,
                        cmd.angular.y,
                        cmd.angular.z,
                    )
                    self.arm.step_cmd_vel(twist, self.period)
                time.sleep(self.period)

        thread = threading.Thread(target=control_loop, daemon=True)
        thread.start()

    def handle_cmd_vel(self, cmd_vel: Twist):
        """Remember the most recent velocity command for the control loop."""
        self.latest_cmd = cmd_vel


@pytest.mark.tool
def run_velocity_controller(arm=None):
    """Deploy a ``VelocityController`` listening on ``/cmd_vel`` and block forever.

    Args:
        arm: ``PiperArm`` to control. A new one is created when omitted, so the
            function no longer depends on a module-level ``arm`` variable.
    """
    if arm is None:
        arm = PiperArm()

    lcmservice.autoconf()
    dimos = core.start(2)

    velocity_controller = dimos.deploy(VelocityController, arm=arm, period=0.01)
    velocity_controller.cmd_vel.transport = core.LCMTransport("/cmd_vel", Twist)

    velocity_controller.start()

    print("Velocity controller started")
    while True:
        time.sleep(1)


if __name__ == "__main__":
    arm = PiperArm()

    print("get_EE_pose")
    arm.get_EE_pose()

    def get_key(timeout=0.1):
        """Non-blocking key reader for arrow keys."""
        fd = sys.stdin.fileno()
        old_settings = termios.tcgetattr(fd)
        try:
            tty.setraw(fd)
            rlist, _, _ = select.select([fd], [], [], timeout)
            if rlist:
                ch1 = sys.stdin.read(1)
                if ch1 == "\x1b":  # Arrow keys start with ESC
                    ch2 = sys.stdin.read(1)
                    if ch2 == "[":
                        ch3 = sys.stdin.read(1)
                        return ch1 + ch2 + ch3
                else:
                    return ch1
            return None
        finally:
            # Always restore the terminal settings saved above.
            termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)

    def teleop_linear_vel(arm):
        """Keyboard teleoperation of the linear end-effector velocity (not invoked below)."""
        print("Use arrow keys to control linear velocity (x/y/z). Press 'q' to quit.")
        print("Up/Down: +x/-x, Left/Right: +y/-y, 'w'/'s': +z/-z")
        x_dot, y_dot, z_dot = 0.0, 0.0, 0.0
        while True:
            key = get_key(timeout=0.1)
            if key == "\x1b[A":  # Up arrow
                x_dot += 0.01
            elif key == "\x1b[B":  # Down arrow
                x_dot -= 0.01
            elif key == "\x1b[C":  # Right arrow
                y_dot += 0.01
            elif key == "\x1b[D":  # Left arrow
                y_dot -= 0.01
            elif key == "w":
                z_dot += 0.01
            elif key == "s":
                z_dot -= 0.01
            elif key == "q":
                print("Exiting teleop.")
                arm.disable()
                break

            # Clamp velocities to [-0.5, 0.5].
            x_dot = max(min(x_dot, 0.5), -0.5)
            y_dot = max(min(y_dot, 0.5), -0.5)
            z_dot = max(min(z_dot, 0.5), -0.5)

            # Only linear velocities, angular set to zero
            arm.cmd_vel_ee(x_dot, y_dot, z_dot, 0, 0, 0)
            print(
                f"Current linear velocity: x={x_dot:.3f} m/s, y={y_dot:.3f} m/s, z={z_dot:.3f} m/s"
            )

    run_velocity_controller(arm)
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod + + +class AbstractSensor(ABC): + def __init__(self, sensor_type=None): + self.sensor_type = sensor_type + + @abstractmethod + def get_sensor_type(self): + """Return the type of sensor.""" + pass + + @abstractmethod + def calculate_intrinsics(self): + """Calculate the sensor's intrinsics.""" + pass + + @abstractmethod + def get_intrinsics(self): + """Return the sensor's intrinsics.""" + pass diff --git a/build/lib/dimos/hardware/stereo_camera.py b/build/lib/dimos/hardware/stereo_camera.py new file mode 100644 index 0000000000..4ffdc51811 --- /dev/null +++ b/build/lib/dimos/hardware/stereo_camera.py @@ -0,0 +1,26 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dimos.hardware.camera import Camera + + +class StereoCamera(Camera): + def __init__(self, baseline=None, **kwargs): + super().__init__(**kwargs) + self.baseline = baseline + + def get_intrinsics(self): + intrinsics = super().get_intrinsics() + intrinsics["baseline"] = self.baseline + return intrinsics diff --git a/build/lib/dimos/hardware/test_simple_module(1).py b/build/lib/dimos/hardware/test_simple_module(1).py new file mode 100644 index 0000000000..759b627ac6 --- /dev/null +++ b/build/lib/dimos/hardware/test_simple_module(1).py @@ -0,0 +1,90 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import random +import threading +import time + +import pytest + +import dimos.core as core +import dimos.protocol.service.lcmservice as lcmservice +from dimos.core import In, Module, Out, rpc +from dimos.msgs.geometry_msgs import Pose, Vector3 + + +class MyComponent(Module): + ctrl: In[Vector3] = None + current_pose: Out[Vector3] = None + + @rpc + def start(self): + # at start you have self.ctrl and self.current_pose available + self.ctrl.subscribe(self.handle_ctrl) + + def handle_ctrl(self, target: Vector3): + print("handling control command:", target) + self.current_pose.publish(target) + + @rpc + def some_service_call(self, x: int) -> int: + return 3 + x + + +class Controller(Module): + cmd: Out[Vector3] = None + + # we can accept some parameters in the constructor + # but make sure to call super().__init__(*args, **kwargs) + def __init__(self, period=1, *args, **kwargs): + super().__init__(*args, **kwargs) + self.period = period + + @rpc + def start(self): + def send_loop(): + while True: + time.sleep(self.period) + vector = Vector3(0, 0, random.uniform(-1, 1)) + print("sending", vector) + self.cmd.publish(vector) + + thread = threading.Thread(target=send_loop, daemon=True) + thread.start() + + +@pytest.mark.tool +def test_my_component(): + # configures underlying system + lcmservice.autoconf() + dimos = core.start(2) + + controller = dimos.deploy(Controller, period=2) + component = dimos.deploy(MyComponent) + + controller.cmd.transport = core.LCMTransport("/cmd", Vector3) + component.current_pose.transport = core.LCMTransport("/pos", Vector3) + + controller.cmd.connect(component.ctrl) + controller.start() + component.start() + + print("service call result is", component.some_service_call(3)) + + while True: + time.sleep(1) + + +if __name__ == "__main__": + test_my_component() diff --git a/build/lib/dimos/hardware/ufactory.py b/build/lib/dimos/hardware/ufactory.py new file mode 100644 index 0000000000..cf4e139ccb --- /dev/null +++ 
b/build/lib/dimos/hardware/ufactory.py @@ -0,0 +1,32 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos.hardware.end_effector import EndEffector + + +class UFactoryEndEffector(EndEffector): + def __init__(self, model=None, **kwargs): + super().__init__(**kwargs) + self.model = model + + def get_model(self): + return self.model + + +class UFactory7DOFArm: + def __init__(self, arm_length=None): + self.arm_length = arm_length + + def get_arm_length(self): + return self.arm_length diff --git a/build/lib/dimos/hardware/zed_camera.py b/build/lib/dimos/hardware/zed_camera.py new file mode 100644 index 0000000000..a2ceeba54e --- /dev/null +++ b/build/lib/dimos/hardware/zed_camera.py @@ -0,0 +1,514 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import numpy as np
import cv2
import open3d as o3d
from typing import Optional, Tuple, Dict, Any
import logging

try:
    import pyzed.sl as sl
except ImportError:
    sl = None
    logging.warning("ZED SDK not found. Please install pyzed to use ZED camera functionality.")

from dimos.hardware.stereo_camera import StereoCamera

logger = logging.getLogger(__name__)


class ZEDCamera(StereoCamera):
    """ZED Camera capture node with neural depth processing."""

    # Fallback baseline in meters, used when the SDK does not expose one.
    # NOTE(review): the original code labels 0.12 a "ZED-M approximation" -- confirm
    # the value for the camera model actually deployed.
    _DEFAULT_BASELINE_M = 0.12

    def __init__(
        self,
        camera_id: int = 0,
        resolution: "Optional[sl.RESOLUTION]" = None,
        depth_mode: "Optional[sl.DEPTH_MODE]" = None,
        fps: int = 30,
        **kwargs,
    ):
        """
        Initialize ZED Camera.

        Args:
            camera_id: Camera ID (0 for first ZED)
            resolution: ZED camera resolution (default: sl.RESOLUTION.HD720)
            depth_mode: Depth computation mode (default: sl.DEPTH_MODE.NEURAL)
            fps: Camera frame rate (default: 30)

        Raises:
            ImportError: If the pyzed package is not installed.
        """
        if sl is None:
            raise ImportError("ZED SDK not installed. Please install pyzed package.")

        # SDK-dependent defaults are resolved here rather than in the signature so
        # that importing this module does not touch `sl` when pyzed is missing.
        if resolution is None:
            resolution = sl.RESOLUTION.HD720
        if depth_mode is None:
            depth_mode = sl.DEPTH_MODE.NEURAL

        super().__init__(**kwargs)

        self.camera_id = camera_id
        self.resolution = resolution
        self.depth_mode = depth_mode
        self.fps = fps

        # Initialize ZED camera
        self.zed = sl.Camera()
        self.init_params = sl.InitParameters()
        self.init_params.camera_resolution = resolution
        self.init_params.depth_mode = depth_mode
        self.init_params.coordinate_system = sl.COORDINATE_SYSTEM.RIGHT_HANDED_Z_UP_X_FWD
        self.init_params.coordinate_units = sl.UNIT.METER
        self.init_params.camera_fps = fps

        # Set camera ID using whichever entry point this SDK version provides
        if hasattr(self.init_params, "set_from_camera_id"):
            self.init_params.set_from_camera_id(camera_id)
        elif hasattr(self.init_params, "input"):
            self.init_params.input.set_from_camera_id(camera_id)

        # Use enable_fill_mode instead of SENSING_MODE.STANDARD
        self.runtime_params = sl.RuntimeParameters()
        self.runtime_params.enable_fill_mode = True  # False = STANDARD mode, True = FILL mode

        # Image containers
        self.image_left = sl.Mat()
        self.image_right = sl.Mat()
        self.depth_map = sl.Mat()
        self.point_cloud = sl.Mat()
        self.confidence_map = sl.Mat()

        # Positional tracking
        self.tracking_enabled = False
        self.tracking_params = sl.PositionalTrackingParameters()
        self.camera_pose = sl.Pose()
        self.sensors_data = sl.SensorsData()

        self.is_opened = False

    def open(self) -> bool:
        """Open the ZED camera.

        Returns:
            True if the camera opened successfully, False otherwise.
        """
        try:
            err = self.zed.open(self.init_params)
            if err != sl.ERROR_CODE.SUCCESS:
                logger.error(f"Failed to open ZED camera: {err}")
                return False

            self.is_opened = True
            logger.info("ZED camera opened successfully")

            # Get camera information
            info = self.zed.get_camera_information()
            logger.info(f"ZED Camera Model: {info.camera_model}")
            logger.info(f"Serial Number: {info.serial_number}")
            logger.info(f"Firmware: {info.camera_configuration.firmware_version}")

            return True

        except Exception as e:
            logger.error(f"Error opening ZED camera: {e}")
            return False

    def enable_positional_tracking(
        self,
        enable_area_memory: bool = False,
        enable_pose_smoothing: bool = True,
        enable_imu_fusion: bool = True,
        set_floor_as_origin: bool = False,
        initial_world_transform: "Optional[sl.Transform]" = None,
    ) -> bool:
        """
        Enable positional tracking on the ZED camera.

        Args:
            enable_area_memory: Enable area learning to correct tracking drift
            enable_pose_smoothing: Enable pose smoothing
            enable_imu_fusion: Enable IMU fusion if available
            set_floor_as_origin: Set the floor as origin (useful for robotics)
            initial_world_transform: Initial world transform

        Returns:
            True if tracking enabled successfully
        """
        if not self.is_opened:
            logger.error("ZED camera not opened")
            return False

        try:
            # Configure tracking parameters
            self.tracking_params.enable_area_memory = enable_area_memory
            self.tracking_params.enable_pose_smoothing = enable_pose_smoothing
            self.tracking_params.enable_imu_fusion = enable_imu_fusion
            self.tracking_params.set_floor_as_origin = set_floor_as_origin

            if initial_world_transform is not None:
                self.tracking_params.initial_world_transform = initial_world_transform

            # Enable tracking
            err = self.zed.enable_positional_tracking(self.tracking_params)
            if err != sl.ERROR_CODE.SUCCESS:
                logger.error(f"Failed to enable positional tracking: {err}")
                return False

            self.tracking_enabled = True
            logger.info("Positional tracking enabled successfully")
            return True

        except Exception as e:
            logger.error(f"Error enabling positional tracking: {e}")
            return False

    def disable_positional_tracking(self):
        """Disable positional tracking."""
        if self.tracking_enabled:
            self.zed.disable_positional_tracking()
            self.tracking_enabled = False
            logger.info("Positional tracking disabled")

    def get_pose(
        self, reference_frame: "Optional[sl.REFERENCE_FRAME]" = None
    ) -> Optional[Dict[str, Any]]:
        """
        Get the current camera pose.

        Args:
            reference_frame: Reference frame (WORLD or CAMERA); defaults to
                sl.REFERENCE_FRAME.WORLD

        Returns:
            None if tracking is disabled or an error occurs. Otherwise a
            dictionary; when the tracking state is OK it contains:
                - position: [x, y, z] in meters
                - rotation: [x, y, z, w] quaternion
                - euler_angles: [roll, pitch, yaw] in radians
                - timestamp: Pose timestamp in nanoseconds
                - confidence: Tracking confidence (0-100)
                - valid: True
                - tracking_state: Tracking state as a string
            and otherwise only ``valid`` (False) and ``tracking_state``.
        """
        if not self.tracking_enabled:
            logger.error("Positional tracking not enabled")
            return None

        if reference_frame is None:
            reference_frame = sl.REFERENCE_FRAME.WORLD

        try:
            # Get current pose
            tracking_state = self.zed.get_position(self.camera_pose, reference_frame)

            if tracking_state == sl.POSITIONAL_TRACKING_STATE.OK:
                # Extract translation
                translation = self.camera_pose.get_translation().get()

                # Extract rotation (quaternion)
                rotation = self.camera_pose.get_orientation().get()

                # Get Euler angles
                euler = self.camera_pose.get_euler_angles()

                return {
                    "position": translation.tolist(),
                    "rotation": rotation.tolist(),  # [x, y, z, w]
                    "euler_angles": euler.tolist(),  # [roll, pitch, yaw]
                    "timestamp": self.camera_pose.timestamp.get_nanoseconds(),
                    "confidence": self.camera_pose.pose_confidence,
                    "valid": True,
                    "tracking_state": str(tracking_state),
                }
            else:
                logger.warning(f"Tracking state: {tracking_state}")
                return {"valid": False, "tracking_state": str(tracking_state)}

        except Exception as e:
            logger.error(f"Error getting pose: {e}")
            return None

    def get_imu_data(self) -> Optional[Dict[str, Any]]:
        """
        Get IMU sensor data if available.

        Returns:
            None if the camera is closed, the read fails, or an error occurs.
            Otherwise a dictionary containing:
                - orientation: IMU orientation quaternion [x, y, z, w]
                - angular_velocity: [x, y, z] in rad/s
                - linear_acceleration: [x, y, z] in m/s²
                - timestamp: IMU data timestamp
                - temperature: Reading for sl.SENSOR_LOCATION.IMU
        """
        if not self.is_opened:
            logger.error("ZED camera not opened")
            return None

        try:
            # Get sensors data synchronized with images
            if (
                self.zed.get_sensors_data(self.sensors_data, sl.TIME_REFERENCE.IMAGE)
                == sl.ERROR_CODE.SUCCESS
            ):
                imu = self.sensors_data.get_imu_data()

                # Get IMU orientation
                imu_orientation = imu.get_pose().get_orientation().get()

                # Get angular velocity
                angular_vel = imu.get_angular_velocity()

                # Get linear acceleration
                linear_accel = imu.get_linear_acceleration()

                return {
                    "orientation": imu_orientation.tolist(),
                    "angular_velocity": angular_vel.tolist(),
                    "linear_acceleration": linear_accel.tolist(),
                    "timestamp": self.sensors_data.timestamp.get_nanoseconds(),
                    "temperature": self.sensors_data.temperature.get(sl.SENSOR_LOCATION.IMU),
                }
            else:
                return None

        except Exception as e:
            logger.error(f"Error getting IMU data: {e}")
            return None

    def capture_frame(
        self,
    ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]:
        """
        Capture a frame from ZED camera.

        Returns:
            Tuple of (left_image, right_image, depth_map) as numpy arrays, or
            (None, None, None) if the camera is closed or the grab fails.
        """
        if not self.is_opened:
            logger.error("ZED camera not opened")
            return None, None, None

        try:
            # Grab frame
            if self.zed.grab(self.runtime_params) == sl.ERROR_CODE.SUCCESS:
                # Retrieve left image
                self.zed.retrieve_image(self.image_left, sl.VIEW.LEFT)
                left_img = self.image_left.get_data()[:, :, :3]  # Remove alpha channel

                # Retrieve right image
                self.zed.retrieve_image(self.image_right, sl.VIEW.RIGHT)
                right_img = self.image_right.get_data()[:, :, :3]  # Remove alpha channel

                # Retrieve depth map
                self.zed.retrieve_measure(self.depth_map, sl.MEASURE.DEPTH)
                depth = self.depth_map.get_data()

                return left_img, right_img, depth
            else:
                logger.warning("Failed to grab frame from ZED camera")
                return None, None, None

        except Exception as e:
            logger.error(f"Error capturing frame: {e}")
            return None, None, None

    def capture_pointcloud(self) -> Optional[o3d.geometry.PointCloud]:
        """
        Capture point cloud from ZED camera.

        Returns:
            Open3D point cloud with XYZ coordinates and RGB colors (empty if no
            point is finite), or None if the camera is closed or the grab fails.
        """
        if not self.is_opened:
            logger.error("ZED camera not opened")
            return None

        try:
            if self.zed.grab(self.runtime_params) == sl.ERROR_CODE.SUCCESS:
                # Retrieve point cloud with RGBA data
                self.zed.retrieve_measure(self.point_cloud, sl.MEASURE.XYZRGBA)
                point_cloud_data = self.point_cloud.get_data()

                # Flatten to one row per pixel: (x, y, z, packed colour)
                points = point_cloud_data.reshape(-1, 4)

                # Extract XYZ coordinates
                xyz = points[:, :3]

                # Reinterpret the 4th channel's bits as a packed 32-bit colour
                rgba_packed = points[:, 3].view(np.uint32)

                # Unpack RGBA: each 32-bit value contains 4 bytes (R, G, B, A)
                colors_rgba = np.zeros((len(rgba_packed), 4), dtype=np.uint8)
                colors_rgba[:, 0] = rgba_packed & 0xFF  # R
                colors_rgba[:, 1] = (rgba_packed >> 8) & 0xFF  # G
                colors_rgba[:, 2] = (rgba_packed >> 16) & 0xFF  # B
                colors_rgba[:, 3] = (rgba_packed >> 24) & 0xFF  # A

                # Extract RGB (ignore alpha) and normalize to [0, 1]
                colors_rgb = colors_rgba[:, :3].astype(np.float64) / 255.0

                # Filter out invalid points (NaN or inf)
                valid = np.isfinite(xyz).all(axis=1)
                valid_xyz = xyz[valid]
                valid_colors = colors_rgb[valid]

                # Create Open3D point cloud
                pcd = o3d.geometry.PointCloud()

                if len(valid_xyz) > 0:
                    pcd.points = o3d.utility.Vector3dVector(valid_xyz)
                    pcd.colors = o3d.utility.Vector3dVector(valid_colors)

                return pcd
            else:
                logger.warning("Failed to grab frame for point cloud")
                return None

        except Exception as e:
            logger.error(f"Error capturing point cloud: {e}")
            return None

    def capture_frame_with_pose(
        self,
    ) -> Tuple[
        Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray], Optional[Dict[str, Any]]
    ]:
        """
        Capture a frame together with the pose read right after it.

        Returns:
            Tuple of (left_image, right_image, depth_map, pose_data); all None
            if the camera is closed or the grab fails. pose_data is None when
            tracking is disabled.
        """
        if not self.is_opened:
            logger.error("ZED camera not opened")
            return None, None, None, None

        try:
            # capture_frame() performs the grab itself. Grabbing here as well
            # would discard one frame per call, so it is called exactly once.
            left_img, right_img, depth = self.capture_frame()
            if left_img is None:
                # capture_frame() has already logged the reason.
                return None, None, None, None

            # Read the pose for the frame that was just grabbed
            pose_data = self.get_pose() if self.tracking_enabled else None

            return left_img, right_img, depth, pose_data

        except Exception as e:
            logger.error(f"Error capturing frame with pose: {e}")
            return None, None, None, None

    def close(self):
        """Close the ZED camera."""
        if self.is_opened:
            # Disable tracking if enabled
            if self.tracking_enabled:
                self.disable_positional_tracking()

            self.zed.close()
            self.is_opened = False
            logger.info("ZED camera closed")

    @staticmethod
    def _read_baseline(info, calibration) -> float:
        """Return the stereo baseline from the SDK, or the default if unavailable."""
        try:
            # Method 1: direct accessor, when this SDK version provides it
            if hasattr(calibration, "getCameraBaseline"):
                return calibration.getCameraBaseline()

            # Method 2: x component of the raw calibration translation
            config = info.camera_configuration
            if hasattr(config, "calibration_parameters_raw"):
                raw_calib = config.calibration_parameters_raw
                if hasattr(raw_calib, "T"):
                    return abs(raw_calib.T[0])
        except Exception as e:
            # Best effort: any SDK hiccup falls through to the default.
            logger.debug(f"Could not read baseline from calibration: {e}")

        return ZEDCamera._DEFAULT_BASELINE_M

    def get_camera_info(self) -> Dict[str, Any]:
        """Get ZED camera information and calibration parameters.

        Returns:
            Dictionary with model, serial number, firmware, resolution, fps,
            per-camera intrinsics/distortion and baseline; empty if the camera
            is closed or an error occurs.
        """
        if not self.is_opened:
            return {}

        try:
            info = self.zed.get_camera_information()
            calibration = info.camera_configuration.calibration_parameters

            baseline = self._read_baseline(info, calibration)

            return {
                "model": str(info.camera_model),
                "serial_number": info.serial_number,
                "firmware": info.camera_configuration.firmware_version,
                "resolution": {
                    "width": info.camera_configuration.resolution.width,
                    "height": info.camera_configuration.resolution.height,
                },
                "fps": info.camera_configuration.fps,
                "left_cam": {
                    "fx": calibration.left_cam.fx,
                    "fy": calibration.left_cam.fy,
                    "cx": calibration.left_cam.cx,
                    "cy": calibration.left_cam.cy,
                    "k1": calibration.left_cam.disto[0],
                    "k2": calibration.left_cam.disto[1],
                    "p1": calibration.left_cam.disto[2],
                    "p2": calibration.left_cam.disto[3],
                    "k3": calibration.left_cam.disto[4],
                },
                "right_cam": {
                    "fx": calibration.right_cam.fx,
                    "fy": calibration.right_cam.fy,
                    "cx": calibration.right_cam.cx,
                    "cy": calibration.right_cam.cy,
                    "k1": calibration.right_cam.disto[0],
                    "k2": calibration.right_cam.disto[1],
                    "p1": calibration.right_cam.disto[2],
                    "p2": calibration.right_cam.disto[3],
                    "k3": calibration.right_cam.disto[4],
                },
                "baseline": baseline,
            }
        except Exception as e:
            logger.error(f"Error getting camera info: {e}")
            return {}

    def calculate_intrinsics(self):
        """Calculate camera intrinsics from ZED calibration.

        Falls back to the parent implementation when no camera info is
        available (for example when the camera is closed).
        """
        info = self.get_camera_info()
        if not info:
            return super().calculate_intrinsics()

        left_cam = info.get("left_cam", {})
        resolution = info.get("resolution", {})

        return {
            "focal_length_x": left_cam.get("fx", 0),
            "focal_length_y": left_cam.get("fy", 0),
            "principal_point_x": left_cam.get("cx", 0),
            "principal_point_y": left_cam.get("cy", 0),
            "baseline": info.get("baseline", 0),
            "resolution_width": resolution.get("width", 0),
            "resolution_height": resolution.get("height", 0),
        }

def __enter__(self): + """Context manager entry.""" + if not self.open(): + raise RuntimeError("Failed to open ZED camera") + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit.""" + self.close() diff --git a/build/lib/dimos/manipulation/__init__.py b/build/lib/dimos/manipulation/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/build/lib/dimos/manipulation/manip_aio_pipeline.py b/build/lib/dimos/manipulation/manip_aio_pipeline.py new file mode 100644 index 0000000000..22e3f5d49e --- /dev/null +++ b/build/lib/dimos/manipulation/manip_aio_pipeline.py @@ -0,0 +1,590 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Asynchronous, reactive manipulation pipeline for realtime detection, filtering, and grasp generation. 
+""" + +import asyncio +import json +import logging +import threading +import time +import traceback +import websockets +from typing import Dict, List, Optional, Any +import numpy as np +import reactivex as rx +import reactivex.operators as ops +from dimos.utils.logging_config import setup_logger +from dimos.perception.detection2d.detic_2d_det import Detic2DDetector +from dimos.perception.pointcloud.pointcloud_filtering import PointcloudFiltering +from dimos.perception.object_detection_stream import ObjectDetectionStream +from dimos.perception.grasp_generation.utils import draw_grasps_on_image +from dimos.perception.pointcloud.utils import create_point_cloud_overlay_visualization +from dimos.perception.common.utils import colorize_depth +from dimos.utils.logging_config import setup_logger +import cv2 + +logger = setup_logger("dimos.perception.manip_aio_pipeline") + + +class ManipulationPipeline: + """ + Clean separated stream pipeline with frame buffering. + + - Object detection runs independently on RGB stream + - Point cloud processing subscribes to both detection and ZED streams separately + - Simple frame buffering to match RGB+depth+objects + """ + + def __init__( + self, + camera_intrinsics: List[float], # [fx, fy, cx, cy] + min_confidence: float = 0.6, + max_objects: int = 10, + vocabulary: Optional[str] = None, + grasp_server_url: Optional[str] = None, + enable_grasp_generation: bool = False, + ): + """ + Initialize the manipulation pipeline. 
+ + Args: + camera_intrinsics: [fx, fy, cx, cy] camera parameters + min_confidence: Minimum detection confidence threshold + max_objects: Maximum number of objects to process + vocabulary: Optional vocabulary for Detic detector + grasp_server_url: Optional WebSocket URL for AnyGrasp server + enable_grasp_generation: Whether to enable async grasp generation + """ + self.camera_intrinsics = camera_intrinsics + self.min_confidence = min_confidence + + # Grasp generation settings + self.grasp_server_url = grasp_server_url + self.enable_grasp_generation = enable_grasp_generation + + # Asyncio event loop for WebSocket communication + self.grasp_loop = None + self.grasp_loop_thread = None + + # Storage for grasp results and filtered objects + self.latest_grasps: List[dict] = [] # Simplified: just a list of grasps + self.grasps_consumed = False + self.latest_filtered_objects = [] + self.latest_rgb_for_grasps = None # Store RGB image for grasp overlay + self.grasp_lock = threading.Lock() + + # Track pending requests - simplified to single task + self.grasp_task: Optional[asyncio.Task] = None + + # Reactive subjects for streaming filtered objects and grasps + self.filtered_objects_subject = rx.subject.Subject() + self.grasps_subject = rx.subject.Subject() + self.grasp_overlay_subject = rx.subject.Subject() # Add grasp overlay subject + + # Initialize grasp client if enabled + if self.enable_grasp_generation and self.grasp_server_url: + self._start_grasp_loop() + + # Initialize object detector + self.detector = Detic2DDetector(vocabulary=vocabulary, threshold=min_confidence) + + # Initialize point cloud processor + self.pointcloud_filter = PointcloudFiltering( + color_intrinsics=camera_intrinsics, + depth_intrinsics=camera_intrinsics, # ZED uses same intrinsics + max_num_objects=max_objects, + ) + + logger.info(f"Initialized ManipulationPipeline with confidence={min_confidence}") + + def create_streams(self, zed_stream: rx.Observable) -> Dict[str, rx.Observable]: + """ + 
Create streams using exact old main logic. + """ + # Create ZED streams (from old main) + zed_frame_stream = zed_stream.pipe(ops.share()) + + # RGB stream for object detection (from old main) + video_stream = zed_frame_stream.pipe( + ops.map(lambda x: x.get("rgb") if x is not None else None), + ops.filter(lambda x: x is not None), + ops.share(), + ) + object_detector = ObjectDetectionStream( + camera_intrinsics=self.camera_intrinsics, + min_confidence=self.min_confidence, + class_filter=None, + detector=self.detector, + video_stream=video_stream, + disable_depth=True, + ) + + # Store latest frames for point cloud processing (from old main) + latest_rgb = None + latest_depth = None + latest_point_cloud_overlay = None + frame_lock = threading.Lock() + + # Subscribe to combined ZED frames (from old main) + def on_zed_frame(zed_data): + nonlocal latest_rgb, latest_depth + if zed_data is not None: + with frame_lock: + latest_rgb = zed_data.get("rgb") + latest_depth = zed_data.get("depth") + + # Depth stream for point cloud filtering (from old main) + def get_depth_or_overlay(zed_data): + if zed_data is None: + return None + + # Check if we have a point cloud overlay available + with frame_lock: + overlay = latest_point_cloud_overlay + + if overlay is not None: + return overlay + else: + # Return regular colorized depth + return colorize_depth(zed_data.get("depth"), max_depth=10.0) + + depth_stream = zed_frame_stream.pipe( + ops.map(get_depth_or_overlay), ops.filter(lambda x: x is not None), ops.share() + ) + + # Process object detection results with point cloud filtering (from old main) + def on_detection_next(result): + nonlocal latest_point_cloud_overlay + if "objects" in result and result["objects"]: + # Get latest RGB and depth frames + with frame_lock: + rgb = latest_rgb + depth = latest_depth + + if rgb is not None and depth is not None: + try: + filtered_objects = self.pointcloud_filter.process_images( + rgb, depth, result["objects"] + ) + + if filtered_objects: 
+ # Store filtered objects + with self.grasp_lock: + self.latest_filtered_objects = filtered_objects + self.filtered_objects_subject.on_next(filtered_objects) + + # Create base image (colorized depth) + base_image = colorize_depth(depth, max_depth=10.0) + + # Create point cloud overlay visualization + overlay_viz = create_point_cloud_overlay_visualization( + base_image=base_image, + objects=filtered_objects, + intrinsics=self.camera_intrinsics, + ) + + # Store the overlay for the stream + with frame_lock: + latest_point_cloud_overlay = overlay_viz + + # Request grasps if enabled + if self.enable_grasp_generation and len(filtered_objects) > 0: + # Save RGB image for later grasp overlay + with frame_lock: + self.latest_rgb_for_grasps = rgb.copy() + + task = self.request_scene_grasps(filtered_objects) + if task: + # Check for results after a delay + def check_grasps_later(): + time.sleep(2.0) # Wait for grasp processing + # Wait for task to complete + if hasattr(self, "grasp_task") and self.grasp_task: + try: + result = self.grasp_task.result( + timeout=3.0 + ) # Get result with timeout + except Exception as e: + logger.warning(f"Grasp task failed or timeout: {e}") + + # Try to get latest grasps and create overlay + with self.grasp_lock: + grasps = self.latest_grasps + + if grasps and hasattr(self, "latest_rgb_for_grasps"): + # Create grasp overlay on the saved RGB image + try: + bgr_image = cv2.cvtColor( + self.latest_rgb_for_grasps, cv2.COLOR_RGB2BGR + ) + result_bgr = draw_grasps_on_image( + bgr_image, + grasps, + self.camera_intrinsics, + max_grasps=-1, # Show all grasps + ) + result_rgb = cv2.cvtColor( + result_bgr, cv2.COLOR_BGR2RGB + ) + + # Emit grasp overlay immediately + self.grasp_overlay_subject.on_next(result_rgb) + + except Exception as e: + logger.error(f"Error creating grasp overlay: {e}") + + # Emit grasps to stream + self.grasps_subject.on_next(grasps) + + threading.Thread(target=check_grasps_later, daemon=True).start() + else: + 
logger.warning("Failed to create grasp task") + except Exception as e: + logger.error(f"Error in point cloud filtering: {e}") + with frame_lock: + latest_point_cloud_overlay = None + + def on_error(error): + logger.error(f"Error in stream: {error}") + + def on_completed(): + logger.info("Stream completed") + + def start_subscriptions(): + """Start subscriptions in background thread (from old main)""" + # Subscribe to combined ZED frames + zed_frame_stream.subscribe(on_next=on_zed_frame) + + # Start subscriptions in background thread (from old main) + subscription_thread = threading.Thread(target=start_subscriptions, daemon=True) + subscription_thread.start() + time.sleep(2) # Give subscriptions time to start + + # Subscribe to object detection stream (from old main) + object_detector.get_stream().subscribe( + on_next=on_detection_next, on_error=on_error, on_completed=on_completed + ) + + # Create visualization stream for web interface (from old main) + viz_stream = object_detector.get_stream().pipe( + ops.map(lambda x: x["viz_frame"] if x is not None else None), + ops.filter(lambda x: x is not None), + ) + + # Create filtered objects stream + filtered_objects_stream = self.filtered_objects_subject + + # Create grasps stream + grasps_stream = self.grasps_subject + + # Create grasp overlay subject for immediate emission + grasp_overlay_stream = self.grasp_overlay_subject + + return { + "detection_viz": viz_stream, + "pointcloud_viz": depth_stream, + "objects": object_detector.get_stream().pipe(ops.map(lambda x: x.get("objects", []))), + "filtered_objects": filtered_objects_stream, + "grasps": grasps_stream, + "grasp_overlay": grasp_overlay_stream, + } + + def _start_grasp_loop(self): + """Start asyncio event loop in a background thread for WebSocket communication.""" + + def run_loop(): + self.grasp_loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.grasp_loop) + self.grasp_loop.run_forever() + + self.grasp_loop_thread = threading.Thread(target=run_loop, 
daemon=True) + self.grasp_loop_thread.start() + + # Wait for loop to start + while self.grasp_loop is None: + time.sleep(0.01) + + async def _send_grasp_request( + self, points: np.ndarray, colors: Optional[np.ndarray] + ) -> Optional[List[dict]]: + """Send grasp request to AnyGrasp server.""" + try: + # Comprehensive client-side validation to prevent server errors + + # Validate points array + if points is None: + logger.error("Points array is None") + return None + if not isinstance(points, np.ndarray): + logger.error(f"Points is not numpy array: {type(points)}") + return None + if points.size == 0: + logger.error("Points array is empty") + return None + if len(points.shape) != 2 or points.shape[1] != 3: + logger.error(f"Points has invalid shape {points.shape}, expected (N, 3)") + return None + if points.shape[0] < 100: # Minimum points for stable grasp detection + logger.error(f"Insufficient points for grasp detection: {points.shape[0]} < 100") + return None + + # Validate and prepare colors + if colors is not None: + if not isinstance(colors, np.ndarray): + colors = None + elif colors.size == 0: + colors = None + elif len(colors.shape) != 2 or colors.shape[1] != 3: + colors = None + elif colors.shape[0] != points.shape[0]: + colors = None + + # If no valid colors, create default colors (required by server) + if colors is None: + # Create default white colors for all points + colors = np.ones((points.shape[0], 3), dtype=np.float32) * 0.5 + + # Ensure data types are correct (server expects float32) + points = points.astype(np.float32) + colors = colors.astype(np.float32) + + # Validate ranges (basic sanity checks) + if np.any(np.isnan(points)) or np.any(np.isinf(points)): + logger.error("Points contain NaN or Inf values") + return None + if np.any(np.isnan(colors)) or np.any(np.isinf(colors)): + logger.error("Colors contain NaN or Inf values") + return None + + # Clamp color values to valid range [0, 1] + colors = np.clip(colors, 0.0, 1.0) + + async with 
websockets.connect(self.grasp_server_url) as websocket: + request = { + "points": points.tolist(), + "colors": colors.tolist(), # Always send colors array + "lims": [-0.19, 0.12, 0.02, 0.15, 0.0, 1.0], # Default workspace limits + } + + await websocket.send(json.dumps(request)) + + response = await websocket.recv() + grasps = json.loads(response) + + # Handle server response validation + if isinstance(grasps, dict) and "error" in grasps: + logger.error(f"Server returned error: {grasps['error']}") + return None + elif isinstance(grasps, (int, float)) and grasps == 0: + return None + elif not isinstance(grasps, list): + logger.error( + f"Server returned unexpected response type: {type(grasps)}, value: {grasps}" + ) + return None + elif len(grasps) == 0: + return None + + converted_grasps = self._convert_grasp_format(grasps) + with self.grasp_lock: + self.latest_grasps = converted_grasps + self.grasps_consumed = False # Reset consumed flag + + # Emit to reactive stream + self.grasps_subject.on_next(self.latest_grasps) + + return converted_grasps + except websockets.exceptions.ConnectionClosed as e: + logger.error(f"WebSocket connection closed: {e}") + except websockets.exceptions.WebSocketException as e: + logger.error(f"WebSocket error: {e}") + except json.JSONDecodeError as e: + logger.error(f"Failed to parse server response as JSON: {e}") + except Exception as e: + logger.error(f"Error requesting grasps: {e}") + + return None + + def request_scene_grasps(self, objects: List[dict]) -> Optional[asyncio.Task]: + """Request grasps for entire scene by combining all object point clouds.""" + if not self.grasp_loop or not objects: + return None + + all_points = [] + all_colors = [] + valid_objects = 0 + + for i, obj in enumerate(objects): + # Validate point cloud data + if "point_cloud_numpy" not in obj or obj["point_cloud_numpy"] is None: + continue + + points = obj["point_cloud_numpy"] + if not isinstance(points, np.ndarray) or points.size == 0: + continue + + # Ensure 
points have correct shape (N, 3) + if len(points.shape) != 2 or points.shape[1] != 3: + continue + + # Validate colors if present + colors = None + if "colors_numpy" in obj and obj["colors_numpy"] is not None: + colors = obj["colors_numpy"] + if isinstance(colors, np.ndarray) and colors.size > 0: + # Ensure colors match points count and have correct shape + if colors.shape[0] != points.shape[0]: + colors = None # Ignore colors for this object + elif len(colors.shape) != 2 or colors.shape[1] != 3: + colors = None # Ignore colors for this object + + all_points.append(points) + if colors is not None: + all_colors.append(colors) + valid_objects += 1 + + if not all_points: + return None + + try: + combined_points = np.vstack(all_points) + + # Only combine colors if ALL objects have valid colors + combined_colors = None + if len(all_colors) == valid_objects and len(all_colors) > 0: + combined_colors = np.vstack(all_colors) + + # Validate final combined data + if combined_points.size == 0: + logger.warning("Combined point cloud is empty") + return None + + if combined_colors is not None and combined_colors.shape[0] != combined_points.shape[0]: + logger.warning( + f"Color/point count mismatch: {combined_colors.shape[0]} colors vs {combined_points.shape[0]} points, dropping colors" + ) + combined_colors = None + + except Exception as e: + logger.error(f"Failed to combine point clouds: {e}") + return None + + try: + # Check if there's already a grasp task running + if hasattr(self, "grasp_task") and self.grasp_task and not self.grasp_task.done(): + return self.grasp_task + + task = asyncio.run_coroutine_threadsafe( + self._send_grasp_request(combined_points, combined_colors), self.grasp_loop + ) + + self.grasp_task = task + return task + except Exception as e: + logger.warning("Failed to create grasp task") + return None + + def get_latest_grasps(self, timeout: float = 5.0) -> Optional[List[dict]]: + """Get latest grasp results, waiting for new ones if current ones have been 
consumed.""" + # Mark current grasps as consumed and get a reference + with self.grasp_lock: + current_grasps = self.latest_grasps + self.grasps_consumed = True + + # If we already have grasps and they haven't been consumed, return them + if current_grasps is not None and not getattr(self, "grasps_consumed", False): + return current_grasps + + # Wait for new grasps + start_time = time.time() + while time.time() - start_time < timeout: + with self.grasp_lock: + # Check if we have new grasps (different from what we marked as consumed) + if self.latest_grasps is not None and not getattr(self, "grasps_consumed", False): + return self.latest_grasps + time.sleep(0.1) # Check every 100ms + + return None # Timeout reached + + def clear_grasps(self) -> None: + """Clear all stored grasp results.""" + with self.grasp_lock: + self.latest_grasps = [] + + def _prepare_colors(self, colors: Optional[np.ndarray]) -> Optional[np.ndarray]: + """Prepare colors array, converting from various formats if needed.""" + if colors is None: + return None + + if colors.max() > 1.0: + colors = colors / 255.0 + + return colors + + def _convert_grasp_format(self, anygrasp_grasps: List[dict]) -> List[dict]: + """Convert AnyGrasp format to our visualization format.""" + converted = [] + + for i, grasp in enumerate(anygrasp_grasps): + rotation_matrix = np.array(grasp.get("rotation_matrix", np.eye(3))) + euler_angles = self._rotation_matrix_to_euler(rotation_matrix) + + converted_grasp = { + "id": f"grasp_{i}", + "score": grasp.get("score", 0.0), + "width": grasp.get("width", 0.0), + "height": grasp.get("height", 0.0), + "depth": grasp.get("depth", 0.0), + "translation": grasp.get("translation", [0, 0, 0]), + "rotation_matrix": rotation_matrix.tolist(), + "euler_angles": euler_angles, + } + converted.append(converted_grasp) + + converted.sort(key=lambda x: x["score"], reverse=True) + + return converted + + def _rotation_matrix_to_euler(self, rotation_matrix: np.ndarray) -> Dict[str, float]: + 
"""Convert rotation matrix to Euler angles (in radians).""" + sy = np.sqrt(rotation_matrix[0, 0] ** 2 + rotation_matrix[1, 0] ** 2) + + singular = sy < 1e-6 + + if not singular: + x = np.arctan2(rotation_matrix[2, 1], rotation_matrix[2, 2]) + y = np.arctan2(-rotation_matrix[2, 0], sy) + z = np.arctan2(rotation_matrix[1, 0], rotation_matrix[0, 0]) + else: + x = np.arctan2(-rotation_matrix[1, 2], rotation_matrix[1, 1]) + y = np.arctan2(-rotation_matrix[2, 0], sy) + z = 0 + + return {"roll": x, "pitch": y, "yaw": z} + + def cleanup(self): + """Clean up resources.""" + if hasattr(self.detector, "cleanup"): + self.detector.cleanup() + + if self.grasp_loop and self.grasp_loop_thread: + self.grasp_loop.call_soon_threadsafe(self.grasp_loop.stop) + self.grasp_loop_thread.join(timeout=1.0) + + if hasattr(self.pointcloud_filter, "cleanup"): + self.pointcloud_filter.cleanup() + logger.info("ManipulationPipeline cleaned up") diff --git a/build/lib/dimos/manipulation/manip_aio_processer.py b/build/lib/dimos/manipulation/manip_aio_processer.py new file mode 100644 index 0000000000..a8afc96a7c --- /dev/null +++ b/build/lib/dimos/manipulation/manip_aio_processer.py @@ -0,0 +1,411 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Sequential manipulation processor for single-frame processing without reactive streams. 
+""" + +import logging +import time +from typing import Dict, List, Optional, Any, Tuple +import numpy as np +import cv2 + +from dimos.utils.logging_config import setup_logger +from dimos.perception.detection2d.detic_2d_det import Detic2DDetector +from dimos.perception.pointcloud.pointcloud_filtering import PointcloudFiltering +from dimos.perception.segmentation.sam_2d_seg import Sam2DSegmenter +from dimos.perception.grasp_generation.grasp_generation import AnyGraspGenerator +from dimos.perception.grasp_generation.utils import create_grasp_overlay +from dimos.perception.pointcloud.utils import ( + create_point_cloud_overlay_visualization, + extract_and_cluster_misc_points, + overlay_point_clouds_on_image, +) +from dimos.perception.common.utils import ( + colorize_depth, + detection_results_to_object_data, + combine_object_data, +) + +logger = setup_logger("dimos.perception.manip_aio_processor") + + +class ManipulationProcessor: + """ + Sequential manipulation processor for single-frame processing. + + Processes RGB-D frames through object detection, point cloud filtering, + and AnyGrasp grasp generation in a single thread without reactive streams. + """ + + def __init__( + self, + camera_intrinsics: List[float], # [fx, fy, cx, cy] + min_confidence: float = 0.6, + max_objects: int = 20, + vocabulary: Optional[str] = None, + enable_grasp_generation: bool = False, + grasp_server_url: Optional[str] = None, # Required when enable_grasp_generation=True + enable_segmentation: bool = True, + ): + """ + Initialize the manipulation processor. 
+ + Args: + camera_intrinsics: [fx, fy, cx, cy] camera parameters + min_confidence: Minimum detection confidence threshold + max_objects: Maximum number of objects to process + vocabulary: Optional vocabulary for Detic detector + enable_grasp_generation: Whether to enable grasp generation + grasp_server_url: WebSocket URL for AnyGrasp server (required when enable_grasp_generation=True) + enable_segmentation: Whether to enable semantic segmentation + segmentation_model: Segmentation model to use (SAM 2 or FastSAM) + """ + self.camera_intrinsics = camera_intrinsics + self.min_confidence = min_confidence + self.max_objects = max_objects + self.enable_grasp_generation = enable_grasp_generation + self.grasp_server_url = grasp_server_url + self.enable_segmentation = enable_segmentation + + # Validate grasp generation requirements + if enable_grasp_generation and not grasp_server_url: + raise ValueError("grasp_server_url is required when enable_grasp_generation=True") + + # Initialize object detector + self.detector = Detic2DDetector(vocabulary=vocabulary, threshold=min_confidence) + + # Initialize point cloud processor + self.pointcloud_filter = PointcloudFiltering( + color_intrinsics=camera_intrinsics, + depth_intrinsics=camera_intrinsics, # ZED uses same intrinsics + max_num_objects=max_objects, + ) + + # Initialize semantic segmentation + self.segmenter = None + if self.enable_segmentation: + self.segmenter = Sam2DSegmenter( + device="cuda", + use_tracker=False, # Disable tracker for simple segmentation + use_analyzer=False, # Disable analyzer for simple segmentation + ) + + # Initialize grasp generator if enabled + self.grasp_generator = None + if self.enable_grasp_generation: + try: + self.grasp_generator = AnyGraspGenerator(server_url=grasp_server_url) + logger.info("AnyGrasp generator initialized successfully") + except Exception as e: + logger.error(f"Failed to initialize AnyGrasp generator: {e}") + self.grasp_generator = None + self.enable_grasp_generation = 
False + + logger.info( + f"Initialized ManipulationProcessor with confidence={min_confidence}, " + f"grasp_generation={enable_grasp_generation}" + ) + + def process_frame( + self, rgb_image: np.ndarray, depth_image: np.ndarray, generate_grasps: bool = None + ) -> Dict[str, Any]: + """ + Process a single RGB-D frame through the complete pipeline. + + Args: + rgb_image: RGB image (H, W, 3) + depth_image: Depth image (H, W) in meters + generate_grasps: Override grasp generation setting for this frame + + Returns: + Dictionary containing: + - detection_viz: Visualization of object detection + - pointcloud_viz: Visualization of point cloud overlay + - segmentation_viz: Visualization of semantic segmentation (if enabled) + - detection2d_objects: Raw detection results as ObjectData + - segmentation2d_objects: Raw segmentation results as ObjectData (if enabled) + - detected_objects: Detection (Object Detection) objects with point clouds filtered + - all_objects: Combined objects with intelligent duplicate removal + - full_pointcloud: Complete scene point cloud (if point cloud processing enabled) + - misc_clusters: List of clustered background/miscellaneous point clouds (DBSCAN) + - misc_voxel_grid: Open3D voxel grid approximating all misc/background points + - misc_pointcloud_viz: Visualization of misc/background cluster overlay + - grasps: Grasp results (AnyGrasp list of dictionaries, if enabled) + - grasp_overlay: Grasp visualization overlay (if enabled) + - processing_time: Total processing time + """ + start_time = time.time() + results = {} + + try: + # Step 1: Object Detection + step_start = time.time() + detection_results = self.run_object_detection(rgb_image) + results["detection2d_objects"] = detection_results.get("objects", []) + results["detection_viz"] = detection_results.get("viz_frame") + detection_time = time.time() - step_start + + # Step 2: Semantic Segmentation (if enabled) + segmentation_time = 0 + if self.enable_segmentation: + step_start = time.time() 
+ segmentation_results = self.run_segmentation(rgb_image) + results["segmentation2d_objects"] = segmentation_results.get("objects", []) + results["segmentation_viz"] = segmentation_results.get("viz_frame") + segmentation_time = time.time() - step_start + + # Step 3: Point Cloud Processing + pointcloud_time = 0 + detection2d_objects = results.get("detection2d_objects", []) + segmentation2d_objects = results.get("segmentation2d_objects", []) + + # Process detection objects if available + detected_objects = [] + if detection2d_objects: + step_start = time.time() + detected_objects = self.run_pointcloud_filtering( + rgb_image, depth_image, detection2d_objects + ) + pointcloud_time += time.time() - step_start + + # Process segmentation objects if available + segmentation_filtered_objects = [] + if segmentation2d_objects: + step_start = time.time() + segmentation_filtered_objects = self.run_pointcloud_filtering( + rgb_image, depth_image, segmentation2d_objects + ) + pointcloud_time += time.time() - step_start + + # Combine all objects using intelligent duplicate removal + all_objects = combine_object_data( + detected_objects, segmentation_filtered_objects, overlap_threshold=0.8 + ) + + # Get full point cloud + full_pcd = self.pointcloud_filter.get_full_point_cloud() + + # Extract misc/background points and create voxel grid + misc_start = time.time() + misc_clusters, misc_voxel_grid = extract_and_cluster_misc_points( + full_pcd, + all_objects, + eps=0.03, + min_points=100, + enable_filtering=True, + voxel_size=0.02, + ) + misc_time = time.time() - misc_start + + # Store results + results.update( + { + "detected_objects": detected_objects, + "all_objects": all_objects, + "full_pointcloud": full_pcd, + "misc_clusters": misc_clusters, + "misc_voxel_grid": misc_voxel_grid, + } + ) + + # Create point cloud visualizations + base_image = colorize_depth(depth_image, max_depth=10.0) + + # Create visualizations + results["pointcloud_viz"] = ( + 
create_point_cloud_overlay_visualization( + base_image=base_image, + objects=all_objects, + intrinsics=self.camera_intrinsics, + ) + if all_objects + else base_image + ) + + results["detected_pointcloud_viz"] = ( + create_point_cloud_overlay_visualization( + base_image=base_image, + objects=detected_objects, + intrinsics=self.camera_intrinsics, + ) + if detected_objects + else base_image + ) + + if misc_clusters: + # Generate consistent colors for clusters + cluster_colors = [ + tuple((np.random.RandomState(i + 100).rand(3) * 255).astype(int)) + for i in range(len(misc_clusters)) + ] + results["misc_pointcloud_viz"] = overlay_point_clouds_on_image( + base_image=base_image, + point_clouds=misc_clusters, + camera_intrinsics=self.camera_intrinsics, + colors=cluster_colors, + point_size=2, + alpha=0.6, + ) + else: + results["misc_pointcloud_viz"] = base_image + + # Step 4: Grasp Generation (if enabled) + should_generate_grasps = ( + generate_grasps if generate_grasps is not None else self.enable_grasp_generation + ) + + if should_generate_grasps and all_objects and full_pcd: + grasps = self.run_grasp_generation(all_objects, full_pcd) + results["grasps"] = grasps + if grasps: + results["grasp_overlay"] = create_grasp_overlay( + rgb_image, grasps, self.camera_intrinsics + ) + + except Exception as e: + logger.error(f"Error processing frame: {e}") + results["error"] = str(e) + + # Add timing information + total_time = time.time() - start_time + results.update( + { + "processing_time": total_time, + "timing_breakdown": { + "detection": detection_time if "detection_time" in locals() else 0, + "segmentation": segmentation_time if "segmentation_time" in locals() else 0, + "pointcloud": pointcloud_time if "pointcloud_time" in locals() else 0, + "misc_extraction": misc_time if "misc_time" in locals() else 0, + "total": total_time, + }, + } + ) + + return results + + def run_object_detection(self, rgb_image: np.ndarray) -> Dict[str, Any]: + """Run object detection on RGB 
image.""" + try: + # Convert RGB to BGR for Detic detector + bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR) + + # Use process_image method from Detic detector + bboxes, track_ids, class_ids, confidences, names, masks = self.detector.process_image( + bgr_image + ) + + # Convert to ObjectData format using utility function + objects = detection_results_to_object_data( + bboxes=bboxes, + track_ids=track_ids, + class_ids=class_ids, + confidences=confidences, + names=names, + masks=masks, + source="detection", + ) + + # Create visualization using detector's built-in method + viz_frame = self.detector.visualize_results( + rgb_image, bboxes, track_ids, class_ids, confidences, names + ) + + return {"objects": objects, "viz_frame": viz_frame} + + except Exception as e: + logger.error(f"Object detection failed: {e}") + return {"objects": [], "viz_frame": rgb_image.copy()} + + def run_pointcloud_filtering( + self, rgb_image: np.ndarray, depth_image: np.ndarray, objects: List[Dict] + ) -> List[Dict]: + """Run point cloud filtering on detected objects.""" + try: + filtered_objects = self.pointcloud_filter.process_images( + rgb_image, depth_image, objects + ) + return filtered_objects if filtered_objects else [] + except Exception as e: + logger.error(f"Point cloud filtering failed: {e}") + return [] + + def run_segmentation(self, rgb_image: np.ndarray) -> Dict[str, Any]: + """Run semantic segmentation on RGB image.""" + if not self.segmenter: + return {"objects": [], "viz_frame": rgb_image.copy()} + + try: + # Convert RGB to BGR for segmenter + bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR) + + # Get segmentation results + masks, bboxes, track_ids, probs, names = self.segmenter.process_image(bgr_image) + + # Convert to ObjectData format using utility function + objects = detection_results_to_object_data( + bboxes=bboxes, + track_ids=track_ids, + class_ids=list(range(len(bboxes))), # Use indices as class IDs for segmentation + confidences=probs, + names=names, + 
masks=masks, + source="segmentation", + ) + + # Create visualization + if masks: + viz_bgr = self.segmenter.visualize_results( + bgr_image, masks, bboxes, track_ids, probs, names + ) + # Convert back to RGB + viz_frame = cv2.cvtColor(viz_bgr, cv2.COLOR_BGR2RGB) + else: + viz_frame = rgb_image.copy() + + return {"objects": objects, "viz_frame": viz_frame} + + except Exception as e: + logger.error(f"Segmentation failed: {e}") + return {"objects": [], "viz_frame": rgb_image.copy()} + + def run_grasp_generation(self, filtered_objects: List[Dict], full_pcd) -> Optional[List[Dict]]: + """Run grasp generation using the configured generator (AnyGrasp).""" + if not self.grasp_generator: + logger.warning("Grasp generation requested but no generator available") + return None + + try: + # Generate grasps using the configured generator + grasps = self.grasp_generator.generate_grasps_from_objects(filtered_objects, full_pcd) + + # Return parsed results directly (list of grasp dictionaries) + return grasps + + except Exception as e: + logger.error(f"AnyGrasp grasp generation failed: {e}") + return None + + def cleanup(self): + """Clean up resources.""" + if hasattr(self.detector, "cleanup"): + self.detector.cleanup() + if hasattr(self.pointcloud_filter, "cleanup"): + self.pointcloud_filter.cleanup() + if self.segmenter and hasattr(self.segmenter, "cleanup"): + self.segmenter.cleanup() + if self.grasp_generator and hasattr(self.grasp_generator, "cleanup"): + self.grasp_generator.cleanup() + logger.info("ManipulationProcessor cleaned up") diff --git a/build/lib/dimos/manipulation/manipulation_history.py b/build/lib/dimos/manipulation/manipulation_history.py new file mode 100644 index 0000000000..8404b225c1 --- /dev/null +++ b/build/lib/dimos/manipulation/manipulation_history.py @@ -0,0 +1,418 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# [http://www.apache.org/licenses/LICENSE-2.0](http://www.apache.org/licenses/LICENSE-2.0) +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Module for manipulation history tracking and search.""" + +from typing import Dict, List, Optional, Any, Tuple, Union, Set, Callable +from dataclasses import dataclass, field +import time +from datetime import datetime +import os +import json +import pickle +import uuid + +from dimos.types.manipulation import ( + ManipulationTask, + AbstractConstraint, + ManipulationTaskConstraint, + ManipulationMetadata, +) +from dimos.utils.logging_config import setup_logger + +logger = setup_logger("dimos.types.manipulation_history") + + +@dataclass +class ManipulationHistoryEntry: + """An entry in the manipulation history. 
+ + Attributes: + task: The manipulation task executed + timestamp: When the manipulation was performed + result: Result of the manipulation (success/failure) + manipulation_response: Response from the motion planner/manipulation executor + """ + + task: ManipulationTask + timestamp: float = field(default_factory=time.time) + result: Dict[str, Any] = field(default_factory=dict) + manipulation_response: Optional[str] = ( + None # Any elaborative response from the motion planner / manipulation executor + ) + + def __str__(self) -> str: + status = self.result.get("status", "unknown") + return f"ManipulationHistoryEntry(task='{self.task.description}', status={status}, time={datetime.fromtimestamp(self.timestamp).strftime('%H:%M:%S')})" + + +class ManipulationHistory: + """A simplified, dictionary-based storage for manipulation history. + + This class provides an efficient way to store and query manipulation tasks, + focusing on quick lookups and flexible search capabilities. + """ + + def __init__(self, output_dir: str = None, new_memory: bool = False): + """Initialize a new manipulation history. 
+ + Args: + output_dir: Directory to save history to + new_memory: If True, creates a new memory instead of loading existing one + """ + self._history: List[ManipulationHistoryEntry] = [] + self._output_dir = output_dir + + if output_dir and not new_memory: + self.load_from_dir(output_dir) + elif output_dir: + os.makedirs(output_dir, exist_ok=True) + logger.info(f"Created new manipulation history at {output_dir}") + + def __len__(self) -> int: + """Return the number of entries in the history.""" + return len(self._history) + + def __str__(self) -> str: + """Return a string representation of the history.""" + if not self._history: + return "ManipulationHistory(empty)" + + return ( + f"ManipulationHistory(entries={len(self._history)}, " + f"time_range={datetime.fromtimestamp(self._history[0].timestamp).strftime('%Y-%m-%d %H:%M:%S')} to " + f"{datetime.fromtimestamp(self._history[-1].timestamp).strftime('%Y-%m-%d %H:%M:%S')})" + ) + + def clear(self) -> None: + """Clear all entries from the history.""" + self._history.clear() + logger.info("Cleared manipulation history") + + if self._output_dir: + self.save_history() + + def add_entry(self, entry: ManipulationHistoryEntry) -> None: + """Add an entry to the history. 
+ + Args: + entry: The entry to add + """ + self._history.append(entry) + self._history.sort(key=lambda e: e.timestamp) + + if self._output_dir: + self.save_history() + + def save_history(self) -> None: + """Save the history to the output directory.""" + if not self._output_dir: + logger.warning("Cannot save history: no output directory specified") + return + + os.makedirs(self._output_dir, exist_ok=True) + history_path = os.path.join(self._output_dir, "manipulation_history.pickle") + + with open(history_path, "wb") as f: + pickle.dump(self._history, f) + + logger.info(f"Saved manipulation history to {history_path}") + + # Also save a JSON representation for easier inspection + json_path = os.path.join(self._output_dir, "manipulation_history.json") + try: + history_data = [ + { + "task": { + "description": entry.task.description, + "target_object": entry.task.target_object, + "target_point": entry.task.target_point, + "timestamp": entry.task.timestamp, + "task_id": entry.task.task_id, + "metadata": entry.task.metadata, + }, + "result": entry.result, + "timestamp": entry.timestamp, + "manipulation_response": entry.manipulation_response, + } + for entry in self._history + ] + + with open(json_path, "w") as f: + json.dump(history_data, f, indent=2) + + logger.info(f"Saved JSON representation to {json_path}") + except Exception as e: + logger.error(f"Failed to save JSON representation: {e}") + + def load_from_dir(self, directory: str) -> None: + """Load history from the specified directory. 
+ + Args: + directory: Directory to load history from + """ + history_path = os.path.join(directory, "manipulation_history.pickle") + + if not os.path.exists(history_path): + logger.warning(f"No history found at {history_path}") + return + + try: + with open(history_path, "rb") as f: + self._history = pickle.load(f) + + logger.info( + f"Loaded manipulation history from {history_path} with {len(self._history)} entries" + ) + except Exception as e: + logger.error(f"Failed to load history: {e}") + + def get_all_entries(self) -> List[ManipulationHistoryEntry]: + """Get all entries in chronological order. + + Returns: + List of all manipulation history entries + """ + return self._history.copy() + + def get_entry_by_index(self, index: int) -> Optional[ManipulationHistoryEntry]: + """Get an entry by its index. + + Args: + index: Index of the entry to retrieve + + Returns: + The entry at the specified index or None if index is out of bounds + """ + if 0 <= index < len(self._history): + return self._history[index] + return None + + def get_entries_by_timerange( + self, start_time: float, end_time: float + ) -> List[ManipulationHistoryEntry]: + """Get entries within a specific time range. + + Args: + start_time: Start time (UNIX timestamp) + end_time: End time (UNIX timestamp) + + Returns: + List of entries within the specified time range + """ + return [entry for entry in self._history if start_time <= entry.timestamp <= end_time] + + def get_entries_by_object(self, object_name: str) -> List[ManipulationHistoryEntry]: + """Get entries related to a specific object. + + Args: + object_name: Name of the object to search for + + Returns: + List of entries related to the specified object + """ + return [entry for entry in self._history if entry.task.target_object == object_name] + + def create_task_entry( + self, task: ManipulationTask, result: Dict[str, Any] = None, agent_response: str = None + ) -> ManipulationHistoryEntry: + """Create a new manipulation history entry. 
+ + Args: + task: The manipulation task + result: Result of the manipulation + agent_response: Response from the agent about this manipulation + + Returns: + The created history entry + """ + entry = ManipulationHistoryEntry( + task=task, result=result or {}, manipulation_response=agent_response + ) + self.add_entry(entry) + return entry + + def search(self, **kwargs) -> List[ManipulationHistoryEntry]: + """Flexible search method that can search by any field in ManipulationHistoryEntry using dot notation. + + This method supports dot notation to access nested fields. String values automatically use + substring matching (contains), while all other types use exact matching. + + Examples: + # Time-based searches: + - search(**{"task.metadata.timestamp": ('>', start_time)}) - entries after start_time + - search(**{"task.metadata.timestamp": ('>=', time - 1800)}) - entries in last 30 mins + + # Constraint searches: + - search(**{"task.constraints.*.reference_point.x": 2.5}) - tasks with x=2.5 reference point + - search(**{"task.constraints.*.end_angle.x": 90}) - tasks with 90-degree x rotation + - search(**{"task.constraints.*.lock_x": True}) - tasks with x-axis translation locked + + # Object and result searches: + - search(**{"task.metadata.objects.*.label": "cup"}) - tasks involving cups + - search(**{"result.status": "success"}) - successful tasks + - search(**{"result.error": "Collision"}) - tasks that had collisions + + Args: + **kwargs: Key-value pairs for searching using dot notation for field paths. 
+ + Returns: + List of matching entries + """ + if not kwargs: + return self._history.copy() + + results = self._history.copy() + + for key, value in kwargs.items(): + # For all searches, automatically determine if we should use contains for strings + results = [e for e in results if self._check_field_match(e, key, value)] + + return results + + def _check_field_match(self, entry, field_path, value) -> bool: + """Check if a field matches the value, with special handling for strings, collections and comparisons. + + For string values, we automatically use substring matching (contains). + For collections (returned by * path), we check if any element matches. + For numeric values (like timestamps), supports >, <, >= and <= comparisons. + For all other types, we use exact matching. + + Args: + entry: The entry to check + field_path: Dot-separated path to the field + value: Value to match against. For comparisons, use tuples like: + ('>', timestamp) - greater than + ('<', timestamp) - less than + ('>=', timestamp) - greater or equal + ('<=', timestamp) - less or equal + + Returns: + True if the field matches the value, False otherwise + """ + try: + field_value = self._get_value_by_path(entry, field_path) + + # Handle comparison operators for timestamps and numbers + if isinstance(value, tuple) and len(value) == 2: + op, compare_value = value + if op == ">": + return field_value > compare_value + elif op == "<": + return field_value < compare_value + elif op == ">=": + return field_value >= compare_value + elif op == "<=": + return field_value <= compare_value + + # Handle lists (from collection searches) + if isinstance(field_value, list): + for item in field_value: + # String values use contains matching + if isinstance(item, str) and isinstance(value, str): + if value in item: + return True + # All other types use exact matching + elif item == value: + return True + return False + + # String values use contains matching + elif isinstance(field_value, str) and 
isinstance(value, str): + return value in field_value + # All other types use exact matching + else: + return field_value == value + + except (AttributeError, KeyError): + return False + + def _get_value_by_path(self, obj, path): + """Get a value from an object using a dot-separated path. + + This method handles three special cases: + 1. Regular attribute access (obj.attr) + 2. Dictionary key access (dict[key]) + 3. Collection search (dict.*.attr) - when * is used, it searches all values in the collection + + Args: + obj: Object to get value from + path: Dot-separated path to the field (e.g., "task.metadata.robot") + + Returns: + Value at the specified path or list of values for collection searches + + Raises: + AttributeError: If an attribute in the path doesn't exist + KeyError: If a dictionary key in the path doesn't exist + """ + current = obj + parts = path.split(".") + + for i, part in enumerate(parts): + # Collection search (*.attr) - search across all items in a collection + if part == "*": + # Get remaining path parts + remaining_path = ".".join(parts[i + 1 :]) + + # Handle different collection types + if isinstance(current, dict): + items = current.values() + if not remaining_path: # If * is the last part, return all values + return list(items) + elif isinstance(current, list): + items = current + if not remaining_path: # If * is the last part, return all items + return items + else: # Not a collection + raise AttributeError( + f"Cannot use wildcard on non-collection type: {type(current)}" + ) + + # Apply remaining path to each item in the collection + results = [] + for item in items: + try: + # Recursively get values from each item + value = self._get_value_by_path(item, remaining_path) + if isinstance(value, list): # Flatten nested lists + results.extend(value) + else: + results.append(value) + except (AttributeError, KeyError): + # Skip items that don't have the attribute + pass + return results + + # Regular attribute/key access + elif 
isinstance(current, dict): + current = current[part] + else: + current = getattr(current, part) + + return current diff --git a/build/lib/dimos/manipulation/manipulation_interface.py b/build/lib/dimos/manipulation/manipulation_interface.py new file mode 100644 index 0000000000..68d3924a99 --- /dev/null +++ b/build/lib/dimos/manipulation/manipulation_interface.py @@ -0,0 +1,292 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +ManipulationInterface provides a unified interface for accessing manipulation history. + +This module defines the ManipulationInterface class, which serves as an access point +for the robot's manipulation history, agent-generated constraints, and manipulation +metadata streams. 
class ManipulationInterface:
    """
    Interface for accessing and managing robot manipulation data.

    This class provides a unified interface for managing manipulation tasks and constraints.
    It maintains a list of constraints generated by the Agent and provides methods to
    add and manage manipulation tasks.
    """

    def __init__(
        self,
        output_dir: str,
        new_memory: bool = False,
        perception_stream: Optional[ObjectDetectionStream] = None,
    ):
        """
        Initialize a new ManipulationInterface instance.

        Args:
            output_dir: Directory for storing manipulation data
            new_memory: If True, creates a new manipulation history from scratch
            perception_stream: ObjectDetectionStream instance for real-time object data
        """
        self.output_dir = output_dir

        # Create manipulation history directory
        manipulation_dir = os.path.join(output_dir, "manipulation_history")
        os.makedirs(manipulation_dir, exist_ok=True)

        # Initialize manipulation history
        self.manipulation_history: ManipulationHistory = ManipulationHistory(
            output_dir=manipulation_dir, new_memory=new_memory
        )

        # List of constraints generated by the Agent via constraint generation skills
        self.agent_constraints: List[AbstractConstraint] = []

        # Initialize object detection stream and related properties
        self.perception_stream = perception_stream
        self.latest_objects: List[ObjectData] = []
        self.stream_subscription: Optional[Disposable] = None

        # Set up subscription to perception stream if available
        self._setup_perception_subscription()

        logger.info("ManipulationInterface initialized")

    def add_constraint(self, constraint: AbstractConstraint) -> None:
        """
        Add a constraint generated by the Agent via a constraint generation skill.

        Args:
            constraint: The constraint to add to agent_constraints
        """
        self.agent_constraints.append(constraint)
        logger.info(f"Added agent constraint: {constraint}")

    def get_constraints(self) -> List[AbstractConstraint]:
        """
        Get all constraints generated by the Agent via constraint generation skills.

        Returns:
            List of all constraints created by the Agent (the live internal list)
        """
        return self.agent_constraints

    def get_constraint(self, constraint_id: str) -> Optional[AbstractConstraint]:
        """
        Get a specific constraint by its ID.

        Args:
            constraint_id: ID of the constraint to retrieve

        Returns:
            The matching constraint or None if not found
        """
        # Find constraint with matching ID
        for constraint in self.agent_constraints:
            if constraint.id == constraint_id:
                return constraint

        logger.warning(f"Constraint with ID {constraint_id} not found")
        return None

    def add_manipulation_task(
        self, task: ManipulationTask, manipulation_response: Optional[str] = None
    ) -> None:
        """
        Add a manipulation task to ManipulationHistory.

        Args:
            task: The ManipulationTask to add
            manipulation_response: Optional response from the motion planner/executor

        """
        # ManipulationHistory.add_entry only accepts a ready-made entry;
        # create_task_entry builds the entry from its parts and then adds it.
        self.manipulation_history.create_task_entry(
            task=task, result=None, agent_response=manipulation_response
        )

    def _find_entry_by_task_id(self, task_id: str):
        """Return the first history entry whose task has exactly this ID, else None."""
        # Exact comparison on purpose: ManipulationHistory.search() does substring
        # matching on strings, which would make "task1" also match "task10".
        for entry in self.manipulation_history.get_all_entries():
            if entry.task.task_id == task_id:
                return entry
        return None

    def get_manipulation_task(self, task_id: str) -> Optional[ManipulationTask]:
        """
        Get a manipulation task by its ID.

        Args:
            task_id: ID of the task to retrieve

        Returns:
            The task object or None if not found
        """
        entry = self._find_entry_by_task_id(task_id)
        if entry is None:
            logger.warning(f"Manipulation task with ID {task_id} not found")
            return None
        return entry.task

    def get_all_manipulation_tasks(self) -> List[ManipulationTask]:
        """
        Get all manipulation tasks.

        Returns:
            List of all manipulation tasks, in history (chronological) order
        """
        return [entry.task for entry in self.manipulation_history.get_all_entries()]

    def update_task_status(
        self, task_id: str, status: str, result: Optional[Dict[str, Any]] = None
    ) -> Optional[ManipulationTask]:
        """
        Update the status and result of a manipulation task.

        The status is stored under the ``"status"`` key of the history entry's
        result dict, the optional ``result`` data is merged into the same dict,
        and the history is saved.

        Args:
            task_id: ID of the task to update
            status: New status for the task (e.g., 'completed', 'failed')
            result: Optional dictionary with result data

        Returns:
            The updated task or None if task not found
        """
        entry = self._find_entry_by_task_id(task_id)
        if entry is None:
            logger.warning(f"Manipulation task with ID {task_id} not found")
            return None

        # NOTE(review): assumes ManipulationHistoryEntry.result is a dict (or None),
        # as create_task_entry produces - confirm against the entry definition.
        if entry.result is None:
            entry.result = {}
        if result:
            entry.result.update(result)
        entry.result["status"] = status

        # Entries were mutated in place, so persist explicitly.
        self.manipulation_history.save_history()
        return entry.task

    # === Perception stream methods ===

    def _setup_perception_subscription(self):
        """
        Set up subscription to perception stream if available.
        """
        if self.perception_stream:
            # Subscribe to the stream and update latest_objects
            self.stream_subscription = self.perception_stream.get_stream().subscribe(
                on_next=self._update_latest_objects,
                on_error=lambda e: logger.error(f"Error in perception stream: {e}"),
            )
            logger.info("Subscribed to perception stream")

    def _update_latest_objects(self, data):
        """
        Update the latest detected objects.

        Args:
            data: Data from the object detection stream; ignored unless it
                contains an "objects" key
        """
        if "objects" in data:
            self.latest_objects = data["objects"]

    def get_latest_objects(self) -> List[ObjectData]:
        """
        Get the latest detected objects from the stream.

        Returns:
            List of the most recently detected objects
        """
        return self.latest_objects

    def get_object_by_id(self, object_id: int) -> Optional[ObjectData]:
        """
        Get a specific object by its tracking ID.

        Args:
            object_id: Tracking ID of the object

        Returns:
            The object data or None if not found
        """
        for obj in self.latest_objects:
            if obj["object_id"] == object_id:
                return obj
        return None

    def get_objects_by_label(self, label: str) -> List[ObjectData]:
        """
        Get all objects with a specific label.

        Args:
            label: Class label to filter objects by

        Returns:
            List of objects matching the label
        """
        return [obj for obj in self.latest_objects if obj["label"] == label]

    def set_perception_stream(self, perception_stream):
        """
        Set or update the perception stream.

        Args:
            perception_stream: The PerceptionStream instance
        """
        # Clean up existing subscription if any
        self.cleanup_perception_subscription()

        # Set new stream and subscribe
        self.perception_stream = perception_stream
        self._setup_perception_subscription()

    def cleanup_perception_subscription(self):
        """
        Clean up the stream subscription.
        """
        # getattr: also reached from __del__ on instances whose __init__ raised
        # before stream_subscription was assigned.
        subscription = getattr(self, "stream_subscription", None)
        if subscription:
            subscription.dispose()
            self.stream_subscription = None

    # === Utility methods ===

    def clear_history(self) -> None:
        """
        Clear all manipulation history data and agent constraints.
        """
        self.manipulation_history.clear()
        self.agent_constraints.clear()
        logger.info("Cleared manipulation history and agent constraints")

    def __str__(self) -> str:
        """
        String representation of the manipulation interface.

        Returns:
            String representation with key stats
        """
        has_stream = self.perception_stream is not None
        return f"ManipulationInterface(history={self.manipulation_history}, agent_constraints={len(self.agent_constraints)}, perception_stream={has_stream}, detected_objects={len(self.latest_objects)})"

    def __del__(self):
        """
        Clean up resources on deletion.
        """
        self.cleanup_perception_subscription()
@pytest.fixture
def sample_task():
    """Create a sample manipulation task for testing."""
    return ManipulationTask(
        description="Pick up the cup",
        target_object="cup",
        target_point=(100, 200),
        task_id="task1",
        metadata={
            # One second in the past so this task is strictly older than the one
            # built by sample_task_with_constraints, even on coarse clocks.
            "timestamp": time.time() - 1.0,
            "objects": {
                "cup1": {
                    "object_id": 1,
                    "label": "cup",
                    "confidence": 0.95,
                    "position": {"x": 1.5, "y": 2.0, "z": 0.5},
                },
                "table1": {
                    "object_id": 2,
                    "label": "table",
                    "confidence": 0.98,
                    "position": {"x": 0.0, "y": 0.0, "z": 0.0},
                },
            },
        },
    )


@pytest.fixture
def sample_task_with_constraints():
    """Create a sample manipulation task with constraints for testing."""
    task = ManipulationTask(
        description="Rotate the bottle",
        target_object="bottle",
        target_point=(150, 250),
        task_id="task2",
        metadata={
            "timestamp": time.time(),
            "objects": {
                "bottle1": {
                    "object_id": 3,
                    "label": "bottle",
                    "confidence": 0.92,
                    "position": {"x": 2.5, "y": 1.0, "z": 0.3},
                }
            },
        },
    )

    # Add rich translation constraint
    translation_constraint = TranslationConstraint(
        translation_axis="y",
        reference_point=Vector(2.5, 1.0, 0.3),
        bounds_min=Vector(2.0, 0.5, 0.3),
        bounds_max=Vector(3.0, 1.5, 0.3),
        target_point=Vector(2.7, 1.2, 0.3),
        description="Constrained translation along Y-axis only",
    )
    task.add_constraint(translation_constraint)

    # Add rich rotation constraint
    rotation_constraint = RotationConstraint(
        rotation_axis="roll",
        start_angle=Vector(0, 0, 0),
        end_angle=Vector(90, 0, 0),
        pivot_point=Vector(2.5, 1.0, 0.3),
        secondary_pivot_point=Vector(2.5, 1.0, 0.5),
        description="Constrained rotation around X-axis (roll only)",
    )
    task.add_constraint(rotation_constraint)

    # Add force constraint
    force_constraint = ForceConstraint(
        min_force=2.0,
        max_force=5.0,
        force_direction=Vector(0, 0, -1),
        description="Apply moderate downward force during manipulation",
    )
    task.add_constraint(force_constraint)

    return task


@pytest.fixture
def temp_output_dir():
    """Create a temporary directory for testing history saving/loading."""
    with tempfile.TemporaryDirectory() as temp_dir:
        yield temp_dir


@pytest.fixture
def populated_history(sample_task, sample_task_with_constraints):
    """Create a populated history with multiple entries for testing."""
    history = ManipulationHistory()

    # Add first entry
    entry1 = ManipulationHistoryEntry(
        task=sample_task,
        result={"status": "success", "execution_time": 2.5},
        manipulation_response="Successfully picked up the cup",
    )
    history.add_entry(entry1)

    # Add second entry
    entry2 = ManipulationHistoryEntry(
        task=sample_task_with_constraints,
        result={"status": "failure", "error": "Collision detected"},
        manipulation_response="Failed to rotate the bottle due to collision",
    )
    history.add_entry(entry2)

    return history


def test_manipulation_history_init():
    """Test initialization of ManipulationHistory."""
    # Default initialization
    history = ManipulationHistory()
    assert len(history) == 0
    assert str(history) == "ManipulationHistory(empty)"

    # With output directory
    with tempfile.TemporaryDirectory() as temp_dir:
        history = ManipulationHistory(output_dir=temp_dir, new_memory=True)
        assert len(history) == 0
        assert os.path.exists(temp_dir)


def test_manipulation_history_add_entry(sample_task):
    """Test adding entries to ManipulationHistory."""
    history = ManipulationHistory()

    # Create and add entry
    entry = ManipulationHistoryEntry(
        task=sample_task, result={"status": "success"}, manipulation_response="Task completed"
    )
    history.add_entry(entry)

    assert len(history) == 1
    assert history.get_entry_by_index(0) == entry


def test_manipulation_history_create_task_entry(sample_task):
    """Test creating a task entry directly."""
    history = ManipulationHistory()

    entry = history.create_task_entry(
        task=sample_task, result={"status": "success"}, agent_response="Task completed"
    )

    assert len(history) == 1
    assert entry.task == sample_task
    assert entry.result["status"] == "success"
    assert entry.manipulation_response == "Task completed"


def test_manipulation_history_save_load(temp_output_dir, sample_task):
    """Test saving and loading history from disk."""
    # Create history and add entry (adding an entry triggers a save)
    history = ManipulationHistory(output_dir=temp_output_dir)
    history.create_task_entry(
        task=sample_task, result={"status": "success"}, agent_response="Task completed"
    )

    # Check that files were created
    pickle_path = os.path.join(temp_output_dir, "manipulation_history.pickle")
    json_path = os.path.join(temp_output_dir, "manipulation_history.json")
    assert os.path.exists(pickle_path)
    assert os.path.exists(json_path)

    # Create new history that loads from the saved files
    loaded_history = ManipulationHistory(output_dir=temp_output_dir)
    assert len(loaded_history) == 1
    assert loaded_history.get_entry_by_index(0).task.description == sample_task.description


def test_manipulation_history_clear(populated_history):
    """Test clearing the history."""
    assert len(populated_history) > 0

    populated_history.clear()
    assert len(populated_history) == 0
    assert str(populated_history) == "ManipulationHistory(empty)"


def test_manipulation_history_get_methods(populated_history):
    """Test various getter methods of ManipulationHistory."""
    # get_all_entries
    entries = populated_history.get_all_entries()
    assert len(entries) == 2

    # get_entry_by_index
    entry = populated_history.get_entry_by_index(0)
    assert entry.task.task_id == "task1"

    # Out of bounds index
    assert populated_history.get_entry_by_index(100) is None

    # get_entries_by_timerange
    start_time = time.time() - 3600  # 1 hour ago
    end_time = time.time() + 3600  # 1 hour from now
    entries = populated_history.get_entries_by_timerange(start_time, end_time)
    assert len(entries) == 2

    # get_entries_by_object
    cup_entries = populated_history.get_entries_by_object("cup")
    assert len(cup_entries) == 1
    assert cup_entries[0].task.task_id == "task1"

    bottle_entries = populated_history.get_entries_by_object("bottle")
    assert len(bottle_entries) == 1
    assert bottle_entries[0].task.task_id == "task2"


def test_manipulation_history_search_basic(populated_history):
    """Test basic search functionality."""
    # Search by exact match on top-level fields
    results = populated_history.search(timestamp=populated_history.get_entry_by_index(0).timestamp)
    assert len(results) == 1

    # Search by task fields
    results = populated_history.search(**{"task.task_id": "task1"})
    assert len(results) == 1
    assert results[0].task.target_object == "cup"

    # Search by result fields
    results = populated_history.search(**{"result.status": "success"})
    assert len(results) == 1
    assert results[0].task.task_id == "task1"

    # Search by manipulation_response (substring match for strings)
    results = populated_history.search(manipulation_response="picked up")
    assert len(results) == 1
    assert results[0].task.task_id == "task1"


def test_manipulation_history_search_nested(populated_history):
    """Test search with nested field paths."""
    # Search by nested metadata fields
    results = populated_history.search(
        **{
            "task.metadata.timestamp": populated_history.get_entry_by_index(0).task.metadata[
                "timestamp"
            ]
        }
    )
    assert len(results) == 1

    # Search by nested object fields
    results = populated_history.search(**{"task.metadata.objects.cup1.label": "cup"})
    assert len(results) == 1
    assert results[0].task.task_id == "task1"

    # Search by position values
    results = populated_history.search(**{"task.metadata.objects.cup1.position.x": 1.5})
    assert len(results) == 1
    assert results[0].task.task_id == "task1"


def test_manipulation_history_search_wildcards(populated_history):
    """Test search with wildcard patterns."""
    # Search for any object with label "cup"
    results = populated_history.search(**{"task.metadata.objects.*.label": "cup"})
    assert len(results) == 1
    assert results[0].task.task_id == "task1"

    # Search for any object with confidence exactly 0.98 (the table in task1)
    results = populated_history.search(**{"task.metadata.objects.*.confidence": 0.98})
    assert len(results) == 1
    assert results[0].task.task_id == "task1"

    # Search for any object position with x=2.5
    results = populated_history.search(**{"task.metadata.objects.*.position.x": 2.5})
    assert len(results) == 1
    assert results[0].task.task_id == "task2"


def test_manipulation_history_search_constraints(populated_history):
    """Test search by constraint properties."""
    # Find entries with any TranslationConstraint with y-axis
    results = populated_history.search(**{"task.constraints.*.translation_axis": "y"})
    assert len(results) == 1
    assert results[0].task.task_id == "task2"

    # Find entries with any RotationConstraint with roll axis
    results = populated_history.search(**{"task.constraints.*.rotation_axis": "roll"})
    assert len(results) == 1
    assert results[0].task.task_id == "task2"


def test_manipulation_history_search_string_contains(populated_history):
    """Test string contains searching."""
    # Basic string contains
    results = populated_history.search(**{"task.description": "Pick"})
    assert len(results) == 1
    assert results[0].task.task_id == "task1"

    # Substring match is case-sensitive: "collision" only occurs in entry 2's response
    results = populated_history.search(manipulation_response="collision")
    assert len(results) == 1
    assert results[0].task.task_id == "task2"


def test_manipulation_history_search_multiple_criteria(populated_history):
    """Test search with multiple criteria."""
    # Multiple criteria - all must match
    results = populated_history.search(**{"task.target_object": "cup", "result.status": "success"})
    assert len(results) == 1
    assert results[0].task.task_id == "task1"

    # Multiple criteria with no matches
    results = populated_history.search(**{"task.target_object": "cup", "result.status": "failure"})
    assert len(results) == 0

    # Combination of direct and wildcard paths
    results = populated_history.search(
        **{"task.target_object": "bottle", "task.metadata.objects.*.position.z": 0.3}
    )
    assert len(results) == 1
    assert results[0].task.task_id == "task2"


def test_manipulation_history_search_nonexistent_fields(populated_history):
    """Test search with fields that don't exist."""
    # Search by nonexistent field
    results = populated_history.search(nonexistent_field="value")
    assert len(results) == 0

    # Search by nonexistent nested field
    results = populated_history.search(**{"task.nonexistent_field": "value"})
    assert len(results) == 0

    # Search by nonexistent object
    results = populated_history.search(**{"task.metadata.objects.nonexistent_object": "value"})
    assert len(results) == 0


def test_manipulation_history_search_timestamp_ranges(populated_history):
    """Test searching by timestamp ranges."""
    # Get reference timestamps
    entry1_time = populated_history.get_entry_by_index(0).task.metadata["timestamp"]
    entry2_time = populated_history.get_entry_by_index(1).task.metadata["timestamp"]
    mid_time = (entry1_time + entry2_time) / 2

    # Search for timestamps before second entry
    results = populated_history.search(**{"task.metadata.timestamp": ("<", entry2_time)})
    assert len(results) == 1
    assert results[0].task.task_id == "task1"

    # Search for timestamps after first entry
    results = populated_history.search(**{"task.metadata.timestamp": (">", entry1_time)})
    assert len(results) == 1
    assert results[0].task.task_id == "task2"

    # Search with a lower bound 30 minutes before the midpoint: both entries qualify
    results = populated_history.search(**{"task.metadata.timestamp": (">=", mid_time - 1800)})
    assert len(results) == 2
    assert results[0].task.task_id == "task1"
    assert results[1].task.task_id == "task2"


def test_manipulation_history_search_vector_fields(populated_history):
    """Test searching by vector components in constraints."""
    # Search by reference point components
    results = populated_history.search(**{"task.constraints.*.reference_point.x": 2.5})
    assert len(results) == 1
    assert results[0].task.task_id == "task2"

    # Search by target point components
    results = populated_history.search(**{"task.constraints.*.target_point.z": 0.3})
    assert len(results) == 1
    assert results[0].task.task_id == "task2"

    # Search by rotation angles
    results = populated_history.search(**{"task.constraints.*.end_angle.x": 90})
    assert len(results) == 1
    assert results[0].task.task_id == "task2"


def test_manipulation_history_search_execution_details(populated_history):
    """Test searching by execution time and error patterns."""
    # Search by execution time
    results = populated_history.search(**{"result.execution_time": 2.5})
    assert len(results) == 1
    assert results[0].task.task_id == "task1"

    # Search by error message pattern
    results = populated_history.search(**{"result.error": "Collision"})
    assert len(results) == 1
    assert results[0].task.task_id == "task2"

    # Search by status
    results = populated_history.search(**{"result.status": "success"})
    assert len(results) == 1
    assert results[0].task.task_id == "task1"
+# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from PIL import Image +import cv2 +import numpy as np + +# May need to add this back for import to work +# external_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'external', 'Metric3D')) +# if external_path not in sys.path: +# sys.path.append(external_path) + + +class Metric3D: + def __init__(self, gt_depth_scale=256.0): + # self.conf = get_config("zoedepth", "infer") + # self.depth_model = build_model(self.conf) + self.depth_model = torch.hub.load( + "yvanyin/metric3d", "metric3d_vit_small", pretrain=True + ).cuda() + if torch.cuda.device_count() > 1: + print(f"Using {torch.cuda.device_count()} GPUs!") + # self.depth_model = torch.nn.DataParallel(self.depth_model) + self.depth_model.eval() + + self.intrinsic = [707.0493, 707.0493, 604.0814, 180.5066] + self.intrinsic_scaled = None + self.gt_depth_scale = gt_depth_scale # And this + self.pad_info = None + self.rgb_origin = None + + """ + Input: Single image in RGB format + Output: Depth map + """ + + def update_intrinsic(self, intrinsic): + """ + Update the intrinsic parameters dynamically. + Ensure that the input intrinsic is valid. + """ + if len(intrinsic) != 4: + raise ValueError("Intrinsic must be a list or tuple with 4 values: [fx, fy, cx, cy]") + self.intrinsic = intrinsic + print(f"Intrinsics updated to: {self.intrinsic}") + + def infer_depth(self, img, debug=False): + if debug: + print(f"Input image: {img}") + try: + if isinstance(img, str): + print(f"Image type string: {type(img)}") + self.rgb_origin = cv2.imread(img)[:, :, ::-1] + else: + # print(f"Image type not string: {type(img)}, cv2 conversion assumed to be handled. 
If not, this will throw an error") + self.rgb_origin = img + except Exception as e: + print(f"Error parsing into infer_depth: {e}") + + img = self.rescale_input(img, self.rgb_origin) + + with torch.no_grad(): + pred_depth, confidence, output_dict = self.depth_model.inference({"input": img}) + + # Convert to PIL format + depth_image = self.unpad_transform_depth(pred_depth) + out_16bit_numpy = (depth_image.squeeze().cpu().numpy() * self.gt_depth_scale).astype( + np.uint16 + ) + depth_map_pil = Image.fromarray(out_16bit_numpy) + + return depth_map_pil + + def save_depth(self, pred_depth): + # Save the depth map to a file + pred_depth_np = pred_depth.cpu().numpy() + output_depth_file = "output_depth_map.png" + cv2.imwrite(output_depth_file, pred_depth_np) + print(f"Depth map saved to {output_depth_file}") + + # Adjusts input size to fit pretrained ViT model + def rescale_input(self, rgb, rgb_origin): + #### ajust input size to fit pretrained model + # keep ratio resize + input_size = (616, 1064) # for vit model + # input_size = (544, 1216) # for convnext model + h, w = rgb_origin.shape[:2] + scale = min(input_size[0] / h, input_size[1] / w) + rgb = cv2.resize( + rgb_origin, (int(w * scale), int(h * scale)), interpolation=cv2.INTER_LINEAR + ) + # remember to scale intrinsic, hold depth + self.intrinsic_scaled = [ + self.intrinsic[0] * scale, + self.intrinsic[1] * scale, + self.intrinsic[2] * scale, + self.intrinsic[3] * scale, + ] + # padding to input_size + padding = [123.675, 116.28, 103.53] + h, w = rgb.shape[:2] + pad_h = input_size[0] - h + pad_w = input_size[1] - w + pad_h_half = pad_h // 2 + pad_w_half = pad_w // 2 + rgb = cv2.copyMakeBorder( + rgb, + pad_h_half, + pad_h - pad_h_half, + pad_w_half, + pad_w - pad_w_half, + cv2.BORDER_CONSTANT, + value=padding, + ) + self.pad_info = [pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half] + + #### normalize + mean = torch.tensor([123.675, 116.28, 103.53]).float()[:, None, None] + std = torch.tensor([58.395, 
57.12, 57.375]).float()[:, None, None] + rgb = torch.from_numpy(rgb.transpose((2, 0, 1))).float() + rgb = torch.div((rgb - mean), std) + rgb = rgb[None, :, :, :].cuda() + return rgb + + def unpad_transform_depth(self, pred_depth): + # un pad + pred_depth = pred_depth.squeeze() + pred_depth = pred_depth[ + self.pad_info[0] : pred_depth.shape[0] - self.pad_info[1], + self.pad_info[2] : pred_depth.shape[1] - self.pad_info[3], + ] + + # upsample to original size + pred_depth = torch.nn.functional.interpolate( + pred_depth[None, None, :, :], self.rgb_origin.shape[:2], mode="bilinear" + ).squeeze() + ###################### canonical camera space ###################### + + #### de-canonical transform + canonical_to_real_scale = ( + self.intrinsic_scaled[0] / 1000.0 + ) # 1000.0 is the focal length of canonical camera + pred_depth = pred_depth * canonical_to_real_scale # now the depth is metric + pred_depth = torch.clamp(pred_depth, 0, 1000) + return pred_depth + + """Set new intrinsic value.""" + + def update_intrinsic(self, intrinsic): + self.intrinsic = intrinsic + + def eval_predicted_depth(self, depth_file, pred_depth): + if depth_file is not None: + gt_depth = cv2.imread(depth_file, -1) + gt_depth = gt_depth / self.gt_depth_scale + gt_depth = torch.from_numpy(gt_depth).float().cuda() + assert gt_depth.shape == pred_depth.shape + + mask = gt_depth > 1e-8 + abs_rel_err = (torch.abs(pred_depth[mask] - gt_depth[mask]) / gt_depth[mask]).mean() + print("abs_rel_err:", abs_rel_err.item()) diff --git a/build/lib/dimos/models/labels/__init__.py b/build/lib/dimos/models/labels/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/build/lib/dimos/models/labels/llava-34b.py b/build/lib/dimos/models/labels/llava-34b.py new file mode 100644 index 0000000000..c59a5c8aa9 --- /dev/null +++ b/build/lib/dimos/models/labels/llava-34b.py @@ -0,0 +1,92 @@ +# Copyright 2025 Dimensional Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os + +# llava v1.6 +from llama_cpp import Llama +from llama_cpp.llama_chat_format import Llava15ChatHandler + +from vqasynth.datasets.utils import image_to_base64_data_uri + + +class Llava: + def __init__( + self, + mmproj=f"{os.getcwd()}/models/mmproj-model-f16.gguf", + model_path=f"{os.getcwd()}/models/llava-v1.6-34b.Q4_K_M.gguf", + gpu=True, + ): + chat_handler = Llava15ChatHandler(clip_model_path=mmproj, verbose=True) + n_gpu_layers = 0 + if gpu: + n_gpu_layers = -1 + self.llm = Llama( + model_path=model_path, + chat_handler=chat_handler, + n_ctx=2048, + logits_all=True, + n_gpu_layers=n_gpu_layers, + ) + + def run_inference(self, image, prompt, return_json=True): + data_uri = image_to_base64_data_uri(image) + res = self.llm.create_chat_completion( + messages=[ + { + "role": "system", + "content": "You are an assistant who perfectly describes images.", + }, + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": data_uri}}, + {"type": "text", "text": prompt}, + ], + }, + ] + ) + if return_json: + return list( + set( + self.extract_descriptions_from_incomplete_json( + res["choices"][0]["message"]["content"] + ) + ) + ) + + return res["choices"][0]["message"]["content"] + + def extract_descriptions_from_incomplete_json(self, json_like_str): + last_object_idx = json_like_str.rfind(',"object') + + if last_object_idx != -1: + json_str = json_like_str[:last_object_idx] + "}" 
+ else: + json_str = json_like_str.strip() + if not json_str.endswith("}"): + json_str += "}" + + try: + json_obj = json.loads(json_str) + descriptions = [ + details["description"].replace(".", "") + for key, details in json_obj.items() + if "description" in details + ] + + return descriptions + except json.JSONDecodeError as e: + raise ValueError(f"Error parsing JSON: {e}") diff --git a/build/lib/dimos/models/manipulation/__init__.py b/build/lib/dimos/models/manipulation/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/build/lib/dimos/models/pointcloud/__init__.py b/build/lib/dimos/models/pointcloud/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/build/lib/dimos/models/pointcloud/pointcloud_utils.py b/build/lib/dimos/models/pointcloud/pointcloud_utils.py new file mode 100644 index 0000000000..c0951f44f2 --- /dev/null +++ b/build/lib/dimos/models/pointcloud/pointcloud_utils.py @@ -0,0 +1,214 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import open3d as o3d +import random + + +def save_pointcloud(pcd, file_path): + """ + Save a point cloud to a file using Open3D. 
+ """ + o3d.io.write_point_cloud(file_path, pcd) + + +def restore_pointclouds(pointcloud_paths): + restored_pointclouds = [] + for path in pointcloud_paths: + restored_pointclouds.append(o3d.io.read_point_cloud(path)) + return restored_pointclouds + + +def create_point_cloud_from_rgbd(rgb_image, depth_image, intrinsic_parameters): + rgbd_image = o3d.geometry.RGBDImage.create_from_color_and_depth( + o3d.geometry.Image(rgb_image), + o3d.geometry.Image(depth_image), + depth_scale=0.125, # 1000.0, + depth_trunc=10.0, # 10.0, + convert_rgb_to_intensity=False, + ) + intrinsic = o3d.camera.PinholeCameraIntrinsic() + intrinsic.set_intrinsics( + intrinsic_parameters["width"], + intrinsic_parameters["height"], + intrinsic_parameters["fx"], + intrinsic_parameters["fy"], + intrinsic_parameters["cx"], + intrinsic_parameters["cy"], + ) + pcd = o3d.geometry.PointCloud.create_from_rgbd_image(rgbd_image, intrinsic) + return pcd + + +def canonicalize_point_cloud(pcd, canonicalize_threshold=0.3): + # Segment the largest plane, assumed to be the floor + plane_model, inliers = pcd.segment_plane( + distance_threshold=0.01, ransac_n=3, num_iterations=1000 + ) + + canonicalized = False + if len(inliers) / len(pcd.points) > canonicalize_threshold: + canonicalized = True + + # Ensure the plane normal points upwards + if np.dot(plane_model[:3], [0, 1, 0]) < 0: + plane_model = -plane_model + + # Normalize the plane normal vector + normal = plane_model[:3] / np.linalg.norm(plane_model[:3]) + + # Compute the new basis vectors + new_y = normal + new_x = np.cross(new_y, [0, 0, -1]) + new_x /= np.linalg.norm(new_x) + new_z = np.cross(new_x, new_y) + + # Create the transformation matrix + transformation = np.identity(4) + transformation[:3, :3] = np.vstack((new_x, new_y, new_z)).T + transformation[:3, 3] = -np.dot(transformation[:3, :3], pcd.points[inliers[0]]) + + # Apply the transformation + pcd.transform(transformation) + + # Additional 180-degree rotation around the Z-axis + rotation_z_180 = 
np.array( + [[np.cos(np.pi), -np.sin(np.pi), 0], [np.sin(np.pi), np.cos(np.pi), 0], [0, 0, 1]] + ) + pcd.rotate(rotation_z_180, center=(0, 0, 0)) + + return pcd, canonicalized, transformation + else: + return pcd, canonicalized, None + + +# Distance calculations +def human_like_distance(distance_meters): + # Define the choices with units included, focusing on the 0.1 to 10 meters range + if distance_meters < 1: # For distances less than 1 meter + choices = [ + ( + round(distance_meters * 100, 2), + "centimeters", + 0.2, + ), # Centimeters for very small distances + ( + round(distance_meters * 39.3701, 2), + "inches", + 0.8, + ), # Inches for the majority of cases under 1 meter + ] + elif distance_meters < 3: # For distances less than 3 meters + choices = [ + (round(distance_meters, 2), "meters", 0.5), + ( + round(distance_meters * 3.28084, 2), + "feet", + 0.5, + ), # Feet as a common unit within indoor spaces + ] + else: # For distances from 3 up to 10 meters + choices = [ + ( + round(distance_meters, 2), + "meters", + 0.7, + ), # Meters for clarity and international understanding + ( + round(distance_meters * 3.28084, 2), + "feet", + 0.3, + ), # Feet for additional context + ] + + # Normalize probabilities and make a selection + total_probability = sum(prob for _, _, prob in choices) + cumulative_distribution = [] + cumulative_sum = 0 + for value, unit, probability in choices: + cumulative_sum += probability / total_probability # Normalize probabilities + cumulative_distribution.append((cumulative_sum, value, unit)) + + # Randomly choose based on the cumulative distribution + r = random.random() + for cumulative_prob, value, unit in cumulative_distribution: + if r < cumulative_prob: + return f"{value} {unit}" + + # Fallback to the last choice if something goes wrong + return f"{choices[-1][0]} {choices[-1][1]}" + + +def calculate_distances_between_point_clouds(A, B): + dist_pcd1_to_pcd2 = np.asarray(A.compute_point_cloud_distance(B)) + dist_pcd2_to_pcd1 = 
np.asarray(B.compute_point_cloud_distance(A)) + combined_distances = np.concatenate((dist_pcd1_to_pcd2, dist_pcd2_to_pcd1)) + avg_dist = np.mean(combined_distances) + return human_like_distance(avg_dist) + + +def calculate_centroid(pcd): + """Calculate the centroid of a point cloud.""" + points = np.asarray(pcd.points) + centroid = np.mean(points, axis=0) + return centroid + + +def calculate_relative_positions(centroids): + """Calculate the relative positions between centroids of point clouds.""" + num_centroids = len(centroids) + relative_positions_info = [] + + for i in range(num_centroids): + for j in range(i + 1, num_centroids): + relative_vector = centroids[j] - centroids[i] + + distance = np.linalg.norm(relative_vector) + relative_positions_info.append( + {"pcd_pair": (i, j), "relative_vector": relative_vector, "distance": distance} + ) + + return relative_positions_info + + +def get_bounding_box_height(pcd): + """ + Compute the height of the bounding box for a given point cloud. + + Parameters: + pcd (open3d.geometry.PointCloud): The input point cloud. + + Returns: + float: The height of the bounding box. + """ + aabb = pcd.get_axis_aligned_bounding_box() + return aabb.get_extent()[1] # Assuming the Y-axis is the up-direction + + +def compare_bounding_box_height(pcd_i, pcd_j): + """ + Compare the bounding box heights of two point clouds. + + Parameters: + pcd_i (open3d.geometry.PointCloud): The first point cloud. + pcd_j (open3d.geometry.PointCloud): The second point cloud. + + Returns: + bool: True if the bounding box of pcd_i is taller than that of pcd_j, False otherwise. 
+ """ + height_i = get_bounding_box_height(pcd_i) + height_j = get_bounding_box_height(pcd_j) + + return height_i > height_j diff --git a/build/lib/dimos/models/segmentation/__init__.py b/build/lib/dimos/models/segmentation/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/build/lib/dimos/models/segmentation/clipseg.py b/build/lib/dimos/models/segmentation/clipseg.py new file mode 100644 index 0000000000..043cd194b0 --- /dev/null +++ b/build/lib/dimos/models/segmentation/clipseg.py @@ -0,0 +1,32 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from transformers import AutoProcessor, CLIPSegForImageSegmentation + + +class CLIPSeg: + def __init__(self, model_name="CIDAS/clipseg-rd64-refined"): + self.clipseg_processor = AutoProcessor.from_pretrained(model_name) + self.clipseg_model = CLIPSegForImageSegmentation.from_pretrained(model_name) + + def run_inference(self, image, text_descriptions): + inputs = self.clipseg_processor( + text=text_descriptions, + images=[image] * len(text_descriptions), + padding=True, + return_tensors="pt", + ) + outputs = self.clipseg_model(**inputs) + logits = outputs.logits + return logits.detach().unsqueeze(1) diff --git a/build/lib/dimos/models/segmentation/sam.py b/build/lib/dimos/models/segmentation/sam.py new file mode 100644 index 0000000000..1efb07c484 --- /dev/null +++ b/build/lib/dimos/models/segmentation/sam.py @@ -0,0 +1,35 @@ +# Copyright 2025 Dimensional Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from transformers import SamModel, SamProcessor +import torch + + +class SAM: + def __init__(self, model_name="facebook/sam-vit-huge", device="cuda"): + self.device = device + self.sam_model = SamModel.from_pretrained(model_name).to(self.device) + self.sam_processor = SamProcessor.from_pretrained(model_name) + + def run_inference_from_points(self, image, points): + sam_inputs = self.sam_processor(image, input_points=points, return_tensors="pt").to( + self.device + ) + with torch.no_grad(): + sam_outputs = self.sam_model(**sam_inputs) + return self.sam_processor.image_processor.post_process_masks( + sam_outputs.pred_masks.cpu(), + sam_inputs["original_sizes"].cpu(), + sam_inputs["reshaped_input_sizes"].cpu(), + ) diff --git a/build/lib/dimos/models/segmentation/segment_utils.py b/build/lib/dimos/models/segmentation/segment_utils.py new file mode 100644 index 0000000000..9808f5d4e4 --- /dev/null +++ b/build/lib/dimos/models/segmentation/segment_utils.py @@ -0,0 +1,73 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import numpy as np + + +def find_medoid_and_closest_points(points, num_closest=5): + """ + Find the medoid from a collection of points and the closest points to the medoid. + + Parameters: + points (np.array): A numpy array of shape (N, D) where N is the number of points and D is the dimensionality. + num_closest (int): Number of closest points to return. + + Returns: + np.array: The medoid point. + np.array: The closest points to the medoid. + """ + distances = np.sqrt(((points[:, np.newaxis, :] - points[np.newaxis, :, :]) ** 2).sum(axis=-1)) + distance_sums = distances.sum(axis=1) + medoid_idx = np.argmin(distance_sums) + medoid = points[medoid_idx] + sorted_indices = np.argsort(distances[medoid_idx]) + closest_indices = sorted_indices[1 : num_closest + 1] + return medoid, points[closest_indices] + + +def sample_points_from_heatmap(heatmap, original_size, num_points=5, percentile=0.95): + """ + Sample points from the given heatmap, focusing on areas with higher values. 
+ """ + width, height = original_size + threshold = np.percentile(heatmap.numpy(), percentile) + masked_heatmap = torch.where(heatmap > threshold, heatmap, torch.tensor(0.0)) + probabilities = torch.softmax(masked_heatmap.flatten(), dim=0) + + attn = torch.sigmoid(heatmap) + w = attn.shape[0] + sampled_indices = torch.multinomial( + torch.tensor(probabilities.ravel()), num_points, replacement=True + ) + + sampled_coords = np.array(np.unravel_index(sampled_indices, attn.shape)).T + medoid, sampled_coords = find_medoid_and_closest_points(sampled_coords) + pts = [] + for pt in sampled_coords.tolist(): + x, y = pt + x = height * x / w + y = width * y / w + pts.append([y, x]) + return pts + + +def apply_mask_to_image(image, mask): + """ + Apply a binary mask to an image. The mask should be a binary array where the regions to keep are True. + """ + masked_image = image.copy() + for c in range(masked_image.shape[2]): + masked_image[:, :, c] = masked_image[:, :, c] * mask + return masked_image diff --git a/build/lib/dimos/msgs/__init__.py b/build/lib/dimos/msgs/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/build/lib/dimos/msgs/geometry_msgs/Pose.py b/build/lib/dimos/msgs/geometry_msgs/Pose.py new file mode 100644 index 0000000000..74b534fefa --- /dev/null +++ b/build/lib/dimos/msgs/geometry_msgs/Pose.py @@ -0,0 +1,181 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import struct +import traceback +from io import BytesIO +from typing import BinaryIO, TypeAlias + +from dimos_lcm.geometry_msgs import Pose as LCMPose +from plum import dispatch + +from dimos.msgs.geometry_msgs.Quaternion import Quaternion, QuaternionConvertable +from dimos.msgs.geometry_msgs.Vector3 import Vector3, VectorConvertable + +# Types that can be converted to/from Pose +PoseConvertable: TypeAlias = ( + tuple[VectorConvertable, QuaternionConvertable] + | LCMPose + | dict[str, VectorConvertable | QuaternionConvertable] +) + + +class Pose(LCMPose): + position: Vector3 + orientation: Quaternion + msg_name = "geometry_msgs.Pose" + + @classmethod + def lcm_decode(cls, data: bytes | BinaryIO): + if not hasattr(data, "read"): + data = BytesIO(data) + if data.read(8) != cls._get_packed_fingerprint(): + traceback.print_exc() + raise ValueError("Decode error") + return cls._lcm_decode_one(data) + + @classmethod + def _lcm_decode_one(cls, buf): + return cls(Vector3._decode_one(buf), Quaternion._decode_one(buf)) + + def lcm_encode(self) -> bytes: + return super().encode() + + @dispatch + def __init__(self) -> None: + """Initialize a pose at origin with identity orientation.""" + self.position = Vector3(0.0, 0.0, 0.0) + self.orientation = Quaternion(0.0, 0.0, 0.0, 1.0) + + @dispatch + def __init__(self, x: int | float, y: int | float, z: int | float) -> None: + """Initialize a pose with position and identity orientation.""" + self.position = Vector3(x, y, z) + self.orientation = Quaternion(0.0, 0.0, 0.0, 1.0) + + @dispatch + def __init__( + self, + x: int | float, + y: int | float, + z: int | float, + qx: int | float, + qy: int | float, + qz: int | float, + qw: int | float, + ) -> None: + """Initialize a pose with position and orientation.""" + self.position = Vector3(x, y, z) + self.orientation = Quaternion(qx, qy, qz, qw) + + @dispatch + def __init__( + self, + position: VectorConvertable | Vector3 = [0, 0, 0], + orientation: 
QuaternionConvertable | Quaternion = [0, 0, 0, 1], + ) -> None: + """Initialize a pose with position and orientation.""" + self.position = Vector3(position) + self.orientation = Quaternion(orientation) + + @dispatch + def __init__(self, pose_tuple: tuple[VectorConvertable, QuaternionConvertable]) -> None: + """Initialize from a tuple of (position, orientation).""" + self.position = Vector3(pose_tuple[0]) + self.orientation = Quaternion(pose_tuple[1]) + + @dispatch + def __init__(self, pose_dict: dict[str, VectorConvertable | QuaternionConvertable]) -> None: + """Initialize from a dictionary with 'position' and 'orientation' keys.""" + self.position = Vector3(pose_dict["position"]) + self.orientation = Quaternion(pose_dict["orientation"]) + + @dispatch + def __init__(self, pose: Pose) -> None: + """Initialize from another Pose (copy constructor).""" + self.position = Vector3(pose.position) + self.orientation = Quaternion(pose.orientation) + + @dispatch + def __init__(self, lcm_pose: LCMPose) -> None: + """Initialize from an LCM Pose.""" + self.position = Vector3(lcm_pose.position.x, lcm_pose.position.y, lcm_pose.position.z) + self.orientation = Quaternion( + lcm_pose.orientation.x, + lcm_pose.orientation.y, + lcm_pose.orientation.z, + lcm_pose.orientation.w, + ) + + @property + def x(self) -> float: + """X coordinate of position.""" + return self.position.x + + @property + def y(self) -> float: + """Y coordinate of position.""" + return self.position.y + + @property + def z(self) -> float: + """Z coordinate of position.""" + return self.position.z + + @property + def roll(self) -> float: + """Roll angle in radians.""" + return self.orientation.to_euler().roll + + @property + def pitch(self) -> float: + """Pitch angle in radians.""" + return self.orientation.to_euler().pitch + + @property + def yaw(self) -> float: + """Yaw angle in radians.""" + return self.orientation.to_euler().yaw + + def __repr__(self) -> str: + return f"Pose(position={self.position!r}, 
orientation={self.orientation!r})" + + def __str__(self) -> str: + return ( + f"Pose(pos=[{self.x:.3f}, {self.y:.3f}, {self.z:.3f}], " + f"euler=[{self.roll:.3f}, {self.pitch:.3f}, {self.yaw:.3f}])" + ) + + def __eq__(self, other) -> bool: + """Check if two poses are equal.""" + if not isinstance(other, Pose): + return False + return self.position == other.position and self.orientation == other.orientation + + +@dispatch +def to_pose(value: "Pose") -> Pose: + """Pass through Pose objects.""" + return value + + +@dispatch +def to_pose(value: PoseConvertable | Pose) -> Pose: + """Convert a pose-compatible value to a Pose object.""" + return Pose(value) + + +PoseLike: TypeAlias = PoseConvertable | Pose diff --git a/build/lib/dimos/msgs/geometry_msgs/PoseStamped.py b/build/lib/dimos/msgs/geometry_msgs/PoseStamped.py new file mode 100644 index 0000000000..3871072d32 --- /dev/null +++ b/build/lib/dimos/msgs/geometry_msgs/PoseStamped.py @@ -0,0 +1,76 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import struct +import time +from io import BytesIO +from typing import BinaryIO, TypeAlias + +from dimos_lcm.geometry_msgs import PoseStamped as LCMPoseStamped +from dimos_lcm.std_msgs import Header as LCMHeader +from dimos_lcm.std_msgs import Time as LCMTime +from plum import dispatch + +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.Quaternion import Quaternion, QuaternionConvertable +from dimos.msgs.geometry_msgs.Vector3 import Vector3, VectorConvertable +from dimos.types.timestamped import Timestamped + +# Types that can be converted to/from Pose +PoseConvertable: TypeAlias = ( + tuple[VectorConvertable, QuaternionConvertable] + | LCMPoseStamped + | dict[str, VectorConvertable | QuaternionConvertable] +) + + +def sec_nsec(ts): + s = int(ts) + return [s, int((ts - s) * 1_000_000_000)] + + +class PoseStamped(Pose, Timestamped): + msg_name = "geometry_msgs.PoseStamped" + ts: float + frame_id: str + + @dispatch + def __init__(self, ts: float = 0.0, frame_id: str = "", **kwargs) -> None: + self.frame_id = frame_id + self.ts = ts if ts != 0 else time.time() + super().__init__(**kwargs) + + def lcm_encode(self) -> bytes: + lcm_mgs = LCMPoseStamped() + lcm_mgs.pose = self + [lcm_mgs.header.stamp.sec, lcm_mgs.header.stamp.sec] = sec_nsec(self.ts) + lcm_mgs.header.frame_id = self.frame_id + return lcm_mgs.encode() + + @classmethod + def lcm_decode(cls, data: bytes | BinaryIO) -> PoseStamped: + lcm_msg = LCMPoseStamped.decode(data) + return cls( + ts=lcm_msg.header.stamp.sec + (lcm_msg.header.stamp.nsec / 1_000_000_000), + frame_id=lcm_msg.header.frame_id, + position=[lcm_msg.pose.position.x, lcm_msg.pose.position.y, lcm_msg.pose.position.z], + orientation=[ + lcm_msg.pose.orientation.x, + lcm_msg.pose.orientation.y, + lcm_msg.pose.orientation.z, + lcm_msg.pose.orientation.w, + ], # noqa: E501, + ) diff --git a/build/lib/dimos/msgs/geometry_msgs/Quaternion.py 
b/build/lib/dimos/msgs/geometry_msgs/Quaternion.py new file mode 100644 index 0000000000..ccb3328510 --- /dev/null +++ b/build/lib/dimos/msgs/geometry_msgs/Quaternion.py @@ -0,0 +1,167 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import struct +from collections.abc import Sequence +from io import BytesIO +from typing import BinaryIO, TypeAlias + +import numpy as np +from dimos_lcm.geometry_msgs import Quaternion as LCMQuaternion +from plum import dispatch + +from dimos.msgs.geometry_msgs.Vector3 import Vector3 + +# Types that can be converted to/from Quaternion +QuaternionConvertable: TypeAlias = Sequence[int | float] | LCMQuaternion | np.ndarray + + +class Quaternion(LCMQuaternion): + x: float = 0.0 + y: float = 0.0 + z: float = 0.0 + w: float = 1.0 + msg_name = "geometry_msgs.Quaternion" + + @classmethod + def lcm_decode(cls, data: bytes | BinaryIO): + if not hasattr(data, "read"): + data = BytesIO(data) + if data.read(8) != cls._get_packed_fingerprint(): + raise ValueError("Decode error") + return cls._lcm_decode_one(data) + + @classmethod + def _lcm_decode_one(cls, buf): + return cls(struct.unpack(">dddd", buf.read(32))) + + def lcm_encode(self): + return super().encode() + + @dispatch + def __init__(self) -> None: ... 
+ + @dispatch + def __init__(self, x: int | float, y: int | float, z: int | float, w: int | float) -> None: + self.x = float(x) + self.y = float(y) + self.z = float(z) + self.w = float(w) + + @dispatch + def __init__(self, sequence: Sequence[int | float] | np.ndarray) -> None: + if isinstance(sequence, np.ndarray): + if sequence.size != 4: + raise ValueError("Quaternion requires exactly 4 components [x, y, z, w]") + else: + if len(sequence) != 4: + raise ValueError("Quaternion requires exactly 4 components [x, y, z, w]") + + self.x = sequence[0] + self.y = sequence[1] + self.z = sequence[2] + self.w = sequence[3] + + @dispatch + def __init__(self, quaternion: "Quaternion") -> None: + """Initialize from another Quaternion (copy constructor).""" + self.x, self.y, self.z, self.w = quaternion.x, quaternion.y, quaternion.z, quaternion.w + + @dispatch + def __init__(self, lcm_quaternion: LCMQuaternion) -> None: + """Initialize from an LCM Quaternion.""" + self.x, self.y, self.z, self.w = ( + lcm_quaternion.x, + lcm_quaternion.y, + lcm_quaternion.z, + lcm_quaternion.w, + ) + + def to_tuple(self) -> tuple[float, float, float, float]: + """Tuple representation of the quaternion (x, y, z, w).""" + return (self.x, self.y, self.z, self.w) + + def to_list(self) -> list[float]: + """List representation of the quaternion (x, y, z, w).""" + return [self.x, self.y, self.z, self.w] + + def to_numpy(self) -> np.ndarray: + """Numpy array representation of the quaternion (x, y, z, w).""" + return np.array([self.x, self.y, self.z, self.w]) + + @property + def euler(self) -> Vector3: + return self.to_euler() + + @property + def radians(self) -> Vector3: + return self.to_euler() + + def to_radians(self) -> Vector3: + """Radians representation of the quaternion (x, y, z, w).""" + return self.to_euler() + + def to_euler(self) -> Vector3: + """Convert quaternion to Euler angles (roll, pitch, yaw) in radians. 
+ + Returns: + Vector3: Euler angles as (roll, pitch, yaw) in radians + """ + # Convert quaternion to Euler angles using ZYX convention (yaw, pitch, roll) + # Source: https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles + + # Roll (x-axis rotation) + sinr_cosp = 2 * (self.w * self.x + self.y * self.z) + cosr_cosp = 1 - 2 * (self.x * self.x + self.y * self.y) + roll = np.arctan2(sinr_cosp, cosr_cosp) + + # Pitch (y-axis rotation) + sinp = 2 * (self.w * self.y - self.z * self.x) + if abs(sinp) >= 1: + pitch = np.copysign(np.pi / 2, sinp) # Use 90 degrees if out of range + else: + pitch = np.arcsin(sinp) + + # Yaw (z-axis rotation) + siny_cosp = 2 * (self.w * self.z + self.x * self.y) + cosy_cosp = 1 - 2 * (self.y * self.y + self.z * self.z) + yaw = np.arctan2(siny_cosp, cosy_cosp) + + return Vector3(roll, pitch, yaw) + + def __getitem__(self, idx: int) -> float: + """Allow indexing into quaternion components: 0=x, 1=y, 2=z, 3=w.""" + if idx == 0: + return self.x + elif idx == 1: + return self.y + elif idx == 2: + return self.z + elif idx == 3: + return self.w + else: + raise IndexError(f"Quaternion index {idx} out of range [0-3]") + + def __repr__(self) -> str: + return f"Quaternion({self.x:.6f}, {self.y:.6f}, {self.z:.6f}, {self.w:.6f})" + + def __str__(self) -> str: + return self.__repr__() + + def __eq__(self, other) -> bool: + if not isinstance(other, Quaternion): + return False + return self.x == other.x and self.y == other.y and self.z == other.z and self.w == other.w diff --git a/build/lib/dimos/msgs/geometry_msgs/Twist.py b/build/lib/dimos/msgs/geometry_msgs/Twist.py new file mode 100644 index 0000000000..b9d9630716 --- /dev/null +++ b/build/lib/dimos/msgs/geometry_msgs/Twist.py @@ -0,0 +1,73 @@ +"""LCM type definitions +This file automatically generated by lcm. +DO NOT MODIFY BY HAND!!!! +""" + + +from io import BytesIO +import struct + +from . 
class Twist(object):
    """LCM message holding two Vector3 fields: ``linear`` and ``angular``.

    NOTE(review): this file is emitted by lcm-gen ("DO NOT MODIFY BY HAND");
    the ``get_hash`` fix below should also be applied to the generator template,
    otherwise it will be lost on regeneration.
    """

    __slots__ = ["linear", "angular"]

    __typenames__ = ["Vector3", "Vector3"]

    __dimensions__ = [None, None]

    def __init__(self):
        # LCM Type: Vector3
        self.linear = Vector3()
        # LCM Type: Vector3
        self.angular = Vector3()

    def encode(self):
        """Return the message as bytes: 8-byte fingerprint followed by the fields."""
        buf = BytesIO()
        buf.write(Twist._get_packed_fingerprint())
        self._encode_one(buf)
        return buf.getvalue()

    def _encode_one(self, buf):
        """Write ``linear`` then ``angular`` to ``buf`` without the fingerprint."""
        assert self.linear._get_packed_fingerprint() == Vector3._get_packed_fingerprint()
        self.linear._encode_one(buf)
        assert self.angular._get_packed_fingerprint() == Vector3._get_packed_fingerprint()
        self.angular._encode_one(buf)

    @classmethod
    def decode(cls, data: bytes):
        """Decode a message from bytes or from an object exposing ``read``.

        Raises:
            ValueError: If the leading 8 bytes do not match this type's fingerprint.
        """
        if hasattr(data, 'read'):
            buf = data
        else:
            buf = BytesIO(data)
        if buf.read(8) != cls._get_packed_fingerprint():
            raise ValueError("Decode error")
        return cls._decode_one(buf)

    @classmethod
    def _decode_one(cls, buf):
        """Read ``linear`` then ``angular`` from ``buf`` (fingerprint already consumed)."""
        self = Twist()
        self.linear = Vector3._decode_one(buf)
        self.angular = Vector3._decode_one(buf)
        return self

    @classmethod
    def _get_hash_recursive(cls, parents):
        """Combine this type's base hash with its members' hashes.

        Returns 0 when ``cls`` is already in ``parents`` so recursive types terminate.
        """
        if cls in parents: return 0
        newparents = parents + [cls]
        tmphash = (0x3a4144772922add7+ Vector3._get_hash_recursive(newparents)+ Vector3._get_hash_recursive(newparents)) & 0xffffffffffffffff
        # Rotate left by one bit within 64 bits.
        tmphash = (((tmphash<<1)&0xffffffffffffffff) + (tmphash>>63)) & 0xffffffffffffffff
        return tmphash
    _packed_fingerprint = None

    @classmethod
    def _get_packed_fingerprint(cls):
        """Return the big-endian packed 64-bit hash, computed once and cached on the class."""
        if cls._packed_fingerprint is None:
            cls._packed_fingerprint = struct.pack(">Q", cls._get_hash_recursive([]))
        return cls._packed_fingerprint

    def get_hash(self):
        """Get the LCM hash of the struct"""
        # The original referenced an undefined name ``cls`` here and raised
        # NameError on every call; classmethods are reachable through ``self``.
        return struct.unpack(">Q", self._get_packed_fingerprint())[0]
# Types that can be converted to/from Vector
VectorConvertable: TypeAlias = Sequence[int | float] | LCMVector3 | np.ndarray


def _ensure_3d(data: np.ndarray) -> np.ndarray:
    """Return ``data`` as a length-3 array, zero-padding shorter input.

    Raises:
        ValueError: If ``data`` has more than 3 components.
    """
    if len(data) == 3:
        return data
    elif len(data) < 3:
        padded = np.zeros(3, dtype=float)
        padded[: len(data)] = data
        return padded
    else:
        raise ValueError(
            f"Vector3 cannot be initialized with more than 3 components. Got {len(data)} components."
        )


class Vector3(LCMVector3):
    """3D vector with float components, LCM encode/decode and basic vector algebra."""

    x: float = 0.0
    y: float = 0.0
    z: float = 0.0
    msg_name = "geometry_msgs.Vector3"

    @classmethod
    def lcm_decode(cls, data: bytes | BinaryIO):
        """Decode from bytes or a readable stream.

        Raises:
            ValueError: If the leading 8 bytes do not match the type fingerprint.
        """
        if not hasattr(data, "read"):
            data = BytesIO(data)
        if data.read(8) != cls._get_packed_fingerprint():
            raise ValueError("Decode error")
        return cls._lcm_decode_one(data)

    @classmethod
    def _lcm_decode_one(cls, buf):
        """Read three big-endian doubles (24 bytes) from ``buf`` and build a vector."""
        return cls(struct.unpack(">ddd", buf.read(24)))

    def lcm_encode(self) -> bytes:
        """Encode using the parent LCM type's ``encode``."""
        return super().encode()

    @dispatch
    def __init__(self) -> None:
        """Initialize a zero 3D vector."""
        self.x = 0.0
        self.y = 0.0
        self.z = 0.0

    @dispatch
    def __init__(self, x: int | float) -> None:
        """Initialize a 3D vector from a single numeric value (x, 0, 0)."""
        self.x = float(x)
        self.y = 0.0
        self.z = 0.0

    @dispatch
    def __init__(self, x: int | float, y: int | float) -> None:
        """Initialize a 3D vector from x, y components (z=0)."""
        self.x = float(x)
        self.y = float(y)
        self.z = 0.0

    @dispatch
    def __init__(self, x: int | float, y: int | float, z: int | float) -> None:
        """Initialize a 3D vector from x, y, z components."""
        self.x = float(x)
        self.y = float(y)
        self.z = float(z)

    @dispatch
    def __init__(self, sequence: Sequence[int | float]) -> None:
        """Initialize from a sequence (list, tuple) of numbers, ensuring 3D."""
        data = _ensure_3d(np.array(sequence, dtype=float))
        self.x = float(data[0])
        self.y = float(data[1])
        self.z = float(data[2])

    @dispatch
    def __init__(self, array: np.ndarray) -> None:
        """Initialize from a numpy array, ensuring 3D."""
        data = _ensure_3d(np.array(array, dtype=float))
        self.x = float(data[0])
        self.y = float(data[1])
        self.z = float(data[2])

    @dispatch
    def __init__(self, vector: "Vector3") -> None:
        """Initialize from another Vector3 (copy constructor)."""
        self.x = vector.x
        self.y = vector.y
        self.z = vector.z

    @dispatch
    def __init__(self, lcm_vector: LCMVector3) -> None:
        """Initialize from an LCM Vector3."""
        self.x = float(lcm_vector.x)
        self.y = float(lcm_vector.y)
        self.z = float(lcm_vector.z)

    @property
    def as_tuple(self) -> tuple[float, float, float]:
        """Components as an (x, y, z) tuple."""
        return (self.x, self.y, self.z)

    @property
    def yaw(self) -> float:
        """The z component."""
        return self.z

    @property
    def pitch(self) -> float:
        """The y component."""
        return self.y

    @property
    def roll(self) -> float:
        """The x component."""
        return self.x

    @property
    def data(self) -> np.ndarray:
        """Components as a freshly built float numpy array [x, y, z]."""
        return np.array([self.x, self.y, self.z], dtype=float)

    def __getitem__(self, idx):
        """Index components: 0=x, 1=y, 2=z; anything else raises IndexError."""
        if idx == 0:
            return self.x
        elif idx == 1:
            return self.y
        elif idx == 2:
            return self.z
        else:
            raise IndexError(f"Vector3 index {idx} out of range [0-2]")

    def __repr__(self) -> str:
        return f"Vector({self.data})"

    def __str__(self) -> str:
        """Return the repr prefixed with an arrow chosen from the x/y direction."""

        def getArrow():
            # Renamed from ``repr`` so the builtin is not shadowed.
            arrows = ["←", "↖", "↑", "↗", "→", "↘", "↓", "↙"]

            if self.x == 0 and self.y == 0:
                return "·"

            # Calculate angle in radians and convert to directional index
            angle = np.arctan2(self.y, self.x)
            # Shift [-pi, pi] to [0, 2*pi] and split it into 8 equal sectors.
            dir_index = int(((angle + np.pi) * 4 / np.pi) % 8)
            # Get directional arrow symbol
            return arrows[dir_index]

        return f"{getArrow()} Vector {self.__repr__()}"

    def serialize(self) -> dict:
        """Serialize the vector to a dict with ``type`` and component tuple ``c``."""
        return {"type": "vector", "c": (self.x, self.y, self.z)}

    def __eq__(self, other) -> bool:
        """Check if two vectors are equal using numpy's allclose for floating point comparison."""
        if not isinstance(other, Vector3):
            return False
        # np.allclose returns numpy.bool_; convert to honour the annotation.
        return bool(np.allclose([self.x, self.y, self.z], [other.x, other.y, other.z]))

    def __add__(self, other: VectorConvertable | Vector3) -> Vector3:
        """Component-wise sum; ``other`` is converted with to_vector()."""
        other_vector: Vector3 = to_vector(other)
        return self.__class__(
            self.x + other_vector.x, self.y + other_vector.y, self.z + other_vector.z
        )

    def __sub__(self, other: VectorConvertable | Vector3) -> Vector3:
        """Component-wise difference; ``other`` is converted with to_vector()."""
        other_vector = to_vector(other)
        return self.__class__(
            self.x - other_vector.x, self.y - other_vector.y, self.z - other_vector.z
        )

    def __mul__(self, scalar: float) -> Vector3:
        """Scale every component by ``scalar``."""
        return self.__class__(self.x * scalar, self.y * scalar, self.z * scalar)

    def __rmul__(self, scalar: float) -> Vector3:
        """Support ``scalar * vector``."""
        return self.__mul__(scalar)

    def __truediv__(self, scalar: float) -> Vector3:
        """Divide every component by ``scalar``."""
        return self.__class__(self.x / scalar, self.y / scalar, self.z / scalar)

    def __neg__(self) -> Vector3:
        """Negate every component."""
        return self.__class__(-self.x, -self.y, -self.z)

    def dot(self, other: VectorConvertable | Vector3) -> float:
        """Compute dot product."""
        other_vector = to_vector(other)
        return self.x * other_vector.x + self.y * other_vector.y + self.z * other_vector.z

    def cross(self, other: VectorConvertable | Vector3) -> Vector3:
        """Compute cross product (3D vectors only)."""
        other_vector = to_vector(other)
        return self.__class__(
            self.y * other_vector.z - self.z * other_vector.y,
            self.z * other_vector.x - self.x * other_vector.z,
            self.x * other_vector.y - self.y * other_vector.x,
        )

    def length(self) -> float:
        """Compute the Euclidean length (magnitude) of the vector."""
        return float(np.sqrt(self.x * self.x + self.y * self.y + self.z * self.z))

    def length_squared(self) -> float:
        """Compute the squared length of the vector (faster than length())."""
        return float(self.x * self.x + self.y * self.y + self.z * self.z)

    def normalize(self) -> Vector3:
        """Return a normalized unit vector in the same direction.

        A vector shorter than 1e-10 yields the zero vector instead of dividing.
        """
        length = self.length()
        if length < 1e-10:  # Avoid division by near-zero
            return self.__class__(0.0, 0.0, 0.0)
        return self.__class__(self.x / length, self.y / length, self.z / length)

    def to_2d(self) -> Vector3:
        """Convert a vector to a 2D vector by taking only the x and y components (z=0)."""
        return self.__class__(self.x, self.y, 0.0)

    def distance(self, other: VectorConvertable | Vector3) -> float:
        """Compute Euclidean distance to another vector."""
        other_vector = to_vector(other)
        dx = self.x - other_vector.x
        dy = self.y - other_vector.y
        dz = self.z - other_vector.z
        return float(np.sqrt(dx * dx + dy * dy + dz * dz))

    def distance_squared(self, other: VectorConvertable | Vector3) -> float:
        """Compute squared Euclidean distance to another vector (faster than distance())."""
        other_vector = to_vector(other)
        dx = self.x - other_vector.x
        dy = self.y - other_vector.y
        dz = self.z - other_vector.z
        return float(dx * dx + dy * dy + dz * dz)

    def angle(self, other: VectorConvertable | Vector3) -> float:
        """Compute the angle (in radians) between this vector and another.

        Returns 0.0 when either vector is shorter than 1e-10.
        """
        other_vector = to_vector(other)
        this_length = self.length()
        other_length = other_vector.length()

        if this_length < 1e-10 or other_length < 1e-10:
            return 0.0

        # Clip so rounding error cannot push the cosine outside arccos' domain.
        cos_angle = np.clip(
            self.dot(other_vector) / (this_length * other_length),
            -1.0,
            1.0,
        )
        return float(np.arccos(cos_angle))

    def project(self, onto: VectorConvertable | Vector3) -> Vector3:
        """Project this vector onto another vector.

        Returns the zero vector when ``onto`` has squared length below 1e-10.
        """
        onto_vector = to_vector(onto)
        onto_length_sq = (
            onto_vector.x * onto_vector.x
            + onto_vector.y * onto_vector.y
            + onto_vector.z * onto_vector.z
        )
        if onto_length_sq < 1e-10:
            return self.__class__(0.0, 0.0, 0.0)

        scalar_projection = self.dot(onto_vector) / onto_length_sq
        return self.__class__(
            scalar_projection * onto_vector.x,
            scalar_projection * onto_vector.y,
            scalar_projection * onto_vector.z,
        )

    # this is here to test ros_observable_topic
    # doesn't happen irl afaik that we want a vector from ros message
    @classmethod
    def from_msg(cls, msg) -> Vector3:
        """Build a vector by unpacking ``msg`` into positional components."""
        return cls(*msg)

    @classmethod
    def zeros(cls) -> Vector3:
        """Create a zero 3D vector."""
        return cls()

    @classmethod
    def ones(cls) -> Vector3:
        """Create a 3D vector of ones."""
        return cls(1.0, 1.0, 1.0)

    @classmethod
    def unit_x(cls) -> Vector3:
        """Create a unit vector in the x direction."""
        return cls(1.0, 0.0, 0.0)

    @classmethod
    def unit_y(cls) -> Vector3:
        """Create a unit vector in the y direction."""
        return cls(0.0, 1.0, 0.0)

    @classmethod
    def unit_z(cls) -> Vector3:
        """Create a unit vector in the z direction."""
        return cls(0.0, 0.0, 1.0)

    def to_list(self) -> list[float]:
        """Convert the vector to a list."""
        return [self.x, self.y, self.z]

    def to_tuple(self) -> tuple[float, float, float]:
        """Convert the vector to a tuple."""
        return (self.x, self.y, self.z)

    def to_numpy(self) -> np.ndarray:
        """Convert the vector to a numpy array."""
        return np.array([self.x, self.y, self.z], dtype=float)

    def is_zero(self) -> bool:
        """Check if this is a zero vector (all components are close to zero).

        Returns:
            True if all components are zero within np.allclose tolerance, False otherwise
        """
        # np.allclose returns numpy.bool_; convert to honour the annotation.
        return bool(np.allclose([self.x, self.y, self.z], 0.0))

    @property
    def quaternion(self):
        """Quaternion equivalent of this vector read as Euler angles; see to_quaternion()."""
        return self.to_quaternion()

    def to_quaternion(self):
        """Convert Vector3 representing Euler angles (roll, pitch, yaw) to a Quaternion.

        Assumes this Vector3 contains Euler angles in radians:
        - x component: roll (rotation around x-axis)
        - y component: pitch (rotation around y-axis)
        - z component: yaw (rotation around z-axis)

        Returns:
            Quaternion: The equivalent quaternion representation
        """
        # Import here to avoid circular imports
        from dimos.msgs.geometry_msgs.Quaternion import Quaternion

        # Extract Euler angles
        roll = self.x
        pitch = self.y
        yaw = self.z

        # Convert Euler angles to quaternion using ZYX convention
        # Source: https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles

        # Compute half angles
        cy = np.cos(yaw * 0.5)
        sy = np.sin(yaw * 0.5)
        cp = np.cos(pitch * 0.5)
        sp = np.sin(pitch * 0.5)
        cr = np.cos(roll * 0.5)
        sr = np.sin(roll * 0.5)

        # Compute quaternion components
        w = cr * cp * cy + sr * sp * sy
        x = sr * cp * cy - cr * sp * sy
        y = cr * sp * cy + sr * cp * sy
        z = cr * cp * sy - sr * sp * cy

        return Quaternion(x, y, z, w)

    def __bool__(self) -> bool:
        """Boolean conversion for Vector.

        A Vector is considered False if it's a zero vector (all components are zero),
        and True otherwise.

        Returns:
            False if vector is zero, True otherwise
        """
        return not self.is_zero()


@dispatch
def to_numpy(value: "Vector3") -> np.ndarray:
    """Convert a Vector3 to a numpy array."""
    return value.to_numpy()


@dispatch
def to_numpy(value: np.ndarray) -> np.ndarray:
    """Pass through numpy arrays."""
    return value


@dispatch
def to_numpy(value: Sequence[int | float]) -> np.ndarray:
    """Convert a sequence to a numpy array."""
    return np.array(value, dtype=float)


@dispatch
def to_vector(value: "Vector3") -> Vector3:
    """Pass through Vector3 objects."""
    return value


@dispatch
def to_vector(value: VectorConvertable | Vector3) -> Vector3:
    """Convert a vector-compatible value to a Vector3 object."""
    return Vector3(value)


@dispatch
def to_tuple(value: Vector3) -> tuple[float, float, float]:
    """Convert a Vector3 to a tuple."""
    return value.to_tuple()


@dispatch
def to_tuple(value: np.ndarray) -> tuple[float, ...]:
    """Convert a numpy array to a tuple."""
    return tuple(value.tolist())


@dispatch
def to_tuple(value: Sequence[int | float]) -> tuple[float, ...]:
    """Convert a sequence to a tuple; an existing tuple is returned as-is."""
    if isinstance(value, tuple):
        return value
    else:
        return tuple(value)


@dispatch
def to_list(value: Vector3) -> list[float]:
    """Convert a Vector3 to a list."""
    return value.to_list()


@dispatch
def to_list(value: np.ndarray) -> list[float]:
    """Convert a numpy array to a list."""
    return value.tolist()


@dispatch
def to_list(value: Sequence[int | float]) -> list[float]:
    """Convert a sequence to a list; an existing list is returned as-is."""
    if isinstance(value, list):
        return value
    else:
        return list(value)


VectorLike: TypeAlias = VectorConvertable | Vector3
def _expect_position(pose, x, y, z):
    """Assert that ``pose.position`` has exactly the given components."""
    assert pose.position.x == x
    assert pose.position.y == y
    assert pose.position.z == z


def _expect_orientation(pose, x, y, z, w):
    """Assert that ``pose.orientation`` has exactly the given components."""
    assert pose.orientation.x == x
    assert pose.orientation.y == y
    assert pose.orientation.z == z
    assert pose.orientation.w == w


def _expect_identity_orientation(pose):
    """Assert that ``pose.orientation`` is the identity quaternion (0, 0, 0, 1)."""
    _expect_orientation(pose, 0.0, 0.0, 0.0, 1.0)


def _expect_xyz(pose, x, y, z):
    """Assert the convenience properties ``pose.x`` / ``pose.y`` / ``pose.z``."""
    assert pose.x == x
    assert pose.y == y
    assert pose.z == z


def test_pose_default_init():
    """A Pose built with no arguments sits at the origin with identity orientation."""
    pose = Pose()

    _expect_position(pose, 0.0, 0.0, 0.0)
    _expect_identity_orientation(pose)
    _expect_xyz(pose, 0.0, 0.0, 0.0)


def test_pose_position_init():
    """Three scalars set the position and leave the orientation at identity."""
    pose = Pose(1.0, 2.0, 3.0)

    _expect_position(pose, 1.0, 2.0, 3.0)
    _expect_identity_orientation(pose)
    _expect_xyz(pose, 1.0, 2.0, 3.0)


def test_pose_full_init():
    """Seven scalars set both position and orientation."""
    pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9)

    _expect_position(pose, 1.0, 2.0, 3.0)
    _expect_orientation(pose, 0.1, 0.2, 0.3, 0.9)
    _expect_xyz(pose, 1.0, 2.0, 3.0)


def test_pose_vector_position_init():
    """A single Vector3 sets the position and leaves the orientation at identity."""
    position = Vector3(4.0, 5.0, 6.0)
    pose = Pose(position)

    _expect_position(pose, 4.0, 5.0, 6.0)
    _expect_identity_orientation(pose)


def test_pose_vector_quaternion_init():
    """A Vector3 plus a Quaternion set position and orientation."""
    position = Vector3(1.0, 2.0, 3.0)
    orientation = Quaternion(0.1, 0.2, 0.3, 0.9)
    pose = Pose(position, orientation)

    _expect_position(pose, 1.0, 2.0, 3.0)
    _expect_orientation(pose, 0.1, 0.2, 0.3, 0.9)


def test_pose_list_init():
    """Plain lists are accepted for position and orientation."""
    position_list = [1.0, 2.0, 3.0]
    orientation_list = [0.1, 0.2, 0.3, 0.9]
    pose = Pose(position_list, orientation_list)

    _expect_position(pose, 1.0, 2.0, 3.0)
    _expect_orientation(pose, 0.1, 0.2, 0.3, 0.9)


def test_pose_tuple_init():
    """A single (position, orientation) tuple is accepted."""
    position = [1.0, 2.0, 3.0]
    orientation = [0.1, 0.2, 0.3, 0.9]
    pose_tuple = (position, orientation)
    pose = Pose(pose_tuple)

    _expect_position(pose, 1.0, 2.0, 3.0)
    _expect_orientation(pose, 0.1, 0.2, 0.3, 0.9)


def test_pose_dict_init():
    """A dict with 'position' and 'orientation' keys is accepted."""
    pose_dict = {"position": [1.0, 2.0, 3.0], "orientation": [0.1, 0.2, 0.3, 0.9]}
    pose = Pose(pose_dict)

    _expect_position(pose, 1.0, 2.0, 3.0)
    _expect_orientation(pose, 0.1, 0.2, 0.3, 0.9)


def test_pose_copy_init():
    """Passing a Pose yields an equal but distinct object."""
    original = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9)
    copy = Pose(original)

    _expect_position(copy, 1.0, 2.0, 3.0)
    _expect_orientation(copy, 0.1, 0.2, 0.3, 0.9)

    # Should be a copy, not the same object
    assert copy is not original
    assert copy == original


def test_pose_lcm_init():
    """A Pose can be built from an LCM Pose message."""
    lcm_pose = LCMPose()
    lcm_pose.position.x = 1.0
    lcm_pose.position.y = 2.0
    lcm_pose.position.z = 3.0
    lcm_pose.orientation.x = 0.1
    lcm_pose.orientation.y = 0.2
    lcm_pose.orientation.z = 0.3
    lcm_pose.orientation.w = 0.9

    pose = Pose(lcm_pose)

    _expect_position(pose, 1.0, 2.0, 3.0)
    _expect_orientation(pose, 0.1, 0.2, 0.3, 0.9)


def test_pose_properties():
    """x/y/z mirror the position; roll/pitch/yaw mirror orientation.to_euler()."""
    pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9)

    _expect_xyz(pose, 1.0, 2.0, 3.0)

    euler = pose.orientation.to_euler()
    assert pose.roll == euler.x
    assert pose.pitch == euler.y
    assert pose.yaw == euler.z


def test_pose_euler_properties_identity():
    """Identity orientation gives (near-)zero roll, pitch and yaw."""
    pose = Pose(1.0, 2.0, 3.0)  # Identity orientation

    assert np.isclose(pose.roll, 0.0, atol=1e-10)
    assert np.isclose(pose.pitch, 0.0, atol=1e-10)
    assert np.isclose(pose.yaw, 0.0, atol=1e-10)

    # The quaternion's euler property must agree
    assert np.isclose(pose.orientation.euler.x, 0.0, atol=1e-10)
    assert np.isclose(pose.orientation.euler.y, 0.0, atol=1e-10)
    assert np.isclose(pose.orientation.euler.z, 0.0, atol=1e-10)


def test_pose_repr():
    """repr() names the type, both fields, and shows the position values."""
    pose = Pose(1.234, 2.567, 3.891, 0.1, 0.2, 0.3, 0.9)

    repr_str = repr(pose)

    assert "Pose" in repr_str
    assert "position" in repr_str
    assert "orientation" in repr_str

    # Values may be printed at either precision
    assert "1.234" in repr_str or "1.23" in repr_str
    assert "2.567" in repr_str or "2.57" in repr_str


def test_pose_str():
    """str() shows the position, mentions euler angles, and names Pose once."""
    pose = Pose(1.234, 2.567, 3.891, 0.1, 0.2, 0.3, 0.9)

    str_repr = str(pose)

    assert "1.234" in str_repr
    assert "2.567" in str_repr
    assert "3.891" in str_repr

    assert "euler" in str_repr

    assert str_repr.count("Pose") == 1


def test_pose_equality():
    """== / != compare position and orientation and reject other types."""
    pose1 = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9)
    pose2 = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9)
    pose3 = Pose(1.1, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9)  # Different position
    pose4 = Pose(1.0, 2.0, 3.0, 0.11, 0.2, 0.3, 0.9)  # Different orientation

    # Equal poses
    assert pose1 == pose2
    assert pose2 == pose1

    # Different poses
    assert pose1 != pose3
    assert pose1 != pose4
    assert pose3 != pose4

    # Different types
    assert pose1 != "not a pose"
    assert pose1 != [1.0, 2.0, 3.0]
    assert pose1 != None


def test_pose_with_numpy_arrays():
    """numpy arrays are accepted for position and orientation."""
    position_array = np.array([1.0, 2.0, 3.0])
    orientation_array = np.array([0.1, 0.2, 0.3, 0.9])

    pose = Pose(position_array, orientation_array)

    _expect_position(pose, 1.0, 2.0, 3.0)
    _expect_orientation(pose, 0.1, 0.2, 0.3, 0.9)


def test_pose_with_mixed_types():
    """Different argument type combinations produce identical components."""
    # Position as tuple, orientation as list
    pose1 = Pose((1.0, 2.0, 3.0), [0.1, 0.2, 0.3, 0.9])

    # Position as numpy array, orientation as Quaternion
    position = np.array([1.0, 2.0, 3.0])
    orientation = Quaternion(0.1, 0.2, 0.3, 0.9)
    pose2 = Pose(position, orientation)

    assert pose1.position.x == pose2.position.x
    assert pose1.position.y == pose2.position.y
    assert pose1.position.z == pose2.position.z
    assert pose1.orientation.x == pose2.orientation.x
    assert pose1.orientation.y == pose2.orientation.y
    assert pose1.orientation.z == pose2.orientation.z
    assert pose1.orientation.w == pose2.orientation.w


def test_to_pose_passthrough():
    """to_pose returns the very same object when given a Pose."""
    original = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9)
    result = to_pose(original)

    assert result is original


def test_to_pose_conversion():
    """Convertible inputs build the expected Pose."""
    # Note: The to_pose conversion function has type checking issues in the current implementation
    # Test direct construction instead to verify the intended functionality

    pose_tuple = ([1.0, 2.0, 3.0], [0.1, 0.2, 0.3, 0.9])
    result1 = Pose(pose_tuple)

    assert isinstance(result1, Pose)
    _expect_position(result1, 1.0, 2.0, 3.0)
    _expect_orientation(result1, 0.1, 0.2, 0.3, 0.9)

    # Test with dictionary
    pose_dict = {"position": [1.0, 2.0, 3.0], "orientation": [0.1, 0.2, 0.3, 0.9]}
    result2 = Pose(pose_dict)

    assert isinstance(result2, Pose)
    _expect_position(result2, 1.0, 2.0, 3.0)
    _expect_orientation(result2, 0.1, 0.2, 0.3, 0.9)


def test_pose_euler_roundtrip():
    """Euler -> quaternion -> Euler returns the starting angles."""
    # Small angles keep the conversion away from gimbal lock
    roll = 0.1
    pitch = 0.2
    yaw = 0.3

    euler_vector = Vector3(roll, pitch, yaw)
    quaternion = euler_vector.to_quaternion()

    pose = Pose(Vector3(0, 0, 0), quaternion)

    result_euler = pose.orientation.euler

    assert np.isclose(result_euler.x, roll, atol=1e-6)
    assert np.isclose(result_euler.y, pitch, atol=1e-6)
    assert np.isclose(result_euler.z, yaw, atol=1e-6)


def test_pose_zero_position():
    """A pose at the origin has zero position and zero Euler angles."""
    # Use manual construction since Vector3.zeros has signature issues
    pose = Pose(0.0, 0.0, 0.0)  # Position at origin with identity orientation

    _expect_xyz(pose, 0.0, 0.0, 0.0)
    assert np.isclose(pose.roll, 0.0, atol=1e-10)
    assert np.isclose(pose.pitch, 0.0, atol=1e-10)
    assert np.isclose(pose.yaw, 0.0, atol=1e-10)


def test_pose_unit_vectors():
    """Unit vectors along each axis become the pose position."""
    pose_x = Pose(Vector3.unit_x())
    _expect_xyz(pose_x, 1.0, 0.0, 0.0)

    pose_y = Pose(Vector3.unit_y())
    _expect_xyz(pose_y, 0.0, 1.0, 0.0)

    pose_z = Pose(Vector3.unit_z())
    _expect_xyz(pose_z, 0.0, 0.0, 1.0)


def test_pose_negative_coordinates():
    """Negative components are stored unchanged."""
    pose = Pose(-1.0, -2.0, -3.0, -0.1, -0.2, -0.3, 0.9)

    _expect_xyz(pose, -1.0, -2.0, -3.0)
    _expect_orientation(pose, -0.1, -0.2, -0.3, 0.9)


def test_pose_large_coordinates():
    """Large position values are stored unchanged with identity orientation."""
    large_value = 1000.0
    pose = Pose(large_value, large_value, large_value)

    _expect_xyz(pose, large_value, large_value, large_value)
    _expect_identity_orientation(pose)


@pytest.mark.parametrize(
    "x,y,z",
    [(0.0, 0.0, 0.0), (1.0, 2.0, 3.0), (-1.0, -2.0, -3.0), (0.5, -0.5, 1.5), (100.0, -100.0, 0.0)],
)
def test_pose_parametrized_positions(x, y, z):
    """Each position triple is stored and paired with identity orientation."""
    pose = Pose(x, y, z)

    _expect_xyz(pose, x, y, z)
    _expect_identity_orientation(pose)


@pytest.mark.parametrize(
    "qx,qy,qz,qw",
    [
        (0.0, 0.0, 0.0, 1.0),  # Identity
        (1.0, 0.0, 0.0, 0.0),  # 180° around x
        (0.0, 1.0, 0.0, 0.0),  # 180° around y
        (0.0, 0.0, 1.0, 0.0),  # 180° around z
        (0.5, 0.5, 0.5, 0.5),  # Equal components
    ],
)
def test_pose_parametrized_orientations(qx, qy, qz, qw):
    """Each orientation quadruple is stored with the position at the origin."""
    pose = Pose(0.0, 0.0, 0.0, qx, qy, qz, qw)

    _expect_xyz(pose, 0.0, 0.0, 0.0)
    _expect_orientation(pose, qx, qy, qz, qw)


def test_lcm_encode_decode():
    """Round-trip a Pose through the binary LCM encoding, timed over 1000 runs."""

    def encodepass():
        pose_source = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9)
        binary_msg = pose_source.lcm_encode()
        pose_dest = Pose.lcm_decode(binary_msg)
        assert isinstance(pose_dest, Pose)
        assert pose_dest is not pose_source
        assert pose_dest == pose_source

    import timeit

    # Total seconds for 1000 runs equals milliseconds per single run.
    print(f"{timeit.timeit(encodepass, number=1000)} ms per cycle")


def test_pickle_encode_decode():
    """Round-trip a Pose through pickle, timed over 1000 runs."""

    def encodepass():
        pose_source = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9)
        binary_msg = pickle.dumps(pose_source)
        pose_dest = pickle.loads(binary_msg)
        assert isinstance(pose_dest, Pose)
        assert pose_dest is not pose_source
        assert pose_dest == pose_source

    import timeit

    # Total seconds for 1000 runs equals milliseconds per single run.
    print(f"{timeit.timeit(encodepass, number=1000)} ms per cycle")
Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest +from dimos_lcm.geometry_msgs import Quaternion as LCMQuaternion + +from dimos.msgs.geometry_msgs.Quaternion import Quaternion + + +def test_quaternion_default_init(): + """Test that default initialization creates an identity quaternion (w=1, x=y=z=0).""" + q = Quaternion() + assert q.x == 0.0 + assert q.y == 0.0 + assert q.z == 0.0 + assert q.w == 1.0 + assert q.to_tuple() == (0.0, 0.0, 0.0, 1.0) + + +def test_quaternion_component_init(): + """Test initialization with four float components (x, y, z, w).""" + q = Quaternion(0.5, 0.5, 0.5, 0.5) + assert q.x == 0.5 + assert q.y == 0.5 + assert q.z == 0.5 + assert q.w == 0.5 + + # Test with different values + q2 = Quaternion(1.0, 2.0, 3.0, 4.0) + assert q2.x == 1.0 + assert q2.y == 2.0 + assert q2.z == 3.0 + assert q2.w == 4.0 + + # Test with negative values + q3 = Quaternion(-1.0, -2.0, -3.0, -4.0) + assert q3.x == -1.0 + assert q3.y == -2.0 + assert q3.z == -3.0 + assert q3.w == -4.0 + + # Test with integers (should convert to float) + q4 = Quaternion(1, 2, 3, 4) + assert q4.x == 1.0 + assert q4.y == 2.0 + assert q4.z == 3.0 + assert q4.w == 4.0 + assert isinstance(q4.x, float) + + +def test_quaternion_sequence_init(): + """Test initialization from sequence (list, tuple) of 4 numbers.""" + # From list + q1 = Quaternion([0.1, 0.2, 0.3, 0.4]) + assert q1.x == 0.1 + assert q1.y == 0.2 + assert q1.z == 0.3 + assert q1.w == 
0.4 + + # From tuple + q2 = Quaternion((0.5, 0.6, 0.7, 0.8)) + assert q2.x == 0.5 + assert q2.y == 0.6 + assert q2.z == 0.7 + assert q2.w == 0.8 + + # Test with integers in sequence + q3 = Quaternion([1, 2, 3, 4]) + assert q3.x == 1.0 + assert q3.y == 2.0 + assert q3.z == 3.0 + assert q3.w == 4.0 + + # Test error with wrong length + with pytest.raises(ValueError, match="Quaternion requires exactly 4 components"): + Quaternion([1, 2, 3]) # Only 3 components + + with pytest.raises(ValueError, match="Quaternion requires exactly 4 components"): + Quaternion([1, 2, 3, 4, 5]) # Too many components + + +def test_quaternion_numpy_init(): + """Test initialization from numpy array.""" + # From numpy array + arr = np.array([0.1, 0.2, 0.3, 0.4]) + q1 = Quaternion(arr) + assert q1.x == 0.1 + assert q1.y == 0.2 + assert q1.z == 0.3 + assert q1.w == 0.4 + + # Test with different dtypes + arr_int = np.array([1, 2, 3, 4], dtype=int) + q2 = Quaternion(arr_int) + assert q2.x == 1.0 + assert q2.y == 2.0 + assert q2.z == 3.0 + assert q2.w == 4.0 + + # Test error with wrong size + with pytest.raises(ValueError, match="Quaternion requires exactly 4 components"): + Quaternion(np.array([1, 2, 3])) # Only 3 elements + + with pytest.raises(ValueError, match="Quaternion requires exactly 4 components"): + Quaternion(np.array([1, 2, 3, 4, 5])) # Too many elements + + +def test_quaternion_copy_init(): + """Test initialization from another Quaternion (copy constructor).""" + original = Quaternion(0.1, 0.2, 0.3, 0.4) + copy = Quaternion(original) + + assert copy.x == 0.1 + assert copy.y == 0.2 + assert copy.z == 0.3 + assert copy.w == 0.4 + + # Verify it's a copy, not the same object + assert copy is not original + assert copy == original + + +def test_quaternion_lcm_init(): + """Test initialization from LCM Quaternion.""" + lcm_quat = LCMQuaternion() + lcm_quat.x = 0.1 + lcm_quat.y = 0.2 + lcm_quat.z = 0.3 + lcm_quat.w = 0.4 + + q = Quaternion(lcm_quat) + assert q.x == 0.1 + assert q.y == 0.2 + 
assert q.z == 0.3 + assert q.w == 0.4 + + +def test_quaternion_properties(): + """Test quaternion component properties.""" + q = Quaternion(1.0, 2.0, 3.0, 4.0) + + # Test property access + assert q.x == 1.0 + assert q.y == 2.0 + assert q.z == 3.0 + assert q.w == 4.0 + + # Test as_tuple property + assert q.to_tuple() == (1.0, 2.0, 3.0, 4.0) + + +def test_quaternion_indexing(): + """Test quaternion indexing support.""" + q = Quaternion(1.0, 2.0, 3.0, 4.0) + + # Test indexing + assert q[0] == 1.0 + assert q[1] == 2.0 + assert q[2] == 3.0 + assert q[3] == 4.0 + + +def test_quaternion_euler(): + """Test quaternion to Euler angles conversion.""" + + # Test identity quaternion (should give zero angles) + q_identity = Quaternion() + angles = q_identity.to_euler() + assert np.isclose(angles.x, 0.0, atol=1e-10) # roll + assert np.isclose(angles.y, 0.0, atol=1e-10) # pitch + assert np.isclose(angles.z, 0.0, atol=1e-10) # yaw + + # Test 90 degree rotation around Z-axis (yaw) + q_z90 = Quaternion(0, 0, np.sin(np.pi / 4), np.cos(np.pi / 4)) + angles_z90 = q_z90.to_euler() + assert np.isclose(angles_z90.roll, 0.0, atol=1e-10) # roll should be 0 + assert np.isclose(angles_z90.pitch, 0.0, atol=1e-10) # pitch should be 0 + assert np.isclose(angles_z90.yaw, np.pi / 2, atol=1e-10) # yaw should be π/2 (90 degrees) + + # Test 90 degree rotation around X-axis (roll) + q_x90 = Quaternion(np.sin(np.pi / 4), 0, 0, np.cos(np.pi / 4)) + angles_x90 = q_x90.to_euler() + assert np.isclose(angles_x90.x, np.pi / 2, atol=1e-10) # roll should be π/2 + assert np.isclose(angles_x90.y, 0.0, atol=1e-10) # pitch should be 0 + assert np.isclose(angles_x90.z, 0.0, atol=1e-10) # yaw should be 0 + + +def test_lcm_encode_decode(): + """Test encoding and decoding of Quaternion to/from binary LCM format.""" + q_source = Quaternion(1.0, 2.0, 3.0, 4.0) + + binary_msg = q_source.lcm_encode() + + q_dest = Quaternion.lcm_decode(binary_msg) + + assert isinstance(q_dest, Quaternion) + assert q_dest is not q_source + 
assert q_dest == q_source diff --git a/build/lib/dimos/msgs/geometry_msgs/test_Vector3.py b/build/lib/dimos/msgs/geometry_msgs/test_Vector3.py new file mode 100644 index 0000000000..81325286f9 --- /dev/null +++ b/build/lib/dimos/msgs/geometry_msgs/test_Vector3.py @@ -0,0 +1,462 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest + +from dimos.msgs.geometry_msgs.Vector3 import Vector3 + + +def test_vector_default_init(): + """Test that default initialization of Vector() has x,y,z components all zero.""" + v = Vector3() + assert v.x == 0.0 + assert v.y == 0.0 + assert v.z == 0.0 + assert len(v.data) == 3 + assert v.to_list() == [0.0, 0.0, 0.0] + assert v.is_zero() == True # Zero vector should be considered zero + + +def test_vector_specific_init(): + """Test initialization with specific values and different input types.""" + + v1 = Vector3(1.0, 2.0) # 2D vector (now becomes 3D with z=0) + assert v1.x == 1.0 + assert v1.y == 2.0 + assert v1.z == 0.0 + + v2 = Vector3(3.0, 4.0, 5.0) # 3D vector + assert v2.x == 3.0 + assert v2.y == 4.0 + assert v2.z == 5.0 + + v3 = Vector3([6.0, 7.0, 8.0]) + assert v3.x == 6.0 + assert v3.y == 7.0 + assert v3.z == 8.0 + + v4 = Vector3((9.0, 10.0, 11.0)) + assert v4.x == 9.0 + assert v4.y == 10.0 + assert v4.z == 11.0 + + v5 = Vector3(np.array([12.0, 13.0, 14.0])) + assert v5.x == 12.0 + assert v5.y == 13.0 + assert v5.z == 14.0 + + original = Vector3([15.0, 16.0, 17.0]) 
+ v6 = Vector3(original) + assert v6.x == 15.0 + assert v6.y == 16.0 + assert v6.z == 17.0 + + assert v6 is not original + assert v6 == original + + +def test_vector_addition(): + """Test vector addition.""" + v1 = Vector3(1.0, 2.0, 3.0) + v2 = Vector3(4.0, 5.0, 6.0) + + v_add = v1 + v2 + assert v_add.x == 5.0 + assert v_add.y == 7.0 + assert v_add.z == 9.0 + + +def test_vector_subtraction(): + """Test vector subtraction.""" + v1 = Vector3(1.0, 2.0, 3.0) + v2 = Vector3(4.0, 5.0, 6.0) + + v_sub = v2 - v1 + assert v_sub.x == 3.0 + assert v_sub.y == 3.0 + assert v_sub.z == 3.0 + + +def test_vector_scalar_multiplication(): + """Test vector multiplication by a scalar.""" + v1 = Vector3(1.0, 2.0, 3.0) + + v_mul = v1 * 2.0 + assert v_mul.x == 2.0 + assert v_mul.y == 4.0 + assert v_mul.z == 6.0 + + # Test right multiplication + v_rmul = 2.0 * v1 + assert v_rmul.x == 2.0 + assert v_rmul.y == 4.0 + assert v_rmul.z == 6.0 + + +def test_vector_scalar_division(): + """Test vector division by a scalar.""" + v2 = Vector3(4.0, 5.0, 6.0) + + v_div = v2 / 2.0 + assert v_div.x == 2.0 + assert v_div.y == 2.5 + assert v_div.z == 3.0 + + +def test_vector_dot_product(): + """Test vector dot product.""" + v1 = Vector3(1.0, 2.0, 3.0) + v2 = Vector3(4.0, 5.0, 6.0) + + dot = v1.dot(v2) + assert dot == 32.0 + + +def test_vector_length(): + """Test vector length calculation.""" + # 2D vector with length 5 (now 3D with z=0) + v1 = Vector3(3.0, 4.0) + assert v1.length() == 5.0 + + # 3D vector + v2 = Vector3(2.0, 3.0, 6.0) + assert v2.length() == pytest.approx(7.0, 0.001) + + # Test length_squared + assert v1.length_squared() == 25.0 + assert v2.length_squared() == 49.0 + + +def test_vector_normalize(): + """Test vector normalization.""" + v = Vector3(2.0, 3.0, 6.0) + assert v.is_zero() == False + + v_norm = v.normalize() + length = v.length() + expected_x = 2.0 / length + expected_y = 3.0 / length + expected_z = 6.0 / length + + assert np.isclose(v_norm.x, expected_x) + assert 
np.isclose(v_norm.y, expected_y) + assert np.isclose(v_norm.z, expected_z) + assert np.isclose(v_norm.length(), 1.0) + assert v_norm.is_zero() == False + + # Test normalizing a zero vector + v_zero = Vector3(0.0, 0.0, 0.0) + assert v_zero.is_zero() == True + v_zero_norm = v_zero.normalize() + assert v_zero_norm.x == 0.0 + assert v_zero_norm.y == 0.0 + assert v_zero_norm.z == 0.0 + assert v_zero_norm.is_zero() == True + + +def test_vector_to_2d(): + """Test conversion to 2D vector.""" + v = Vector3(2.0, 3.0, 6.0) + + v_2d = v.to_2d() + assert v_2d.x == 2.0 + assert v_2d.y == 3.0 + assert v_2d.z == 0.0 # z should be 0 for 2D conversion + + # Already 2D vector (z=0) + v2 = Vector3(4.0, 5.0) + v2_2d = v2.to_2d() + assert v2_2d.x == 4.0 + assert v2_2d.y == 5.0 + assert v2_2d.z == 0.0 + + +def test_vector_distance(): + """Test distance calculations between vectors.""" + v1 = Vector3(1.0, 2.0, 3.0) + v2 = Vector3(4.0, 6.0, 8.0) + + # Distance + dist = v1.distance(v2) + expected_dist = np.sqrt(9.0 + 16.0 + 25.0) # sqrt((4-1)² + (6-2)² + (8-3)²) + assert dist == pytest.approx(expected_dist) + + # Distance squared + dist_sq = v1.distance_squared(v2) + assert dist_sq == 50.0 # 9 + 16 + 25 + + +def test_vector_cross_product(): + """Test vector cross product.""" + v1 = Vector3(1.0, 0.0, 0.0) # Unit x vector + v2 = Vector3(0.0, 1.0, 0.0) # Unit y vector + + # v1 × v2 should be unit z vector + cross = v1.cross(v2) + assert cross.x == 0.0 + assert cross.y == 0.0 + assert cross.z == 1.0 + + # Test with more complex vectors + a = Vector3(2.0, 3.0, 4.0) + b = Vector3(5.0, 6.0, 7.0) + c = a.cross(b) + + # Cross product manually calculated: + # (3*7-4*6, 4*5-2*7, 2*6-3*5) + assert c.x == -3.0 + assert c.y == 6.0 + assert c.z == -3.0 + + # Test with vectors that have z=0 (still works as they're 3D) + v_2d1 = Vector3(1.0, 2.0) # (1, 2, 0) + v_2d2 = Vector3(3.0, 4.0) # (3, 4, 0) + cross_2d = v_2d1.cross(v_2d2) + # (2*0-0*4, 0*3-1*0, 1*4-2*3) = (0, 0, -2) + assert cross_2d.x == 0.0 + 
assert cross_2d.y == 0.0 + assert cross_2d.z == -2.0 + + +def test_vector_zeros(): + """Test Vector3.zeros class method.""" + # 3D zero vector + v_zeros = Vector3.zeros() + assert v_zeros.x == 0.0 + assert v_zeros.y == 0.0 + assert v_zeros.z == 0.0 + assert v_zeros.is_zero() == True + + +def test_vector_ones(): + """Test Vector3.ones class method.""" + # 3D ones vector + v_ones = Vector3.ones() + assert v_ones.x == 1.0 + assert v_ones.y == 1.0 + assert v_ones.z == 1.0 + + +def test_vector_conversion_methods(): + """Test vector conversion methods (to_list, to_tuple, to_numpy).""" + v = Vector3(1.0, 2.0, 3.0) + + # to_list + assert v.to_list() == [1.0, 2.0, 3.0] + + # to_tuple + assert v.to_tuple() == (1.0, 2.0, 3.0) + + # to_numpy + np_array = v.to_numpy() + assert isinstance(np_array, np.ndarray) + assert np.array_equal(np_array, np.array([1.0, 2.0, 3.0])) + + +def test_vector_equality(): + """Test vector equality.""" + v1 = Vector3(1, 2, 3) + v2 = Vector3(1, 2, 3) + v3 = Vector3(4, 5, 6) + + assert v1 == v2 + assert v1 != v3 + assert v1 != Vector3(1, 2) # Now (1, 2, 0) vs (1, 2, 3) + assert v1 != Vector3(1.1, 2, 3) # Different values + assert v1 != [1, 2, 3] + + +def test_vector_is_zero(): + """Test is_zero method for vectors.""" + # Default zero vector + v0 = Vector3() + assert v0.is_zero() == True + + # Explicit zero vector + v1 = Vector3(0.0, 0.0, 0.0) + assert v1.is_zero() == True + + # Zero vector with different initialization (now always 3D) + v2 = Vector3(0.0, 0.0) # Becomes (0, 0, 0) + assert v2.is_zero() == True + + # Non-zero vectors + v3 = Vector3(1.0, 0.0, 0.0) + assert v3.is_zero() == False + + v4 = Vector3(0.0, 2.0, 0.0) + assert v4.is_zero() == False + + v5 = Vector3(0.0, 0.0, 3.0) + assert v5.is_zero() == False + + # Almost zero (within tolerance) + v6 = Vector3(1e-10, 1e-10, 1e-10) + assert v6.is_zero() == True + + # Almost zero (outside tolerance) + v7 = Vector3(1e-6, 1e-6, 1e-6) + assert v7.is_zero() == False + + +def 
test_vector_bool_conversion(): + """Test boolean conversion of vectors.""" + # Zero vectors should be False + v0 = Vector3() + assert bool(v0) == False + + v1 = Vector3(0.0, 0.0, 0.0) + assert bool(v1) == False + + # Almost zero vectors should be False + v2 = Vector3(1e-10, 1e-10, 1e-10) + assert bool(v2) == False + + # Non-zero vectors should be True + v3 = Vector3(1.0, 0.0, 0.0) + assert bool(v3) == True + + v4 = Vector3(0.0, 2.0, 0.0) + assert bool(v4) == True + + v5 = Vector3(0.0, 0.0, 3.0) + assert bool(v5) == True + + # Direct use in if statements + if v0: + assert False, "Zero vector should be False in boolean context" + else: + pass # Expected path + + if v3: + pass # Expected path + else: + assert False, "Non-zero vector should be True in boolean context" + + +def test_vector_add(): + """Test vector addition operator.""" + v1 = Vector3(1.0, 2.0, 3.0) + v2 = Vector3(4.0, 5.0, 6.0) + + # Using __add__ method + v_add = v1.__add__(v2) + assert v_add.x == 5.0 + assert v_add.y == 7.0 + assert v_add.z == 9.0 + + # Using + operator + v_add_op = v1 + v2 + assert v_add_op.x == 5.0 + assert v_add_op.y == 7.0 + assert v_add_op.z == 9.0 + + # Adding zero vector should return original vector + v_zero = Vector3.zeros() + assert (v1 + v_zero) == v1 + + +def test_vector_add_dim_mismatch(): + """Test vector addition with different input dimensions (now all vectors are 3D).""" + v1 = Vector3(1.0, 2.0) # Becomes (1, 2, 0) + v2 = Vector3(4.0, 5.0, 6.0) # (4, 5, 6) + + # Using + operator - should work fine now since both are 3D + v_add_op = v1 + v2 + assert v_add_op.x == 5.0 # 1 + 4 + assert v_add_op.y == 7.0 # 2 + 5 + assert v_add_op.z == 6.0 # 0 + 6 + + +def test_yaw_pitch_roll_accessors(): + """Test yaw, pitch, and roll accessor properties.""" + # Test with a 3D vector + v = Vector3(1.0, 2.0, 3.0) + + # According to standard convention: + # roll = rotation around x-axis = x component + # pitch = rotation around y-axis = y component + # yaw = rotation around z-axis = z 
component + assert v.roll == 1.0 # Should return x component + assert v.pitch == 2.0 # Should return y component + assert v.yaw == 3.0 # Should return z component + + # Test with a 2D vector (z should be 0.0) + v_2d = Vector3(4.0, 5.0) + assert v_2d.roll == 4.0 # Should return x component + assert v_2d.pitch == 5.0 # Should return y component + assert v_2d.yaw == 0.0 # Should return z component (defaults to 0 for 2D) + + # Test with empty vector (all should be 0.0) + v_empty = Vector3() + assert v_empty.roll == 0.0 + assert v_empty.pitch == 0.0 + assert v_empty.yaw == 0.0 + + # Test with negative values + v_neg = Vector3(-1.5, -2.5, -3.5) + assert v_neg.roll == -1.5 + assert v_neg.pitch == -2.5 + assert v_neg.yaw == -3.5 + + +def test_vector_to_quaternion(): + """Test vector to quaternion conversion.""" + # Test with zero Euler angles (should produce identity quaternion) + v_zero = Vector3(0.0, 0.0, 0.0) + q_identity = v_zero.to_quaternion() + + # Identity quaternion should have w=1, x=y=z=0 + assert np.isclose(q_identity.x, 0.0, atol=1e-10) + assert np.isclose(q_identity.y, 0.0, atol=1e-10) + assert np.isclose(q_identity.z, 0.0, atol=1e-10) + assert np.isclose(q_identity.w, 1.0, atol=1e-10) + + # Test with small angles (to avoid gimbal lock issues) + v_small = Vector3(0.1, 0.2, 0.3) # Small roll, pitch, yaw + q_small = v_small.to_quaternion() + + # Quaternion should be normalized (magnitude = 1) + magnitude = np.sqrt(q_small.x**2 + q_small.y**2 + q_small.z**2 + q_small.w**2) + assert np.isclose(magnitude, 1.0, atol=1e-10) + + # Test conversion back to Euler (should be close to original) + v_back = q_small.to_euler() + assert np.isclose(v_back.x, 0.1, atol=1e-6) + assert np.isclose(v_back.y, 0.2, atol=1e-6) + assert np.isclose(v_back.z, 0.3, atol=1e-6) + + # Test with π/2 rotation around x-axis + v_x_90 = Vector3(np.pi / 2, 0.0, 0.0) + q_x_90 = v_x_90.to_quaternion() + + # Should be approximately (sin(π/4), 0, 0, cos(π/4)) = (√2/2, 0, 0, √2/2) + expected = 
np.sqrt(2) / 2 + assert np.isclose(q_x_90.x, expected, atol=1e-10) + assert np.isclose(q_x_90.y, 0.0, atol=1e-10) + assert np.isclose(q_x_90.z, 0.0, atol=1e-10) + assert np.isclose(q_x_90.w, expected, atol=1e-10) + + +def test_lcm_encode_decode(): + v_source = Vector3(1.0, 2.0, 3.0) + + binary_msg = v_source.lcm_encode() + + v_dest = Vector3.lcm_decode(binary_msg) + + assert isinstance(v_dest, Vector3) + assert v_dest is not v_source + assert v_dest == v_source diff --git a/build/lib/dimos/msgs/geometry_msgs/test_publish.py b/build/lib/dimos/msgs/geometry_msgs/test_publish.py new file mode 100644 index 0000000000..4e364dc19a --- /dev/null +++ b/build/lib/dimos/msgs/geometry_msgs/test_publish.py @@ -0,0 +1,54 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import time

import lcm
import pytest

from dimos.msgs.geometry_msgs import Vector3


@pytest.mark.tool
def test_runpublish():
    """Publish ten Vector3 messages on the ``thing1_vector3`` channel.

    Manual tool (marked ``tool``): messages are sent 0.1 s apart and each one
    is echoed to stdout after publishing.
    """
    # One LCM handle is enough for the whole run; building a new one per
    # message only re-creates the transport for no benefit.
    lc = lcm.LCM()
    for i in range(10):
        msg = Vector3(-5 + i, -5 + i, i)
        lc.publish("thing1_vector3#geometry_msgs.Vector3", msg.encode())
        time.sleep(0.1)
        print(f"Published: {msg}")


@pytest.mark.tool
def test_receive():
    """Subscribe to the ``thing1_vector3`` channel and print every decoded Vector3.

    Manual tool (marked ``tool``): the handling loop never returns, so this
    has to be interrupted by hand.
    """
    lc = lcm.LCM()

    def receive(channel, msg):
        # ``channel`` is unused; only the payload is decoded and printed.
        print(Vector3.decode(msg))

    lc.subscribe("thing1_vector3#geometry_msgs.Vector3", receive)

    def _loop():
        """LCM message handling loop"""
        while True:
            try:
                lc.handle()
                # NOTE(review): busy-wait of 10,000,000 no-op iterations after
                # each handled message (the previous comment said 10000);
                # presumably this simulates a slow consumer - confirm.
                for _ in range(10000000):
                    3 + 3
            except Exception as e:
                # Best-effort tool: report the failure and keep listening.
                print(f"Error in LCM handling: {e}")

    _loop()


# Copyright 2025 Dimensional Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional, Tuple

import cv2
import numpy as np

# Import LCM types
from dimos_lcm.sensor_msgs.Image import Image as LCMImage
from dimos_lcm.std_msgs.Header import Header

from dimos.types.timestamped import Timestamped


class ImageFormat(Enum):
    """Supported image formats.

    Each value is the encoding string written to / parsed from LCM messages.
    """

    BGR = "bgr8"
    RGB = "rgb8"
    RGBA = "rgba8"
    BGRA = "bgra8"
    GRAY = "mono8"
    GRAY16 = "mono16"


@dataclass
class Image(Timestamped):
    """Standardized image type with LCM integration.

    Wraps a numpy pixel array together with its channel layout, a frame id
    and a timestamp, and offers format conversion plus LCM (de)serialization.
    """

    msg_name = "sensor_msgs.Image"
    # Pixel array, at least 2-D; made C-contiguous in __post_init__.
    data: np.ndarray
    # Channel layout of ``data``.
    format: ImageFormat = field(default=ImageFormat.BGR)
    frame_id: str = field(default="")
    # Timestamp in seconds; defaults to the construction time.
    ts: float = field(default_factory=time.time)

    def __post_init__(self):
        """Validate image data and format.

        Raises:
            ValueError: If ``data`` is None, not a numpy array, or has fewer
                than two dimensions.
        """
        if self.data is None:
            raise ValueError("Image data cannot be None")

        if not isinstance(self.data, np.ndarray):
            raise ValueError("Image data must be a numpy array")

        if len(self.data.shape) < 2:
            raise ValueError("Image data must be at least 2D")

        # Ensure data is contiguous for efficient operations
        if not self.data.flags["C_CONTIGUOUS"]:
            self.data = np.ascontiguousarray(self.data)

    @property
    def height(self) -> int:
        """Get image height."""
        return self.data.shape[0]

    @property
    def width(self) -> int:
        """Get image width."""
        return self.data.shape[1]

    @property
    def channels(self) -> int:
        """Get number of channels (1 for a 2-D array).

        Raises:
            ValueError: If the array is neither 2-D nor 3-D.
        """
        if len(self.data.shape) == 2:
            return 1
        elif len(self.data.shape) == 3:
            return self.data.shape[2]
        else:
            raise ValueError("Invalid image dimensions")

    @property
    def shape(self) -> Tuple[int, ...]:
        """Get image shape."""
        return self.data.shape

    @property
    def dtype(self) -> np.dtype:
        """Get image data type."""
        return self.data.dtype

    def copy(self) -> "Image":
        """Create a deep copy of the image."""
        return self.__class__(
            data=self.data.copy(),
            format=self.format,
            frame_id=self.frame_id,
            ts=self.ts,
        )

    @classmethod
    def from_opencv(
        cls, cv_image: np.ndarray, format: ImageFormat = ImageFormat.BGR, **kwargs
    ) -> "Image":
        """Create Image from OpenCV image array."""
        return cls(data=cv_image, format=format, **kwargs)

    @classmethod
    def from_numpy(
        cls, np_image: np.ndarray, format: ImageFormat = ImageFormat.BGR, **kwargs
    ) -> "Image":
        """Create Image from numpy array."""
        return cls(data=np_image, format=format, **kwargs)

    @classmethod
    def from_file(cls, filepath: str, format: ImageFormat = ImageFormat.BGR) -> "Image":
        """Load image from file.

        The format is detected from the loaded array; ``format`` is only used
        as a fallback when the channel count is not 1, 3 or 4.

        Raises:
            ValueError: If the file cannot be read as an image.
        """
        # OpenCV loads as BGR by default
        cv_image = cv2.imread(filepath, cv2.IMREAD_UNCHANGED)
        if cv_image is None:
            raise ValueError(f"Could not load image from {filepath}")

        # Detect format based on channels
        if len(cv_image.shape) == 2:
            # IMREAD_UNCHANGED keeps the file's bit depth, so single-channel
            # 16-bit data must be tagged mono16; tagging it mono8 would make
            # the encoded message impossible to decode back.
            if cv_image.dtype == np.uint16:
                detected_format = ImageFormat.GRAY16
            else:
                detected_format = ImageFormat.GRAY
        elif cv_image.shape[2] == 3:
            detected_format = ImageFormat.BGR  # OpenCV default
        elif cv_image.shape[2] == 4:
            detected_format = ImageFormat.BGRA
        else:
            detected_format = format

        return cls(data=cv_image, format=detected_format)

    def to_opencv(self) -> np.ndarray:
        """Convert to OpenCV-compatible array (BGR format).

        BGR and grayscale data are returned as-is (no copy); alpha channels
        are dropped for RGBA/BGRA input.

        Raises:
            ValueError: If the format is not one of the supported formats.
        """
        if self.format == ImageFormat.BGR:
            return self.data
        elif self.format == ImageFormat.RGB:
            return cv2.cvtColor(self.data, cv2.COLOR_RGB2BGR)
        elif self.format == ImageFormat.RGBA:
            return cv2.cvtColor(self.data, cv2.COLOR_RGBA2BGR)
        elif self.format == ImageFormat.BGRA:
            return cv2.cvtColor(self.data, cv2.COLOR_BGRA2BGR)
        elif self.format == ImageFormat.GRAY:
            return self.data
        elif self.format == ImageFormat.GRAY16:
            return self.data
        else:
            raise ValueError(f"Unsupported format conversion: {self.format}")

    def to_rgb(self) -> "Image":
        """Convert image to RGB format.

        Images carrying an alpha channel keep it: RGBA is copied unchanged
        and BGRA becomes RGBA.

        Raises:
            ValueError: If the format is not one of the supported formats.
        """
        if self.format == ImageFormat.RGB:
            return self.copy()
        elif self.format == ImageFormat.BGR:
            rgb_data = cv2.cvtColor(self.data, cv2.COLOR_BGR2RGB)
        elif self.format == ImageFormat.RGBA:
            return self.copy()  # Already RGB with alpha
        elif self.format == ImageFormat.BGRA:
            rgb_data = cv2.cvtColor(self.data, cv2.COLOR_BGRA2RGBA)
        elif self.format == ImageFormat.GRAY:
            rgb_data = cv2.cvtColor(self.data, cv2.COLOR_GRAY2RGB)
        elif self.format == ImageFormat.GRAY16:
            # Convert 16-bit grayscale to 8-bit then to RGB
            gray8 = (self.data / 256).astype(np.uint8)
            rgb_data = cv2.cvtColor(gray8, cv2.COLOR_GRAY2RGB)
        else:
            raise ValueError(f"Unsupported format conversion from {self.format} to RGB")

        return self.__class__(
            data=rgb_data,
            format=ImageFormat.RGB if self.format != ImageFormat.BGRA else ImageFormat.RGBA,
            frame_id=self.frame_id,
            ts=self.ts,
        )

    def to_bgr(self) -> "Image":
        """Convert image to BGR format (alpha channels are dropped).

        Raises:
            ValueError: If the format is not one of the supported formats.
        """
        if self.format == ImageFormat.BGR:
            return self.copy()
        elif self.format == ImageFormat.RGB:
            bgr_data = cv2.cvtColor(self.data, cv2.COLOR_RGB2BGR)
        elif self.format == ImageFormat.RGBA:
            bgr_data = cv2.cvtColor(self.data, cv2.COLOR_RGBA2BGR)
        elif self.format == ImageFormat.BGRA:
            bgr_data = cv2.cvtColor(self.data, cv2.COLOR_BGRA2BGR)
        elif self.format == ImageFormat.GRAY:
            bgr_data = cv2.cvtColor(self.data, cv2.COLOR_GRAY2BGR)
        elif self.format == ImageFormat.GRAY16:
            # Convert 16-bit grayscale to 8-bit then to BGR
            gray8 = (self.data / 256).astype(np.uint8)
            bgr_data = cv2.cvtColor(gray8, cv2.COLOR_GRAY2BGR)
        else:
            raise ValueError(f"Unsupported format conversion from {self.format} to BGR")

        return self.__class__(
            data=bgr_data,
            format=ImageFormat.BGR,
            frame_id=self.frame_id,
            ts=self.ts,
        )

    def to_grayscale(self) -> "Image":
        """Convert image to grayscale.

        Images that are already GRAY or GRAY16 are copied with their format
        unchanged; colour images become 8-bit GRAY.

        Raises:
            ValueError: If the format is not one of the supported formats.
        """
        if self.format == ImageFormat.GRAY:
            return self.copy()
        elif self.format == ImageFormat.GRAY16:
            return self.copy()
        elif self.format == ImageFormat.BGR:
            gray_data = cv2.cvtColor(self.data, cv2.COLOR_BGR2GRAY)
        elif self.format == ImageFormat.RGB:
            gray_data = cv2.cvtColor(self.data, cv2.COLOR_RGB2GRAY)
        elif self.format == ImageFormat.RGBA:
            gray_data = cv2.cvtColor(self.data, cv2.COLOR_RGBA2GRAY)
        elif self.format == ImageFormat.BGRA:
            gray_data = cv2.cvtColor(self.data, cv2.COLOR_BGRA2GRAY)
        else:
            raise ValueError(f"Unsupported format conversion from {self.format} to grayscale")

        return self.__class__(
            data=gray_data,
            format=ImageFormat.GRAY,
            frame_id=self.frame_id,
            ts=self.ts,
        )

    def resize(self, width: int, height: int, interpolation: int = cv2.INTER_LINEAR) -> "Image":
        """Resize the image to the specified dimensions."""
        resized_data = cv2.resize(self.data, (width, height), interpolation=interpolation)

        return self.__class__(
            data=resized_data,
            format=self.format,
            frame_id=self.frame_id,
            ts=self.ts,
        )

    def crop(self, x: int, y: int, width: int, height: int) -> "Image":
        """Crop the image to the specified region.

        The region is clamped to the image bounds, so the result may be
        smaller than requested.
        """
        # Ensure crop region is within image bounds
        x = max(0, min(x, self.width))
        y = max(0, min(y, self.height))
        x2 = min(x + width, self.width)
        y2 = min(y + height, self.height)

        cropped_data = self.data[y:y2, x:x2]

        return self.__class__(
            data=cropped_data,
            format=self.format,
            frame_id=self.frame_id,
            ts=self.ts,
        )

    def save(self, filepath: str) -> bool:
        """Save image to file; returns the result of ``cv2.imwrite``."""
        # Convert to OpenCV format for saving
        cv_image = self.to_opencv()
        return cv2.imwrite(filepath, cv_image)

    def lcm_encode(self, frame_id: Optional[str] = None) -> bytes:
        """Serialize to the binary form of an LCM Image message.

        Args:
            frame_id: Optional override for the header frame id; falls back
                to ``self.frame_id`` when None or empty.

        Returns:
            The encoded message bytes.
        """
        msg = LCMImage()

        # Header
        msg.header = Header()
        msg.header.seq = 0  # Initialize sequence number
        msg.header.frame_id = frame_id or self.frame_id

        # Split the timestamp into whole seconds and nanoseconds; fall back
        # to the current time if ts was explicitly set to None.
        stamp = self.ts if self.ts is not None else time.time()
        msg.header.stamp.sec = int(stamp)
        msg.header.stamp.nsec = int((stamp - int(stamp)) * 1e9)

        # Image properties
        msg.height = self.height
        msg.width = self.width
        msg.encoding = self.format.value
        msg.is_bigendian = False  # Use little endian
        msg.step = self._get_row_step()

        # Image data
        image_bytes = self.data.tobytes()
        msg.data_length = len(image_bytes)
        msg.data = image_bytes

        return msg.encode()

    @classmethod
    def lcm_decode(cls, data: bytes, **kwargs) -> "Image":
        """Create Image from the binary form of an LCM Image message.

        Args:
            data: Encoded LCM Image message.
            **kwargs: Extra keyword arguments forwarded to the constructor.

        Raises:
            ValueError: If the message encoding is not supported.
        """
        # Parse encoding to determine format and data type
        msg = LCMImage.decode(data)
        format_info = cls._parse_encoding(msg.encoding)

        # Convert bytes back to numpy array
        pixels = np.frombuffer(msg.data, dtype=format_info["dtype"])

        # Reshape to image dimensions
        if format_info["channels"] == 1:
            pixels = pixels.reshape((msg.height, msg.width))
        else:
            pixels = pixels.reshape((msg.height, msg.width, format_info["channels"]))

        has_header = hasattr(msg, "header")
        return cls(
            data=pixels,
            format=format_info["format"],
            frame_id=msg.header.frame_id if has_header else "",
            # A missing or zero stamp is treated as "unset": use the current time.
            ts=msg.header.stamp.sec + msg.header.stamp.nsec / 1e9
            if has_header and msg.header.stamp.sec > 0
            else time.time(),
            **kwargs,
        )

    def _get_row_step(self) -> int:
        """Calculate row step (bytes per row)."""
        bytes_per_pixel = self._get_bytes_per_pixel()
        return self.width * bytes_per_pixel

    def _get_bytes_per_pixel(self) -> int:
        """Calculate bytes per pixel based on format and data type."""
        bytes_per_element = self.data.dtype.itemsize
        return self.channels * bytes_per_element

    @staticmethod
    def _parse_encoding(encoding: str) -> dict:
        """Parse LCM image encoding string to determine format and data type.

        Returns:
            Dict with keys ``format``, ``dtype`` and ``channels``.

        Raises:
            ValueError: If the encoding string is not recognised.
        """
        encoding_map = {
            "mono8": {"format": ImageFormat.GRAY, "dtype": np.uint8, "channels": 1},
            "mono16": {"format": ImageFormat.GRAY16, "dtype": np.uint16, "channels": 1},
            "rgb8": {"format": ImageFormat.RGB, "dtype": np.uint8, "channels": 3},
            "rgba8": {"format": ImageFormat.RGBA, "dtype": np.uint8, "channels": 4},
            "bgr8": {"format": ImageFormat.BGR, "dtype": np.uint8, "channels": 3},
            "bgra8": {"format": ImageFormat.BGRA, "dtype": np.uint8, "channels": 4},
        }

        if encoding not in encoding_map:
            raise ValueError(f"Unsupported encoding: {encoding}")

        return encoding_map[encoding]

    def __repr__(self) -> str:
        """String representation."""
        return (
            f"Image(shape={self.shape}, format={self.format.value}, "
            f"dtype={self.dtype}, frame_id='{self.frame_id}', ts={self.ts})"
        )

    def __eq__(self, other) -> bool:
        """Check equality with another Image.

        Pixels, format and frame id must match exactly; timestamps may
        differ by less than one microsecond.
        """
        if not isinstance(other, Image):
            # Let Python try the reflected comparison; the final result for
            # unrelated types is still False.
            return NotImplemented

        return (
            np.array_equal(self.data, other.data)
            and self.format == other.format
            and self.frame_id == other.frame_id
            and abs(self.ts - other.ts) < 1e-6
        )

    def __len__(self) -> int:
        """Return total number of pixels."""
        return self.height * self.width


# Copyright 2025 Dimensional Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+ +from __future__ import annotations + +import struct +import time +from typing import Optional + +import numpy as np +import open3d as o3d + +# Import LCM types +from dimos_lcm.sensor_msgs.PointCloud2 import ( + PointCloud2 as LCMPointCloud2, +) +from dimos_lcm.sensor_msgs.PointField import PointField +from dimos_lcm.std_msgs.Header import Header + +from dimos.types.timestamped import Timestamped + + +# TODO: encode/decode need to be updated to work with full spectrum of pointcloud2 fields +class PointCloud2(Timestamped): + msg_name = "sensor_msgs.PointCloud2" + + def __init__( + self, + pointcloud: o3d.geometry.PointCloud = None, + frame_id: str = "", + ts: Optional[float] = None, + ): + self.ts = ts if ts is not None else time.time() + self.pointcloud = pointcloud if pointcloud is not None else o3d.geometry.PointCloud() + self.frame_id = frame_id + + # TODO what's the usual storage here? is it already numpy? + def as_numpy(self) -> np.ndarray: + """Get points as numpy array.""" + return np.asarray(self.pointcloud.points) + + def lcm_encode(self, frame_id: Optional[str] = None) -> bytes: + """Convert to LCM PointCloud2 message.""" + msg = LCMPointCloud2() + + # Header + msg.header = Header() + msg.header.seq = 0 # Initialize sequence number + msg.header.frame_id = frame_id or self.frame_id + + msg.header.stamp.sec = int(self.ts) + msg.header.stamp.nsec = int((self.ts - int(self.ts)) * 1e9) + + points = self.as_numpy() + if len(points) == 0: + # Empty point cloud + msg.height = 0 + msg.width = 0 + msg.point_step = 16 # 4 floats * 4 bytes (x, y, z, intensity) + msg.row_step = 0 + msg.data_length = 0 + msg.data = b"" + msg.is_dense = True + msg.is_bigendian = False + msg.fields_length = 4 # x, y, z, intensity + msg.fields = self._create_xyz_field() + return msg.encode() + + # Point cloud dimensions + msg.height = 1 # Unorganized point cloud + msg.width = len(points) + + # Define fields (X, Y, Z, intensity as float32) + msg.fields_length = 4 # x, y, z, intensity + 
msg.fields = self._create_xyz_field() + + # Point step and row step + msg.point_step = 16 # 4 floats * 4 bytes each (x, y, z, intensity) + msg.row_step = msg.point_step * msg.width + + # Convert points to bytes with intensity padding (little endian float32) + # Add intensity column (zeros) to make it 4 columns: x, y, z, intensity + points_with_intensity = np.column_stack( + [ + points, # x, y, z columns + np.zeros(len(points), dtype=np.float32), # intensity column (padding) + ] + ) + data_bytes = points_with_intensity.astype(np.float32).tobytes() + msg.data_length = len(data_bytes) + msg.data = data_bytes + + # Properties + msg.is_dense = True # No invalid points + msg.is_bigendian = False # Little endian + + return msg.encode() + + @classmethod + def lcm_decode(cls, data: bytes) -> "PointCloud2": + msg = LCMPointCloud2.decode(data) + + if msg.width == 0 or msg.height == 0: + # Empty point cloud + pc = o3d.geometry.PointCloud() + return cls( + pointcloud=pc, + frame_id=msg.header.frame_id if hasattr(msg, "header") else "", + ts=msg.header.stamp.sec + msg.header.stamp.nsec / 1e9 + if hasattr(msg, "header") and msg.header.stamp.sec > 0 + else None, + ) + + # Parse field information to find X, Y, Z offsets + x_offset = y_offset = z_offset = None + for msgfield in msg.fields: + if msgfield.name == "x": + x_offset = msgfield.offset + elif msgfield.name == "y": + y_offset = msgfield.offset + elif msgfield.name == "z": + z_offset = msgfield.offset + + if any(offset is None for offset in [x_offset, y_offset, z_offset]): + raise ValueError("PointCloud2 message missing X, Y, or Z msgfields") + + # Extract points from binary data + num_points = msg.width * msg.height + points = np.zeros((num_points, 3), dtype=np.float32) + + data = msg.data + point_step = msg.point_step + + for i in range(num_points): + base_offset = i * point_step + + # Extract X, Y, Z (assuming float32, little endian) + x_bytes = data[base_offset + x_offset : base_offset + x_offset + 4] + y_bytes = 
data[base_offset + y_offset : base_offset + y_offset + 4] + z_bytes = data[base_offset + z_offset : base_offset + z_offset + 4] + + points[i, 0] = struct.unpack(" 0 + else None, + ) + + def _create_xyz_field(self) -> list: + """Create standard X, Y, Z field definitions for LCM PointCloud2.""" + fields = [] + + # X field + x_field = PointField() + x_field.name = "x" + x_field.offset = 0 + x_field.datatype = 7 # FLOAT32 + x_field.count = 1 + fields.append(x_field) + + # Y field + y_field = PointField() + y_field.name = "y" + y_field.offset = 4 + y_field.datatype = 7 # FLOAT32 + y_field.count = 1 + fields.append(y_field) + + # Z field + z_field = PointField() + z_field.name = "z" + z_field.offset = 8 + z_field.datatype = 7 # FLOAT32 + z_field.count = 1 + fields.append(z_field) + + # I field + i_field = PointField() + i_field.name = "intensity" + i_field.offset = 12 + i_field.datatype = 7 # FLOAT32 + i_field.count = 1 + fields.append(i_field) + + return fields + + def __len__(self) -> int: + """Return number of points.""" + return len(self.pointcloud.points) + + def __repr__(self) -> str: + """String representation.""" + return f"PointCloud(points={len(self)}, frame_id='{self.frame_id}', ts={self.ts})" diff --git a/build/lib/dimos/msgs/sensor_msgs/__init__.py b/build/lib/dimos/msgs/sensor_msgs/__init__.py new file mode 100644 index 0000000000..170587e286 --- /dev/null +++ b/build/lib/dimos/msgs/sensor_msgs/__init__.py @@ -0,0 +1,2 @@ +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 diff --git a/build/lib/dimos/msgs/sensor_msgs/test_PointCloud2.py b/build/lib/dimos/msgs/sensor_msgs/test_PointCloud2.py new file mode 100644 index 0000000000..eee1778680 --- /dev/null +++ b/build/lib/dimos/msgs/sensor_msgs/test_PointCloud2.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python3 +# Copyright 2025 Dimensional Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.utils.testing import SensorReplay + + +def test_lcm_encode_decode(): + """Test LCM encode/decode preserves pointcloud data.""" + replay = SensorReplay("office_lidar", autocast=LidarMessage.from_msg) + lidar_msg: LidarMessage = replay.load_one("lidar_data_021") + + binary_msg = lidar_msg.lcm_encode() + decoded = PointCloud2.lcm_decode(binary_msg) + + # 1. Check number of points + original_points = lidar_msg.as_numpy() + decoded_points = decoded.as_numpy() + + print(f"Original points: {len(original_points)}") + print(f"Decoded points: {len(decoded_points)}") + assert len(original_points) == len(decoded_points), ( + f"Point count mismatch: {len(original_points)} vs {len(decoded_points)}" + ) + + # 2. Check point coordinates are preserved (within floating point tolerance) + if len(original_points) > 0: + np.testing.assert_allclose( + original_points, + decoded_points, + rtol=1e-6, + atol=1e-6, + err_msg="Point coordinates don't match between original and decoded", + ) + print(f"✓ All {len(original_points)} point coordinates match within tolerance") + + # 3. Check frame_id is preserved + assert lidar_msg.frame_id == decoded.frame_id, ( + f"Frame ID mismatch: '{lidar_msg.frame_id}' vs '{decoded.frame_id}'" + ) + print(f"✓ Frame ID preserved: '{decoded.frame_id}'") + + # 4. 
Check timestamp is preserved (within reasonable tolerance for float precision) + if lidar_msg.ts is not None and decoded.ts is not None: + assert abs(lidar_msg.ts - decoded.ts) < 1e-6, ( + f"Timestamp mismatch: {lidar_msg.ts} vs {decoded.ts}" + ) + print(f"✓ Timestamp preserved: {decoded.ts}") + + # 5. Check pointcloud properties + assert len(lidar_msg.pointcloud.points) == len(decoded.pointcloud.points), ( + "Open3D pointcloud size mismatch" + ) + + # 6. Additional detailed checks + print("✓ Original pointcloud summary:") + print(f" - Points: {len(original_points)}") + print(f" - Bounds: {original_points.min(axis=0)} to {original_points.max(axis=0)}") + print(f" - Mean: {original_points.mean(axis=0)}") + + print("✓ Decoded pointcloud summary:") + print(f" - Points: {len(decoded_points)}") + print(f" - Bounds: {decoded_points.min(axis=0)} to {decoded_points.max(axis=0)}") + print(f" - Mean: {decoded_points.mean(axis=0)}") + + print("✓ LCM encode/decode test passed - all properties preserved!") diff --git a/build/lib/dimos/msgs/sensor_msgs/test_image.py b/build/lib/dimos/msgs/sensor_msgs/test_image.py new file mode 100644 index 0000000000..8e4e0a413f --- /dev/null +++ b/build/lib/dimos/msgs/sensor_msgs/test_image.py @@ -0,0 +1,63 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import numpy as np +import pytest + +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat +from dimos.utils.data import get_data + + +@pytest.fixture +def img(): + image_file_path = get_data("cafe.jpg") + return Image.from_file(str(image_file_path)) + + +def test_file_load(img: Image): + assert isinstance(img.data, np.ndarray) + assert img.width == 1024 + assert img.height == 771 + assert img.channels == 3 + assert img.shape == (771, 1024, 3) + assert img.data.dtype == np.uint8 + assert img.format == ImageFormat.BGR + assert img.frame_id == "" + assert isinstance(img.ts, float) + assert img.ts > 0 + assert img.data.flags["C_CONTIGUOUS"] + + +def test_lcm_encode_decode(img: Image): + binary_msg = img.lcm_encode() + decoded_img = Image.lcm_decode(binary_msg) + + assert isinstance(decoded_img, Image) + assert decoded_img is not img + assert decoded_img == img + + +def test_rgb_bgr_conversion(img: Image): + rgb = img.to_rgb() + assert not rgb == img + assert rgb.to_bgr() == img + + +def test_opencv_conversion(img: Image): + ocv = img.to_opencv() + decoded_img = Image.from_opencv(ocv) + + # artificially patch timestamp + decoded_img.ts = img.ts + assert decoded_img == img diff --git a/build/lib/dimos/perception/__init__.py b/build/lib/dimos/perception/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/build/lib/dimos/perception/common/__init__.py b/build/lib/dimos/perception/common/__init__.py new file mode 100644 index 0000000000..ad815d3f46 --- /dev/null +++ b/build/lib/dimos/perception/common/__init__.py @@ -0,0 +1,3 @@ +from .detection2d_tracker import target2dTracker, get_tracked_results +from .cuboid_fit import * +from .ibvs import * diff --git a/build/lib/dimos/perception/common/cuboid_fit.py b/build/lib/dimos/perception/common/cuboid_fit.py new file mode 100644 index 0000000000..9848332c06 --- /dev/null +++ b/build/lib/dimos/perception/common/cuboid_fit.py @@ -0,0 +1,331 @@ +# Copyright 2025 Dimensional Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from sklearn.decomposition import PCA +import matplotlib.pyplot as plt +import cv2 + + +def depth_to_point_cloud(depth_image, camera_matrix, subsample_factor=4): + """ + Convert depth image to point cloud using camera intrinsics. + Subsamples points to reduce density. + + Args: + depth_image: HxW depth image in meters + camera_matrix: 3x3 camera intrinsic matrix + subsample_factor: Factor to subsample points (higher = fewer points) + + Returns: + Nx3 array of 3D points + """ + # Get focal length and principal point from camera matrix + fx = camera_matrix[0, 0] + fy = camera_matrix[1, 1] + cx = camera_matrix[0, 2] + cy = camera_matrix[1, 2] + + # Create pixel coordinate grid + rows, cols = depth_image.shape + x_grid, y_grid = np.meshgrid( + np.arange(0, cols, subsample_factor), np.arange(0, rows, subsample_factor) + ) + + # Flatten grid and depth + x = x_grid.flatten() + y = y_grid.flatten() + z = depth_image[y_grid, x_grid].flatten() + + # Remove points with invalid depth + valid = z > 0 + x = x[valid] + y = y[valid] + z = z[valid] + + # Convert to 3D points + X = (x - cx) * z / fx + Y = (y - cy) * z / fy + Z = z + + return np.column_stack([X, Y, Z]) + + +def fit_cuboid(points, n_iterations=5, inlier_thresh=2.0): + """ + Fit a cuboid to a point cloud using iteratively refined PCA. 
+ + Args: + points: Nx3 array of points + n_iterations: Number of refinement iterations + inlier_thresh: Threshold for inlier detection in standard deviations + + Returns: + dict containing: + - center: 3D center point + - dimensions: 3D dimensions + - rotation: 3x3 rotation matrix + - error: fitting error + """ + points = np.asarray(points) + if len(points) < 4: + return None + + # Initial center estimate using median for robustness + best_error = float("inf") + best_params = None + center = np.median(points, axis=0) + current_points = points - center + + for iteration in range(n_iterations): + if len(current_points) < 4: # Need at least 4 points for PCA + break + + # Perform PCA + pca = PCA(n_components=3) + pca.fit(current_points) + + # Get rotation matrix from PCA + rotation = pca.components_ + + # Transform points to PCA space + local_points = current_points @ rotation.T + + # Initialize mask for this iteration + inlier_mask = np.ones(len(current_points), dtype=bool) + dimensions = np.zeros(3) + + # Filter points along each dimension + for dim in range(3): + points_1d = local_points[inlier_mask, dim] + if len(points_1d) < 4: + break + + median = np.median(points_1d) + mad = np.median(np.abs(points_1d - median)) + sigma = mad * 1.4826 # Convert MAD to standard deviation estimate + + # Avoid issues with constant values + if sigma < 1e-6: + continue + + # Update mask for this dimension + dim_inliers = np.abs(points_1d - median) < (inlier_thresh * sigma) + inlier_mask[inlier_mask] = dim_inliers + + # Calculate dimension based on robust statistics + valid_points = points_1d[dim_inliers] + if len(valid_points) > 0: + dimensions[dim] = np.max(valid_points) - np.min(valid_points) + + # Skip if we don't have enough inliers + if np.sum(inlier_mask) < 4: + continue + + # Calculate error for this iteration + # Mean squared distance from points to cuboid surface + half_dims = dimensions / 2 + dx = np.abs(local_points[:, 0]) - half_dims[0] + dy = np.abs(local_points[:, 1]) 
- half_dims[1] + dz = np.abs(local_points[:, 2]) - half_dims[2] + + outside_dist = np.sqrt( + np.maximum(dx, 0) ** 2 + np.maximum(dy, 0) ** 2 + np.maximum(dz, 0) ** 2 + ) + inside_dist = np.minimum(np.maximum(np.maximum(dx, dy), dz), 0) + distances = outside_dist + inside_dist + error = np.mean(distances**2) + + if error < best_error: + best_error = error + best_params = { + "center": center, + "rotation": rotation, + "dimensions": dimensions, + "error": error, + } + + # Update points for next iteration + current_points = current_points[inlier_mask] + + return best_params + + +def compute_fitting_error(local_points, dimensions): + """Compute mean squared distance from points to cuboid surface.""" + half_dims = dimensions / 2 + dx = np.abs(local_points[:, 0]) - half_dims[0] + dy = np.abs(local_points[:, 1]) - half_dims[1] + dz = np.abs(local_points[:, 2]) - half_dims[2] + + outside_dist = np.sqrt(np.maximum(dx, 0) ** 2 + np.maximum(dy, 0) ** 2 + np.maximum(dz, 0) ** 2) + inside_dist = np.minimum(np.maximum(np.maximum(dx, dy), dz), 0) + + distances = outside_dist + inside_dist + return np.mean(distances**2) + + +def get_cuboid_corners(center, dimensions, rotation): + """Get the 8 corners of a cuboid.""" + half_dims = dimensions / 2 + corners_local = ( + np.array( + [ + [-1, -1, -1], # 0: left bottom back + [-1, -1, 1], # 1: left bottom front + [-1, 1, -1], # 2: left top back + [-1, 1, 1], # 3: left top front + [1, -1, -1], # 4: right bottom back + [1, -1, 1], # 5: right bottom front + [1, 1, -1], # 6: right top back + [1, 1, 1], # 7: right top front + ] + ) + * half_dims + ) + + return corners_local @ rotation + center + + +def visualize_fit(image, cuboid_params, camera_matrix, R=None, t=None): + """ + Draw the fitted cuboid on the image. 
+ """ + # Get corners in world coordinates + corners = get_cuboid_corners( + cuboid_params["center"], cuboid_params["dimensions"], cuboid_params["rotation"] + ) + + # Transform corners if R and t are provided + if R is not None and t is not None: + corners = (R @ corners.T).T + t + + # Project corners to image space + corners_img = ( + cv2.projectPoints( + corners, + np.zeros(3), + np.zeros(3), # Already in camera frame + camera_matrix, + None, + )[0] + .reshape(-1, 2) + .astype(int) + ) + + # Define edges for visualization + edges = [ + # Bottom face + (0, 1), + (1, 5), + (5, 4), + (4, 0), + # Top face + (2, 3), + (3, 7), + (7, 6), + (6, 2), + # Vertical edges + (0, 2), + (1, 3), + (5, 7), + (4, 6), + ] + + # Draw edges + vis_img = image.copy() + for i, j in edges: + cv2.line(vis_img, tuple(corners_img[i]), tuple(corners_img[j]), (0, 255, 0), 2) + + # Add text with dimensions + dims = cuboid_params["dimensions"] + dim_text = f"Dims: {dims[0]:.3f} x {dims[1]:.3f} x {dims[2]:.3f}" + cv2.putText(vis_img, dim_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2) + + return vis_img + + +def plot_3d_fit(points, cuboid_params, title="3D Cuboid Fit"): + """Plot points and fitted cuboid in 3D.""" + fig = plt.figure(figsize=(10, 10)) + ax = fig.add_subplot(111, projection="3d") + + # Plot points + ax.scatter( + points[:, 0], points[:, 1], points[:, 2], c="b", marker=".", alpha=0.1, label="Points" + ) + + # Plot fitted cuboid + corners = get_cuboid_corners( + cuboid_params["center"], cuboid_params["dimensions"], cuboid_params["rotation"] + ) + + # Define edges + edges = [ + # Bottom face + (0, 1), + (1, 5), + (5, 4), + (4, 0), + # Top face + (2, 3), + (3, 7), + (7, 6), + (6, 2), + # Vertical edges + (0, 2), + (1, 3), + (5, 7), + (4, 6), + ] + + # Plot edges + for i, j in edges: + ax.plot3D( + [corners[i, 0], corners[j, 0]], + [corners[i, 1], corners[j, 1]], + [corners[i, 2], corners[j, 2]], + "r-", + ) + + # Set labels and title + ax.set_xlabel("X") + 
ax.set_ylabel("Y") + ax.set_zlabel("Z") + ax.set_title(title) + + # Make scaling uniform + all_points = np.vstack([points, corners]) + max_range = ( + np.array( + [ + all_points[:, 0].max() - all_points[:, 0].min(), + all_points[:, 1].max() - all_points[:, 1].min(), + all_points[:, 2].max() - all_points[:, 2].min(), + ] + ).max() + / 2.0 + ) + + mid_x = (all_points[:, 0].max() + all_points[:, 0].min()) * 0.5 + mid_y = (all_points[:, 1].max() + all_points[:, 1].min()) * 0.5 + mid_z = (all_points[:, 2].max() + all_points[:, 2].min()) * 0.5 + + ax.set_xlim(mid_x - max_range, mid_x + max_range) + ax.set_ylim(mid_y - max_range, mid_y + max_range) + ax.set_zlim(mid_z - max_range, mid_z + max_range) + + ax.set_box_aspect([1, 1, 1]) + plt.legend() + return fig, ax diff --git a/build/lib/dimos/perception/common/detection2d_tracker.py b/build/lib/dimos/perception/common/detection2d_tracker.py new file mode 100644 index 0000000000..2e4582cc00 --- /dev/null +++ b/build/lib/dimos/perception/common/detection2d_tracker.py @@ -0,0 +1,385 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from collections import deque + + +def compute_iou(bbox1, bbox2): + """ + Compute Intersection over Union (IoU) of two bounding boxes. + Each bbox is [x1, y1, x2, y2]. 
+ """ + x1 = max(bbox1[0], bbox2[0]) + y1 = max(bbox1[1], bbox2[1]) + x2 = min(bbox1[2], bbox2[2]) + y2 = min(bbox1[3], bbox2[3]) + + inter_area = max(0, x2 - x1) * max(0, y2 - y1) + area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1]) + area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1]) + + union_area = area1 + area2 - inter_area + if union_area == 0: + return 0 + return inter_area / union_area + + +def get_tracked_results(tracked_targets): + """ + Extract tracked results from a list of target2d objects. + + Args: + tracked_targets (list[target2d]): List of target2d objects (published targets) + returned by the tracker's update() function. + + Returns: + tuple: (tracked_masks, tracked_bboxes, tracked_track_ids, tracked_probs, tracked_names) + where each is a list of the corresponding attribute from each target. + """ + tracked_masks = [] + tracked_bboxes = [] + tracked_track_ids = [] + tracked_probs = [] + tracked_names = [] + + for target in tracked_targets: + # Extract the latest values stored in each target. + tracked_masks.append(target.latest_mask) + tracked_bboxes.append(target.latest_bbox) + # Here we use the most recent detection's track ID. + tracked_track_ids.append(target.target_id) + # Use the latest probability from the history. + tracked_probs.append(target.score) + # Use the stored name (if any). If not available, you can use a default value. + tracked_names.append(target.name) + + return tracked_masks, tracked_bboxes, tracked_track_ids, tracked_probs, tracked_names + + +class target2d: + """ + Represents a tracked 2D target. + Stores the latest bounding box and mask along with a short history of track IDs, + detection probabilities, and computed texture values. + """ + + def __init__( + self, + initial_mask, + initial_bbox, + track_id, + prob, + name, + texture_value, + target_id, + history_size=10, + ): + """ + Args: + initial_mask (torch.Tensor): Latest segmentation mask. + initial_bbox (list): Bounding box in [x1, y1, x2, y2] format. 
+ track_id (int): Detection’s track ID (may be -1 if not provided). + prob (float): Detection probability. + name (str): Object class name. + texture_value (float): Computed average texture value for this detection. + target_id (int): Unique identifier assigned by the tracker. + history_size (int): Maximum number of frames to keep in the history. + """ + self.target_id = target_id + self.latest_mask = initial_mask + self.latest_bbox = initial_bbox + self.name = name + self.score = 1.0 + + self.track_id = track_id + self.probs_history = deque(maxlen=history_size) + self.texture_history = deque(maxlen=history_size) + + self.frame_count = deque(maxlen=history_size) # Total frames this target has been seen. + self.missed_frames = 0 # Consecutive frames when no detection was assigned. + self.history_size = history_size + + def update(self, mask, bbox, track_id, prob, name, texture_value): + """ + Update the target with a new detection. + """ + self.latest_mask = mask + self.latest_bbox = bbox + self.name = name + + self.track_id = track_id + self.probs_history.append(prob) + self.texture_history.append(texture_value) + + self.frame_count.append(1) + self.missed_frames = 0 + + def mark_missed(self): + """ + Increment the count of consecutive frames where this target was not updated. + """ + self.missed_frames += 1 + self.frame_count.append(0) + + def compute_score( + self, + frame_shape, + min_area_ratio, + max_area_ratio, + texture_range=(0.0, 1.0), + border_safe_distance=50, + weights=None, + ): + """ + Compute a combined score for the target based on several factors. + + Factors: + - **Detection probability:** Average over recent frames. + - **Temporal stability:** How consistently the target has appeared. + - **Texture quality:** Normalized using the provided min and max values. + - **Border proximity:** Computed from the minimum distance from the bbox to the frame edges. + - **Size:** How the object's area (relative to the frame) compares to acceptable bounds. 
+ + Args: + frame_shape (tuple): (height, width) of the frame. + min_area_ratio (float): Minimum acceptable ratio (bbox area / frame area). + max_area_ratio (float): Maximum acceptable ratio. + texture_range (tuple): (min_texture, max_texture) expected values. + border_safe_distance (float): Distance (in pixels) considered safe from the border. + weights (dict): Weights for each component. Expected keys: + 'prob', 'temporal', 'texture', 'border', and 'size'. + + Returns: + float: The combined (normalized) score in the range [0, 1]. + """ + # Default weights if none provided. + if weights is None: + weights = {"prob": 1.0, "temporal": 1.0, "texture": 1.0, "border": 1.0, "size": 1.0} + + h, w = frame_shape + x1, y1, x2, y2 = self.latest_bbox + bbox_area = (x2 - x1) * (y2 - y1) + frame_area = w * h + area_ratio = bbox_area / frame_area + + # Detection probability factor. + avg_prob = np.mean(self.probs_history) + # Temporal stability factor: normalized by history size. + temporal_stability = np.mean(self.frame_count) + # Texture factor: normalize average texture using the provided range. + avg_texture = np.mean(self.texture_history) if self.texture_history else 0.0 + min_texture, max_texture = texture_range + if max_texture == min_texture: + normalized_texture = avg_texture + else: + normalized_texture = (avg_texture - min_texture) / (max_texture - min_texture) + normalized_texture = max(0.0, min(normalized_texture, 1.0)) + + # Border factor: compute the minimum distance from the bbox to any frame edge. + left_dist = x1 + top_dist = y1 + right_dist = w - x2 + min_border_dist = min(left_dist, top_dist, right_dist) + # Normalize the border distance: full score (1.0) if at least border_safe_distance away. + border_factor = min(1.0, min_border_dist / border_safe_distance) + + # Size factor: penalize objects that are too small or too big. 
+ if area_ratio < min_area_ratio: + size_factor = area_ratio / min_area_ratio + elif area_ratio > max_area_ratio: + # Here we compute a linear penalty if the area exceeds max_area_ratio. + if 1 - max_area_ratio > 0: + size_factor = max(0, (1 - area_ratio) / (1 - max_area_ratio)) + else: + size_factor = 0.0 + else: + size_factor = 1.0 + + # Combine factors using a weighted sum (each factor is assumed in [0, 1]). + w_prob = weights.get("prob", 1.0) + w_temporal = weights.get("temporal", 1.0) + w_texture = weights.get("texture", 1.0) + w_border = weights.get("border", 1.0) + w_size = weights.get("size", 1.0) + total_weight = w_prob + w_temporal + w_texture + w_border + w_size + + # print(f"track_id: {self.target_id}, avg_prob: {avg_prob:.2f}, temporal_stability: {temporal_stability:.2f}, normalized_texture: {normalized_texture:.2f}, border_factor: {border_factor:.2f}, size_factor: {size_factor:.2f}") + + final_score = ( + w_prob * avg_prob + + w_temporal * temporal_stability + + w_texture * normalized_texture + + w_border * border_factor + + w_size * size_factor + ) / total_weight + + self.score = final_score + + return final_score + + +class target2dTracker: + """ + Tracker that maintains a history of targets across frames. + New segmentation detections (frame, masks, bboxes, track_ids, probabilities, + and computed texture values) are matched to existing targets or used to create new ones. + + The tracker uses a scoring system that incorporates: + - **Detection probability** + - **Temporal stability** + - **Texture quality** (normalized within a specified range) + - **Proximity to image borders** (a continuous penalty based on the distance) + - **Object size** relative to the frame + + Targets are published if their score exceeds the start threshold and are removed if their score + falls below the stop threshold or if they are missed for too many consecutive frames. 
+ """ + + def __init__( + self, + history_size=10, + score_threshold_start=0.5, + score_threshold_stop=0.3, + min_frame_count=10, + max_missed_frames=3, + min_area_ratio=0.001, + max_area_ratio=0.1, + texture_range=(0.0, 1.0), + border_safe_distance=50, + weights=None, + ): + """ + Args: + history_size (int): Maximum history length (number of frames) per target. + score_threshold_start (float): Minimum score for a target to be published. + score_threshold_stop (float): If a target’s score falls below this, it is removed. + min_frame_count (int): Minimum number of frames a target must be seen to be published. + max_missed_frames (int): Maximum consecutive frames a target can be missing before deletion. + min_area_ratio (float): Minimum acceptable bbox area relative to the frame. + max_area_ratio (float): Maximum acceptable bbox area relative to the frame. + texture_range (tuple): (min_texture, max_texture) expected values. + border_safe_distance (float): Distance (in pixels) considered safe from the border. + weights (dict): Weights for the scoring components (keys: 'prob', 'temporal', + 'texture', 'border', 'size'). + """ + self.history_size = history_size + self.score_threshold_start = score_threshold_start + self.score_threshold_stop = score_threshold_stop + self.min_frame_count = min_frame_count + self.max_missed_frames = max_missed_frames + self.min_area_ratio = min_area_ratio + self.max_area_ratio = max_area_ratio + self.texture_range = texture_range + self.border_safe_distance = border_safe_distance + # Default weights if none are provided. + if weights is None: + weights = {"prob": 1.0, "temporal": 1.0, "texture": 1.0, "border": 1.0, "size": 1.0} + self.weights = weights + + self.targets = {} # Dictionary mapping target_id -> target2d instance. + self.next_target_id = 0 + + def update(self, frame, masks, bboxes, track_ids, probs, names, texture_values): + """ + Update the tracker with new detections from the current frame. 
+ + Args: + frame (np.ndarray): Current BGR frame. + masks (list[torch.Tensor]): List of segmentation masks. + bboxes (list): List of bounding boxes [x1, y1, x2, y2]. + track_ids (list): List of detection track IDs. + probs (list): List of detection probabilities. + names (list): List of class names. + texture_values (list): List of computed texture values. + + Returns: + published_targets (list[target2d]): Targets that are active and have scores above + the start threshold. + """ + updated_target_ids = set() + frame_shape = frame.shape[:2] # (height, width) + + # For each detection, try to match with an existing target. + for mask, bbox, det_tid, prob, name, texture in zip( + masks, bboxes, track_ids, probs, names, texture_values + ): + matched_target = None + + # First, try matching by detection track ID if valid. + if det_tid != -1: + for target in self.targets.values(): + if target.track_id == det_tid: + matched_target = target + break + + # Otherwise, try matching using IoU. + if matched_target is None: + best_iou = 0 + for target in self.targets.values(): + iou = compute_iou(bbox, target.latest_bbox) + if iou > 0.5 and iou > best_iou: + best_iou = iou + matched_target = target + + # Update existing target or create a new one. + if matched_target is not None: + matched_target.update(mask, bbox, det_tid, prob, name, texture) + updated_target_ids.add(matched_target.target_id) + else: + new_target = target2d( + mask, bbox, det_tid, prob, name, texture, self.next_target_id, self.history_size + ) + self.targets[self.next_target_id] = new_target + updated_target_ids.add(self.next_target_id) + self.next_target_id += 1 + + # Mark targets that were not updated. + for target_id, target in list(self.targets.items()): + if target_id not in updated_target_ids: + target.mark_missed() + if target.missed_frames > self.max_missed_frames: + del self.targets[target_id] + continue # Skip further checks for this target. 
+ # Remove targets whose score falls below the stop threshold. + score = target.compute_score( + frame_shape, + self.min_area_ratio, + self.max_area_ratio, + texture_range=self.texture_range, + border_safe_distance=self.border_safe_distance, + weights=self.weights, + ) + if score < self.score_threshold_stop: + del self.targets[target_id] + + # Publish targets with scores above the start threshold. + published_targets = [] + for target in self.targets.values(): + score = target.compute_score( + frame_shape, + self.min_area_ratio, + self.max_area_ratio, + texture_range=self.texture_range, + border_safe_distance=self.border_safe_distance, + weights=self.weights, + ) + if ( + score >= self.score_threshold_start + and sum(target.frame_count) >= self.min_frame_count + and target.missed_frames <= 5 + ): + published_targets.append(target) + + return published_targets diff --git a/build/lib/dimos/perception/common/export_tensorrt.py b/build/lib/dimos/perception/common/export_tensorrt.py new file mode 100644 index 0000000000..9c021eb0a0 --- /dev/null +++ b/build/lib/dimos/perception/common/export_tensorrt.py @@ -0,0 +1,57 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse +from ultralytics import YOLO, FastSAM + + +def parse_args(): + parser = argparse.ArgumentParser(description="Export YOLO/FastSAM models to different formats") + parser.add_argument("--model_path", type=str, required=True, help="Path to the model weights") + parser.add_argument( + "--model_type", + type=str, + choices=["yolo", "fastsam"], + required=True, + help="Type of model to export", + ) + parser.add_argument( + "--precision", + type=str, + choices=["fp32", "fp16", "int8"], + default="fp32", + help="Precision for export", + ) + parser.add_argument( + "--format", type=str, choices=["onnx", "engine"], default="onnx", help="Export format" + ) + return parser.parse_args() + + +def main(): + args = parse_args() + half = args.precision == "fp16" + int8 = args.precision == "int8" + # Load the appropriate model + if args.model_type == "yolo": + model = YOLO(args.model_path) + else: + model = FastSAM(args.model_path) + + # Export the model + model.export(format=args.format, half=half, int8=int8) + + +if __name__ == "__main__": + main() diff --git a/build/lib/dimos/perception/common/ibvs.py b/build/lib/dimos/perception/common/ibvs.py new file mode 100644 index 0000000000..d580c71b23 --- /dev/null +++ b/build/lib/dimos/perception/common/ibvs.py @@ -0,0 +1,280 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import numpy as np + + +class PersonDistanceEstimator: + def __init__(self, K, camera_pitch, camera_height): + """ + Initialize the distance estimator using ground plane constraint. + + Args: + K: 3x3 Camera intrinsic matrix in OpenCV format + (Assumed to be already for an undistorted image) + camera_pitch: Upward pitch of the camera (in radians), in the robot frame + Positive means looking up, negative means looking down + camera_height: Height of the camera above the ground (in meters) + """ + self.K = K + self.camera_height = camera_height + + # Precompute the inverse intrinsic matrix + self.K_inv = np.linalg.inv(K) + + # Transform from camera to robot frame (z-forward to x-forward) + self.T = np.array([[0, 0, 1], [-1, 0, 0], [0, -1, 0]]) + + # Pitch rotation matrix (positive is upward) + theta = -camera_pitch # Negative since positive pitch is negative rotation about robot Y + self.R_pitch = np.array( + [[np.cos(theta), 0, np.sin(theta)], [0, 1, 0], [-np.sin(theta), 0, np.cos(theta)]] + ) + + # Combined transform from camera to robot frame + self.A = self.R_pitch @ self.T + + # Store focal length and principal point for angle calculation + self.fx = K[0, 0] + self.cx = K[0, 2] + + def estimate_distance_angle(self, bbox: tuple, robot_pitch: float = None): + """ + Estimate distance and angle to person using ground plane constraint. 
+ + Args: + bbox: tuple (x_min, y_min, x_max, y_max) + where y_max represents the feet position + robot_pitch: Current pitch of the robot body (in radians) + If provided, this will be combined with the camera's fixed pitch + + Returns: + depth: distance to person along camera's z-axis (meters) + angle: horizontal angle in camera frame (radians, positive right) + """ + x_min, _, x_max, y_max = bbox + + # Get center point of feet + u_c = (x_min + x_max) / 2.0 + v_feet = y_max + + # Create homogeneous feet point and get ray direction + p_feet = np.array([u_c, v_feet, 1.0]) + d_feet_cam = self.K_inv @ p_feet + + # If robot_pitch is provided, recalculate the transformation matrix + if robot_pitch is not None: + # Combined pitch (fixed camera pitch + current robot pitch) + total_pitch = -camera_pitch - robot_pitch # Both negated for correct rotation direction + R_total_pitch = np.array( + [ + [np.cos(total_pitch), 0, np.sin(total_pitch)], + [0, 1, 0], + [-np.sin(total_pitch), 0, np.cos(total_pitch)], + ] + ) + # Use the updated transformation matrix + A = R_total_pitch @ self.T + else: + # Use the precomputed transformation matrix + A = self.A + + # Convert ray to robot frame using appropriate transformation + d_feet_robot = A @ d_feet_cam + + # Ground plane intersection (z=0) + # camera_height + t * d_feet_robot[2] = 0 + if abs(d_feet_robot[2]) < 1e-6: + raise ValueError("Feet ray is parallel to ground plane") + + # Solve for scaling factor t + t = -self.camera_height / d_feet_robot[2] + + # Get 3D feet position in robot frame + p_feet_robot = t * d_feet_robot + + # Convert back to camera frame + p_feet_cam = self.A.T @ p_feet_robot + + # Extract depth (z-coordinate in camera frame) + depth = p_feet_cam[2] + + # Calculate horizontal angle from image center + angle = np.arctan((u_c - self.cx) / self.fx) + + return depth, angle + + +class ObjectDistanceEstimator: + """ + Estimate distance to an object using the ground plane constraint. 
+ This class assumes the camera is mounted on a robot and uses the + camera's intrinsic parameters to estimate the distance to a detected object. + """ + + def __init__(self, K, camera_pitch, camera_height): + """ + Initialize the distance estimator using ground plane constraint. + + Args: + K: 3x3 Camera intrinsic matrix in OpenCV format + (Assumed to be already for an undistorted image) + camera_pitch: Upward pitch of the camera (in radians) + Positive means looking up, negative means looking down + camera_height: Height of the camera above the ground (in meters) + """ + self.K = K + self.camera_height = camera_height + + # Precompute the inverse intrinsic matrix + self.K_inv = np.linalg.inv(K) + + # Transform from camera to robot frame (z-forward to x-forward) + self.T = np.array([[0, 0, 1], [-1, 0, 0], [0, -1, 0]]) + + # Pitch rotation matrix (positive is upward) + theta = -camera_pitch # Negative since positive pitch is negative rotation about robot Y + self.R_pitch = np.array( + [[np.cos(theta), 0, np.sin(theta)], [0, 1, 0], [-np.sin(theta), 0, np.cos(theta)]] + ) + + # Combined transform from camera to robot frame + self.A = self.R_pitch @ self.T + + # Store focal length and principal point for angle calculation + self.fx = K[0, 0] + self.fy = K[1, 1] + self.cx = K[0, 2] + self.estimated_object_size = None + + def estimate_object_size(self, bbox: tuple, distance: float): + """ + Estimate the physical size of an object based on its bbox and known distance. 
+ + Args: + bbox: tuple (x_min, y_min, x_max, y_max) bounding box in the image + distance: Known distance to the object (in meters) + robot_pitch: Current pitch of the robot body (in radians), if any + + Returns: + estimated_size: Estimated physical height of the object (in meters) + """ + x_min, y_min, x_max, y_max = bbox + + # Calculate object height in pixels + object_height_px = y_max - y_min + + # Calculate the physical height using the known distance and focal length + estimated_size = object_height_px * distance / self.fy + self.estimated_object_size = estimated_size + + return estimated_size + + def set_estimated_object_size(self, size: float): + """ + Set the estimated object size for future distance calculations. + + Args: + size: Estimated physical size of the object (in meters) + """ + self.estimated_object_size = size + + def estimate_distance_angle(self, bbox: tuple): + """ + Estimate distance and angle to object using size-based estimation. + + Args: + bbox: tuple (x_min, y_min, x_max, y_max) + where y_max represents the bottom of the object + robot_pitch: Current pitch of the robot body (in radians) + If provided, this will be combined with the camera's fixed pitch + initial_distance: Initial distance estimate for the object (in meters) + Used to calibrate object size if not previously known + + Returns: + depth: distance to object along camera's z-axis (meters) + angle: horizontal angle in camera frame (radians, positive right) + or None, None if estimation not possible + """ + # If we don't have estimated object size and no initial distance is provided, + # we can't estimate the distance + if self.estimated_object_size is None: + return None, None + + x_min, y_min, x_max, y_max = bbox + + # Calculate center of the object for angle calculation + u_c = (x_min + x_max) / 2.0 + + # If we have an initial distance estimate and no object size yet, + # calculate and store the object size using the initial distance + object_height_px = y_max - y_min + 
depth = self.estimated_object_size * self.fy / object_height_px + + # Calculate horizontal angle from image center + angle = np.arctan((u_c - self.cx) / self.fx) + + return depth, angle + + +# Example usage: +if __name__ == "__main__": + # Example camera calibration + K = np.array([[600, 0, 320], [0, 600, 240], [0, 0, 1]], dtype=np.float32) + + # Camera mounted 1.2m high, pitched down 10 degrees + camera_pitch = np.deg2rad(0) # negative for downward pitch + camera_height = 1.0 # meters + + estimator = PersonDistanceEstimator(K, camera_pitch, camera_height) + object_estimator = ObjectDistanceEstimator(K, camera_pitch, camera_height) + + # Example detection + bbox = (300, 100, 380, 400) # x1, y1, x2, y2 + + depth, angle = estimator.estimate_distance_angle(bbox) + # Estimate object size based on the known distance + object_size = object_estimator.estimate_object_size(bbox, depth) + depth_obj, angle_obj = object_estimator.estimate_distance_angle(bbox) + + print(f"Estimated person depth: {depth:.2f} m") + print(f"Estimated person angle: {np.rad2deg(angle):.1f}°") + print(f"Estimated object depth: {depth_obj:.2f} m") + print(f"Estimated object angle: {np.rad2deg(angle_obj):.1f}°") + + # Shrink the bbox by 30 pixels while keeping the same center + x_min, y_min, x_max, y_max = bbox + width = x_max - x_min + height = y_max - y_min + center_x = (x_min + x_max) // 2 + center_y = (y_min + y_max) // 2 + + new_width = max(width - 20, 2) # Ensure width is at least 2 pixels + new_height = max(height - 20, 2) # Ensure height is at least 2 pixels + + x_min = center_x - new_width // 2 + x_max = center_x + new_width // 2 + y_min = center_y - new_height // 2 + y_max = center_y + new_height // 2 + + bbox = (x_min, y_min, x_max, y_max) + + # Re-estimate distance and angle with the new bbox + depth, angle = estimator.estimate_distance_angle(bbox) + depth_obj, angle_obj = object_estimator.estimate_distance_angle(bbox) + + print(f"New estimated person depth: {depth:.2f} m") + print(f"New 
estimated person angle: {np.rad2deg(angle):.1f}°") + print(f"New estimated object depth: {depth_obj:.2f} m") + print(f"New estimated object angle: {np.rad2deg(angle_obj):.1f}°") diff --git a/build/lib/dimos/perception/common/utils.py b/build/lib/dimos/perception/common/utils.py new file mode 100644 index 0000000000..fc50e042ad --- /dev/null +++ b/build/lib/dimos/perception/common/utils.py @@ -0,0 +1,364 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import cv2 +import numpy as np +from typing import List, Tuple, Optional, Any +from dimos.types.manipulation import ObjectData +from dimos.types.vector import Vector +from dimos.utils.logging_config import setup_logger +import torch + +logger = setup_logger("dimos.perception.common.utils") + + +def colorize_depth(depth_img: np.ndarray, max_depth: float = 5.0) -> Optional[np.ndarray]: + """ + Normalize and colorize depth image using COLORMAP_JET. 
+ + Args: + depth_img: Depth image (H, W) in meters + max_depth: Maximum depth value for normalization + + Returns: + Colorized depth image (H, W, 3) in RGB format, or None if input is None + """ + if depth_img is None: + return None + + valid_mask = np.isfinite(depth_img) & (depth_img > 0) + depth_norm = np.zeros_like(depth_img) + depth_norm[valid_mask] = np.clip(depth_img[valid_mask] / max_depth, 0, 1) + depth_colored = cv2.applyColorMap((depth_norm * 255).astype(np.uint8), cv2.COLORMAP_JET) + depth_rgb = cv2.cvtColor(depth_colored, cv2.COLOR_BGR2RGB) + + # Make the depth image less bright by scaling down the values + depth_rgb = (depth_rgb * 0.6).astype(np.uint8) + + return depth_rgb + + +def draw_bounding_box( + image: np.ndarray, + bbox: List[float], + color: Tuple[int, int, int] = (0, 255, 0), + thickness: int = 2, + label: Optional[str] = None, + confidence: Optional[float] = None, + object_id: Optional[int] = None, + font_scale: float = 0.6, +) -> np.ndarray: + """ + Draw a bounding box with optional label on an image. 
+ + Args: + image: Image to draw on (H, W, 3) + bbox: Bounding box [x1, y1, x2, y2] + color: RGB color tuple for the box + thickness: Line thickness for the box + label: Optional class label + confidence: Optional confidence score + object_id: Optional object ID + font_scale: Font scale for text + + Returns: + Image with bounding box drawn + """ + x1, y1, x2, y2 = map(int, bbox) + + # Draw bounding box + cv2.rectangle(image, (x1, y1), (x2, y2), color, thickness) + + # Create label text + text_parts = [] + if label is not None: + text_parts.append(str(label)) + if object_id is not None: + text_parts.append(f"ID: {object_id}") + if confidence is not None: + text_parts.append(f"({confidence:.2f})") + + if text_parts: + text = ", ".join(text_parts) + + # Draw text background + text_size = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, 1)[0] + cv2.rectangle( + image, + (x1, y1 - text_size[1] - 5), + (x1 + text_size[0], y1), + (0, 0, 0), + -1, + ) + + # Draw text + cv2.putText( + image, + text, + (x1, y1 - 5), + cv2.FONT_HERSHEY_SIMPLEX, + font_scale, + (255, 255, 255), + 1, + ) + + return image + + +def draw_segmentation_mask( + image: np.ndarray, + mask: np.ndarray, + color: Tuple[int, int, int] = (0, 200, 200), + alpha: float = 0.5, + draw_contours: bool = True, + contour_thickness: int = 2, +) -> np.ndarray: + """ + Draw segmentation mask overlay on an image. 
+ + Args: + image: Image to draw on (H, W, 3) + mask: Segmentation mask (H, W) - boolean or uint8 + color: RGB color for the mask + alpha: Transparency factor (0.0 = transparent, 1.0 = opaque) + draw_contours: Whether to draw mask contours + contour_thickness: Thickness of contour lines + + Returns: + Image with mask overlay drawn + """ + if mask is None: + return image + + try: + # Ensure mask is uint8 + mask = mask.astype(np.uint8) + + # Create colored mask overlay + colored_mask = np.zeros_like(image) + colored_mask[mask > 0] = color + + # Apply the mask with transparency + mask_area = mask > 0 + image[mask_area] = cv2.addWeighted( + image[mask_area], 1 - alpha, colored_mask[mask_area], alpha, 0 + ) + + # Draw mask contours if requested + if draw_contours: + contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + cv2.drawContours(image, contours, -1, color, contour_thickness) + + except Exception as e: + logger.warning(f"Error drawing segmentation mask: {e}") + + return image + + +def draw_object_detection_visualization( + image: np.ndarray, + objects: List[ObjectData], + draw_masks: bool = False, + bbox_color: Tuple[int, int, int] = (0, 255, 0), + mask_color: Tuple[int, int, int] = (0, 200, 200), + font_scale: float = 0.6, +) -> np.ndarray: + """ + Create object detection visualization with bounding boxes and optional masks. 
+ + Args: + image: Base image to draw on (H, W, 3) + objects: List of ObjectData with detection information + draw_masks: Whether to draw segmentation masks + bbox_color: Default color for bounding boxes + mask_color: Default color for segmentation masks + font_scale: Font scale for text labels + + Returns: + Image with detection visualization + """ + viz_image = image.copy() + + for obj in objects: + try: + # Draw segmentation mask first (if enabled and available) + if draw_masks and "segmentation_mask" in obj and obj["segmentation_mask"] is not None: + viz_image = draw_segmentation_mask( + viz_image, obj["segmentation_mask"], color=mask_color, alpha=0.5 + ) + + # Draw bounding box + if "bbox" in obj and obj["bbox"] is not None: + # Use object's color if available, otherwise default + color = bbox_color + if "color" in obj and obj["color"] is not None: + obj_color = obj["color"] + if isinstance(obj_color, np.ndarray): + color = tuple(int(c) for c in obj_color) + elif isinstance(obj_color, (list, tuple)): + color = tuple(int(c) for c in obj_color[:3]) + + viz_image = draw_bounding_box( + viz_image, + obj["bbox"], + color=color, + label=obj.get("label"), + confidence=obj.get("confidence"), + object_id=obj.get("object_id"), + font_scale=font_scale, + ) + + except Exception as e: + logger.warning(f"Error drawing object visualization: {e}") + + return viz_image + + +def detection_results_to_object_data( + bboxes: List[List[float]], + track_ids: List[int], + class_ids: List[int], + confidences: List[float], + names: List[str], + masks: Optional[List[np.ndarray]] = None, + source: str = "detection", +) -> List[ObjectData]: + """ + Convert detection/segmentation results to ObjectData format. 
+ + Args: + bboxes: List of bounding boxes [x1, y1, x2, y2] + track_ids: List of tracking IDs + class_ids: List of class indices + confidences: List of detection confidences + names: List of class names + masks: Optional list of segmentation masks + source: Source type ("detection" or "segmentation") + + Returns: + List of ObjectData dictionaries + """ + objects = [] + + for i in range(len(bboxes)): + # Calculate basic properties from bbox + bbox = bboxes[i] + width = bbox[2] - bbox[0] + height = bbox[3] - bbox[1] + center_x = bbox[0] + width / 2 + center_y = bbox[1] + height / 2 + + # Create ObjectData + object_data: ObjectData = { + "object_id": track_ids[i] if i < len(track_ids) else i, + "bbox": bbox, + "depth": -1.0, # Will be populated by depth estimation or point cloud processing + "confidence": confidences[i] if i < len(confidences) else 1.0, + "class_id": class_ids[i] if i < len(class_ids) else 0, + "label": names[i] if i < len(names) else f"{source}_object", + "movement_tolerance": 1.0, # Default to freely movable + "segmentation_mask": masks[i].cpu().numpy() + if masks and i < len(masks) and isinstance(masks[i], torch.Tensor) + else masks[i] + if masks and i < len(masks) + else None, + # Initialize 3D properties (will be populated by point cloud processing) + "position": Vector(0, 0, 0), + "rotation": Vector(0, 0, 0), + "size": { + "width": 0.0, + "height": 0.0, + "depth": 0.0, + }, + } + objects.append(object_data) + + return objects + + +def combine_object_data( + list1: List[ObjectData], list2: List[ObjectData], overlap_threshold: float = 0.8 +) -> List[ObjectData]: + """ + Combine two ObjectData lists, removing duplicates based on segmentation mask overlap. 
+ """ + combined = list1.copy() + used_ids = set(obj.get("object_id", 0) for obj in list1) + next_id = max(used_ids) + 1 if used_ids else 1 + + for obj2 in list2: + obj_copy = obj2.copy() + + # Handle duplicate object_id + if obj_copy.get("object_id", 0) in used_ids: + obj_copy["object_id"] = next_id + next_id += 1 + used_ids.add(obj_copy["object_id"]) + + # Check mask overlap + mask2 = obj2.get("segmentation_mask") + if mask2 is None or np.sum(mask2 > 0) == 0: + combined.append(obj_copy) + continue + + mask2_area = np.sum(mask2 > 0) + is_duplicate = False + + for obj1 in list1: + mask1 = obj1.get("segmentation_mask") + if mask1 is None: + continue + + intersection = np.sum((mask1 > 0) & (mask2 > 0)) + if intersection / mask2_area >= overlap_threshold: + is_duplicate = True + break + + if not is_duplicate: + combined.append(obj_copy) + + return combined + + +def point_in_bbox(point: Tuple[int, int], bbox: List[float]) -> bool: + """ + Check if a point is inside a bounding box. + + Args: + point: (x, y) coordinates + bbox: Bounding box [x1, y1, x2, y2] + + Returns: + True if point is inside bbox + """ + x, y = point + x1, y1, x2, y2 = bbox + return x1 <= x <= x2 and y1 <= y <= y2 + + +def find_clicked_object(click_point: Tuple[int, int], objects: List[Any]) -> Optional[Any]: + """ + Find which object was clicked based on bounding boxes. 
+ + Args: + click_point: (x, y) coordinates of mouse click + objects: List of objects with 'bbox' field + + Returns: + Clicked object or None + """ + for obj in objects: + if "bbox" in obj and point_in_bbox(click_point, obj["bbox"]): + return obj + return None diff --git a/build/lib/dimos/perception/detection2d/__init__.py b/build/lib/dimos/perception/detection2d/__init__.py new file mode 100644 index 0000000000..a43c5da6ce --- /dev/null +++ b/build/lib/dimos/perception/detection2d/__init__.py @@ -0,0 +1,2 @@ +from .utils import * +from .yolo_2d_det import * diff --git a/build/lib/dimos/perception/detection2d/detic_2d_det.py b/build/lib/dimos/perception/detection2d/detic_2d_det.py new file mode 100644 index 0000000000..fc81526ad2 --- /dev/null +++ b/build/lib/dimos/perception/detection2d/detic_2d_det.py @@ -0,0 +1,414 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import numpy as np +import os +import sys + +# Add Detic to Python path +detic_path = os.path.join(os.path.dirname(__file__), "..", "..", "models", "Detic") +if detic_path not in sys.path: + sys.path.append(detic_path) + sys.path.append(os.path.join(detic_path, "third_party/CenterNet2")) + +# PIL patch for compatibility +import PIL.Image + +if not hasattr(PIL.Image, "LINEAR") and hasattr(PIL.Image, "BILINEAR"): + PIL.Image.LINEAR = PIL.Image.BILINEAR + +# Detectron2 imports +from detectron2.config import get_cfg +from detectron2.data import MetadataCatalog + + +# Simple tracking implementation +class SimpleTracker: + """Simple IOU-based tracker implementation without external dependencies""" + + def __init__(self, iou_threshold=0.3, max_age=5): + self.iou_threshold = iou_threshold + self.max_age = max_age + self.next_id = 1 + self.tracks = {} # id -> {bbox, class_id, age, mask, etc} + + def _calculate_iou(self, bbox1, bbox2): + """Calculate IoU between two bboxes in format [x1,y1,x2,y2]""" + x1 = max(bbox1[0], bbox2[0]) + y1 = max(bbox1[1], bbox2[1]) + x2 = min(bbox1[2], bbox2[2]) + y2 = min(bbox1[3], bbox2[3]) + + if x2 < x1 or y2 < y1: + return 0.0 + + intersection = (x2 - x1) * (y2 - y1) + area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1]) + area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1]) + union = area1 + area2 - intersection + + return intersection / union if union > 0 else 0 + + def update(self, detections, masks): + """Update tracker with new detections + + Args: + detections: List of [x1,y1,x2,y2,score,class_id] + masks: List of segmentation masks corresponding to detections + + Returns: + List of [track_id, bbox, score, class_id, mask] + """ + if len(detections) == 0: + # Age existing tracks + for track_id in list(self.tracks.keys()): + self.tracks[track_id]["age"] += 1 + # Remove old tracks + if self.tracks[track_id]["age"] > self.max_age: + del self.tracks[track_id] + return [] + + # Convert to numpy for easier handling + if not 
isinstance(detections, np.ndarray): + detections = np.array(detections) + + result = [] + matched_indices = set() + + # Update existing tracks + for track_id, track in list(self.tracks.items()): + track["age"] += 1 + + if track["age"] > self.max_age: + del self.tracks[track_id] + continue + + # Find best matching detection for this track + best_iou = self.iou_threshold + best_idx = -1 + + for i, det in enumerate(detections): + if i in matched_indices: + continue + + # Check class match + if det[5] != track["class_id"]: + continue + + iou = self._calculate_iou(track["bbox"], det[:4]) + if iou > best_iou: + best_iou = iou + best_idx = i + + # If we found a match, update the track + if best_idx >= 0: + self.tracks[track_id]["bbox"] = detections[best_idx][:4] + self.tracks[track_id]["score"] = detections[best_idx][4] + self.tracks[track_id]["age"] = 0 + self.tracks[track_id]["mask"] = masks[best_idx] + matched_indices.add(best_idx) + + # Add to results with mask + result.append( + [ + track_id, + detections[best_idx][:4], + detections[best_idx][4], + int(detections[best_idx][5]), + self.tracks[track_id]["mask"], + ] + ) + + # Create new tracks for unmatched detections + for i, det in enumerate(detections): + if i in matched_indices: + continue + + # Create new track + new_id = self.next_id + self.next_id += 1 + + self.tracks[new_id] = { + "bbox": det[:4], + "score": det[4], + "class_id": int(det[5]), + "age": 0, + "mask": masks[i], + } + + # Add to results with mask directly from the track + result.append([new_id, det[:4], det[4], int(det[5]), masks[i]]) + + return result + + +class Detic2DDetector: + def __init__(self, model_path=None, device="cuda", vocabulary=None, threshold=0.5): + """ + Initialize the Detic detector with open vocabulary support. 
+ + Args: + model_path (str): Path to a custom Detic model weights (optional) + device (str): Device to run inference on ('cuda' or 'cpu') + vocabulary (list): Custom vocabulary (list of class names) or 'lvis', 'objects365', 'openimages', 'coco' + threshold (float): Detection confidence threshold + """ + self.device = device + self.threshold = threshold + + # Set up Detic paths - already added to sys.path at module level + + # Import Detic modules + from centernet.config import add_centernet_config + from detic.config import add_detic_config + from detic.modeling.utils import reset_cls_test + from detic.modeling.text.text_encoder import build_text_encoder + + # Keep reference to these functions for later use + self.reset_cls_test = reset_cls_test + self.build_text_encoder = build_text_encoder + + # Setup model configuration + self.cfg = get_cfg() + add_centernet_config(self.cfg) + add_detic_config(self.cfg) + + # Use default Detic config + self.cfg.merge_from_file( + os.path.join( + detic_path, "configs/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.yaml" + ) + ) + + # Set default weights if not provided + if model_path is None: + self.cfg.MODEL.WEIGHTS = "https://dl.fbaipublicfiles.com/detic/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.pth" + else: + self.cfg.MODEL.WEIGHTS = model_path + + # Set device + if device == "cpu": + self.cfg.MODEL.DEVICE = "cpu" + + # Set detection threshold + self.cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = threshold + self.cfg.MODEL.ROI_BOX_HEAD.ZEROSHOT_WEIGHT_PATH = "rand" + self.cfg.MODEL.ROI_HEADS.ONE_CLASS_PER_PROPOSAL = True + + # Built-in datasets for Detic - use absolute paths with detic_path + self.builtin_datasets = { + "lvis": { + "metadata": "lvis_v1_val", + "classifier": os.path.join( + detic_path, "datasets/metadata/lvis_v1_clip_a+cname.npy" + ), + }, + "objects365": { + "metadata": "objects365_v2_val", + "classifier": os.path.join( + detic_path, "datasets/metadata/o365_clip_a+cnamefix.npy" + ), + }, + "openimages": 
{ + "metadata": "oid_val_expanded", + "classifier": os.path.join(detic_path, "datasets/metadata/oid_clip_a+cname.npy"), + }, + "coco": { + "metadata": "coco_2017_val", + "classifier": os.path.join(detic_path, "datasets/metadata/coco_clip_a+cname.npy"), + }, + } + + # Override config paths to use absolute paths + self.cfg.MODEL.ROI_BOX_HEAD.CAT_FREQ_PATH = os.path.join( + detic_path, "datasets/metadata/lvis_v1_train_cat_info.json" + ) + + # Initialize model + self.predictor = None + + # Setup with initial vocabulary + vocabulary = vocabulary or "lvis" + self.setup_vocabulary(vocabulary) + + # Initialize our simple tracker + self.tracker = SimpleTracker(iou_threshold=0.5, max_age=5) + + def setup_vocabulary(self, vocabulary): + """ + Setup the model's vocabulary. + + Args: + vocabulary: Either a string ('lvis', 'objects365', 'openimages', 'coco') + or a list of class names for custom vocabulary. + """ + if self.predictor is None: + # Initialize the model + from detectron2.engine import DefaultPredictor + + self.predictor = DefaultPredictor(self.cfg) + + if isinstance(vocabulary, str) and vocabulary in self.builtin_datasets: + # Use built-in dataset + dataset = vocabulary + metadata = MetadataCatalog.get(self.builtin_datasets[dataset]["metadata"]) + classifier = self.builtin_datasets[dataset]["classifier"] + num_classes = len(metadata.thing_classes) + self.class_names = metadata.thing_classes + else: + # Use custom vocabulary + if isinstance(vocabulary, str): + # If it's a string but not a built-in dataset, treat as a file + try: + with open(vocabulary, "r") as f: + class_names = [line.strip() for line in f if line.strip()] + except: + # Default to LVIS if there's an issue + print(f"Error loading vocabulary from {vocabulary}, using LVIS") + return self.setup_vocabulary("lvis") + else: + # Assume it's a list of class names + class_names = vocabulary + + # Create classifier from text embeddings + metadata = MetadataCatalog.get("__unused") + metadata.thing_classes = 
class_names + self.class_names = class_names + + # Generate CLIP embeddings for custom vocabulary + classifier = self._get_clip_embeddings(class_names) + num_classes = len(class_names) + + # Reset model with new vocabulary + self.reset_cls_test(self.predictor.model, classifier, num_classes) + return self.class_names + + def _get_clip_embeddings(self, vocabulary, prompt="a "): + """ + Generate CLIP embeddings for a vocabulary list. + + Args: + vocabulary (list): List of class names + prompt (str): Prompt prefix to use for CLIP + + Returns: + torch.Tensor: Tensor of embeddings + """ + text_encoder = self.build_text_encoder(pretrain=True) + text_encoder.eval() + texts = [prompt + x for x in vocabulary] + emb = text_encoder(texts).detach().permute(1, 0).contiguous().cpu() + return emb + + def process_image(self, image): + """ + Process an image and return detection results. + + Args: + image: Input image in BGR format (OpenCV) + + Returns: + tuple: (bboxes, track_ids, class_ids, confidences, names, masks) + - bboxes: list of [x1, y1, x2, y2] coordinates + - track_ids: list of tracking IDs (or -1 if no tracking) + - class_ids: list of class indices + - confidences: list of detection confidences + - names: list of class names + - masks: list of segmentation masks (numpy arrays) + """ + # Run inference with Detic + outputs = self.predictor(image) + instances = outputs["instances"].to("cpu") + + # Extract bounding boxes, classes, scores, and masks + if len(instances) == 0: + return [], [], [], [], [], [] + + boxes = instances.pred_boxes.tensor.numpy() + class_ids = instances.pred_classes.numpy() + scores = instances.scores.numpy() + masks = instances.pred_masks.numpy() + + # Convert boxes to [x1, y1, x2, y2] format + bboxes = [] + for box in boxes: + x1, y1, x2, y2 = box.tolist() + bboxes.append([x1, y1, x2, y2]) + + # Get class names + names = [self.class_names[class_id] for class_id in class_ids] + + # Apply tracking + detections = [] + filtered_masks = [] + for i, bbox 
in enumerate(bboxes): + if scores[i] >= self.threshold: + # Format for tracker: [x1, y1, x2, y2, score, class_id] + detections.append(bbox + [scores[i], class_ids[i]]) + filtered_masks.append(masks[i]) + + if not detections: + return [], [], [], [], [], [] + + # Update tracker with detections and correctly aligned masks + track_results = self.tracker.update(detections, filtered_masks) + + # Process tracking results + track_ids = [] + tracked_bboxes = [] + tracked_class_ids = [] + tracked_scores = [] + tracked_names = [] + tracked_masks = [] + + for track_id, bbox, score, class_id, mask in track_results: + track_ids.append(int(track_id)) + tracked_bboxes.append(bbox.tolist() if isinstance(bbox, np.ndarray) else bbox) + tracked_class_ids.append(int(class_id)) + tracked_scores.append(score) + tracked_names.append(self.class_names[int(class_id)]) + tracked_masks.append(mask) + + return ( + tracked_bboxes, + track_ids, + tracked_class_ids, + tracked_scores, + tracked_names, + tracked_masks, + ) + + def visualize_results(self, image, bboxes, track_ids, class_ids, confidences, names): + """ + Generate visualization of detection results. + + Args: + image: Original input image + bboxes: List of bounding boxes + track_ids: List of tracking IDs + class_ids: List of class indices + confidences: List of detection confidences + names: List of class names + + Returns: + Image with visualized detections + """ + from dimos.perception.detection2d.utils import plot_results + + return plot_results(image, bboxes, track_ids, class_ids, confidences, names) + + def cleanup(self): + """Clean up resources.""" + # Nothing specific to clean up for Detic + pass diff --git a/build/lib/dimos/perception/detection2d/test_yolo_2d_det.py b/build/lib/dimos/perception/detection2d/test_yolo_2d_det.py new file mode 100644 index 0000000000..4240625744 --- /dev/null +++ b/build/lib/dimos/perception/detection2d/test_yolo_2d_det.py @@ -0,0 +1,177 @@ +# Copyright 2025 Dimensional Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import time + +import cv2 +import numpy as np +import pytest +import reactivex as rx +from reactivex import operators as ops + +from dimos.perception.detection2d.yolo_2d_det import Yolo2DDetector +from dimos.stream.video_provider import VideoProvider + + +class TestYolo2DDetector: + def test_yolo_detector_initialization(self): + """Test YOLO detector initializes correctly with default model path.""" + try: + detector = Yolo2DDetector() + assert detector is not None + assert detector.model is not None + except Exception as e: + # If the model file doesn't exist, the test should still pass with a warning + pytest.skip(f"Skipping test due to model initialization error: {e}") + + def test_yolo_detector_process_image(self): + """Test YOLO detector can process video frames and return detection results.""" + try: + # Import data inside method to avoid pytest fixture confusion + from dimos.utils.data import get_data + + detector = Yolo2DDetector() + + video_path = get_data("assets") / "trimmed_video_office.mov" + + # Create video provider and directly get a video stream observable + assert os.path.exists(video_path), f"Test video not found: {video_path}" + video_provider = VideoProvider(dev_name="test_video", video_source=video_path) + # Process more frames for thorough testing + video_stream = video_provider.capture_video_as_observable(realtime=False, fps=15) + + # Use ReactiveX operators to process the stream + 
def process_frame(frame): + try: + # Process frame with YOLO + bboxes, track_ids, class_ids, confidences, names = detector.process_image(frame) + print( + f"YOLO results - boxes: {(bboxes)}, tracks: {len(track_ids)}, classes: {(class_ids)}, confidences: {(confidences)}, names: {(names)}" + ) + + return { + "frame": frame, + "bboxes": bboxes, + "track_ids": track_ids, + "class_ids": class_ids, + "confidences": confidences, + "names": names, + } + except Exception as e: + print(f"Exception in process_frame: {e}") + return {} + + # Create the detection stream using pipe and map operator + detection_stream = video_stream.pipe(ops.map(process_frame)) + + # Collect results from the stream + results = [] + + frames_processed = 0 + target_frames = 10 + + def on_next(result): + nonlocal frames_processed + if not result: + return + + results.append(result) + frames_processed += 1 + + # Stop after processing target number of frames + if frames_processed >= target_frames: + subscription.dispose() + + def on_error(error): + pytest.fail(f"Error in detection stream: {error}") + + def on_completed(): + pass + + # Subscribe and wait for results + subscription = detection_stream.subscribe( + on_next=on_next, on_error=on_error, on_completed=on_completed + ) + + timeout = 10.0 + start_time = time.time() + while frames_processed < target_frames and time.time() - start_time < timeout: + time.sleep(0.5) + + # Clean up subscription + subscription.dispose() + video_provider.dispose_all() + # Check that we got detection results + if len(results) == 0: + pytest.skip("Skipping test due to error: Failed to get any detection results") + + # Verify we have detection results with expected properties + assert len(results) > 0, "No detection results were received" + + # Print statistics about detections + total_detections = sum(len(r["bboxes"]) for r in results if r.get("bboxes")) + avg_detections = total_detections / len(results) if results else 0 + print(f"Total detections: {total_detections}, 
Average per frame: {avg_detections:.2f}") + + # Print most common detected objects + object_counts = {} + for r in results: + if r.get("names"): + for name in r["names"]: + if name: + object_counts[name] = object_counts.get(name, 0) + 1 + + if object_counts: + print("Detected objects:") + for obj, count in sorted(object_counts.items(), key=lambda x: x[1], reverse=True)[ + :5 + ]: + print(f" - {obj}: {count} times") + + # Analyze the first result + result = results[0] + + # Check that we have a frame + assert "frame" in result, "Result doesn't contain a frame" + assert isinstance(result["frame"], np.ndarray), "Frame is not a numpy array" + + # Check that detection results are valid + assert isinstance(result["bboxes"], list) + assert isinstance(result["track_ids"], list) + assert isinstance(result["class_ids"], list) + assert isinstance(result["confidences"], list) + assert isinstance(result["names"], list) + + # All result lists should be the same length + assert ( + len(result["bboxes"]) + == len(result["track_ids"]) + == len(result["class_ids"]) + == len(result["confidences"]) + == len(result["names"]) + ) + + # If we have detections, check that bbox format is valid + if result["bboxes"]: + assert len(result["bboxes"][0]) == 4, ( + "Bounding boxes should be in [x1, y1, x2, y2] format" + ) + + except Exception as e: + pytest.skip(f"Skipping test due to error: {e}") + + +if __name__ == "__main__": + pytest.main(["-v", __file__]) diff --git a/build/lib/dimos/perception/detection2d/utils.py b/build/lib/dimos/perception/detection2d/utils.py new file mode 100644 index 0000000000..dbe19baf30 --- /dev/null +++ b/build/lib/dimos/perception/detection2d/utils.py @@ -0,0 +1,338 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import cv2 +from dimos.types.vector import Vector +from dimos.utils.transform_utils import distance_angle_to_goal_xy + + +def filter_detections( + bboxes, + track_ids, + class_ids, + confidences, + names, + class_filter=None, + name_filter=None, + track_id_filter=None, +): + """ + Filter detection results based on class IDs, names, and/or tracking IDs. + + Args: + bboxes: List of bounding boxes [x1, y1, x2, y2] + track_ids: List of tracking IDs + class_ids: List of class indices + confidences: List of detection confidences + names: List of class names + class_filter: List/set of class IDs to keep, or None to keep all + name_filter: List/set of class names to keep, or None to keep all + track_id_filter: List/set of track IDs to keep, or None to keep all + + Returns: + tuple: (filtered_bboxes, filtered_track_ids, filtered_class_ids, + filtered_confidences, filtered_names) + """ + # Convert filters to sets for efficient lookup + if class_filter is not None: + class_filter = set(class_filter) + if name_filter is not None: + name_filter = set(name_filter) + if track_id_filter is not None: + track_id_filter = set(track_id_filter) + + # Initialize lists for filtered results + filtered_bboxes = [] + filtered_track_ids = [] + filtered_class_ids = [] + filtered_confidences = [] + filtered_names = [] + + # Filter detections + for bbox, track_id, class_id, conf, name in zip( + bboxes, track_ids, class_ids, confidences, names + ): + # Check if detection passes all specified filters + keep = True + + if class_filter is not None: + keep = keep and 
(class_id in class_filter) + + if name_filter is not None: + keep = keep and (name in name_filter) + + if track_id_filter is not None: + keep = keep and (track_id in track_id_filter) + + # If detection passes all filters, add it to results + if keep: + filtered_bboxes.append(bbox) + filtered_track_ids.append(track_id) + filtered_class_ids.append(class_id) + filtered_confidences.append(conf) + filtered_names.append(name) + + return ( + filtered_bboxes, + filtered_track_ids, + filtered_class_ids, + filtered_confidences, + filtered_names, + ) + + +def extract_detection_results(result, class_filter=None, name_filter=None, track_id_filter=None): + """ + Extract and optionally filter detection information from a YOLO result object. + + Args: + result: Ultralytics result object + class_filter: List/set of class IDs to keep, or None to keep all + name_filter: List/set of class names to keep, or None to keep all + track_id_filter: List/set of track IDs to keep, or None to keep all + + Returns: + tuple: (bboxes, track_ids, class_ids, confidences, names) + - bboxes: list of [x1, y1, x2, y2] coordinates + - track_ids: list of tracking IDs + - class_ids: list of class indices + - confidences: list of detection confidences + - names: list of class names + """ + bboxes = [] + track_ids = [] + class_ids = [] + confidences = [] + names = [] + + if result.boxes is None: + return bboxes, track_ids, class_ids, confidences, names + + for box in result.boxes: + # Extract bounding box coordinates + x1, y1, x2, y2 = box.xyxy[0].tolist() + + # Extract tracking ID if available + track_id = -1 + if hasattr(box, "id") and box.id is not None: + track_id = int(box.id[0].item()) + + # Extract class information + cls_idx = int(box.cls[0]) + name = result.names[cls_idx] + + # Extract confidence + conf = float(box.conf[0]) + + # Check filters before adding to results + keep = True + if class_filter is not None: + keep = keep and (cls_idx in class_filter) + if name_filter is not None: + keep = keep 
and (name in name_filter) + if track_id_filter is not None: + keep = keep and (track_id in track_id_filter) + + if keep: + bboxes.append([x1, y1, x2, y2]) + track_ids.append(track_id) + class_ids.append(cls_idx) + confidences.append(conf) + names.append(name) + + return bboxes, track_ids, class_ids, confidences, names + + +def plot_results(image, bboxes, track_ids, class_ids, confidences, names, alpha=0.5): + """ + Draw bounding boxes and labels on the image. + + Args: + image: Original input image + bboxes: List of bounding boxes [x1, y1, x2, y2] + track_ids: List of tracking IDs + class_ids: List of class indices + confidences: List of detection confidences + names: List of class names + alpha: Transparency of the overlay + + Returns: + Image with visualized detections + """ + vis_img = image.copy() + + for bbox, track_id, conf, name in zip(bboxes, track_ids, confidences, names): + # Generate consistent color based on track_id or class name + if track_id != -1: + np.random.seed(track_id) + else: + np.random.seed(hash(name) % 100000) + color = np.random.randint(0, 255, (3,), dtype=np.uint8) + np.random.seed(None) + + # Draw bounding box + x1, y1, x2, y2 = map(int, bbox) + cv2.rectangle(vis_img, (x1, y1), (x2, y2), color.tolist(), 2) + + # Prepare label text + if track_id != -1: + label = f"ID:{track_id} {name} {conf:.2f}" + else: + label = f"{name} {conf:.2f}" + + # Calculate text size for background rectangle + (text_w, text_h), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1) + + # Draw background rectangle for text + cv2.rectangle(vis_img, (x1, y1 - text_h - 8), (x1 + text_w + 4, y1), color.tolist(), -1) + + # Draw text with white color for better visibility + cv2.putText( + vis_img, label, (x1 + 2, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1 + ) + + return vis_img + + +def calculate_depth_from_bbox(depth_map, bbox): + """ + Calculate the average depth of an object within a bounding box. 
+ Uses the 25th to 75th percentile range to filter outliers. + + Args: + depth_map: The depth map + bbox: Bounding box in format [x1, y1, x2, y2] + + Returns: + float: Average depth in meters, or None if depth estimation fails + """ + try: + # Extract region of interest from the depth map + x1, y1, x2, y2 = map(int, bbox) + roi_depth = depth_map[y1:y2, x1:x2] + + if roi_depth.size == 0: + return None + + # Calculate 25th and 75th percentile to filter outliers + p25 = np.percentile(roi_depth, 25) + p75 = np.percentile(roi_depth, 75) + + # Filter depth values within this range + filtered_depth = roi_depth[(roi_depth >= p25) & (roi_depth <= p75)] + + # Calculate average depth (convert to meters) + if filtered_depth.size > 0: + return np.mean(filtered_depth) / 1000.0 # Convert mm to meters + + return None + except Exception as e: + print(f"Error calculating depth from bbox: {e}") + return None + + +def calculate_distance_angle_from_bbox(bbox, depth, camera_intrinsics): + """ + Calculate distance and angle to object center based on bbox and depth. 
+ + Args: + bbox: Bounding box [x1, y1, x2, y2] + depth: Depth value in meters + camera_intrinsics: List [fx, fy, cx, cy] with camera parameters + + Returns: + tuple: (distance, angle) in meters and radians + """ + if camera_intrinsics is None: + raise ValueError("Camera intrinsics required for distance calculation") + + # Extract camera parameters + fx, fy, cx, cy = camera_intrinsics + + # Calculate center of bounding box in pixels + x1, y1, x2, y2 = bbox + center_x = (x1 + x2) / 2 + center_y = (y1 + y2) / 2 + + # Calculate normalized image coordinates + x_norm = (center_x - cx) / fx + + # Calculate angle (positive to the right) + angle = np.arctan(x_norm) + + # Calculate distance using depth and angle + distance = depth / np.cos(angle) if np.cos(angle) != 0 else depth + + return distance, angle + + +def calculate_object_size_from_bbox(bbox, depth, camera_intrinsics): + """ + Estimate physical width and height of object in meters. + + Args: + bbox: Bounding box [x1, y1, x2, y2] + depth: Depth value in meters + camera_intrinsics: List [fx, fy, cx, cy] with camera parameters + + Returns: + tuple: (width, height) in meters + """ + if camera_intrinsics is None: + return 0.0, 0.0 + + fx, fy, _, _ = camera_intrinsics + + # Calculate bbox dimensions in pixels + x1, y1, x2, y2 = bbox + width_px = x2 - x1 + height_px = y2 - y1 + + # Convert to meters using similar triangles and depth + width_m = (width_px * depth) / fx + height_m = (height_px * depth) / fy + + return width_m, height_m + + +def calculate_position_rotation_from_bbox(bbox, depth, camera_intrinsics): + """ + Calculate position (xyz) and rotation (roll, pitch, yaw) for an object + based on its bounding box and depth. 
+ + Args: + bbox: Bounding box [x1, y1, x2, y2] + depth: Depth value in meters + camera_intrinsics: List [fx, fy, cx, cy] with camera parameters + + Returns: + Vector: position + Vector: rotation + """ + # Calculate distance and angle to object + distance, angle = calculate_distance_angle_from_bbox(bbox, depth, camera_intrinsics) + + # Convert distance and angle to x,y coordinates (in camera frame) + # Note: We negate the angle since positive angle means object is to the right, + # but we want positive y to be to the left in the standard coordinate system + x, y = distance_angle_to_goal_xy(distance, -angle) + + # For now, rotation is only in yaw (around z-axis) + # We can use the negative of the angle as an estimate of the object's yaw + # assuming objects tend to face the camera + position = Vector([x, y, 0.0]) + rotation = Vector([0.0, 0.0, -angle]) + + return position, rotation diff --git a/build/lib/dimos/perception/detection2d/yolo_2d_det.py b/build/lib/dimos/perception/detection2d/yolo_2d_det.py new file mode 100644 index 0000000000..b9b04165cd --- /dev/null +++ b/build/lib/dimos/perception/detection2d/yolo_2d_det.py @@ -0,0 +1,157 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os + +import cv2 +import onnxruntime +from ultralytics import YOLO + +from dimos.perception.detection2d.utils import ( + extract_detection_results, + filter_detections, + plot_results, +) +from dimos.utils.data import get_data +from dimos.utils.gpu_utils import is_cuda_available +from dimos.utils.logging_config import setup_logger +from dimos.utils.path_utils import get_project_root + +logger = setup_logger("dimos.perception.detection2d.yolo_2d_det") + + +class Yolo2DDetector: + def __init__(self, model_path="models_yolo", model_name="yolo11n.onnx", device="cpu"): + """ + Initialize the YOLO detector. + + Args: + model_path (str): Path to the YOLO model weights in tests/data LFS directory + model_name (str): Name of the YOLO model weights file + device (str): Device to run inference on ('cuda' or 'cpu') + """ + self.device = device + self.model = YOLO(get_data(model_path) / model_name) + + module_dir = os.path.dirname(__file__) + self.tracker_config = os.path.join(module_dir, "config", "custom_tracker.yaml") + if is_cuda_available(): + if hasattr(onnxruntime, "preload_dlls"): # Handles CUDA 11 / onnxruntime-gpu<=1.18 + onnxruntime.preload_dlls(cuda=True, cudnn=True) + self.device = "cuda" + logger.info("Using CUDA for YOLO 2d detector") + else: + self.device = "cpu" + logger.info("Using CPU for YOLO 2d detector") + + def process_image(self, image): + """ + Process an image and return detection results. 
+ + Args: + image: Input image in BGR format (OpenCV) + + Returns: + tuple: (bboxes, track_ids, class_ids, confidences, names) + - bboxes: list of [x1, y1, x2, y2] coordinates + - track_ids: list of tracking IDs (or -1 if no tracking) + - class_ids: list of class indices + - confidences: list of detection confidences + - names: list of class names + """ + results = self.model.track( + source=image, + device=self.device, + conf=0.5, + iou=0.6, + persist=True, + verbose=False, + tracker=self.tracker_config, + ) + + if len(results) > 0: + # Extract detection results + bboxes, track_ids, class_ids, confidences, names = extract_detection_results(results[0]) + return bboxes, track_ids, class_ids, confidences, names + + return [], [], [], [], [] + + def visualize_results(self, image, bboxes, track_ids, class_ids, confidences, names): + """ + Generate visualization of detection results. + + Args: + image: Original input image + bboxes: List of bounding boxes + track_ids: List of tracking IDs + class_ids: List of class indices + confidences: List of detection confidences + names: List of class names + + Returns: + Image with visualized detections + """ + return plot_results(image, bboxes, track_ids, class_ids, confidences, names) + + +def main(): + """Example usage of the Yolo2DDetector class.""" + # Initialize video capture + cap = cv2.VideoCapture(0) + + # Initialize detector + detector = Yolo2DDetector() + + enable_person_filter = True + + try: + while cap.isOpened(): + ret, frame = cap.read() + if not ret: + break + + # Process frame + bboxes, track_ids, class_ids, confidences, names = detector.process_image(frame) + + # Apply person filtering if enabled + if enable_person_filter and len(bboxes) > 0: + # Person is class_id 0 in COCO dataset + bboxes, track_ids, class_ids, confidences, names = filter_detections( + bboxes, + track_ids, + class_ids, + confidences, + names, + class_filter=[0], # 0 is the class_id for person + name_filter=["person"], + ) + + # Visualize 
results + if len(bboxes) > 0: + frame = detector.visualize_results( + frame, bboxes, track_ids, class_ids, confidences, names + ) + + # Display results + cv2.imshow("YOLO Detection", frame) + if cv2.waitKey(1) & 0xFF == ord("q"): + break + + finally: + cap.release() + cv2.destroyAllWindows() + + +if __name__ == "__main__": + main() diff --git a/build/lib/dimos/perception/grasp_generation/__init__.py b/build/lib/dimos/perception/grasp_generation/__init__.py new file mode 100644 index 0000000000..16281fe0b6 --- /dev/null +++ b/build/lib/dimos/perception/grasp_generation/__init__.py @@ -0,0 +1 @@ +from .utils import * diff --git a/build/lib/dimos/perception/grasp_generation/grasp_generation.py b/build/lib/dimos/perception/grasp_generation/grasp_generation.py new file mode 100644 index 0000000000..947a3bcd96 --- /dev/null +++ b/build/lib/dimos/perception/grasp_generation/grasp_generation.py @@ -0,0 +1,228 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +AnyGrasp-based grasp generation for manipulation pipeline. 
+""" + +import asyncio +import numpy as np +import open3d as o3d +from typing import Dict, List, Optional + +from dimos.types.manipulation import ObjectData +from dimos.utils.logging_config import setup_logger +from dimos.perception.grasp_generation.utils import parse_anygrasp_results + +logger = setup_logger("dimos.perception.grasp_generation") + + +class AnyGraspGenerator: + """ + AnyGrasp-based grasp generator using WebSocket communication. + """ + + def __init__(self, server_url: str): + """ + Initialize AnyGrasp generator. + + Args: + server_url: WebSocket URL for AnyGrasp server + """ + self.server_url = server_url + logger.info(f"Initialized AnyGrasp generator with server: {server_url}") + + def generate_grasps_from_objects( + self, objects: List[ObjectData], full_pcd: o3d.geometry.PointCloud + ) -> List[Dict]: + """ + Generate grasps from ObjectData objects using AnyGrasp. + + Args: + objects: List of ObjectData with point clouds + full_pcd: Open3D point cloud of full scene + + Returns: + Parsed grasp results as list of dictionaries + """ + try: + # Combine all point clouds + all_points = [] + all_colors = [] + valid_objects = 0 + + for obj in objects: + if "point_cloud_numpy" not in obj or obj["point_cloud_numpy"] is None: + continue + + points = obj["point_cloud_numpy"] + if not isinstance(points, np.ndarray) or points.size == 0: + continue + + if len(points.shape) != 2 or points.shape[1] != 3: + continue + + colors = None + if "colors_numpy" in obj and obj["colors_numpy"] is not None: + colors = obj["colors_numpy"] + if isinstance(colors, np.ndarray) and colors.size > 0: + if ( + colors.shape[0] != points.shape[0] + or len(colors.shape) != 2 + or colors.shape[1] != 3 + ): + colors = None + + all_points.append(points) + if colors is not None: + all_colors.append(colors) + valid_objects += 1 + + if not all_points: + return [] + + # Combine point clouds + combined_points = np.vstack(all_points) + combined_colors = None + if len(all_colors) == valid_objects 
and len(all_colors) > 0: + combined_colors = np.vstack(all_colors) + + # Send grasp request + grasps = self._send_grasp_request_sync(combined_points, combined_colors) + + if not grasps: + return [] + + # Parse and return results in list of dictionaries format + return parse_anygrasp_results(grasps) + + except Exception as e: + logger.error(f"AnyGrasp generation failed: {e}") + return [] + + def _send_grasp_request_sync( + self, points: np.ndarray, colors: Optional[np.ndarray] + ) -> Optional[List[Dict]]: + """Send synchronous grasp request to AnyGrasp server.""" + + try: + # Prepare colors + colors = np.ones((points.shape[0], 3), dtype=np.float32) * 0.5 + + # Ensure correct data types + points = points.astype(np.float32) + colors = colors.astype(np.float32) + + # Validate ranges + if np.any(np.isnan(points)) or np.any(np.isinf(points)): + logger.error("Points contain NaN or Inf values") + return None + if np.any(np.isnan(colors)) or np.any(np.isinf(colors)): + logger.error("Colors contain NaN or Inf values") + return None + + colors = np.clip(colors, 0.0, 1.0) + + # Run async request in sync context + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + result = loop.run_until_complete(self._async_grasp_request(points, colors)) + return result + finally: + loop.close() + + except Exception as e: + logger.error(f"Error in synchronous grasp request: {e}") + return None + + async def _async_grasp_request( + self, points: np.ndarray, colors: np.ndarray + ) -> Optional[List[Dict]]: + """Async grasp request helper.""" + import json + import websockets + + try: + async with websockets.connect(self.server_url) as websocket: + request = { + "points": points.tolist(), + "colors": colors.tolist(), + "lims": [-1.0, 1.0, -1.0, 1.0, 0.0, 2.0], + } + + await websocket.send(json.dumps(request)) + response = await websocket.recv() + grasps = json.loads(response) + + if isinstance(grasps, dict) and "error" in grasps: + logger.error(f"Server returned error: 
{grasps['error']}") + return None + elif isinstance(grasps, (int, float)) and grasps == 0: + return None + elif not isinstance(grasps, list): + logger.error(f"Server returned unexpected response type: {type(grasps)}") + return None + elif len(grasps) == 0: + return None + + return self._convert_grasp_format(grasps) + + except Exception as e: + logger.error(f"Async grasp request failed: {e}") + return None + + def _convert_grasp_format(self, anygrasp_grasps: List[dict]) -> List[dict]: + """Convert AnyGrasp format to visualization format.""" + converted = [] + + for i, grasp in enumerate(anygrasp_grasps): + rotation_matrix = np.array(grasp.get("rotation_matrix", np.eye(3))) + euler_angles = self._rotation_matrix_to_euler(rotation_matrix) + + converted_grasp = { + "id": f"grasp_{i}", + "score": grasp.get("score", 0.0), + "width": grasp.get("width", 0.0), + "height": grasp.get("height", 0.0), + "depth": grasp.get("depth", 0.0), + "translation": grasp.get("translation", [0, 0, 0]), + "rotation_matrix": rotation_matrix.tolist(), + "euler_angles": euler_angles, + } + converted.append(converted_grasp) + + converted.sort(key=lambda x: x["score"], reverse=True) + return converted + + def _rotation_matrix_to_euler(self, rotation_matrix: np.ndarray) -> Dict[str, float]: + """Convert rotation matrix to Euler angles (in radians).""" + sy = np.sqrt(rotation_matrix[0, 0] ** 2 + rotation_matrix[1, 0] ** 2) + + singular = sy < 1e-6 + + if not singular: + x = np.arctan2(rotation_matrix[2, 1], rotation_matrix[2, 2]) + y = np.arctan2(-rotation_matrix[2, 0], sy) + z = np.arctan2(rotation_matrix[1, 0], rotation_matrix[0, 0]) + else: + x = np.arctan2(-rotation_matrix[1, 2], rotation_matrix[1, 1]) + y = np.arctan2(-rotation_matrix[2, 0], sy) + z = 0 + + return {"roll": x, "pitch": y, "yaw": z} + + def cleanup(self): + """Clean up resources.""" + logger.info("AnyGrasp generator cleaned up") diff --git a/build/lib/dimos/perception/grasp_generation/utils.py 
b/build/lib/dimos/perception/grasp_generation/utils.py new file mode 100644 index 0000000000..ba461f9d90 --- /dev/null +++ b/build/lib/dimos/perception/grasp_generation/utils.py @@ -0,0 +1,621 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for grasp generation and visualization.""" + +import numpy as np +import open3d as o3d +import cv2 +from typing import List, Dict, Tuple, Optional, Union + + +def project_3d_points_to_2d( + points_3d: np.ndarray, camera_intrinsics: Union[List[float], np.ndarray] +) -> np.ndarray: + """ + Project 3D points to 2D image coordinates using camera intrinsics. 
+ + Args: + points_3d: Nx3 array of 3D points (X, Y, Z) + camera_intrinsics: Camera parameters as [fx, fy, cx, cy] list or 3x3 matrix + + Returns: + Nx2 array of 2D image coordinates (u, v) + """ + if len(points_3d) == 0: + return np.zeros((0, 2), dtype=np.int32) + + # Filter out points with zero or negative depth + valid_mask = points_3d[:, 2] > 0 + if not np.any(valid_mask): + return np.zeros((0, 2), dtype=np.int32) + + valid_points = points_3d[valid_mask] + + # Extract camera parameters + if isinstance(camera_intrinsics, list) and len(camera_intrinsics) == 4: + fx, fy, cx, cy = camera_intrinsics + else: + camera_matrix = np.array(camera_intrinsics) + fx = camera_matrix[0, 0] + fy = camera_matrix[1, 1] + cx = camera_matrix[0, 2] + cy = camera_matrix[1, 2] + + # Project to image coordinates + u = (valid_points[:, 0] * fx / valid_points[:, 2]) + cx + v = (valid_points[:, 1] * fy / valid_points[:, 2]) + cy + + # Round to integer pixel coordinates + points_2d = np.column_stack([u, v]).astype(np.int32) + + return points_2d + + +def euler_to_rotation_matrix(roll: float, pitch: float, yaw: float) -> np.ndarray: + """ + Convert Euler angles to rotation matrix. + + Args: + roll: Roll angle in radians + pitch: Pitch angle in radians + yaw: Yaw angle in radians + + Returns: + 3x3 rotation matrix + """ + Rx = np.array([[1, 0, 0], [0, np.cos(roll), -np.sin(roll)], [0, np.sin(roll), np.cos(roll)]]) + + Ry = np.array( + [[np.cos(pitch), 0, np.sin(pitch)], [0, 1, 0], [-np.sin(pitch), 0, np.cos(pitch)]] + ) + + Rz = np.array([[np.cos(yaw), -np.sin(yaw), 0], [np.sin(yaw), np.cos(yaw), 0], [0, 0, 1]]) + + # Combined rotation matrix + R = Rz @ Ry @ Rx + + return R + + +def create_gripper_geometry( + grasp_data: dict, + finger_length: float = 0.08, + finger_thickness: float = 0.004, +) -> List[o3d.geometry.TriangleMesh]: + """ + Create a simple fork-like gripper geometry from grasp data. 
+ + Args: + grasp_data: Dictionary containing grasp parameters + - translation: 3D position list + - rotation_matrix: 3x3 rotation matrix defining gripper coordinate system + * X-axis: gripper width direction (opening/closing) + * Y-axis: finger length direction + * Z-axis: approach direction (toward object) + - width: Gripper opening width + finger_length: Length of gripper fingers (longer) + finger_thickness: Thickness of gripper fingers + base_height: Height of gripper base (longer) + color: RGB color for the gripper (solid blue) + + Returns: + List of Open3D TriangleMesh geometries for the gripper + """ + + translation = np.array(grasp_data["translation"]) + rotation_matrix = np.array(grasp_data["rotation_matrix"]) + + width = grasp_data.get("width", 0.04) + + # Create transformation matrix + transform = np.eye(4) + transform[:3, :3] = rotation_matrix + transform[:3, 3] = translation + + geometries = [] + + # Gripper dimensions + finger_width = 0.006 # Thickness of each finger + handle_length = 0.05 # Length of handle extending backward + + # Build gripper in local coordinate system: + # X-axis = width direction (left/right finger separation) + # Y-axis = finger length direction (fingers extend along +Y) + # Z-axis = approach direction (toward object, handle extends along -Z) + # IMPORTANT: Fingertips should be at origin (translation point) + + # Create left finger extending along +Y, positioned at +X + left_finger = o3d.geometry.TriangleMesh.create_box( + width=finger_width, # Thin finger + height=finger_length, # Extends along Y (finger length direction) + depth=finger_thickness, # Thin in Z direction + ) + left_finger.translate( + [ + width / 2 - finger_width / 2, # Position at +X (half width from center) + -finger_length, # Shift so fingertips are at origin + -finger_thickness / 2, # Center in Z + ] + ) + + # Create right finger extending along +Y, positioned at -X + right_finger = o3d.geometry.TriangleMesh.create_box( + width=finger_width, # Thin finger + 
height=finger_length, # Extends along Y (finger length direction) + depth=finger_thickness, # Thin in Z direction + ) + right_finger.translate( + [ + -width / 2 - finger_width / 2, # Position at -X (half width from center) + -finger_length, # Shift so fingertips are at origin + -finger_thickness / 2, # Center in Z + ] + ) + + # Create base connecting fingers - flat like a stickman body + base = o3d.geometry.TriangleMesh.create_box( + width=width + finger_width, # Full width plus finger thickness + height=finger_thickness, # Flat like fingers (stickman style) + depth=finger_thickness, # Thin like fingers + ) + base.translate( + [ + -width / 2 - finger_width / 2, # Start from left finger position + -finger_length - finger_thickness, # Behind fingers, adjusted for fingertips at origin + -finger_thickness / 2, # Center in Z + ] + ) + + # Create handle extending backward - flat stick like stickman arm + handle = o3d.geometry.TriangleMesh.create_box( + width=finger_width, # Same width as fingers + height=handle_length, # Extends backward along Y direction (same plane) + depth=finger_thickness, # Thin like fingers (same plane) + ) + handle.translate( + [ + -finger_width / 2, # Center in X + -finger_length + - finger_thickness + - handle_length, # Extend backward from base, adjusted for fingertips at origin + -finger_thickness / 2, # Same Z plane as other components + ] + ) + + # Use solid red color for all parts (user changed to red) + solid_color = [1.0, 0.0, 0.0] # Red color + + left_finger.paint_uniform_color(solid_color) + right_finger.paint_uniform_color(solid_color) + base.paint_uniform_color(solid_color) + handle.paint_uniform_color(solid_color) + + # Apply transformation to all parts + left_finger.transform(transform) + right_finger.transform(transform) + base.transform(transform) + handle.transform(transform) + + geometries.extend([left_finger, right_finger, base, handle]) + + return geometries + + +def create_all_gripper_geometries( + grasp_list: List[dict], 
def _gripper_outline_corners(
    width: float,
    finger_length: float,
    finger_thickness: float,
    finger_width: float = 0.006,
    handle_length: float = 0.05,
) -> List[np.ndarray]:
    """Return the four gripper rectangles (left finger, right finger, base, handle).

    Each rectangle is a 4x3 array of corners in the gripper's local frame, ordered
    back-left, back-right, front-right, front-left. Fingertips sit at y = 0 and
    every corner lies in the plane z = -finger_thickness / 2.
    """
    z = -finger_thickness / 2

    def rect(x_lo: float, x_hi: float, y_back: float, y_front: float) -> np.ndarray:
        return np.array(
            [
                [x_lo, y_back, z],
                [x_hi, y_back, z],
                [x_hi, y_front, z],
                [x_lo, y_front, z],
            ]
        )

    half_w = width / 2
    half_f = finger_width / 2
    base_back = -finger_length - finger_thickness
    return [
        rect(half_w - half_f, half_w + half_f, -finger_length, 0),  # left finger (+X)
        rect(-half_w - half_f, -half_w + half_f, -finger_length, 0),  # right finger (-X)
        rect(-half_w - half_f, half_w + half_f, base_back, -finger_length),  # base
        rect(-half_f, half_f, base_back - handle_length, base_back),  # handle
    ]


def draw_grasps_on_image(
    image: np.ndarray,
    grasp_data: Union[dict, Dict[Union[int, str], List[dict]], List[dict]],
    camera_intrinsics: Union[List[float], np.ndarray],  # [fx, fy, cx, cy] or 3x3 matrix
    max_grasps: int = -1,  # -1 means show all grasps
    finger_length: float = 0.08,  # Match 3D gripper
    finger_thickness: float = 0.004,  # Match 3D gripper
) -> np.ndarray:
    """
    Draw fork-like gripper visualizations on the image matching 3D gripper design.

    Drawing is best-effort: a grasp that lacks ``translation`` or
    ``rotation_matrix``, or that fails while being projected or drawn, is skipped
    (the failure is logged at debug level) and the remaining grasps are still drawn.

    Args:
        image: Base image to draw on (not modified; a copy is returned)
        grasp_data: Can be:
            - A single grasp dict (recognised by its ``translation`` /
              ``rotation_matrix`` keys)
            - A list of grasp dicts
            - A dictionary mapping object IDs or "scene" to list of grasps
        camera_intrinsics: Camera parameters as a flat [fx, fy, cx, cy] sequence
            or a 3x3 matrix
        max_grasps: Maximum number of grasps to visualize; values <= 0 draw all
        finger_length: Length of gripper fingers (matches 3D design)
        finger_thickness: Thickness of gripper fingers (matches 3D design)

    Returns:
        Image with grasps drawn
    """
    import logging  # function-scope import: this module defines no logger of its own

    log = logging.getLogger(__name__)
    result = image.copy()
    if grasp_data is None:
        return result

    # Convert camera intrinsics to 3x3 matrix if needed. Decide by shape so that
    # tuples and flat arrays are handled the same way as lists.
    intrinsics = np.asarray(camera_intrinsics, dtype=float)
    if intrinsics.ndim == 1 and intrinsics.size == 4:
        fx, fy, cx, cy = intrinsics
        camera_matrix = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
    else:
        camera_matrix = intrinsics

    # Convert input to a flat list of (grasp, index-within-its-group) pairs.
    if isinstance(grasp_data, dict):
        if "translation" in grasp_data or "rotation_matrix" in grasp_data:
            # Single grasp. Identified by its own keys rather than by guessing
            # which object IDs a mapping might use.
            grasps_to_draw = [(grasp_data, 0)]
        else:
            # Dictionary of grasps by object ID
            grasps_to_draw = []
            for grasps in grasp_data.values():
                if isinstance(grasps, dict):
                    grasps = [grasps]
                elif not isinstance(grasps, (list, tuple)):
                    continue
                grasps_to_draw.extend((grasp, i) for i, grasp in enumerate(grasps))
    else:
        # List (or other sequence) of grasps
        grasps_to_draw = [(grasp, i) for i, grasp in enumerate(grasp_data)]

    # Limit number of grasps if specified
    if max_grasps > 0:
        grasps_to_draw = grasps_to_draw[:max_grasps]

    # Solid red for all grasps to match the 3D design (BGR order for OpenCV)
    color = (0, 0, 255)

    for grasp, index in grasps_to_draw:
        if (
            not isinstance(grasp, dict)
            or "translation" not in grasp
            or "rotation_matrix" not in grasp
        ):
            continue

        try:
            thickness = max(1, 4 - index // 3)
            translation = np.asarray(grasp["translation"], dtype=float).reshape(3)
            rotation_matrix = np.asarray(grasp["rotation_matrix"], dtype=float).reshape(3, 3)
            width = grasp.get("width", 0.04)

            # Local frame matching the 3D design:
            #   X = width direction, Y = finger length direction, fingertips at origin.
            for corners in _gripper_outline_corners(width, finger_length, finger_thickness):
                # Rotate then translate into the camera/world frame; for row-vector
                # points this is the same as (R @ P.T).T + t.
                world = corners @ rotation_matrix.T + translation
                pixels = project_3d_points_to_2d(world, camera_matrix)
                # Corners behind the camera are dropped by the projection; draw
                # whatever remains instead of aborting the whole gripper.
                if len(pixels) > 0:
                    cv2.polylines(result, [pixels.astype(np.int32)], True, color, thickness)

            # Draw grasp center (fingertips at origin)
            center = project_3d_points_to_2d(translation.reshape(1, 3), camera_matrix)
            if len(center) > 0:
                cv2.circle(result, (int(center[0][0]), int(center[0][1])), 3, color, -1)
        except Exception as exc:
            # Visualization must never break the caller; skip this grasp but
            # leave a trace for debugging.
            log.debug("Skipping grasp %s while drawing: %s", index, exc)
            continue

    return result
def parse_anygrasp_results(grasps: List[Dict]) -> List[Dict]:
    """
    Convert raw AnyGrasp grasp dictionaries into the visualization format.

    Args:
        grasps: List of AnyGrasp grasp dictionaries (may be empty or None)

    Returns:
        One dictionary per input grasp, in input order, with keys:
            - id: "grasp_<position in the input list>"
            - score: Confidence score as float (0.0 when missing)
            - width: Gripper opening width as float (0.08 when missing)
            - translation: 3D position as given (``[0, 0, 0]`` when missing)
            - rotation_matrix: 3x3 rotation as nested list (identity when missing)
    """

    def to_viz(position: int, raw: Dict) -> Dict:
        # Read the fields first, then assemble the record in its output key order.
        where = raw.get("translation", [0, 0, 0])
        orientation = np.array(raw.get("rotation_matrix", np.eye(3)))
        confidence = float(raw.get("score", 0.0))
        opening = float(raw.get("width", 0.08))
        return {
            "id": f"grasp_{position}",
            "score": confidence,
            "width": opening,
            "translation": where,
            "rotation_matrix": orientation.tolist(),
        }

    if not grasps:
        return []

    return [to_viz(position, raw) for position, raw in enumerate(grasps)]
+ + Args: + rgb_image: RGB input image + grasps: List of grasp dictionaries in viz format + camera_intrinsics: Camera parameters + + Returns: + RGB image with grasp overlay + """ + try: + bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR) + + result_bgr = draw_grasps_on_image( + bgr_image, + grasps, + camera_intrinsics, + max_grasps=-1, + ) + return cv2.cvtColor(result_bgr, cv2.COLOR_BGR2RGB) + except Exception as e: + return rgb_image.copy() diff --git a/build/lib/dimos/perception/object_detection_stream.py b/build/lib/dimos/perception/object_detection_stream.py new file mode 100644 index 0000000000..3284d99f8b --- /dev/null +++ b/build/lib/dimos/perception/object_detection_stream.py @@ -0,0 +1,373 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class ObjectDetectionStream:
    """
    A stream processor that:
    1. Detects objects using a Detector (Detic or Yolo)
    2. Estimates depth using Metric3D
    3. Calculates 3D position and dimensions using camera intrinsics
    4. Transforms coordinates to map frame
    5. Draws bounding boxes and segmentation masks on the frame

    Provides a stream of structured object data with position and rotation information.
    """

    def __init__(
        self,
        camera_intrinsics=None,  # [fx, fy, cx, cy]
        device="cuda",
        gt_depth_scale=1000.0,
        min_confidence=0.7,
        class_filter=None,  # Optional list of class names to filter (e.g., ["person", "car"])
        get_pose: Callable = None,  # Optional function to transform coordinates to map frame
        detector: Optional[Union[Detic2DDetector, Yolo2DDetector]] = None,
        video_stream: Observable = None,
        disable_depth: bool = False,  # Flag to disable monocular Metric3D depth estimation
        draw_masks: bool = False,  # Flag to enable drawing segmentation masks
    ):
        """
        Initialize the ObjectDetectionStream.

        If depth estimation is requested but cannot be set up (no
        ``camera_intrinsics``, or the depth model fails to initialize), a warning
        is logged and the instance falls back to ``disable_depth=True``.

        Args:
            camera_intrinsics: List [fx, fy, cx, cy] with camera parameters
            device: Device to run inference on ("cuda" or "cpu"); accepted for
                interface compatibility, not read by this constructor
            gt_depth_scale: Ground truth depth scale for Metric3D
            min_confidence: Minimum confidence for detections
            class_filter: Optional list of class names to filter
            get_pose: Optional function to transform pose to map coordinates
            detector: Optional detector instance (Detic or Yolo)
            video_stream: Observable of video frames to process (if provided, returns a stream immediately)
            disable_depth: Flag to disable monocular Metric3D depth estimation
            draw_masks: Flag to enable drawing segmentation masks
        """
        self.min_confidence = min_confidence
        self.class_filter = class_filter
        self.get_pose = get_pose
        self.disable_depth = disable_depth
        self.draw_masks = draw_masks

        # Initialize object detector
        if detector is not None:
            self.detector = detector
        elif DETIC_AVAILABLE:
            try:
                self.detector = Detic2DDetector(vocabulary=None, threshold=min_confidence)
                logger.info("Using Detic2DDetector")
            except Exception as e:
                logger.warning(
                    f"Failed to initialize Detic2DDetector: {e}. Falling back to Yolo2DDetector."
                )
                self.detector = Yolo2DDetector()
        else:
            logger.info("Detic not available. Using Yolo2DDetector.")
            self.detector = Yolo2DDetector()

        # Set up camera intrinsics
        self.camera_intrinsics = camera_intrinsics
        # Always define the attribute so it exists even when depth is disabled.
        self.camera_matrix = None

        # Initialize depth estimation model
        self.depth_model = None
        if not disable_depth:
            try:
                # Check the precondition before constructing the depth model so a
                # missing configuration does not pay for a model that is discarded.
                if camera_intrinsics is None:
                    raise ValueError("camera_intrinsics must be provided")

                self.depth_model = Metric3D(gt_depth_scale)
                self.depth_model.update_intrinsic(camera_intrinsics)

                # Create 3x3 camera matrix for calculations
                fx, fy, cx, cy = camera_intrinsics
                self.camera_matrix = np.array(
                    [[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32
                )

                logger.info("Depth estimation enabled with Metric3D")
            except Exception as e:
                logger.warning(f"Failed to initialize Metric3D depth model: {e}")
                logger.warning("Falling back to disable_depth=True mode")
                self.disable_depth = True
                self.depth_model = None
        else:
            logger.info("Depth estimation disabled")

        # If video_stream is provided, create and store the stream immediately
        self.stream = None
        if video_stream is not None:
            self.stream = self.create_stream(video_stream)

    def create_stream(self, video_stream: Observable) -> Observable:
        """
        Create an Observable stream of object data from a video stream.

        The created stream is also stored on ``self.stream``.

        Args:
            video_stream: Observable that emits video frames

        Returns:
            Observable that emits dictionaries with keys ``frame`` (the input
            frame), ``viz_frame`` (annotated copy) and ``objects`` (list of
            per-object data with position and rotation information)
        """

        def process_frame(frame):
            # TODO: More modular detector output interface
            # Appending ([],) guarantees at least one trailing element, so a
            # detector result with only the five leading fields still unpacks.
            bboxes, track_ids, class_ids, confidences, names, *mask_data = (
                self.detector.process_image(frame) + ([],)
            )

            # Use the first trailing element as masks only when it lines up with
            # the boxes; otherwise every object gets a None mask.
            masks = (
                mask_data[0]
                if mask_data and len(mask_data[0]) == len(bboxes)
                else [None] * len(bboxes)
            )

            # Create visualization
            viz_frame = frame.copy()

            # Process detections
            objects = []
            if not self.disable_depth:
                depth_map = self.depth_model.infer_depth(frame)
                depth_map = np.array(depth_map)
            else:
                depth_map = None

            for i, bbox in enumerate(bboxes):
                # Skip if confidence is too low
                if i < len(confidences) and confidences[i] < self.min_confidence:
                    continue

                # Skip if class filter is active and class not in filter
                class_name = names[i] if i < len(names) else None
                if self.class_filter and class_name not in self.class_filter:
                    continue

                if not self.disable_depth and depth_map is not None:
                    # Get depth for this object
                    depth = calculate_depth_from_bbox(depth_map, bbox)
                    if depth is None:
                        # Skip objects with invalid depth
                        continue
                    # Calculate object position and rotation
                    position, rotation = calculate_position_rotation_from_bbox(
                        bbox, depth, self.camera_intrinsics
                    )
                    # Get object dimensions
                    width, height = calculate_object_size_from_bbox(
                        bbox, depth, self.camera_intrinsics
                    )

                    # Transform to map frame if a transform function is provided.
                    # On failure the untransformed position/rotation are kept.
                    try:
                        if self.get_pose:
                            # position and rotation are already Vector objects, no need to convert
                            robot_pose = self.get_pose()
                            position, rotation = transform_robot_to_map(
                                robot_pose["position"], robot_pose["rotation"], position, rotation
                            )
                    except Exception as e:
                        logger.error(f"Error transforming to map frame: {e}")

                else:
                    # Placeholder values used when no depth information is available
                    depth = -1
                    position = Vector(0, 0, 0)
                    rotation = Vector(0, 0, 0)
                    width = -1
                    height = -1

                # Create a properly typed ObjectData instance
                object_data: ObjectData = {
                    "object_id": track_ids[i] if i < len(track_ids) else -1,
                    "bbox": bbox,
                    "depth": depth,
                    "confidence": confidences[i] if i < len(confidences) else None,
                    "class_id": class_ids[i] if i < len(class_ids) else None,
                    "label": class_name,
                    "position": position,
                    "rotation": rotation,
                    "size": {"width": width, "height": height},
                    "segmentation_mask": masks[i],
                }

                objects.append(object_data)

                # Add visualization
                x1, y1, x2, y2 = map(int, bbox)
                color = (0, 255, 0)  # Green for detected objects
                mask_color = (0, 200, 200)  # Yellow-green for masks

                # Draw segmentation mask if available and valid (best-effort)
                try:
                    if self.draw_masks and object_data["segmentation_mask"] is not None:
                        # Create a colored mask overlay
                        mask = object_data["segmentation_mask"].astype(np.uint8)
                        colored_mask = np.zeros_like(viz_frame)
                        colored_mask[mask > 0] = mask_color

                        # Apply the mask with transparency
                        alpha = 0.5  # transparency factor
                        mask_area = mask > 0
                        viz_frame[mask_area] = cv2.addWeighted(
                            viz_frame[mask_area], 1 - alpha, colored_mask[mask_area], alpha, 0
                        )

                        # Draw mask contour
                        contours, _ = cv2.findContours(
                            mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
                        )
                        cv2.drawContours(viz_frame, contours, -1, mask_color, 2)
                except Exception as e:
                    logger.warning(f"Error drawing segmentation mask: {e}")

                # Draw bounding box with metadata (best-effort)
                try:
                    cv2.rectangle(viz_frame, (x1, y1), (x2, y2), color, 1)

                    # Add text for class only (removed position data)
                    # Handle possible None values for class_name or track_ids[i]
                    class_text = class_name if class_name is not None else "Unknown"
                    id_text = (
                        track_ids[i] if i < len(track_ids) and track_ids[i] is not None else "?"
                    )
                    text = f"{class_text}, ID: {id_text}"

                    # Draw text background with smaller font
                    text_size = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.3, 1)[0]
                    cv2.rectangle(
                        viz_frame,
                        (x1, y1 - text_size[1] - 5),
                        (x1 + text_size[0], y1),
                        (0, 0, 0),
                        -1,
                    )

                    # Draw text with smaller font
                    cv2.putText(
                        viz_frame,
                        text,
                        (x1, y1 - 5),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        0.3,
                        (255, 255, 255),
                        1,
                    )
                except Exception as e:
                    logger.warning(f"Error drawing bounding box or text: {e}")

            return {"frame": frame, "viz_frame": viz_frame, "objects": objects}

        self.stream = video_stream.pipe(ops.map(process_frame))

        return self.stream

    def get_stream(self):
        """
        Return the current detection stream.

        Returns:
            Observable: The reactive stream of detection results

        Raises:
            ValueError: If no stream exists yet, i.e. no video_stream was given
                at initialization and create_stream has not been called.
        """
        if self.stream is None:
            raise ValueError(
                "Stream not initialized. Either provide a video_stream during initialization or call create_stream first."
            )
        return self.stream

    def get_formatted_stream(self):
        """
        Returns a formatted stream of object detection data for better readability.
        This is especially useful for LLMs like Claude that need structured text input.

        Each object is formatted independently: if one object cannot be formatted,
        an error placeholder is emitted for it and the remaining objects are still
        included. A missing (None) confidence is shown as "N/A".

        Returns:
            Observable: A stream of formatted string representations of object data

        Raises:
            ValueError: If no stream exists yet.
        """
        if self.stream is None:
            raise ValueError(
                "Stream not initialized. Either provide a video_stream during initialization or call create_stream first."
            )

        separator = "----------------------------------\n"

        def format_object(i, obj):
            pos = obj["position"]
            rot = obj["rotation"]
            size = obj["size"]
            bbox = obj["bbox"]
            confidence = obj["confidence"]
            # Confidence may legitimately be None when the detector reports none.
            confidence_str = f"{confidence:.2f}" if confidence is not None else "N/A"

            bbox_str = f"[{bbox[0]}, {bbox[1]}, {bbox[2]}, {bbox[3]}]"
            return (
                f"Object {i + 1}: {obj['label']}\n"
                f"  ID: {obj['object_id']}\n"
                f"  Confidence: {confidence_str}\n"
                f"  Position: x={pos.x:.2f}m, y={pos.y:.2f}m, z={pos.z:.2f}m\n"
                f"  Rotation: yaw={rot.z:.2f} rad\n"
                f"  Size: width={size['width']:.2f}m, height={size['height']:.2f}m\n"
                f"  Depth: {obj['depth']:.2f}m\n"
                f"  Bounding box: {bbox_str}\n"
                f"{separator}"
            )

        def format_detection_data(result):
            # Extract objects from result
            objects = result.get("objects", [])

            if not objects:
                return "No objects detected."

            sections = ["[DETECTED OBJECTS]\n"]
            for i, obj in enumerate(objects):
                # Guard each object on its own so one malformed entry does not
                # hide every object that follows it.
                try:
                    sections.append(format_object(i, obj))
                except Exception as e:
                    logger.warning(f"Error formatting object {i}: {e}")
                    sections.append(f"Object {i + 1}: [Error formatting data]\n{separator}")

            return "".join(sections)

        # Return a new stream with the formatter applied
        return self.stream.pipe(ops.map(format_detection_data))

    def cleanup(self):
        """Clean up resources. Currently a no-op."""
        pass
+ +import cv2 +from reactivex import Observable +from reactivex import operators as ops +import numpy as np +from dimos.perception.common.ibvs import ObjectDistanceEstimator +from dimos.models.depth.metric3d import Metric3D +from dimos.perception.detection2d.utils import calculate_depth_from_bbox + + +class ObjectTrackingStream: + def __init__( + self, + camera_intrinsics=None, + camera_pitch=0.0, + camera_height=1.0, + reid_threshold=5, + reid_fail_tolerance=10, + gt_depth_scale=1000.0, + ): + """ + Initialize an object tracking stream using OpenCV's CSRT tracker with ORB re-ID. + + Args: + camera_intrinsics: List in format [fx, fy, cx, cy] where: + - fx: Focal length in x direction (pixels) + - fy: Focal length in y direction (pixels) + - cx: Principal point x-coordinate (pixels) + - cy: Principal point y-coordinate (pixels) + camera_pitch: Camera pitch angle in radians (positive is up) + camera_height: Height of the camera from the ground in meters + reid_threshold: Minimum good feature matches needed to confirm re-ID. + reid_fail_tolerance: Number of consecutive frames Re-ID can fail before + tracking is stopped. 
+ gt_depth_scale: Ground truth depth scale factor for Metric3D model + """ + self.tracker = None + self.tracking_bbox = None # Stores (x, y, w, h) for tracker initialization + self.tracking_initialized = False + self.orb = cv2.ORB_create() + self.bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=False) + self.original_des = None # Store original ORB descriptors + self.reid_threshold = reid_threshold + self.reid_fail_tolerance = reid_fail_tolerance + self.reid_fail_count = 0 # Counter for consecutive re-id failures + + # Initialize distance estimator if camera parameters are provided + self.distance_estimator = None + if camera_intrinsics is not None: + # Convert [fx, fy, cx, cy] to 3x3 camera matrix + fx, fy, cx, cy = camera_intrinsics + K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32) + + self.distance_estimator = ObjectDistanceEstimator( + K=K, camera_pitch=camera_pitch, camera_height=camera_height + ) + + # Initialize depth model + self.depth_model = Metric3D(gt_depth_scale) + if camera_intrinsics is not None: + self.depth_model.update_intrinsic(camera_intrinsics) + + def track(self, bbox, frame=None, distance=None, size=None): + """ + Set the initial bounding box for tracking. Features are extracted later. + + Args: + bbox: Bounding box in format [x1, y1, x2, y2] + frame: Optional - Current frame for depth estimation and feature extraction + distance: Optional - Known distance to object (meters) + size: Optional - Known size of object (meters) + + Returns: + bool: True if intention to track is set (bbox is valid) + """ + x1, y1, x2, y2 = map(int, bbox) + w, h = x2 - x1, y2 - y1 + if w <= 0 or h <= 0: + print(f"Warning: Invalid initial bbox provided: {bbox}. 
Tracking not started.") + self.stop_track() # Ensure clean state + return False + + self.tracking_bbox = (x1, y1, w, h) # Store in (x, y, w, h) format + self.tracker = cv2.legacy.TrackerCSRT_create() + self.tracking_initialized = False # Reset flag + self.original_des = None # Clear previous descriptors + self.reid_fail_count = 0 # Reset counter on new track + print(f"Tracking target set with bbox: {self.tracking_bbox}") + + # Calculate depth only if distance and size not provided + if frame is not None and distance is None and size is None: + depth_map = self.depth_model.infer_depth(frame) + depth_map = np.array(depth_map) + depth_estimate = calculate_depth_from_bbox(depth_map, bbox) + if depth_estimate is not None: + print(f"Estimated depth for object: {depth_estimate:.2f}m") + + # Update distance estimator if needed + if self.distance_estimator is not None: + if size is not None: + self.distance_estimator.set_estimated_object_size(size) + elif distance is not None: + self.distance_estimator.estimate_object_size(bbox, distance) + elif depth_estimate is not None: + self.distance_estimator.estimate_object_size(bbox, depth_estimate) + else: + print("No distance or size provided. Cannot estimate object size.") + + return True # Indicate intention to track is set + + def calculate_depth_from_bbox(self, frame, bbox): + """ + Calculate the average depth of an object within a bounding box. + Uses the 25th to 75th percentile range to filter outliers. 
+ + Args: + frame: The image frame + bbox: Bounding box in format [x1, y1, x2, y2] + + Returns: + float: Average depth in meters, or None if depth estimation fails + """ + try: + # Get depth map for the entire frame + depth_map = self.depth_model.infer_depth(frame) + depth_map = np.array(depth_map) + + # Extract region of interest from the depth map + x1, y1, x2, y2 = map(int, bbox) + roi_depth = depth_map[y1:y2, x1:x2] + + if roi_depth.size == 0: + return None + + # Calculate 25th and 75th percentile to filter outliers + p25 = np.percentile(roi_depth, 25) + p75 = np.percentile(roi_depth, 75) + + # Filter depth values within this range + filtered_depth = roi_depth[(roi_depth >= p25) & (roi_depth <= p75)] + + # Calculate average depth (convert to meters) + if filtered_depth.size > 0: + return np.mean(filtered_depth) / 1000.0 # Convert mm to meters + + return None + except Exception as e: + print(f"Error calculating depth from bbox: {e}") + return None + + def reid(self, frame, current_bbox) -> bool: + """Check if features in current_bbox match stored original features.""" + if self.original_des is None: + return True # Cannot re-id if no original features + x1, y1, x2, y2 = map(int, current_bbox) + roi = frame[y1:y2, x1:x2] + if roi.size == 0: + return False # Empty ROI cannot match + + _, des_current = self.orb.detectAndCompute(roi, None) + if des_current is None or len(des_current) < 2: + return False # Need at least 2 descriptors for knnMatch + + # Handle case where original_des has only 1 descriptor (cannot use knnMatch with k=2) + if len(self.original_des) < 2: + matches = self.bf.match(self.original_des, des_current) + good_matches = len(matches) + else: + matches = self.bf.knnMatch(self.original_des, des_current, k=2) + # Apply Lowe's ratio test robustly + good_matches = 0 + for match_pair in matches: + if len(match_pair) == 2: + m, n = match_pair + if m.distance < 0.75 * n.distance: + good_matches += 1 + + # print(f"ReID: Good Matches={good_matches}, 
Threshold={self.reid_threshold}") # Debug + return good_matches >= self.reid_threshold + + def stop_track(self): + """ + Stop tracking the current object. + This resets the tracker and all tracking state. + + Returns: + bool: True if tracking was successfully stopped + """ + self.tracker = None + self.tracking_bbox = None + self.tracking_initialized = False + self.original_des = None + self.reid_fail_count = 0 # Reset counter + return True + + def create_stream(self, video_stream: Observable) -> Observable: + """ + Create an Observable stream of object tracking results from a video stream. + + Args: + video_stream: Observable that emits video frames + + Returns: + Observable that emits dictionaries containing tracking results and visualizations + """ + + def process_frame(frame): + viz_frame = frame.copy() + tracker_succeeded = False # Success from tracker.update() + reid_confirmed_this_frame = False # Result of reid() call for this frame + final_success = False # Overall success considering re-id tolerance + target_data = None + current_bbox_x1y1x2y2 = None # Store current bbox if tracking succeeds + + if self.tracker is not None and self.tracking_bbox is not None: + if not self.tracking_initialized: + # Extract initial features and initialize tracker on first frame + x_init, y_init, w_init, h_init = self.tracking_bbox + roi = frame[y_init : y_init + h_init, x_init : x_init + w_init] + + if roi.size > 0: + _, self.original_des = self.orb.detectAndCompute(roi, None) + if self.original_des is None: + print( + "Warning: No ORB features found in initial ROI during stream processing." 
+ ) + else: + print(f"Initial ORB features extracted: {len(self.original_des)}") + + # Initialize the tracker + init_success = self.tracker.init(frame, self.tracking_bbox) + if init_success: + self.tracking_initialized = True + tracker_succeeded = True + reid_confirmed_this_frame = True # Assume re-id true on init + current_bbox_x1y1x2y2 = [ + x_init, + y_init, + x_init + w_init, + y_init + h_init, + ] + print("Tracker initialized successfully.") + else: + print("Error: Tracker initialization failed in stream.") + self.stop_track() # Reset if init fails + else: + print("Error: Empty ROI during tracker initialization in stream.") + self.stop_track() # Reset if ROI is bad + + else: # Tracker already initialized, perform update and re-id + tracker_succeeded, bbox_cv = self.tracker.update(frame) + if tracker_succeeded: + x, y, w, h = map(int, bbox_cv) + current_bbox_x1y1x2y2 = [x, y, x + w, y + h] + # Perform re-ID check + reid_confirmed_this_frame = self.reid(frame, current_bbox_x1y1x2y2) + + if reid_confirmed_this_frame: + self.reid_fail_count = 0 # Reset counter on success + else: + self.reid_fail_count += 1 # Increment counter on failure + print( + f"Re-ID failed ({self.reid_fail_count}/{self.reid_fail_tolerance}). Continuing track..." + ) + + # --- Determine final success and stop tracking if needed --- + if tracker_succeeded: + if self.reid_fail_count >= self.reid_fail_tolerance: + print(f"Re-ID failed consecutively {self.reid_fail_count} times. Target lost.") + final_success = False # Stop tracking + else: + final_success = True # Tracker ok, Re-ID ok or within tolerance + else: # Tracker update failed + final_success = False + if self.tracking_initialized: + print("Tracker update failed. 
Stopping track.") + + # --- Post-processing based on final_success --- + if final_success and current_bbox_x1y1x2y2 is not None: + # Tracking is considered successful (tracker ok, re-id ok or within tolerance) + x1, y1, x2, y2 = current_bbox_x1y1x2y2 + # Visualize based on *this frame's* re-id result + viz_color = ( + (0, 255, 0) if reid_confirmed_this_frame else (0, 165, 255) + ) # Green if confirmed, Orange if failed but tolerated + cv2.rectangle(viz_frame, (x1, y1), (x2, y2), viz_color, 2) + + target_data = { + "target_id": 0, + "bbox": current_bbox_x1y1x2y2, + "confidence": 1.0, + "reid_confirmed": reid_confirmed_this_frame, # Report actual re-id status + } + + dist_text = "Object Tracking" + if not reid_confirmed_this_frame: + dist_text += " (Re-ID Failed - Tolerated)" + + if ( + self.distance_estimator is not None + and self.distance_estimator.estimated_object_size is not None + ): + distance, angle = self.distance_estimator.estimate_distance_angle( + current_bbox_x1y1x2y2 + ) + if distance is not None: + target_data["distance"] = distance + target_data["angle"] = angle + dist_text = f"Object: {distance:.2f}m, {np.rad2deg(angle):.1f} deg" + if not reid_confirmed_this_frame: + dist_text += " (Re-ID Failed - Tolerated)" + + text_size = cv2.getTextSize(dist_text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)[0] + label_bg_y = max(y1 - text_size[1] - 5, 0) + cv2.rectangle(viz_frame, (x1, label_bg_y), (x1 + text_size[0], y1), (0, 0, 0), -1) + cv2.putText( + viz_frame, + dist_text, + (x1, y1 - 5), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + (255, 255, 255), + 1, + ) + + elif ( + self.tracking_initialized + ): # Tracking stopped this frame (either tracker fail or re-id tolerance exceeded) + self.stop_track() # Reset tracker state and counter + + # else: # Not tracking or initialization failed, do nothing, return empty result + # pass + + return { + "frame": frame, + "viz_frame": viz_frame, + "targets": [target_data] if target_data else [], + } + + return 
video_stream.pipe(ops.map(process_frame)) + + def cleanup(self): + """Clean up resources.""" + self.stop_track() diff --git a/build/lib/dimos/perception/person_tracker.py b/build/lib/dimos/perception/person_tracker.py new file mode 100644 index 0000000000..0a2f9cc7b7 --- /dev/null +++ b/build/lib/dimos/perception/person_tracker.py @@ -0,0 +1,154 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos.perception.detection2d.yolo_2d_det import Yolo2DDetector +from dimos.perception.detection2d.utils import filter_detections +from dimos.perception.common.ibvs import PersonDistanceEstimator +from reactivex import Observable +from reactivex import operators as ops +import numpy as np +import cv2 + + +class PersonTrackingStream: + def __init__( + self, + camera_intrinsics=None, + camera_pitch=0.0, + camera_height=1.0, + ): + """ + Initialize a person tracking stream using Yolo2DDetector and PersonDistanceEstimator. 
+ + Args: + camera_intrinsics: List in format [fx, fy, cx, cy] where: + - fx: Focal length in x direction (pixels) + - fy: Focal length in y direction (pixels) + - cx: Principal point x-coordinate (pixels) + - cy: Principal point y-coordinate (pixels) + camera_pitch: Camera pitch angle in radians (positive is up) + camera_height: Height of the camera from the ground in meters + """ + self.detector = Yolo2DDetector() + + # Initialize distance estimator + if camera_intrinsics is None: + raise ValueError("Camera intrinsics are required for distance estimation") + + # Validate camera intrinsics format [fx, fy, cx, cy] + if ( + not isinstance(camera_intrinsics, (list, tuple, np.ndarray)) + or len(camera_intrinsics) != 4 + ): + raise ValueError("Camera intrinsics must be provided as [fx, fy, cx, cy]") + + # Convert [fx, fy, cx, cy] to 3x3 camera matrix + fx, fy, cx, cy = camera_intrinsics + K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32) + + self.distance_estimator = PersonDistanceEstimator( + K=K, camera_pitch=camera_pitch, camera_height=camera_height + ) + + def create_stream(self, video_stream: Observable) -> Observable: + """ + Create an Observable stream of person tracking results from a video stream. 
+ + Args: + video_stream: Observable that emits video frames + + Returns: + Observable that emits dictionaries containing tracking results and visualizations + """ + + def process_frame(frame): + # Detect people in the frame + bboxes, track_ids, class_ids, confidences, names = self.detector.process_image(frame) + + # Filter to keep only person detections using filter_detections + ( + filtered_bboxes, + filtered_track_ids, + filtered_class_ids, + filtered_confidences, + filtered_names, + ) = filter_detections( + bboxes, + track_ids, + class_ids, + confidences, + names, + class_filter=[0], # 0 is the class_id for person + name_filter=["person"], + ) + + # Create visualization + viz_frame = self.detector.visualize_results( + frame, + filtered_bboxes, + filtered_track_ids, + filtered_class_ids, + filtered_confidences, + filtered_names, + ) + + # Calculate distance and angle for each person + targets = [] + for i, bbox in enumerate(filtered_bboxes): + target_data = { + "target_id": filtered_track_ids[i] if i < len(filtered_track_ids) else -1, + "bbox": bbox, + "confidence": filtered_confidences[i] + if i < len(filtered_confidences) + else None, + } + + distance, angle = self.distance_estimator.estimate_distance_angle(bbox) + target_data["distance"] = distance + target_data["angle"] = angle + + # Add text to visualization + x1, y1, x2, y2 = map(int, bbox) + dist_text = f"{distance:.2f}m, {np.rad2deg(angle):.1f} deg" + + # Add black background for better visibility + text_size = cv2.getTextSize(dist_text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0] + # Position at top-right corner + cv2.rectangle( + viz_frame, (x2 - text_size[0], y1 - text_size[1] - 5), (x2, y1), (0, 0, 0), -1 + ) + + # Draw text in white at top-right + cv2.putText( + viz_frame, + dist_text, + (x2 - text_size[0], y1 - 5), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + (255, 255, 255), + 2, + ) + + targets.append(target_data) + + # Create the result dictionary + result = {"frame": frame, "viz_frame": viz_frame, "targets": 
targets} + + return result + + return video_stream.pipe(ops.map(process_frame)) + + def cleanup(self): + """Clean up resources.""" + pass # No specific cleanup needed for now diff --git a/build/lib/dimos/perception/pointcloud/__init__.py b/build/lib/dimos/perception/pointcloud/__init__.py new file mode 100644 index 0000000000..1f282bb738 --- /dev/null +++ b/build/lib/dimos/perception/pointcloud/__init__.py @@ -0,0 +1,3 @@ +from .utils import * +from .cuboid_fit import * +from .pointcloud_filtering import * diff --git a/build/lib/dimos/perception/pointcloud/cuboid_fit.py b/build/lib/dimos/perception/pointcloud/cuboid_fit.py new file mode 100644 index 0000000000..d567f40395 --- /dev/null +++ b/build/lib/dimos/perception/pointcloud/cuboid_fit.py @@ -0,0 +1,414 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import open3d as o3d +import cv2 +from typing import Dict, Optional, Union, Tuple + + +def fit_cuboid( + points: Union[np.ndarray, o3d.geometry.PointCloud], method: str = "minimal" +) -> Optional[Dict]: + """ + Fit a cuboid to a point cloud using Open3D's built-in methods. 
+ + Args: + points: Nx3 array of points or Open3D PointCloud + method: Fitting method: + - 'minimal': Minimal oriented bounding box (best fit) + - 'oriented': PCA-based oriented bounding box + - 'axis_aligned': Axis-aligned bounding box + + Returns: + Dictionary containing: + - center: 3D center point + - dimensions: 3D dimensions (extent) + - rotation: 3x3 rotation matrix + - error: Fitting error + - bounding_box: Open3D OrientedBoundingBox object + Returns None if insufficient points or fitting fails. + + Raises: + ValueError: If method is invalid or inputs are malformed + """ + # Validate method + valid_methods = ["minimal", "oriented", "axis_aligned"] + if method not in valid_methods: + raise ValueError(f"method must be one of {valid_methods}, got '{method}'") + + # Convert to point cloud if needed + if isinstance(points, np.ndarray): + points = np.asarray(points) + if len(points.shape) != 2 or points.shape[1] != 3: + raise ValueError(f"points array must be Nx3, got shape {points.shape}") + if len(points) < 4: + return None + + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(points) + elif isinstance(points, o3d.geometry.PointCloud): + pcd = points + points = np.asarray(pcd.points) + if len(points) < 4: + return None + else: + raise ValueError(f"points must be numpy array or Open3D PointCloud, got {type(points)}") + + try: + # Get bounding box based on method + if method == "minimal": + obb = pcd.get_minimal_oriented_bounding_box(robust=True) + elif method == "oriented": + obb = pcd.get_oriented_bounding_box(robust=True) + elif method == "axis_aligned": + # Convert axis-aligned to oriented format for consistency + aabb = pcd.get_axis_aligned_bounding_box() + obb = o3d.geometry.OrientedBoundingBox() + obb.center = aabb.get_center() + obb.extent = aabb.get_extent() + obb.R = np.eye(3) # Identity rotation for axis-aligned + + # Extract parameters + center = np.asarray(obb.center) + dimensions = np.asarray(obb.extent) + rotation = 
np.asarray(obb.R) + + # Calculate fitting error + error = _compute_fitting_error(points, center, dimensions, rotation) + + return { + "center": center, + "dimensions": dimensions, + "rotation": rotation, + "error": error, + "bounding_box": obb, + "method": method, + } + + except Exception as e: + # Log error but don't crash - return None for graceful handling + print(f"Warning: Cuboid fitting failed with method '{method}': {e}") + return None + + +def fit_cuboid_simple(points: Union[np.ndarray, o3d.geometry.PointCloud]) -> Optional[Dict]: + """ + Simple wrapper for minimal oriented bounding box fitting. + + Args: + points: Nx3 array of points or Open3D PointCloud + + Returns: + Dictionary with center, dimensions, rotation, and bounding_box, + or None if insufficient points + """ + return fit_cuboid(points, method="minimal") + + +def _compute_fitting_error( + points: np.ndarray, center: np.ndarray, dimensions: np.ndarray, rotation: np.ndarray +) -> float: + """ + Compute fitting error as mean squared distance from points to cuboid surface. 
+ + Args: + points: Nx3 array of points + center: 3D center point + dimensions: 3D dimensions + rotation: 3x3 rotation matrix + + Returns: + Mean squared error + """ + if len(points) == 0: + return 0.0 + + # Transform points to local coordinates + local_points = (points - center) @ rotation + half_dims = dimensions / 2 + + # Calculate distance to cuboid surface + dx = np.abs(local_points[:, 0]) - half_dims[0] + dy = np.abs(local_points[:, 1]) - half_dims[1] + dz = np.abs(local_points[:, 2]) - half_dims[2] + + # Points outside: distance to nearest face + # Points inside: negative distance to nearest face + outside_dist = np.sqrt(np.maximum(dx, 0) ** 2 + np.maximum(dy, 0) ** 2 + np.maximum(dz, 0) ** 2) + inside_dist = np.minimum(np.minimum(dx, dy), dz) + distances = np.where((dx > 0) | (dy > 0) | (dz > 0), outside_dist, -inside_dist) + + return float(np.mean(distances**2)) + + +def get_cuboid_corners( + center: np.ndarray, dimensions: np.ndarray, rotation: np.ndarray +) -> np.ndarray: + """ + Get the 8 corners of a cuboid. 
+ + Args: + center: 3D center point + dimensions: 3D dimensions + rotation: 3x3 rotation matrix + + Returns: + 8x3 array of corner coordinates + """ + half_dims = dimensions / 2 + corners_local = ( + np.array( + [ + [-1, -1, -1], # 0: left bottom back + [-1, -1, 1], # 1: left bottom front + [-1, 1, -1], # 2: left top back + [-1, 1, 1], # 3: left top front + [1, -1, -1], # 4: right bottom back + [1, -1, 1], # 5: right bottom front + [1, 1, -1], # 6: right top back + [1, 1, 1], # 7: right top front + ] + ) + * half_dims + ) + + # Apply rotation and translation + return corners_local @ rotation.T + center + + +def visualize_cuboid_on_image( + image: np.ndarray, + cuboid_params: Dict, + camera_matrix: np.ndarray, + extrinsic_rotation: Optional[np.ndarray] = None, + extrinsic_translation: Optional[np.ndarray] = None, + color: Tuple[int, int, int] = (0, 255, 0), + thickness: int = 2, + show_dimensions: bool = True, +) -> np.ndarray: + """ + Draw a fitted cuboid on an image using camera projection. 
+ + Args: + image: Input image to draw on + cuboid_params: Dictionary containing cuboid parameters + camera_matrix: Camera intrinsic matrix (3x3) + extrinsic_rotation: Optional external rotation (3x3) + extrinsic_translation: Optional external translation (3x1) + color: Line color as (B, G, R) tuple + thickness: Line thickness + show_dimensions: Whether to display dimension text + + Returns: + Image with cuboid visualization + + Raises: + ValueError: If required parameters are missing or invalid + """ + # Validate inputs + required_keys = ["center", "dimensions", "rotation"] + if not all(key in cuboid_params for key in required_keys): + raise ValueError(f"cuboid_params must contain keys: {required_keys}") + + if camera_matrix.shape != (3, 3): + raise ValueError(f"camera_matrix must be 3x3, got {camera_matrix.shape}") + + # Get corners in world coordinates + corners = get_cuboid_corners( + cuboid_params["center"], cuboid_params["dimensions"], cuboid_params["rotation"] + ) + + # Transform corners if extrinsic parameters are provided + if extrinsic_rotation is not None and extrinsic_translation is not None: + if extrinsic_rotation.shape != (3, 3): + raise ValueError(f"extrinsic_rotation must be 3x3, got {extrinsic_rotation.shape}") + if extrinsic_translation.shape not in [(3,), (3, 1)]: + raise ValueError( + f"extrinsic_translation must be (3,) or (3,1), got {extrinsic_translation.shape}" + ) + + extrinsic_translation = extrinsic_translation.flatten() + corners = (extrinsic_rotation @ corners.T).T + extrinsic_translation + + try: + # Project 3D corners to image coordinates + corners_img, _ = cv2.projectPoints( + corners.astype(np.float32), + np.zeros(3), + np.zeros(3), # No additional rotation/translation + camera_matrix.astype(np.float32), + None, # No distortion + ) + corners_img = corners_img.reshape(-1, 2).astype(int) + + # Check if corners are within image bounds + h, w = image.shape[:2] + valid_corners = ( + (corners_img[:, 0] >= 0) + & (corners_img[:, 0] < w) + 
& (corners_img[:, 1] >= 0) + & (corners_img[:, 1] < h) + ) + + if not np.any(valid_corners): + print("Warning: All cuboid corners are outside image bounds") + return image.copy() + + except Exception as e: + print(f"Warning: Failed to project cuboid corners: {e}") + return image.copy() + + # Define edges for wireframe visualization + edges = [ + # Bottom face + (0, 1), + (1, 5), + (5, 4), + (4, 0), + # Top face + (2, 3), + (3, 7), + (7, 6), + (6, 2), + # Vertical edges + (0, 2), + (1, 3), + (5, 7), + (4, 6), + ] + + # Draw edges + vis_img = image.copy() + for i, j in edges: + # Only draw edge if both corners are valid + if valid_corners[i] and valid_corners[j]: + cv2.line(vis_img, tuple(corners_img[i]), tuple(corners_img[j]), color, thickness) + + # Add dimension text if requested + if show_dimensions and np.any(valid_corners): + dims = cuboid_params["dimensions"] + dim_text = f"Dims: {dims[0]:.3f} x {dims[1]:.3f} x {dims[2]:.3f}" + + # Find a good position for text (top-left of image) + text_pos = (10, 30) + font_scale = 0.7 + + # Add background rectangle for better readability + text_size = cv2.getTextSize(dim_text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, 2)[0] + cv2.rectangle( + vis_img, + (text_pos[0] - 5, text_pos[1] - text_size[1] - 5), + (text_pos[0] + text_size[0] + 5, text_pos[1] + 5), + (0, 0, 0), + -1, + ) + + cv2.putText(vis_img, dim_text, text_pos, cv2.FONT_HERSHEY_SIMPLEX, font_scale, color, 2) + + return vis_img + + +def compute_cuboid_volume(cuboid_params: Dict) -> float: + """ + Compute the volume of a cuboid. + + Args: + cuboid_params: Dictionary containing cuboid parameters + + Returns: + Volume in cubic units + """ + if "dimensions" not in cuboid_params: + raise ValueError("cuboid_params must contain 'dimensions' key") + + dims = cuboid_params["dimensions"] + return float(np.prod(dims)) + + +def compute_cuboid_surface_area(cuboid_params: Dict) -> float: + """ + Compute the surface area of a cuboid. 
+ + Args: + cuboid_params: Dictionary containing cuboid parameters + + Returns: + Surface area in square units + """ + if "dimensions" not in cuboid_params: + raise ValueError("cuboid_params must contain 'dimensions' key") + + dims = cuboid_params["dimensions"] + return 2.0 * (dims[0] * dims[1] + dims[1] * dims[2] + dims[2] * dims[0]) + + +def check_cuboid_quality(cuboid_params: Dict, points: np.ndarray) -> Dict: + """ + Assess the quality of a cuboid fit. + + Args: + cuboid_params: Dictionary containing cuboid parameters + points: Original points used for fitting + + Returns: + Dictionary with quality metrics + """ + if len(points) == 0: + return {"error": "No points provided"} + + # Basic metrics + volume = compute_cuboid_volume(cuboid_params) + surface_area = compute_cuboid_surface_area(cuboid_params) + error = cuboid_params.get("error", 0.0) + + # Aspect ratio analysis + dims = cuboid_params["dimensions"] + aspect_ratios = [ + dims[0] / dims[1] if dims[1] > 0 else float("inf"), + dims[1] / dims[2] if dims[2] > 0 else float("inf"), + dims[2] / dims[0] if dims[0] > 0 else float("inf"), + ] + max_aspect_ratio = max(aspect_ratios) + + # Volume ratio (cuboid volume vs convex hull volume) + try: + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(points) + hull, _ = pcd.compute_convex_hull() + hull_volume = hull.get_volume() + volume_ratio = volume / hull_volume if hull_volume > 0 else float("inf") + except: + volume_ratio = None + + return { + "fitting_error": error, + "volume": volume, + "surface_area": surface_area, + "max_aspect_ratio": max_aspect_ratio, + "volume_ratio": volume_ratio, + "num_points": len(points), + "method": cuboid_params.get("method", "unknown"), + } + + +# Backward compatibility +def visualize_fit(image, cuboid_params, camera_matrix, R=None, t=None): + """ + Legacy function for backward compatibility. + Use visualize_cuboid_on_image instead. 
+ """ + return visualize_cuboid_on_image( + image, cuboid_params, camera_matrix, R, t, show_dimensions=True + ) diff --git a/build/lib/dimos/perception/pointcloud/pointcloud_filtering.py b/build/lib/dimos/perception/pointcloud/pointcloud_filtering.py new file mode 100644 index 0000000000..ef033bff3f --- /dev/null +++ b/build/lib/dimos/perception/pointcloud/pointcloud_filtering.py @@ -0,0 +1,674 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import cv2 +import os +import torch +import open3d as o3d +import argparse +import pickle +from typing import Dict, List, Optional, Union +import time +from dimos.types.manipulation import ObjectData +from dimos.types.vector import Vector +from dimos.perception.pointcloud.utils import ( + load_camera_matrix_from_yaml, + create_point_cloud_and_extract_masks, + o3d_point_cloud_to_numpy, +) +from dimos.perception.pointcloud.cuboid_fit import fit_cuboid + + +class PointcloudFiltering: + """ + A production-ready point cloud filtering pipeline for segmented objects. + + This class takes segmentation results and produces clean, filtered point clouds + for each object with consistent coloring and optional outlier removal. 
+ """ + + def __init__( + self, + color_intrinsics: Optional[Union[str, List[float], np.ndarray]] = None, + depth_intrinsics: Optional[Union[str, List[float], np.ndarray]] = None, + color_weight: float = 0.3, + enable_statistical_filtering: bool = True, + statistical_neighbors: int = 20, + statistical_std_ratio: float = 1.5, + enable_radius_filtering: bool = True, + radius_filtering_radius: float = 0.015, + radius_filtering_min_neighbors: int = 25, + enable_subsampling: bool = True, + voxel_size: float = 0.005, + max_num_objects: int = 10, + min_points_for_cuboid: int = 10, + cuboid_method: str = "oriented", + max_bbox_size_percent: float = 30.0, + ): + """ + Initialize the point cloud filtering pipeline. + + Args: + color_intrinsics: Camera intrinsics for color image + depth_intrinsics: Camera intrinsics for depth image + color_weight: Weight for blending generated color with original (0.0-1.0) + enable_statistical_filtering: Enable/disable statistical outlier filtering + statistical_neighbors: Number of neighbors for statistical filtering + statistical_std_ratio: Standard deviation ratio for statistical filtering + enable_radius_filtering: Enable/disable radius outlier filtering + radius_filtering_radius: Search radius for radius filtering (meters) + radius_filtering_min_neighbors: Min neighbors within radius + enable_subsampling: Enable/disable point cloud subsampling + voxel_size: Voxel size for downsampling (meters, when subsampling enabled) + max_num_objects: Maximum number of objects to process (top N by confidence) + min_points_for_cuboid: Minimum points required for cuboid fitting + cuboid_method: Method for cuboid fitting ('minimal', 'oriented', 'axis_aligned') + max_bbox_size_percent: Maximum percentage of image size for object bboxes (0-100) + + Raises: + ValueError: If invalid parameters are provided + """ + # Validate parameters + if not 0.0 <= color_weight <= 1.0: + raise ValueError(f"color_weight must be between 0.0 and 1.0, got {color_weight}") + 
if not 0.0 <= max_bbox_size_percent <= 100.0: + raise ValueError( + f"max_bbox_size_percent must be between 0.0 and 100.0, got {max_bbox_size_percent}" + ) + + # Store settings + self.color_weight = color_weight + self.enable_statistical_filtering = enable_statistical_filtering + self.statistical_neighbors = statistical_neighbors + self.statistical_std_ratio = statistical_std_ratio + self.enable_radius_filtering = enable_radius_filtering + self.radius_filtering_radius = radius_filtering_radius + self.radius_filtering_min_neighbors = radius_filtering_min_neighbors + self.enable_subsampling = enable_subsampling + self.voxel_size = voxel_size + self.max_num_objects = max_num_objects + self.min_points_for_cuboid = min_points_for_cuboid + self.cuboid_method = cuboid_method + self.max_bbox_size_percent = max_bbox_size_percent + + # Load camera matrices + self.color_camera_matrix = load_camera_matrix_from_yaml(color_intrinsics) + self.depth_camera_matrix = load_camera_matrix_from_yaml(depth_intrinsics) + + # Store the full point cloud + self.full_pcd = None + + def generate_color_from_id(self, object_id: int) -> np.ndarray: + """Generate a consistent color for a given object ID.""" + np.random.seed(object_id) + color = np.random.randint(0, 255, 3, dtype=np.uint8) + np.random.seed(None) + return color + + def _validate_inputs( + self, color_img: np.ndarray, depth_img: np.ndarray, objects: List[ObjectData] + ): + """Validate input parameters.""" + if color_img.shape[:2] != depth_img.shape: + raise ValueError("Color and depth image dimensions don't match") + + def _prepare_masks(self, masks: List[np.ndarray], target_shape: tuple) -> List[np.ndarray]: + """Prepare and validate masks to match target shape.""" + processed_masks = [] + for mask in masks: + # Convert mask to numpy if it's a tensor + if hasattr(mask, "cpu"): + mask = mask.cpu().numpy() + + mask = mask.astype(bool) + + # Handle shape mismatches + if mask.shape != target_shape: + if len(mask.shape) > 2: + mask = 
mask[:, :, 0] + + if mask.shape != target_shape: + mask = cv2.resize( + mask.astype(np.uint8), + (target_shape[1], target_shape[0]), + interpolation=cv2.INTER_NEAREST, + ).astype(bool) + + processed_masks.append(mask) + + return processed_masks + + def _apply_color_mask( + self, pcd: o3d.geometry.PointCloud, rgb_color: np.ndarray + ) -> o3d.geometry.PointCloud: + """Apply weighted color mask to point cloud.""" + if len(np.asarray(pcd.colors)) > 0: + original_colors = np.asarray(pcd.colors) + generated_color = rgb_color.astype(np.float32) / 255.0 + colored_mask = ( + 1.0 - self.color_weight + ) * original_colors + self.color_weight * generated_color + colored_mask = np.clip(colored_mask, 0.0, 1.0) + pcd.colors = o3d.utility.Vector3dVector(colored_mask) + return pcd + + def _apply_filtering(self, pcd: o3d.geometry.PointCloud) -> o3d.geometry.PointCloud: + """Apply optional filtering to point cloud based on enabled flags.""" + current_pcd = pcd + + # Apply statistical filtering if enabled + if self.enable_statistical_filtering: + current_pcd, _ = current_pcd.remove_statistical_outlier( + nb_neighbors=self.statistical_neighbors, std_ratio=self.statistical_std_ratio + ) + + # Apply radius filtering if enabled + if self.enable_radius_filtering: + current_pcd, _ = current_pcd.remove_radius_outlier( + nb_points=self.radius_filtering_min_neighbors, radius=self.radius_filtering_radius + ) + + return current_pcd + + def _apply_subsampling(self, pcd: o3d.geometry.PointCloud) -> o3d.geometry.PointCloud: + """Apply subsampling to limit point cloud size using Open3D's voxel downsampling.""" + if self.enable_subsampling: + return pcd.voxel_down_sample(self.voxel_size) + return pcd + + def _extract_masks_from_objects(self, objects: List[ObjectData]) -> List[np.ndarray]: + """Extract segmentation masks from ObjectData objects.""" + return [obj["segmentation_mask"] for obj in objects] + + def get_full_point_cloud(self) -> o3d.geometry.PointCloud: + """Get the full point cloud.""" + 
def process_images(
    self, color_img: np.ndarray, depth_img: np.ndarray, objects: List[ObjectData]
) -> List[ObjectData]:
    """
    Build a filtered point cloud and a fitted 3D cuboid for each detected object.

    Detections are reduced to the ``max_num_objects`` most confident ones,
    detections whose bounding box covers more than ``max_bbox_size_percent``
    of the image are dropped, and the remaining segmentation masks are used
    to cut per-object clouds out of one full-scene cloud (stored in
    ``self.full_pcd``). Objects whose masked cloud is empty are omitted from
    the result.

    Args:
        color_img: RGB image as numpy array (H, W, 3)
        depth_img: Depth image as numpy array (H, W) in meters
        objects: List of ObjectData from object detection stream; each entry
            must carry a "segmentation_mask"

    Returns:
        List of shallow copies of the surviving ObjectData entries, each
        enhanced with:

        **3D Spatial Information** (only when the filtered cloud has at least
        ``min_points_for_cuboid`` points and cuboid fitting succeeds):
        - "position": Vector(x, y, z) - cuboid center in meters, expressed in
          the camera frame (only the camera intrinsics are applied here)
        - "rotation": Vector(roll, pitch, yaw) - Euler angles in radians
        - "size": {"width": float, "height": float, "depth": float} in meters

        **Point Cloud Data**:
        - "point_cloud": o3d.geometry.PointCloud - filtered cloud with colors
        - "color": np.ndarray - RGB color (0-255) derived from the object ID

        **Grasp Generation Arrays** (AnyGrasp format):
        - "point_cloud_numpy": np.ndarray - Nx3 XYZ coordinates as float32
        - "colors_numpy": np.ndarray - Nx3 RGB colors as float32 (0.0-1.0)

    Raises:
        ValueError: If the image sizes disagree, or if there is something to
            process but no depth camera matrix was configured
    """
    self._validate_inputs(color_img, depth_img, objects)

    if not objects:
        return []

    # Keep only the N most confident detections; a missing or None
    # confidence ranks as 0.0.
    if len(objects) > self.max_num_objects:

        def _confidence(obj):
            value = obj.get("confidence")
            return 0.0 if value is None else value

        objects = sorted(objects, key=_confidence, reverse=True)[: self.max_num_objects]

    # Drop detections whose bbox ([x1, y1, x2, y2]) is too large a share of
    # the image; detections without a bbox are kept.
    image_area = color_img.shape[0] * color_img.shape[1]
    max_bbox_area = image_area * (self.max_bbox_size_percent / 100.0)

    def _bbox_small_enough(obj):
        bbox = obj.get("bbox")
        if bbox is None:
            return True
        return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) <= max_bbox_area

    objects = [obj for obj in objects if _bbox_small_enough(obj)]

    masks = self._extract_masks_from_objects(objects)
    processed_masks = self._prepare_masks(masks, depth_img.shape)

    # Fail with a clear message instead of an AttributeError on None inside
    # the point cloud helper. Only checked when there is work to do, so calls
    # that previously returned [] still do.
    if processed_masks and self.depth_camera_matrix is None:
        raise ValueError("Depth camera matrix must be provided to process images")

    # One full cloud, then one sub-cloud per mask
    self.full_pcd, masked_pcds = create_point_cloud_and_extract_masks(
        color_img, depth_img, processed_masks, self.depth_camera_matrix, depth_scale=1.0
    )

    updated_objects = []

    for i, (obj, pcd) in enumerate(zip(objects, masked_pcds)):
        # Nothing measurable behind this mask
        if len(np.asarray(pcd.points)) == 0:
            continue

        # Shallow copy so the caller's ObjectData is not modified
        updated_obj = obj.copy()

        # Fall back to the position in the filtered list when there is no ID
        object_id = obj.get("object_id", i)
        rgb_color = self.generate_color_from_id(object_id)

        pcd = self._apply_color_mask(pcd, rgb_color)
        pcd = self._apply_subsampling(pcd)
        pcd_filtered = self._apply_filtering(pcd)

        # Fit a cuboid only when enough points survived filtering
        points = np.asarray(pcd_filtered.points)
        cuboid_params = None
        if len(points) >= self.min_points_for_cuboid:
            cuboid_params = fit_cuboid(points, method=self.cuboid_method)

        if cuboid_params is not None:
            center = cuboid_params["center"]
            dimensions = cuboid_params["dimensions"]
            rot = cuboid_params["rotation"]

            # Rotation matrix -> (roll, pitch, yaw). ``sy`` is |cos(pitch)|;
            # near zero the roll/yaw split is degenerate (gimbal lock), so
            # yaw is pinned to 0 and roll absorbs the rotation.
            sy = np.sqrt(rot[0, 0] * rot[0, 0] + rot[1, 0] * rot[1, 0])
            pitch = np.arctan2(-rot[2, 0], sy)
            if sy < 1e-6:
                roll = np.arctan2(-rot[1, 2], rot[1, 1])
                yaw = 0.0
            else:
                roll = np.arctan2(rot[2, 1], rot[2, 2])
                yaw = np.arctan2(rot[1, 0], rot[0, 0])

            updated_obj["position"] = Vector(center[0], center[1], center[2])
            updated_obj["rotation"] = Vector(roll, pitch, yaw)
            updated_obj["size"] = {
                "width": float(dimensions[0]),
                "height": float(dimensions[1]),
                "depth": float(dimensions[2]),
            }

        updated_obj["point_cloud"] = pcd_filtered
        updated_obj["color"] = rgb_color

        # Plain arrays for grasp generation (AnyGrasp format)
        points_array = np.asarray(pcd_filtered.points).astype(np.float32)
        if pcd_filtered.has_colors():
            colors_array = np.asarray(pcd_filtered.colors).astype(np.float32)
        else:
            # No colors available: use black so both arrays stay Nx3
            colors_array = np.zeros((len(points_array), 3), dtype=np.float32)

        updated_obj["point_cloud_numpy"] = points_array
        updated_obj["colors_numpy"] = colors_array

        updated_objects.append(updated_obj)

    return updated_objects
def load_test_images(data_dir: str) -> tuple:
    """
    Load the first color/depth image pair found under ``data_dir``.

    The alphabetically first image file in ``<data_dir>/color`` and in
    ``<data_dir>/depth`` is read. The color image is converted from OpenCV's
    BGR order to RGB; a 16-bit depth image is divided by 1000 and returned as
    float32 (other depth dtypes are returned as read).

    Args:
        data_dir: Directory containing color and depth subdirectories

    Returns:
        Tuple of (color_img, depth_img)

    Raises:
        FileNotFoundError: If either directory has no image, or an image
            cannot be decoded
    """
    image_suffixes = (".jpg", ".jpeg", ".png", ".bmp")

    def first_image_in(folder):
        """Return the path of the alphabetically first image in ``folder``, or None."""
        if not os.path.exists(folder):
            return None
        matches = (
            name for name in sorted(os.listdir(folder)) if name.lower().endswith(image_suffixes)
        )
        first = next(matches, None)
        return None if first is None else os.path.join(folder, first)

    color_path = first_image_in(os.path.join(data_dir, "color"))
    depth_path = first_image_in(os.path.join(data_dir, "depth"))

    if not (color_path and depth_path):
        raise FileNotFoundError(f"Could not find color or depth images in {data_dir}")

    # imread returns None (rather than raising) when decoding fails
    bgr = cv2.imread(color_path)
    if bgr is None:
        raise FileNotFoundError(f"Could not load color image from {color_path}")
    color_img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    # IMREAD_UNCHANGED keeps the stored bit depth of the depth image
    depth_img = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED)
    if depth_img is None:
        raise FileNotFoundError(f"Could not load depth image from {depth_path}")

    if depth_img.dtype == np.uint16:
        depth_img = depth_img.astype(np.float32) / 1000.0

    return color_img, depth_img
+ + Args: + objects: List of ObjectData with point clouds + """ + all_pcds = [] + + for obj in objects: + if "point_cloud" in obj and obj["point_cloud"] is not None: + pcd = obj["point_cloud"] + all_pcds.append(pcd) + + # Draw 3D bounding box if position, rotation, and size are available + if ( + "position" in obj + and "rotation" in obj + and "size" in obj + and obj["position"] is not None + and obj["rotation"] is not None + and obj["size"] is not None + ): + try: + position = obj["position"] + rotation = obj["rotation"] + size = obj["size"] + + # Convert position to numpy array + if hasattr(position, "x"): # Vector object + center = np.array([position.x, position.y, position.z]) + else: # Dictionary + center = np.array([position["x"], position["y"], position["z"]]) + + # Convert rotation (euler angles) to rotation matrix + if hasattr(rotation, "x"): # Vector object (roll, pitch, yaw) + roll, pitch, yaw = rotation.x, rotation.y, rotation.z + else: # Dictionary + roll, pitch, yaw = rotation["roll"], rotation["pitch"], rotation["yaw"] + + # Create rotation matrix from euler angles (ZYX order) + # Roll (X), Pitch (Y), Yaw (Z) + cos_r, sin_r = np.cos(roll), np.sin(roll) + cos_p, sin_p = np.cos(pitch), np.sin(pitch) + cos_y, sin_y = np.cos(yaw), np.sin(yaw) + + # Rotation matrix for ZYX euler angles + R = np.array( + [ + [ + cos_y * cos_p, + cos_y * sin_p * sin_r - sin_y * cos_r, + cos_y * sin_p * cos_r + sin_y * sin_r, + ], + [ + sin_y * cos_p, + sin_y * sin_p * sin_r + cos_y * cos_r, + sin_y * sin_p * cos_r - cos_y * sin_r, + ], + [-sin_p, cos_p * sin_r, cos_p * cos_r], + ] + ) + + # Get dimensions + width = size.get("width", 0.1) + height = size.get("height", 0.1) + depth = size.get("depth", 0.1) + extent = np.array([width, height, depth]) + + # Create oriented bounding box + obb = o3d.geometry.OrientedBoundingBox(center=center, R=R, extent=extent) + obb.color = [1, 0, 0] # Red bounding boxes + all_pcds.append(obb) + + except Exception as e: + print( + f"Warning: 
def main():
    """Run the PointcloudFiltering demo: load, segment, filter, report, visualize.

    Command-line options:
        --save-pickle FILE: also write the resulting ObjectData (without the
            Open3D clouds) to FILE; ".pkl" is appended if missing.
        --data-dir DIR: directory with the RGBD test data; defaults to
            ``assets/rgbd_data`` three levels above this file.

    Any exception raised by the pipeline is printed with its traceback
    instead of propagating, since this is a demo entry point.
    """
    parser = argparse.ArgumentParser(description="Point cloud filtering pipeline demonstration")
    parser.add_argument(
        "--save-pickle",
        type=str,
        help="Save generated ObjectData to pickle file (provide filename)",
    )
    parser.add_argument(
        "--data-dir", type=str, help="Directory containing RGBD data (default: auto-detect)"
    )
    args = parser.parse_args()

    try:
        # Setup paths
        if args.data_dir:
            data_dir = args.data_dir
        else:
            script_dir = os.path.dirname(os.path.abspath(__file__))
            dimos_dir = os.path.abspath(os.path.join(script_dir, "../../../"))
            data_dir = os.path.join(dimos_dir, "assets/rgbd_data")

        # Load test data
        print("Loading test images...")
        color_img, depth_img = load_test_images(data_dir)
        print(f"Loaded images: color {color_img.shape}, depth {depth_img.shape}")

        # Run segmentation
        print("Running segmentation...")
        objects = run_segmentation(color_img)
        print(f"Found {len(objects)} objects")

        # Create filtering pipeline
        print("Creating filtering pipeline...")
        filter_pipeline, _, _ = create_test_pipeline(data_dir)

        # Process images
        print("Processing point clouds...")
        updated_objects = filter_pipeline.process_images(color_img, depth_img, objects)

        # Print results
        print("Processing complete:")
        print(f" Objects processed: {len(updated_objects)}/{len(objects)}")

        # Print per-object stats
        for i, obj in enumerate(updated_objects):
            if obj.get("point_cloud") is None:
                continue
            num_points = len(np.asarray(obj["point_cloud"].points))
            position = obj.get("position", Vector(0, 0, 0))
            size = obj.get("size", {})
            # .get(): an object without an ID must not abort the report with
            # a KeyError; the rest of the pipeline tolerates missing IDs too.
            object_id = obj.get("object_id", "unknown")
            print(f" Object {i + 1} (ID: {object_id}): {num_points} points")
            print(f" Position: ({position.x:.2f}, {position.y:.2f}, {position.z:.2f})")
            print(
                f" Size: {size.get('width', 0):.3f} x {size.get('height', 0):.3f} x {size.get('depth', 0):.3f}"
            )

        # Save to pickle file if requested
        if args.save_pickle:
            pickle_path = args.save_pickle
            if not pickle_path.endswith(".pkl"):
                pickle_path += ".pkl"

            print(f"Saving ObjectData to {pickle_path}...")

            # Create serializable objects (exclude Open3D point clouds)
            serializable_objects = []
            for obj in updated_objects:
                obj_copy = obj.copy()
                # Remove the Open3D point cloud object (can't be pickled)
                obj_copy.pop("point_cloud", None)
                serializable_objects.append(obj_copy)

            with open(pickle_path, "wb") as f:
                pickle.dump(serializable_objects, f)

            print(f"Successfully saved {len(serializable_objects)} objects to {pickle_path}")
            print("To load: objects = pickle.load(open('filename.pkl', 'rb'))")
            print(
                "Note: Open3D point clouds excluded - use point_cloud_numpy and colors_numpy for processing"
            )

        # Visualize results
        print("Visualizing results...")
        visualize_results(updated_objects)

    except Exception as e:
        # Top-level demo boundary: report the failure instead of propagating
        print(f"Error: {e}")
        import traceback

        traceback.print_exc()
def depth_to_point_cloud(depth_image, camera_intrinsics, subsample_factor=4):
    """
    Back-project a depth image into 3D points with the pinhole camera model.

    Every ``subsample_factor``-th pixel along each axis is sampled. Pixels
    whose depth is NaN, infinite, zero or negative are skipped.

    Args:
        depth_image: HxW depth image in meters
        camera_intrinsics: Camera parameters as a flat ``[fx, fy, cx, cy]``
            sequence (list, tuple or 1-D array) or a matrix whose top-left
            3x3 block is the intrinsic matrix
        subsample_factor: Pixel stride along each axis (higher = fewer points)

    Returns:
        Nx3 array of 3D points (X, Y, Z) in the camera frame

    Raises:
        ValueError: If ``subsample_factor`` is smaller than 1, the depth image
            is not 2-D, or the intrinsics have an unsupported shape
    """
    if subsample_factor < 1:
        raise ValueError(f"subsample_factor must be >= 1, got {subsample_factor}")

    depth = np.asarray(depth_image)
    if depth.ndim != 2:
        raise ValueError(f"depth_image must be (H, W), got {depth.shape}")

    # Dispatch on shape rather than on ``list`` so tuples and flat arrays of
    # [fx, fy, cx, cy] work as well.
    intrinsics = np.asarray(camera_intrinsics, dtype=np.float64)
    if intrinsics.shape == (4,):
        fx, fy, cx, cy = intrinsics
    elif intrinsics.ndim == 2 and intrinsics.shape[0] >= 2 and intrinsics.shape[1] >= 3:
        fx, fy = intrinsics[0, 0], intrinsics[1, 1]
        cx, cy = intrinsics[0, 2], intrinsics[1, 2]
    else:
        raise ValueError(
            "camera_intrinsics must be [fx, fy, cx, cy] or a 3x3 matrix, "
            f"got shape {intrinsics.shape}"
        )

    # Sample the pixel grid first; validity is then tested on the samples
    # only, so the full image is never copied.
    rows, cols = depth.shape
    x_grid, y_grid = np.meshgrid(
        np.arange(0, cols, subsample_factor), np.arange(0, rows, subsample_factor)
    )
    x = x_grid.ravel()
    y = y_grid.ravel()
    z = depth[y_grid, x_grid].ravel()

    # NaN fails ``> 0`` on its own; isfinite additionally rejects +inf
    valid = np.isfinite(z) & (z > 0)
    x, y, z = x[valid], y[valid], z[valid]

    # Pinhole back-projection
    X = (x - cx) * z / fx
    Y = (y - cy) * z / fy

    return np.column_stack([X, Y, z])
format. Expected str, list, dict, or numpy array, got {type(camera_info)}" + ) + + +def _extract_matrix_from_dict(data: dict) -> np.ndarray: + """Extract camera matrix from dictionary with various formats.""" + # ROS format with 'K' field (most common) + if "K" in data: + k_data = data["K"] + if len(k_data) == 9: + return np.array(k_data, dtype=np.float32).reshape(3, 3) + + # Standard format with 'camera_matrix' + if "camera_matrix" in data: + if "data" in data["camera_matrix"]: + matrix_data = data["camera_matrix"]["data"] + if len(matrix_data) == 9: + return np.array(matrix_data, dtype=np.float32).reshape(3, 3) + + # Explicit intrinsics format + if all(k in data for k in ["fx", "fy", "cx", "cy"]): + fx, fy = float(data["fx"]), float(data["fy"]) + cx, cy = float(data["cx"]), float(data["cy"]) + return np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32) + + # Error case - provide helpful debug info + available_keys = list(data.keys()) + if "K" in data: + k_info = f"K field length: {len(data['K']) if hasattr(data['K'], '__len__') else 'unknown'}" + else: + k_info = "K field not found" + + raise ValueError( + f"Cannot extract camera matrix from data. " + f"Available keys: {available_keys}. {k_info}. " + f"Expected formats: 'K' (9 elements), 'camera_matrix.data' (9 elements), " + f"or individual 'fx', 'fy', 'cx', 'cy' fields." + ) + + +def create_o3d_point_cloud_from_rgbd( + color_img: np.ndarray, + depth_img: np.ndarray, + intrinsic: np.ndarray, + depth_scale: float = 1.0, + depth_trunc: float = 3.0, +) -> o3d.geometry.PointCloud: + """ + Create an Open3D point cloud from RGB and depth images. 
def o3d_point_cloud_to_numpy(pcd: o3d.geometry.PointCloud) -> np.ndarray:
    """
    Flatten an Open3D point cloud into one float32 XYZRGB array.

    Colors are rescaled from Open3D's [0, 1] range to [0, 255]. A cloud
    without colors gets zeros in the RGB columns.

    Args:
        pcd: Open3D point cloud object

    Returns:
        Nx6 float32 array of XYZRGB points; shape (0, 6) for an empty cloud
    """
    xyz = np.asarray(pcd.points)
    if len(xyz) == 0:
        return np.zeros((0, 6), dtype=np.float32)

    if pcd.has_colors():
        rgb = np.asarray(pcd.colors) * 255.0  # [0,1] -> [0,255]
    else:
        rgb = np.zeros((len(xyz), 3), dtype=np.float32)

    return np.column_stack([xyz, rgb]).astype(np.float32)
def create_point_cloud_and_extract_masks(
    color_img: np.ndarray,
    depth_img: np.ndarray,
    masks: List[np.ndarray],
    intrinsic: np.ndarray,
    depth_scale: float = 1.0,
    depth_trunc: float = 3.0,
) -> Tuple[o3d.geometry.PointCloud, List[o3d.geometry.PointCloud]]:
    """
    Create the full-scene point cloud once and cut one sub-cloud per mask out of it.

    The sub-clouds are selected through a pixel -> point-index table. That
    table is only correct if it marks exactly the pixels the cloud builder
    turned into points, in the same row-major order. The validity test here
    therefore applies the same ``depth_scale`` division and ``depth_trunc``
    cut-off as the builder, instead of comparing the raw depth with
    ``depth_trunc``.

    Args:
        color_img: RGB image (H, W, 3)
        depth_img: Depth image (H, W)
        masks: List of masks, each of shape (H, W); non-zero means "inside".
            A mask of any other shape yields an empty point cloud
        intrinsic: Camera intrinsic matrix (3x3 numpy array)
        depth_scale: Scale factor to convert depth to meters
        depth_trunc: Maximum depth in meters

    Returns:
        Tuple of (full_point_cloud, list_of_masked_point_clouds); the list has
        one entry per mask, in order. With no masks, an empty cloud and an
        empty list are returned.
    """
    if not masks:
        return o3d.geometry.PointCloud(), []

    full_pcd = create_o3d_point_cloud_from_rgbd(
        color_img, depth_img, intrinsic, depth_scale, depth_trunc
    )

    num_points = len(np.asarray(full_pcd.points))
    if num_points == 0:
        return full_pcd, [o3d.geometry.PointCloud() for _ in masks]

    # Mirror the builder: depth is cast to float32 and divided by depth_scale
    # before it is compared with depth_trunc.
    with np.errstate(invalid="ignore", over="ignore", divide="ignore"):
        measurable = (np.isfinite(depth_img) & (depth_img > 0)).ravel()
        scaled = (
            (depth_img.astype(np.float32) / np.float32(depth_scale)).astype(np.float64).ravel()
        )
        positive = measurable & (scaled > 0)
        valid_depth = positive & (scaled < depth_trunc)

        # NOTE(review): the Open3D docs describe the cut-off as "larger than
        # depth_trunc", while the implementation appears to use ">=" — confirm.
        # The table needs exactly one entry per point, so if the strict
        # variant does not match the cloud, use the inclusive one when it does.
        if np.count_nonzero(valid_depth) != num_points:
            inclusive = positive & (scaled <= depth_trunc)
            if np.count_nonzero(inclusive) == num_points:
                valid_depth = inclusive

    if not np.any(valid_depth):
        return full_pcd, [o3d.geometry.PointCloud() for _ in masks]

    # pixel_to_point[flat pixel index] -> index into full_pcd, or -1
    pixel_to_point = np.full(valid_depth.size, -1, dtype=np.int64)
    pixel_to_point[valid_depth] = np.arange(np.count_nonzero(valid_depth))

    masked_pcds = []
    for mask in masks:
        if mask.shape != depth_img.shape:
            masked_pcds.append(o3d.geometry.PointCloud())
            continue

        # Force bool: an integer mask combined with ``&`` would be
        # interpreted as positions instead of a selection.
        inside = np.asarray(mask, dtype=bool).ravel() & valid_depth
        point_indices = pixel_to_point[inside]
        # Drop (rather than clip) indices the cloud cannot hold: clipping
        # would silently alias them to the last point.
        point_indices = point_indices[(point_indices >= 0) & (point_indices < num_points)]

        if len(point_indices) > 0:
            masked_pcds.append(full_pcd.select_by_index(point_indices.tolist()))
        else:
            masked_pcds.append(o3d.geometry.PointCloud())

    return full_pcd, masked_pcds
+ + Args: + pcd: Open3D point cloud object + distance_threshold: Maximum distance a point can be from the plane to be considered an inlier (in meters) + ransac_n: Number of points to sample for each RANSAC iteration + num_iterations: Number of RANSAC iterations + + Returns: + Open3D point cloud with the dominant plane removed + """ + # Make a copy of the input point cloud to avoid modifying the original + pcd_filtered = o3d.geometry.PointCloud() + pcd_filtered.points = o3d.utility.Vector3dVector(np.asarray(pcd.points)) + if pcd.has_colors(): + pcd_filtered.colors = o3d.utility.Vector3dVector(np.asarray(pcd.colors)) + if pcd.has_normals(): + pcd_filtered.normals = o3d.utility.Vector3dVector(np.asarray(pcd.normals)) + + # Check if point cloud has enough points + if len(pcd_filtered.points) < ransac_n: + return pcd_filtered + + # Run RANSAC to find the largest plane + _, inliers = pcd_filtered.segment_plane( + distance_threshold=distance_threshold, ransac_n=ransac_n, num_iterations=num_iterations + ) + + # Remove the dominant plane (regardless of orientation) + pcd_without_dominant_plane = pcd_filtered.select_by_index(inliers, invert=True) + return pcd_without_dominant_plane + + +def filter_point_cloud_statistical( + pcd: o3d.geometry.PointCloud, nb_neighbors: int = 20, std_ratio: float = 2.0 +) -> Tuple[o3d.geometry.PointCloud, np.ndarray]: + """ + Apply statistical outlier filtering to point cloud. 
def project_3d_points_to_2d(
    points_3d: np.ndarray, camera_intrinsics: Union[List[float], np.ndarray]
) -> np.ndarray:
    """
    Project 3D points to 2D image coordinates using camera intrinsics.

    Points with zero or negative Z are dropped, so the result can have fewer
    rows than the input and row ``i`` of the result does not necessarily
    belong to input row ``i``.

    Coordinates are rounded to the nearest pixel. Plain integer casting would
    truncate toward zero, turning 99.9 into 99 and, worse, -0.7 into 0 - a
    point just outside the image would then look like it lies on its border.

    Args:
        points_3d: Nx3 array of 3D points (X, Y, Z)
        camera_intrinsics: Camera parameters as a flat ``[fx, fy, cx, cy]``
            sequence (list, tuple or 1-D array) or a matrix whose top-left
            3x3 block is the intrinsic matrix

    Returns:
        Mx2 int32 array of image coordinates (u, v), M <= N

    Raises:
        ValueError: If the intrinsics have an unsupported shape
    """
    points_3d = np.asarray(points_3d)
    if len(points_3d) == 0:
        return np.zeros((0, 2), dtype=np.int32)

    # Only points in front of the camera can be projected
    in_front = points_3d[:, 2] > 0
    if not np.any(in_front):
        return np.zeros((0, 2), dtype=np.int32)

    visible = points_3d[in_front]

    # Dispatch on shape rather than on ``list`` so tuples and flat arrays of
    # [fx, fy, cx, cy] work as well.
    intrinsics = np.asarray(camera_intrinsics, dtype=np.float64)
    if intrinsics.shape == (4,):
        fx, fy, cx, cy = intrinsics
    elif intrinsics.ndim == 2 and intrinsics.shape[0] >= 2 and intrinsics.shape[1] >= 3:
        fx, fy = intrinsics[0, 0], intrinsics[1, 1]
        cx, cy = intrinsics[0, 2], intrinsics[1, 2]
    else:
        raise ValueError(
            "camera_intrinsics must be [fx, fy, cx, cy] or a 3x3 matrix, "
            f"got shape {intrinsics.shape}"
        )

    # Pinhole projection
    u = (visible[:, 0] * fx / visible[:, 2]) + cx
    v = (visible[:, 1] * fy / visible[:, 2]) + cy

    # Round to the nearest integer pixel, then cast
    return np.rint(np.column_stack([u, v])).astype(np.int32)
def _object_overlay_color(obj: dict) -> Tuple[int, ...]:
    """Return obj["color"] as a tuple of Python ints, or white when it is missing or unusable."""
    color = obj.get("color")
    if isinstance(color, (list, tuple, np.ndarray)) and len(color) >= 3:
        return tuple(int(c) for c in color[:3])
    return (255, 255, 255)


def create_point_cloud_overlay_visualization(
    base_image: np.ndarray,
    objects: List[dict],
    intrinsics: np.ndarray,
) -> np.ndarray:
    """
    Create a visualization showing object point clouds and bounding boxes overlaid on a base image.

    Args:
        base_image: Base image to overlay onto (H, W, 3)
        objects: List of object dictionaries. 'point_cloud' and 'color' drive the point
            overlay; 'position', 'rotation' and 'size' drive the 3D box. Every key is
            optional: a missing color falls back to white, and a box is only attempted
            when all three pose keys are present and not None.
        intrinsics: Camera intrinsics as [fx, fy, cx, cy] or 3x3 matrix

    Returns:
        Visualization image with overlaid point clouds and bounding boxes (H, W, 3).
        The input image is not modified.
    """
    # Collect the clouds to draw together with their colors.
    point_clouds = []
    cloud_colors = []
    for obj in objects:
        if obj.get("point_cloud") is not None:
            point_clouds.append(obj["point_cloud"])
            cloud_colors.append(_object_overlay_color(obj))

    # Create visualization
    if point_clouds:
        result = overlay_point_clouds_on_image(
            base_image=base_image,
            point_clouds=point_clouds,
            camera_intrinsics=intrinsics,
            colors=cloud_colors,
            point_size=3,
            alpha=0.8,
        )
    else:
        result = base_image.copy()

    # Draw 3D bounding boxes
    height_img, width_img = result.shape[:2]
    for obj in objects:
        if any(obj.get(key) is None for key in ("position", "rotation", "size")):
            continue

        try:
            corners_3d = create_3d_bounding_box_corners(
                obj["position"], obj["rotation"], obj["size"]
            )
            corners_2d = project_3d_points_to_2d(corners_3d, intrinsics)

            # The projection drops corners that are not in front of the camera,
            # and box edges are addressed by corner position, so a box can only
            # be drawn when every corner survived.
            if len(corners_2d) != len(corners_3d):
                continue

            # Check if any corners are visible
            visible = (
                (corners_2d[:, 0] >= 0)
                & (corners_2d[:, 0] < width_img)
                & (corners_2d[:, 1] >= 0)
                & (corners_2d[:, 1] < height_img)
            )
            if np.any(visible):
                # Take the color from the object itself: the cloud color list only
                # has entries for objects that own a point cloud, so indexing it by
                # object position would pick another object's color.
                draw_3d_bounding_box_on_image(
                    result, corners_2d, _object_overlay_color(obj), thickness=2
                )
        except Exception as e:
            # Best effort: one malformed object must not abort the whole overlay.
            print(f"Skipping 3D bounding box overlay: {e}")
            continue

    return result
def draw_3d_bounding_box_on_image(image, corners_2d, color, thickness=2):
    """
    Draw a 3D bounding box on an image using projected 2D corners.

    The image is modified in place and nothing is returned. Nothing is drawn
    when fewer than eight corners are supplied (for example because some corners
    were dropped during projection): the edge list addresses corners by
    position, so a partial set would produce a wrong, half-drawn box.

    Args:
        image: Image to draw on
        corners_2d: 8x2 array of 2D corner coordinates; corners 0-3 form one
            face, 4-7 the opposite face, and corner k is joined to corner k + 4
        color: RGB color tuple
        thickness: Line thickness
    """
    corners = np.asarray(corners_2d)
    if corners.ndim != 2 or corners.shape[0] < 8 or corners.shape[1] < 2:
        return

    # Define the 12 edges of a cube (connecting corner indices)
    edges = [
        (0, 1), (1, 2), (2, 3), (3, 0),  # First face
        (4, 5), (5, 6), (6, 7), (7, 4),  # Opposite face
        (0, 4), (1, 5), (2, 6), (3, 7),  # Edges joining the two faces
    ]

    # OpenCV's Python bindings expect plain Python ints; NumPy integer scalars
    # are rejected by some builds, so convert points (and color) up front.
    pixel_corners = [(int(x), int(y)) for x, y in corners[:8, :2].astype(int)]
    if isinstance(color, (list, tuple, np.ndarray)):
        color = tuple(int(c) for c in color)

    # Draw each edge
    for start_idx, end_idx in edges:
        cv2.line(image, pixel_corners[start_idx], pixel_corners[end_idx], color, thickness)
+ + Args: + full_pcd: Complete scene point cloud + all_objects: List of objects with point clouds to subtract + eps: DBSCAN epsilon parameter (max distance between points in cluster) + min_points: DBSCAN min_samples parameter (min points to form cluster) + enable_filtering: Whether to apply statistical and radius filtering + voxel_size: Size of voxels for voxel grid generation + + Returns: + Tuple of (clustered_point_clouds, voxel_grid) + """ + if full_pcd is None or len(np.asarray(full_pcd.points)) == 0: + return [], o3d.geometry.VoxelGrid() + + if not all_objects: + # If no objects detected, cluster the full point cloud + clusters = _cluster_point_cloud_dbscan(full_pcd, eps, min_points) + voxel_grid = _create_voxel_grid_from_clusters(clusters, voxel_size) + return clusters, voxel_grid + + try: + # Start with a copy of the full point cloud + misc_pcd = o3d.geometry.PointCloud(full_pcd) + + # Remove object points by combining all object point clouds + all_object_points = [] + for obj in all_objects: + if "point_cloud" in obj and obj["point_cloud"] is not None: + obj_points = np.asarray(obj["point_cloud"].points) + if len(obj_points) > 0: + all_object_points.append(obj_points) + + if not all_object_points: + # No object points to remove, cluster full point cloud + clusters = _cluster_point_cloud_dbscan(misc_pcd, eps, min_points) + voxel_grid = _create_voxel_grid_from_clusters(clusters, voxel_size) + return clusters, voxel_grid + + # Combine all object points + combined_obj_points = np.vstack(all_object_points) + + # For efficiency, downsample both point clouds + misc_downsampled = misc_pcd.voxel_down_sample(voxel_size=0.005) + + # Create object point cloud for efficient operations + obj_pcd = o3d.geometry.PointCloud() + obj_pcd.points = o3d.utility.Vector3dVector(combined_obj_points) + obj_downsampled = obj_pcd.voxel_down_sample(voxel_size=0.005) + + misc_points = np.asarray(misc_downsampled.points) + obj_points_down = np.asarray(obj_downsampled.points) + + if 
len(misc_points) == 0 or len(obj_points_down) == 0: + clusters = _cluster_point_cloud_dbscan(misc_downsampled, eps, min_points) + voxel_grid = _create_voxel_grid_from_clusters(clusters, voxel_size) + return clusters, voxel_grid + + # Build tree for object points + obj_tree = cKDTree(obj_points_down) + + # Find distances from misc points to nearest object points + distances, _ = obj_tree.query(misc_points, k=1) + + # Keep points that are far enough from any object point + threshold = 0.015 # 1.5cm threshold + keep_mask = distances > threshold + + if not np.any(keep_mask): + return [], o3d.geometry.VoxelGrid() + + # Filter misc points + misc_indices = np.where(keep_mask)[0] + final_misc_pcd = misc_downsampled.select_by_index(misc_indices) + + if len(np.asarray(final_misc_pcd.points)) == 0: + return [], o3d.geometry.VoxelGrid() + + # Apply additional filtering if enabled + if enable_filtering: + # Apply statistical outlier filtering + filtered_misc_pcd, _ = filter_point_cloud_statistical( + final_misc_pcd, nb_neighbors=30, std_ratio=2.0 + ) + + if len(np.asarray(filtered_misc_pcd.points)) == 0: + return [], o3d.geometry.VoxelGrid() + + # Apply radius outlier filtering + final_filtered_misc_pcd, _ = filter_point_cloud_radius( + filtered_misc_pcd, + nb_points=20, + radius=0.03, # 3cm radius + ) + + if len(np.asarray(final_filtered_misc_pcd.points)) == 0: + return [], o3d.geometry.VoxelGrid() + + final_misc_pcd = final_filtered_misc_pcd + + # Cluster the misc points using DBSCAN + clusters = _cluster_point_cloud_dbscan(final_misc_pcd, eps, min_points) + + # Create voxel grid from all misc points (before clustering) + voxel_grid = _create_voxel_grid_from_point_cloud(final_misc_pcd, voxel_size) + + return clusters, voxel_grid + + except Exception as e: + print(f"Error in misc point extraction and clustering: {e}") + # Fallback: return downsampled full point cloud as single cluster + try: + downsampled = full_pcd.voxel_down_sample(voxel_size=0.02) + if 
def _create_voxel_grid_from_point_cloud(
    pcd: o3d.geometry.PointCloud, voxel_size: float = 0.02
) -> o3d.geometry.VoxelGrid:
    """
    Create a uniformly gray voxel grid from a point cloud.

    Args:
        pcd: Input point cloud (left unmodified)
        voxel_size: Size of each voxel

    Returns:
        Open3D VoxelGrid object; an empty grid when the cloud has no points or
        voxelization fails (the error is printed, not raised)
    """
    if len(np.asarray(pcd.points)) == 0:
        return o3d.geometry.VoxelGrid()

    try:
        # VoxelGrid.get_voxels() returns copies of the voxels, so assigning a
        # color to those objects never reaches the grid. Voxel colors come from
        # the source points instead, so paint a copy of the cloud gray before
        # voxelizing; copying keeps the caller's colors intact.
        gray_pcd = o3d.geometry.PointCloud(pcd)
        gray_pcd.paint_uniform_color([0.5, 0.5, 0.5])

        voxel_grid = o3d.geometry.VoxelGrid.create_from_point_cloud(gray_pcd, voxel_size)

        print(
            f"Created voxel grid with {len(voxel_grid.get_voxels())} voxels (voxel_size={voxel_size})"
        )
        return voxel_grid

    except Exception as e:
        print(f"Error creating voxel grid: {e}")
        return o3d.geometry.VoxelGrid()
def _cluster_point_cloud_dbscan(
    pcd: o3d.geometry.PointCloud, eps: float = 0.05, min_points: int = 50
) -> List[o3d.geometry.PointCloud]:
    """
    Split a point cloud into its DBSCAN clusters.

    Each cluster is returned as its own point cloud painted with one random
    color. Points labeled -1 (noise) are discarded.

    Args:
        pcd: Point cloud to cluster
        eps: DBSCAN epsilon parameter
        min_points: DBSCAN min_samples parameter

    Returns:
        List of point clouds, one per cluster. An empty input yields an empty
        list; if clustering raises, the error is printed and the original
        cloud is returned as the only element.
    """
    total_points = len(np.asarray(pcd.points))
    if total_points == 0:
        return []

    try:
        labels = np.array(pcd.cluster_dbscan(eps=eps, min_points=min_points))

        clusters = []
        for cluster_id in np.unique(labels):
            # -1 marks noise points, which belong to no cluster
            if cluster_id == -1:
                continue

            member_idx = np.where(labels == cluster_id)[0]
            if len(member_idx) == 0:
                continue

            # Extract the members and give the cluster a random RGB color
            segment = pcd.select_by_index(member_idx)
            segment.paint_uniform_color(np.random.rand(3))
            clusters.append(segment)

        print(f"DBSCAN clustering found {len(clusters)} clusters from {total_points} points")
        return clusters

    except Exception as e:
        print(f"Error in DBSCAN clustering: {e}")
        return [pcd]  # Return original point cloud as fallback
def visualize_pcd(
    pcd: o3d.geometry.PointCloud,
    window_name: str = "Point Cloud Visualization",
    point_size: float = 1.0,
    show_coordinate_frame: bool = True,
    coordinate_frame_size: float = 0.1,
) -> None:
    """
    Show a single Open3D point cloud in an Open3D window.

    A copy of the cloud is transformed with get_standard_coordinate_transform()
    before display, so the caller's cloud is not modified. If the interactive
    Visualizer fails, the error is printed and draw_geometries() is used instead.
    A None or empty cloud only prints a warning.

    Args:
        pcd: Open3D point cloud to visualize
        window_name: Name of the visualization window
        point_size: Size of points in the visualization
        show_coordinate_frame: Whether to show coordinate frame
        coordinate_frame_size: Size of the coordinate frame
    """
    if pcd is None:
        print("Warning: Point cloud is None, nothing to visualize")
        return

    num_points = len(np.asarray(pcd.points))
    if num_points == 0:
        print("Warning: Point cloud is empty, nothing to visualize")
        return

    # Work on a transformed copy so the input stays in its original frame
    flip = get_standard_coordinate_transform()
    display_pcd = o3d.geometry.PointCloud(pcd)
    display_pcd.transform(flip)
    scene = [display_pcd]

    if show_coordinate_frame:
        axes = o3d.geometry.TriangleMesh.create_coordinate_frame(size=coordinate_frame_size)
        axes.transform(flip)
        scene.append(axes)

    print(f"Visualizing point cloud with {num_points} points")

    try:
        viewer = o3d.visualization.Visualizer()
        viewer.create_window(window_name=window_name, width=1280, height=720)
        for item in scene:
            viewer.add_geometry(item)
        viewer.get_render_option().point_size = point_size
        viewer.run()
        viewer.destroy_window()
    except Exception as e:
        print(f"Failed to create interactive visualization: {e}")
        o3d.visualization.draw_geometries(
            scene, window_name=window_name, width=1280, height=720
        )
def combine_object_pointclouds(
    point_clouds: Union[List[np.ndarray], List[o3d.geometry.PointCloud]],
    colors: Optional[List[np.ndarray]] = None,
) -> o3d.geometry.PointCloud:
    """
    Combine multiple point clouds into a single Open3D point cloud.

    Args:
        point_clouds: List of point clouds as numpy arrays (first three columns
            are taken as XYZ) or Open3D point clouds. Entries of any other type
            and clouds without points are skipped.
        colors: Optional colors for the numpy-array clouds, indexed like
            `point_clouds`. Each entry is either a single color (applied to
            every point of that cloud) or one color row per point. Open3D
            clouds use their own colors and ignore this argument.
    Returns:
        Combined Open3D point cloud. Colors are attached only when every
        contributing cloud supplied one color per point; otherwise the result
        carries no colors, since a partial color array would not line up with
        the points.
    """
    all_points = []
    all_colors = []
    every_cloud_colored = True

    for i, pcd in enumerate(point_clouds):
        # NOTE: cloud colors live in their own local name. Rebinding `colors`
        # here would corrupt the lookup for every later numpy-array cloud.
        if isinstance(pcd, np.ndarray):
            points = np.asarray(pcd[:, :3], dtype=np.float64)
            cloud_colors = None
            if colors is not None and len(colors) > 0:
                cloud_colors = np.asarray(colors[i], dtype=np.float64)
                if cloud_colors.ndim == 1:
                    # One color for the whole cloud: repeat it for every point
                    cloud_colors = np.tile(cloud_colors[:3], (len(points), 1))
        elif isinstance(pcd, o3d.geometry.PointCloud):
            points = np.asarray(pcd.points)
            cloud_colors = np.asarray(pcd.colors) if pcd.has_colors() else None
        else:
            continue

        if len(points) == 0:
            continue

        all_points.append(points)
        if (
            cloud_colors is not None
            and cloud_colors.ndim == 2
            and len(cloud_colors) == len(points)
        ):
            all_colors.append(cloud_colors[:, :3])
        else:
            every_cloud_colored = False

    if not all_points:
        return o3d.geometry.PointCloud()

    combined_pcd = o3d.geometry.PointCloud()
    combined_pcd.points = o3d.utility.Vector3dVector(np.vstack(all_points))

    if all_colors and every_cloud_colored:
        combined_pcd.colors = o3d.utility.Vector3dVector(np.vstack(all_colors))

    return combined_pcd
+ + Args: + rgb_image: RGB image (H, W, 3) + depth_image: Depth image (H, W) in meters + masks: List of boolean masks (H, W) + camera_intrinsics: Camera parameters as [fx, fy, cx, cy] or 3x3 matrix + min_points: Minimum number of valid 3D points required for a detection + max_depth: Maximum valid depth in meters + + Returns: + List of dictionaries containing: + - centroid: 3D centroid position [x, y, z] in camera frame + - orientation: Normalized direction vector from camera to centroid + - num_points: Number of valid 3D points + - mask_idx: Index of the mask in the input list + """ + # Extract camera parameters + if isinstance(camera_intrinsics, list) and len(camera_intrinsics) == 4: + fx, fy, cx, cy = camera_intrinsics + else: + fx = camera_intrinsics[0, 0] + fy = camera_intrinsics[1, 1] + cx = camera_intrinsics[0, 2] + cy = camera_intrinsics[1, 2] + + results = [] + + for mask_idx, mask in enumerate(masks): + if mask is None or mask.sum() == 0: + continue + + # Get pixel coordinates where mask is True + y_coords, x_coords = np.where(mask) + + # Get depth values at mask locations + depths = depth_image[y_coords, x_coords] + + # Filter valid depths + valid_mask = (depths > 0) & (depths < max_depth) & np.isfinite(depths) + if valid_mask.sum() < min_points: + continue + + # Get valid coordinates and depths + valid_x = x_coords[valid_mask] + valid_y = y_coords[valid_mask] + valid_z = depths[valid_mask] + + # Convert to 3D points in camera frame + X = (valid_x - cx) * valid_z / fx + Y = (valid_y - cy) * valid_z / fy + Z = valid_z + + # Calculate centroid + centroid_x = np.mean(X) + centroid_y = np.mean(Y) + centroid_z = np.mean(Z) + centroid = np.array([centroid_x, centroid_y, centroid_z]) + + # Calculate orientation as normalized direction from camera origin to centroid + # Camera origin is at (0, 0, 0) + orientation = centroid / np.linalg.norm(centroid) + + results.append( + { + "centroid": centroid, + "orientation": orientation, + "num_points": 
int(valid_mask.sum()), + "mask_idx": mask_idx, + } + ) + + return results diff --git a/build/lib/dimos/perception/segmentation/__init__.py b/build/lib/dimos/perception/segmentation/__init__.py new file mode 100644 index 0000000000..a8f9a291ce --- /dev/null +++ b/build/lib/dimos/perception/segmentation/__init__.py @@ -0,0 +1,2 @@ +from .utils import * +from .sam_2d_seg import * diff --git a/build/lib/dimos/perception/segmentation/image_analyzer.py b/build/lib/dimos/perception/segmentation/image_analyzer.py new file mode 100644 index 0000000000..1260e41fe7 --- /dev/null +++ b/build/lib/dimos/perception/segmentation/image_analyzer.py @@ -0,0 +1,161 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 +from openai import OpenAI +import cv2 +import os + +NORMAL_PROMPT = "What are in these images? Give a short word answer with at most two words, \ + if not sure, give a description of its shape or color like 'small tube', 'blue item'. \" \ + if does not look like an object, say 'unknown'. Export objects as a list of strings \ + in this exact format '['object 1', 'object 2', '...']'." + +RICH_PROMPT = ( + "What are in these images? Give a detailed description of each item, the first n images will be \ + cropped patches of the original image detected by the object detection model. \ + The last image will be the original image. Use the last image only for context, \ + do not describe objects in the last image. 
class ImageAnalyzer:
    """Describes image crops by sending them to an OpenAI chat model that accepts images."""

    def __init__(self):
        """
        Create the OpenAI client used for every request.

        No credentials are passed explicitly; they are resolved by the OpenAI
        client itself.
        """
        self.client = OpenAI()

    def encode_image(self, image):
        """
        JPEG-compress an image and encode it as Base64 text.

        Parameters:
            image (numpy array): Image array (BGR format).

        Returns:
            str: Base64 encoded string of the JPEG-compressed image.
        """
        _, jpeg_bytes = cv2.imencode(".jpg", image)
        encoded = base64.b64encode(jpeg_bytes)
        return encoded.decode("utf-8")

    def analyze_images(self, images, detail="auto", prompt_type="normal"):
        """
        Send a list of images plus a fixed prompt to the model and return its answer.

        Parameters:
            images (list of numpy arrays): Cropped images from the original frame.
            detail (str): "low", "high", or "auto" to set image processing detail.
            prompt_type (str): "normal" or "rich" to set the prompt type.

        Returns:
            The message content of the first choice in the response, i.e. one
            string (the prompts ask the model to format it as a list of
            descriptions), not a Python list.

        Raises:
            ValueError: If prompt_type is neither "normal" nor "rich".
        """
        # Every image travels inline as a Base64 data URL
        image_data = []
        for img in images:
            data_url = f"data:image/jpeg;base64,{self.encode_image(img)}"
            image_data.append(
                {"type": "image_url", "image_url": {"url": data_url, "detail": detail}}
            )

        if prompt_type not in ("normal", "rich"):
            raise ValueError(f"Invalid prompt type: {prompt_type}")
        prompt = NORMAL_PROMPT if prompt_type == "normal" else RICH_PROMPT

        # The text prompt comes first, followed by all images, in a single user message
        text_part = {"type": "text", "text": prompt}
        response = self.client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": [text_part] + image_data}],
            max_tokens=300,
            timeout=5,
        )

        answers = [choice.message.content for choice in response.choices]
        return answers[0]
image + x = 10 + y = text_height + 10 + + # Add white background for text + cv2.rectangle( + img, (x - 5, y - text_height - 5), (x + text_width + 5, y + 5), (255, 255, 255), -1 + ) + # Add text + cv2.putText(img, text, (x, y), font, font_scale, (0, 0, 0), thickness) + + # Save or display the image + cv2.imwrite(f"annotated_image_{i}.jpg", img) + print(f"Detected object: {obj}") + + +if __name__ == "__main__": + main() diff --git a/build/lib/dimos/perception/segmentation/sam_2d_seg.py b/build/lib/dimos/perception/segmentation/sam_2d_seg.py new file mode 100644 index 0000000000..d33c7faa0d --- /dev/null +++ b/build/lib/dimos/perception/segmentation/sam_2d_seg.py @@ -0,0 +1,335 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class Sam2DSegmenter:
    """FastSAM-based 2D segmenter with optional tracking and background labeling.

    The segmenter runs ``FastSAM.track`` on each frame, filters the raw
    detections, optionally feeds them through ``target2dTracker`` and, when the
    analyzer is enabled, queues newly tracked objects for naming by
    ``ImageAnalyzer`` on a single background worker thread.
    """

    def __init__(
        self,
        model_path="models_fastsam",
        model_name="FastSAM-s.onnx",
        device="cpu",
        min_analysis_interval=5.0,
        use_tracker=True,
        use_analyzer=True,
        use_rich_labeling=False,
    ):
        """Load the model and set up the optional tracker / analyzer.

        Args:
            model_path: Data directory name resolved through ``get_data``.
            model_name: Model file name inside ``model_path``.
            device: Requested device. NOTE(review): the value is overwritten
                below based on ``is_cuda_available()``, so the argument has no
                effect today - confirm whether that is intended.
            min_analysis_interval: Minimum seconds between two analysis batches.
            use_tracker: Enable ``target2dTracker`` post-processing.
            use_analyzer: Enable background object naming.
            use_rich_labeling: Use the "rich" prompt and append the full frame
                to the crops sent to the analyzer.
        """
        self.device = device
        if is_cuda_available():
            logger.info("Using CUDA for SAM 2d segmenter")
            if hasattr(onnxruntime, "preload_dlls"):  # Handles CUDA 11 / onnxruntime-gpu<=1.18
                onnxruntime.preload_dlls(cuda=True, cudnn=True)
            self.device = "cuda"
        else:
            logger.info("Using CPU for SAM 2d segmenter")
            self.device = "cpu"

        # Core components
        self.model = FastSAM(get_data(model_path) / model_name)
        self.use_tracker = use_tracker
        self.use_analyzer = use_analyzer
        self.use_rich_labeling = use_rich_labeling

        module_dir = os.path.dirname(__file__)
        self.tracker_config = os.path.join(module_dir, "config", "custom_tracker.yaml")

        # Initialize tracker if enabled
        if self.use_tracker:
            self.tracker = target2dTracker(
                history_size=80,
                score_threshold_start=0.7,
                score_threshold_stop=0.05,
                min_frame_count=10,
                max_missed_frames=50,
                min_area_ratio=0.05,
                max_area_ratio=0.4,
                texture_range=(0.0, 0.35),
                border_safe_distance=100,
                weights={"prob": 1.0, "temporal": 3.0, "texture": 2.0, "border": 3.0, "size": 1.0},
            )

        # Initialize analyzer components if enabled
        if self.use_analyzer:
            self.image_analyzer = ImageAnalyzer()
            self.min_analysis_interval = min_analysis_interval
            self.last_analysis_time = 0
            self.to_be_analyzed = deque()
            self.object_names = {}
            self.analysis_executor = ThreadPoolExecutor(max_workers=1)
            self.current_future = None
            self.current_queue_ids = None

    def _refresh_analysis_queue(self, tracked_target_ids):
        """Synchronise ``object_names`` and ``to_be_analyzed`` with the tracker.

        IDs that are part of the batch currently being analyzed
        (``current_queue_ids``) are neither kept in nor re-added to the queue;
        otherwise they would be analyzed a second time once the running batch
        finishes.
        """
        visible_ids = set(tracked_target_ids)
        known_ids = set(self.tracker.targets.keys())
        in_flight = set(self.current_queue_ids or ())

        # Forget names of targets the tracker no longer knows about
        self.object_names = {
            track_id: name for track_id, name in self.object_names.items() if track_id in known_ids
        }

        # Drop queued IDs that vanished from this frame or are already being analyzed
        self.to_be_analyzed = deque(
            track_id
            for track_id in self.to_be_analyzed
            if track_id in visible_ids and track_id not in in_flight
        )

        # Queue every visible target that has no name yet and is not pending
        for track_id in tracked_target_ids:
            if (
                track_id not in self.object_names
                and track_id not in in_flight
                and track_id not in self.to_be_analyzed
            ):
                self.to_be_analyzed.append(track_id)

    def process_image(self, image):
        """Process an image and return segmentation results.

        Returns:
            tuple: ``(masks, bboxes, ids, probs, names)``. The values come from
            the tracker when it is enabled, otherwise from the filtered
            detections. Five empty lists are returned when the model yields no
            result.
        """
        results = self.model.track(
            source=image,
            device=self.device,
            retina_masks=True,
            conf=0.6,
            iou=0.9,
            persist=True,
            verbose=False,
            tracker=self.tracker_config,
        )

        if len(results) == 0:
            return [], [], [], [], []

        # Get initial segmentation results
        masks, bboxes, track_ids, probs, names, areas = extract_masks_bboxes_probs_names(
            results[0]
        )

        # Filter results
        (
            filtered_masks,
            filtered_bboxes,
            filtered_track_ids,
            filtered_probs,
            filtered_names,
            filtered_texture_values,
        ) = filter_segmentation_results(image, masks, bboxes, track_ids, probs, names, areas)

        if not self.use_tracker:
            # Return filtered results directly if tracker is disabled
            return (
                filtered_masks,
                filtered_bboxes,
                filtered_track_ids,
                filtered_probs,
                filtered_names,
            )

        # Update tracker with filtered results
        tracked_targets = self.tracker.update(
            image,
            filtered_masks,
            filtered_bboxes,
            filtered_track_ids,
            filtered_probs,
            filtered_names,
            filtered_texture_values,
        )

        # Get tracked results
        tracked_masks, tracked_bboxes, tracked_target_ids, tracked_probs, tracked_names = (
            get_tracked_results(tracked_targets)
        )

        if self.use_analyzer:
            self._refresh_analysis_queue(tracked_target_ids)

        return (
            tracked_masks,
            tracked_bboxes,
            tracked_target_ids,
            tracked_probs,
            tracked_names,
        )

    def check_analysis_status(self, tracked_target_ids):
        """Check if analysis is complete and prepare new queue if needed.

        Args:
            tracked_target_ids: List of IDs visible in the current frame.

        Returns:
            tuple: ``(queue_indices, queue_ids)`` - positions inside
            ``tracked_target_ids`` and the matching IDs for the next batch, or
            ``(None, None)`` when nothing should be analyzed now.
        """
        import ast  # local import: only needed to parse analyzer output

        if not self.use_analyzer:
            return None, None

        current_time = time.time()

        # Check if current queue analysis is complete
        if self.current_future and self.current_future.done():
            try:
                results = self.current_future.result()
                if results is not None:
                    # The analyzer answers with the text of a Python list.
                    # Parse it as a literal only; eval() would execute
                    # arbitrary expressions contained in the model output.
                    object_list = ast.literal_eval(results)
                    if not isinstance(object_list, (list, tuple)):
                        raise ValueError(
                            f"expected a list of labels, got {type(object_list).__name__}"
                        )
                    # Map results to track IDs
                    for track_id, result in zip(self.current_queue_ids, object_list):
                        self.object_names[track_id] = result
            except Exception as e:
                # Best effort: a failed batch must not break the video loop
                print(f"Queue analysis failed: {e}")
            self.current_future = None
            self.current_queue_ids = None
            self.last_analysis_time = current_time

        # If enough time has passed and we have items to analyze, start new analysis
        if (
            not self.current_future
            and self.to_be_analyzed
            and current_time - self.last_analysis_time >= self.min_analysis_interval
        ):
            queue_indices = []
            queue_ids = []

            # Drain the queue, keeping only IDs still visible in this frame
            while self.to_be_analyzed:
                track_id = self.to_be_analyzed.popleft()
                if track_id in tracked_target_ids:
                    queue_indices.append(tracked_target_ids.index(track_id))
                    queue_ids.append(track_id)

            if queue_indices:
                return queue_indices, queue_ids
        return None, None

    def run_analysis(self, frame, tracked_bboxes, tracked_target_ids):
        """Run queue image analysis in background."""
        if not self.use_analyzer:
            return

        queue_indices, queue_ids = self.check_analysis_status(tracked_target_ids)
        if not queue_indices:
            return

        selected_bboxes = [tracked_bboxes[i] for i in queue_indices]
        cropped_images = crop_images_from_bboxes(frame, selected_bboxes)
        if not cropped_images:
            return

        self.current_queue_ids = queue_ids
        print(f"Analyzing objects with track_ids: {queue_ids}")

        if self.use_rich_labeling:
            prompt_type = "rich"
            # The rich prompt also receives the full frame as the last image
            cropped_images.append(frame)
        else:
            prompt_type = "normal"

        self.current_future = self.analysis_executor.submit(
            self.image_analyzer.analyze_images, cropped_images, prompt_type=prompt_type
        )

    def get_object_names(self, track_ids, tracked_names):
        """Get object names for the given track IDs, falling back to tracked names."""
        if not self.use_analyzer:
            return tracked_names

        return [
            self.object_names.get(track_id, tracked_name)
            for track_id, tracked_name in zip(track_ids, tracked_names)
        ]

    def visualize_results(self, image, masks, bboxes, track_ids, probs, names):
        """Generate an overlay visualization with segmentation results and object names."""
        return plot_results(image, masks, bboxes, track_ids, probs, names)

    def cleanup(self):
        """Cleanup resources."""
        if self.use_analyzer:
            self.analysis_executor.shutdown()


def main():
    """Webcam demo: segment frames from camera 0 until 'q' is pressed."""
    # Example usage with different configurations
    cap = cv2.VideoCapture(0)

    # Example 1: Full functionality with rich labeling
    segmenter = Sam2DSegmenter(
        min_analysis_interval=4.0,
        use_tracker=True,
        use_analyzer=True,
        use_rich_labeling=True,  # Enable rich labeling
    )

    # Example 2: Full functionality with normal labeling
    # segmenter = Sam2DSegmenter(min_analysis_interval=4.0, use_tracker=True, use_analyzer=True)

    # Example 3: Tracker only (analyzer disabled)
    # segmenter = Sam2DSegmenter(use_analyzer=False)

    # Example 4: Basic segmentation only (both tracker and analyzer disabled)
    # segmenter = Sam2DSegmenter(use_tracker=False, use_analyzer=False)

    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            # Process image and get results
            masks, bboxes, target_ids, probs, names = segmenter.process_image(frame)

            # Run analysis if enabled
            if segmenter.use_tracker and segmenter.use_analyzer:
                segmenter.run_analysis(frame, bboxes, target_ids)
                names = segmenter.get_object_names(target_ids, names)

            overlay = segmenter.visualize_results(frame, masks, bboxes, target_ids, probs, names)

            cv2.imshow("Segmentation", overlay)
            key = cv2.waitKey(1)
            if key & 0xFF == ord("q"):
                break

    finally:
        segmenter.cleanup()
        cap.release()
        cv2.destroyAllWindows()


if __name__ == "__main__":
    main()
@pytest.mark.heavy
class TestSam2DSegmenter:
    """Integration tests for Sam2DSegmenter (need the FastSAM model and test video)."""

    def test_sam_segmenter_initialization(self):
        """Test FastSAM segmenter initializes correctly with default model path."""
        try:
            # Try to initialize with the default model path and existing device setting
            segmenter = Sam2DSegmenter(use_analyzer=False)
        except Exception as e:
            # If the model file doesn't exist, the test is skipped rather than failed
            pytest.skip(f"Skipping test due to model initialization error: {e}")

        # Assertions live outside the try block so a failure is reported as a
        # failure instead of being converted into a skip.
        assert segmenter is not None
        assert segmenter.model is not None

    def test_sam_segmenter_process_image(self):
        """Test FastSAM segmenter can process video frames and return segmentation masks."""
        # Import get data inside method to avoid pytest fixture confusion
        from dimos.utils.data import get_data

        # Get test video path directly
        video_path = get_data("assets") / "trimmed_video_office.mov"
        try:
            # Initialize segmenter without analyzer for faster testing
            segmenter = Sam2DSegmenter(use_analyzer=False)

            # Note: conf and iou are hard-coded inside process_image, so the
            # method is monkey patched on the instance to use lower thresholds.
            def patched_process_image(image):
                results = segmenter.model.track(
                    source=image,
                    device=segmenter.device,
                    retina_masks=True,
                    conf=0.1,  # Lower confidence threshold for testing
                    iou=0.5,  # Lower IoU threshold
                    persist=True,
                    verbose=False,
                    tracker=segmenter.tracker_config
                    if hasattr(segmenter, "tracker_config")
                    else None,
                )

                if len(results) > 0:
                    masks, bboxes, track_ids, probs, names, areas = (
                        extract_masks_bboxes_probs_names(results[0])
                    )
                    return masks, bboxes, track_ids, probs, names
                return [], [], [], [], []

            # Replace the method
            segmenter.process_image = patched_process_image

            # Create video provider and directly get a video stream observable
            assert os.path.exists(video_path), f"Test video not found: {video_path}"
            video_provider = VideoProvider(dev_name="test_video", video_source=video_path)

            video_stream = video_provider.capture_video_as_observable(realtime=False, fps=1)

            # Use ReactiveX operators to process the stream
            def process_frame(frame):
                try:
                    # Process frame with FastSAM
                    masks, bboxes, track_ids, probs, names = segmenter.process_image(frame)
                    print(
                        f"SAM results - masks: {len(masks)}, bboxes: {len(bboxes)}, track_ids: {len(track_ids)}, names: {len(names)}"
                    )

                    return {
                        "frame": frame,
                        "masks": masks,
                        "bboxes": bboxes,
                        "track_ids": track_ids,
                        "probs": probs,
                        "names": names,
                    }
                except Exception as e:
                    print(f"Error in process_frame: {e}")
                    return {}

            # Create the segmentation stream using pipe and map operator
            segmentation_stream = video_stream.pipe(ops.map(process_frame))

            # Collect results from the stream
            results = []
            frames_processed = 0
            target_frames = 5

            def on_next(result):
                nonlocal frames_processed
                if not result:
                    return

                results.append(result)
                frames_processed += 1

                # Stop processing after target frames
                if frames_processed >= target_frames:
                    subscription.dispose()

            def on_error(error):
                pytest.fail(f"Error in segmentation stream: {error}")

            def on_completed():
                pass

            # Subscribe and wait for results
            subscription = segmentation_stream.subscribe(
                on_next=on_next, on_error=on_error, on_completed=on_completed
            )

            # Wait for frames to be processed
            timeout = 30.0  # seconds
            start_time = time.time()
            while frames_processed < target_frames and time.time() - start_time < timeout:
                time.sleep(0.5)

            # Clean up subscription
            subscription.dispose()
            video_provider.dispose_all()

            # Check if we have results
            if len(results) == 0:
                pytest.skip(
                    "No segmentation results found, but test connection established correctly"
                )

            print(f"Processed {len(results)} frames with segmentation results")

            # Analyze the first result
            result = results[0]

            # Check that we have a frame
            assert "frame" in result, "Result doesn't contain a frame"
            assert isinstance(result["frame"], np.ndarray), "Frame is not a numpy array"

            # Check that segmentation results are valid
            assert isinstance(result["masks"], list)
            assert isinstance(result["bboxes"], list)
            assert isinstance(result["track_ids"], list)
            assert isinstance(result["probs"], list)
            assert isinstance(result["names"], list)

            # All result lists should be the same length
            assert (
                len(result["masks"])
                == len(result["bboxes"])
                == len(result["track_ids"])
                == len(result["probs"])
                == len(result["names"])
            )

            # If we have masks, check that they have valid shape
            if result.get("masks") and len(result["masks"]) > 0:
                assert result["masks"][0].shape == (
                    result["frame"].shape[0],
                    result["frame"].shape[1],
                ), "Mask shape should match image dimensions"
                print(f"Found {len(result['masks'])} masks in first frame")
            else:
                print("No masks found in first frame, but test connection established correctly")

            # Test visualization function
            if result["masks"]:
                vis_frame = segmenter.visualize_results(
                    result["frame"],
                    result["masks"],
                    result["bboxes"],
                    result["track_ids"],
                    result["probs"],
                    result["names"],
                )
                assert isinstance(vis_frame, np.ndarray), "Visualization output should be an image"
                assert vis_frame.shape == result["frame"].shape, (
                    "Visualization should have same dimensions as input frame"
                )

        except AssertionError:
            # A failed check is a real test failure; without this clause the
            # generic handler below would turn it into a skip.
            raise
        except Exception as e:
            # Environment problems (missing model, codec, ...) skip the test
            pytest.skip(f"Skipping test due to error: {e}")
class SimpleTracker:
    """Temporal filter that keeps track IDs seen often enough in recent frames."""

    def __init__(self, history_size=100, min_count=10, count_window=20):
        """
        Simple temporal tracker that counts appearances in a fixed window.

        :param history_size: Number of past frames to remember
        :param min_count: Minimum number of appearances required
        :param count_window: Number of latest frames to consider for counting
        """
        self.history = []
        self.history_size = history_size
        self.min_count = min_count
        self.count_window = count_window
        self.total_counts = {}

    def update(self, track_ids):
        """Record one frame of track IDs and return the IDs that are stable.

        An ID is returned when it occurs at least ``min_count`` times in the
        latest ``count_window`` frames. ``total_counts`` is refreshed from the
        whole retained history as a side effect.
        """
        # Add new frame's track IDs to history, dropping the oldest frame if full
        self.history.append(track_ids)
        if len(self.history) > self.history_size:
            self.history.pop(0)

        # Consider only the latest `count_window` frames for counting
        recent_history = self.history[-self.count_window :]
        all_tracks = np.concatenate(recent_history) if recent_history else np.array([])

        # Compute occurrences using numpy
        unique_ids, counts = np.unique(all_tracks, return_counts=True)
        id_counts = dict(zip(unique_ids, counts))

        # Total counts only cover IDs still inside the retained history
        total_tracked_ids = np.concatenate(self.history) if self.history else np.array([])
        unique_total_ids, total_counts = np.unique(total_tracked_ids, return_counts=True)
        self.total_counts = dict(zip(unique_total_ids, total_counts))

        # Return IDs that appear often enough
        return [track_id for track_id, count in id_counts.items() if count >= self.min_count]

    def get_total_counts(self):
        """Returns the total count of each tracking ID seen over time, limited to history size."""
        return self.total_counts


def extract_masks_bboxes_probs_names(result, max_size=0.7):
    """
    Extracts masks, bounding boxes, probabilities, and class names from one Ultralytics result object.

    Parameters:
        result: Ultralytics result object
        max_size: float, maximum allowed bounding-box area relative to the image area (0-1)

    Returns:
        tuple: (masks, bboxes, track_ids, probs, names, areas); ``track_ids``
        holds -1 for boxes without a tracking ID and ``areas`` holds
        bounding-box areas in pixels. Six empty lists when ``result.masks`` is None.
    """
    masks = []
    bboxes = []
    track_ids = []
    probs = []
    names = []
    areas = []

    if result.masks is None:
        return masks, bboxes, track_ids, probs, names, areas

    total_area = result.masks.orig_shape[0] * result.masks.orig_shape[1]

    for box, mask_data in zip(result.boxes, result.masks.data):
        # Extract bounding box
        x1, y1, x2, y2 = box.xyxy[0].tolist()

        # Extract track_id if available
        track_id = -1  # default if no tracking
        if hasattr(box, "id") and box.id is not None:
            track_id = int(box.id[0].item())

        # Extract probability and class index
        conf = float(box.conf[0])
        cls_idx = int(box.cls[0])
        area = (x2 - x1) * (y2 - y1)

        # Skip detections that cover too much of the image
        if area / total_area > max_size:
            continue

        # The mask is stored as provided by the result object (not converted)
        masks.append(mask_data)
        bboxes.append([x1, y1, x2, y2])
        track_ids.append(track_id)
        probs.append(conf)
        names.append(result.names[cls_idx])
        areas.append(area)

    return masks, bboxes, track_ids, probs, names, areas


def compute_texture_map(frame, blur_size=3):
    """
    Compute texture map using gradient statistics.
    Returns high values for textured regions and low values for smooth regions.

    Parameters:
        frame: BGR image (or an already single-channel image)
        blur_size: Size of Gaussian blur kernel for pre-processing; 0 disables
            the blur. OpenCV requires Gaussian kernel sizes to be odd.

    Returns:
        numpy array: Texture map with values normalized to [0,1]
    """
    # Convert to grayscale
    if len(frame.shape) == 3:
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    else:
        gray = frame

    # Pre-process with slight blur to reduce noise
    if blur_size > 0:
        gray = cv2.GaussianBlur(gray, (blur_size, blur_size), 0)

    # Compute gradients in x and y directions
    grad_x = cv2.Sobel(gray, cv2.CV_32F, 1, 0, ksize=3)
    grad_y = cv2.Sobel(gray, cv2.CV_32F, 0, 1, ksize=3)

    # Compute gradient magnitude
    magnitude = np.sqrt(grad_x**2 + grad_y**2)

    # Smooth the magnitude: a Gaussian-weighted local average of edge strength
    texture_map = cv2.GaussianBlur(magnitude, (15, 15), 0)

    # Normalize to [0,1]; the epsilon avoids 0/0 on a perfectly flat image
    texture_map = (texture_map - texture_map.min()) / (texture_map.max() - texture_map.min() + 1e-8)

    return texture_map


def filter_segmentation_results(
    frame,
    masks,
    bboxes,
    track_ids,
    probs,
    names,
    areas,
    texture_threshold=0.07,
    size_filter=800,
    texture_map=None,
):
    """
    Filters segmentation results by texture and by visible (unoccluded) size.

    Masks are painted into a shared label image from the smallest to the
    largest area, so a larger mask overwrites the pixels of smaller masks it
    covers. A mask survives when its mean texture reaches ``texture_threshold``
    and it still owns more than ``size_filter`` pixels after painting.

    Parameters:
        frame: BGR image used to compute the texture map
        masks: list of torch.Tensor containing mask data
        bboxes: list of bounding boxes [x1, y1, x2, y2]
        track_ids: list of tracking IDs
        probs: list of confidence scores
        names: list of class names
        areas: list of object areas
        texture_threshold: Average texture value required for mask to be kept
        size_filter: Minimum number of owned pixels for the object to be kept
        texture_map: optional precomputed texture map (numpy array or tensor
            with the shape of the masks); ``frame`` is not used when given

    Returns:
        tuple: (filtered_masks, filtered_bboxes, filtered_track_ids,
        filtered_probs, filtered_names, filtered_texture_values), all aligned
        element by element and ordered from smallest to largest area.
    """
    # NOTE(review): with zero or one mask nothing is filtered and the texture
    # list is empty, so for a single mask it is shorter than the other lists.
    # Kept as-is for compatibility - confirm what the tracker expects.
    if len(masks) <= 1:
        return masks, bboxes, track_ids, probs, names, []

    # Compute texture map once unless the caller supplied one
    if texture_map is None:
        texture_map = compute_texture_map(frame)

    # Sort by area (smallest to largest)
    order = torch.tensor(areas).argsort(descending=False).tolist()

    device = masks[0].device  # Get the device of the first mask
    texture_tensor = torch.as_tensor(texture_map).to(device)

    # Label image: 0 means "unclaimed", k > 0 means "owned by the k-th smallest
    # mask". Ranks start at 1 so the smallest mask does not collide with the
    # background value and get discarded.
    claims = torch.zeros_like(masks[0], dtype=torch.int32)
    texture_by_rank = {}

    for rank, idx in enumerate(order, start=1):
        region = masks[idx] > 0
        if not torch.any(region):
            continue

        # Average texture value within the mask, as a Python float
        texture_value = torch.mean(texture_tensor[region]).item()

        # Only claim pixels if mask passes texture threshold
        if texture_value >= texture_threshold:
            claims[region] = rank
            texture_by_rank[rank] = texture_value

    # Ranks still present in the label image, with the pixel count they own
    ranks, counts = torch.unique(claims[claims > 0], return_counts=True)
    kept_ranks = ranks[counts > size_filter].tolist()

    # Map back to original indices and filter
    final_indices = [order[rank - 1] for rank in kept_ranks]

    filtered_masks = [masks[i] for i in final_indices]
    filtered_bboxes = [bboxes[i] for i in final_indices]
    filtered_track_ids = [track_ids[i] for i in final_indices]
    filtered_probs = [probs[i] for i in final_indices]
    filtered_names = [names[i] for i in final_indices]
    # One texture value per surviving mask, in the same order
    filtered_texture_values = [texture_by_rank[rank] for rank in kept_ranks]

    return (
        filtered_masks,
        filtered_bboxes,
        filtered_track_ids,
        filtered_probs,
        filtered_names,
        filtered_texture_values,
    )


def plot_results(image, masks, bboxes, track_ids, probs, names, alpha=0.5):
    """
    Draws bounding boxes, masks, and labels on the given image with enhanced visualization.
    Includes object names in the overlay and improved text visibility.

    Returns a new image: the colored overlay blended with ``image`` using
    ``alpha`` as the overlay weight.
    """
    h, w = image.shape[:2]
    overlay = image.copy()

    for mask, bbox, track_id, prob, name in zip(masks, bboxes, track_ids, probs, names):
        # Convert mask tensor to numpy if needed
        if isinstance(mask, torch.Tensor):
            mask = mask.cpu().numpy()

        mask_resized = cv2.resize(mask, (w, h), interpolation=cv2.INTER_LINEAR)

        # Generate consistent color based on track_id. A private RandomState
        # yields the same values as seeding the global generator, without
        # disturbing the random state of the rest of the program.
        if track_id != -1:
            rng = np.random.RandomState(track_id)
            color = rng.randint(0, 255, (3,), dtype=np.uint8)
        else:
            color = np.random.randint(0, 255, (3,), dtype=np.uint8)

        # Apply mask color
        overlay[mask_resized > 0.5] = color

        # Draw bounding box
        x1, y1, x2, y2 = map(int, bbox)
        cv2.rectangle(overlay, (x1, y1), (x2, y2), color.tolist(), 2)

        # Prepare label text
        label = f"ID:{track_id} {prob:.2f}"
        if name:  # Add object name if available
            label += f" {name}"

        # Calculate text size for background rectangle
        (text_w, text_h), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)

        # Draw background rectangle for text
        cv2.rectangle(overlay, (x1, y1 - text_h - 8), (x1 + text_w + 4, y1), color.tolist(), -1)

        # Draw text with white color for better visibility
        cv2.putText(
            overlay,
            label,
            (x1 + 2, y1 - 5),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            (255, 255, 255),  # White text
            1,
        )

    # Blend overlay with original image
    result = cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0)
    return result


def crop_images_from_bboxes(image, bboxes, buffer=0):
    """
    Crops regions from an image based on bounding boxes with an optional buffer.

    Parameters:
        image (numpy array): Input image, either (H, W) or (H, W, C).
        bboxes (list of lists): List of bounding boxes [x1, y1, x2, y2].
        buffer (int): Number of pixels to expand each bounding box.

    Returns:
        list of numpy arrays: Cropped image regions (views into ``image``).
    """
    # Only the first two dimensions matter, so grayscale images work too
    height, width = image.shape[:2]
    cropped_images = []

    for bbox in bboxes:
        x1, y1, x2, y2 = bbox

        # Apply buffer and clamp to the image bounds
        x1 = max(0, x1 - buffer)
        y1 = max(0, y1 - buffer)
        x2 = min(width, x2 + buffer)
        y2 = min(height, y2 + buffer)

        # Coordinates may be floats; slicing needs integers
        cropped_image = image[int(y1) : int(y2), int(x1) : int(x2)]
        cropped_images.append(cropped_image)

    return cropped_images
class SemanticSegmentationStream:
    """Reactive wrapper that turns a video stream into segmentation results.

    Each frame is segmented with ``Sam2DSegmenter``; when monocular depth is
    enabled, a ``Metric3D`` depth map is used to attach an average depth to
    every detected object.
    """

    def __init__(
        self,
        device: str = "cuda",
        enable_mono_depth: bool = True,
        enable_rich_labeling: bool = True,
        camera_params: dict = None,
        gt_depth_scale=256.0,
    ):
        """
        Initialize a semantic segmentation stream using Sam2DSegmenter.

        Args:
            device: Computation device ("cuda" or "cpu")
            enable_mono_depth: Whether to enable monocular depth processing
            enable_rich_labeling: Whether to enable rich labeling
            camera_params: Dictionary containing either:
                - Direct intrinsics: [fx, fy, cx, cy]
                - Physical parameters: resolution, focal_length, sensor_size
            gt_depth_scale: Value passed to the ``Metric3D`` constructor

        Raises:
            ValueError: If depth is enabled without camera parameters, or the
                supplied intrinsics do not contain exactly four values.
        """
        self.segmenter = Sam2DSegmenter(
            device=device,
            min_analysis_interval=5.0,
            use_tracker=True,
            use_analyzer=True,
            use_rich_labeling=enable_rich_labeling,
        )

        self.enable_mono_depth = enable_mono_depth
        if enable_mono_depth:
            self.depth_model = Metric3D(gt_depth_scale)

            if not camera_params:
                raise ValueError("Camera parameters are required for monocular depth processing.")

            # Check if direct intrinsics are provided
            if "intrinsics" in camera_params:
                intrinsics = camera_params["intrinsics"]
                if len(intrinsics) != 4:
                    raise ValueError("Intrinsics must be a list of 4 values: [fx, fy, cx, cy]")
                self.depth_model.update_intrinsic(intrinsics)
            else:
                # Create camera object and calculate intrinsics from physical parameters
                self.camera = Camera(
                    resolution=camera_params.get("resolution"),
                    focal_length=camera_params.get("focal_length"),
                    sensor_size=camera_params.get("sensor_size"),
                )
                intrinsics = self.camera.calculate_intrinsics()
                self.depth_model.update_intrinsic(
                    [
                        intrinsics["focal_length_x"],
                        intrinsics["focal_length_y"],
                        intrinsics["principal_point_x"],
                        intrinsics["principal_point_y"],
                    ]
                )

    def create_stream(self, video_stream: "Observable") -> "Observable[SegmentationType]":
        """
        Create an Observable stream of segmentation results from a video stream.

        Args:
            video_stream: Observable that emits video frames

        Returns:
            Observable that emits SegmentationType objects containing masks and metadata
        """

        def process_frame(frame):
            # Process image and get results
            masks, bboxes, target_ids, probs, names = self.segmenter.process_image(frame)

            # Run analysis if enabled
            if self.segmenter.use_analyzer:
                self.segmenter.run_analysis(frame, bboxes, target_ids)
                names = self.segmenter.get_object_names(target_ids, names)

            viz_frame = self.segmenter.visualize_results(
                frame, masks, bboxes, target_ids, probs, names
            )

            # Process depth if enabled
            depth_viz = None
            object_depths = []
            if self.enable_mono_depth:
                # Get depth map
                depth_map = self.depth_model.infer_depth(frame)
                depth_map = np.array(depth_map)

                # Calculate average depth for each object
                for mask in masks:
                    # Convert mask to numpy if needed
                    mask_np = mask.cpu().numpy() if hasattr(mask, "cpu") else mask
                    # Get depth values where mask is True
                    object_depth = depth_map[mask_np > 0.5]
                    # Average over the mask, divided by 1000 (0 for an empty mask).
                    # NOTE(review): the original comment calls this value meters
                    # while the overlay below labels it "mm" - confirm the unit
                    # returned by infer_depth and fix whichever side is wrong.
                    avg_depth = np.mean(object_depth) if len(object_depth) > 0 else 0
                    object_depths.append(avg_depth / 1000)

                # Create colorized depth visualization
                depth_viz = self._create_depth_visualization(depth_map)

                # Overlay depth values on the visualization frame
                for bbox, depth in zip(bboxes, object_depths):
                    x1, y1, x2, y2 = map(int, bbox)
                    # Draw depth text at bottom left of bounding box
                    depth_text = f"{depth:.2f}mm"
                    # Add black background for better visibility
                    text_size = cv2.getTextSize(depth_text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0]
                    cv2.rectangle(
                        viz_frame,
                        (x1, y2 - text_size[1] - 5),
                        (x1 + text_size[0], y2),
                        (0, 0, 0),
                        -1,
                    )
                    # Draw text in white
                    cv2.putText(
                        viz_frame,
                        depth_text,
                        (x1, y2 - 5),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        0.5,
                        (255, 255, 255),
                        2,
                    )

            # Create per-object metadata; the guards tolerate result lists of
            # unequal length
            objects = []
            for i, bbox in enumerate(bboxes):
                obj_data = {
                    "object_id": target_ids[i] if i < len(target_ids) else None,
                    "bbox": bbox,
                    "prob": probs[i] if i < len(probs) else None,
                    "label": names[i] if i < len(names) else None,
                }

                # Add depth if available
                if self.enable_mono_depth and i < len(object_depths):
                    obj_data["depth"] = object_depths[i]

                objects.append(obj_data)

            # Create the metadata dictionary
            metadata = {"frame": frame, "viz_frame": viz_frame, "objects": objects}

            # Add depth visualization if available
            if depth_viz is not None:
                metadata["depth_viz"] = depth_viz

            # Convert masks to numpy arrays if they aren't already
            numpy_masks = [mask.cpu().numpy() if hasattr(mask, "cpu") else mask for mask in masks]

            return SegmentationType(masks=numpy_masks, metadata=metadata)

        return video_stream.pipe(ops.map(process_frame))

    @staticmethod
    def _normalize_depth(depth_map):
        """Scale ``depth_map`` linearly to the 0-255 range as ``uint8``.

        A constant map has no range to stretch; it maps to all zeros instead
        of dividing by zero.
        """
        depth_min = np.min(depth_map)
        depth_max = np.max(depth_map)
        span = depth_max - depth_min
        if span == 0:
            return np.zeros(np.shape(depth_map), dtype=np.uint8)
        return ((depth_map - depth_min) / span * 255).astype(np.uint8)

    def _create_depth_visualization(self, depth_map):
        """
        Create a colorized visualization of the depth map.

        Args:
            depth_map: Raw depth map. NOTE(review): previously documented as
                meters while the scale bar labels the same values "mm" -
                confirm the unit.

        Returns:
            Colorized depth map with a 30-pixel-high scale bar stacked below it
        """
        # Normalize depth map to 0-255 range for visualization
        depth_min = np.min(depth_map)
        depth_max = np.max(depth_map)
        depth_normalized = self._normalize_depth(depth_map)

        # Apply colormap (using JET colormap for better depth perception)
        depth_colored = cv2.applyColorMap(depth_normalized, cv2.COLORMAP_JET)

        # Add depth scale bar
        scale_height = 30
        scale_width = depth_map.shape[1]  # Match width with depth map
        scale_bar = np.zeros((scale_height, scale_width, 3), dtype=np.uint8)

        # Create gradient for scale bar
        for i in range(scale_width):
            color = cv2.applyColorMap(
                np.array([[i * 255 // scale_width]], dtype=np.uint8), cv2.COLORMAP_JET
            )
            scale_bar[:, i] = color[0, 0]

        # Add depth values to scale bar
        cv2.putText(
            scale_bar,
            f"{depth_min:.1f}mm",
            (5, 20),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            (255, 255, 255),
            1,
        )
        cv2.putText(
            scale_bar,
            f"{depth_max:.1f}mm",
            (scale_width - 60, 20),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            (255, 255, 255),
            1,
        )

        # Combine depth map and scale bar
        combined_viz = np.vstack((depth_colored, scale_bar))

        return combined_viz

    def cleanup(self):
        """Clean up resources."""
        self.segmenter.cleanup()
        if self.enable_mono_depth:
            del self.depth_model
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Spatial Memory module for creating a semantic map of the environment. +""" + +import uuid +import time +import os +from typing import Dict, List, Optional, Any + +import numpy as np +from reactivex import Observable, disposable +from reactivex import operators as ops +from datetime import datetime + +from dimos.utils.logging_config import setup_logger +from dimos.agents.memory.spatial_vector_db import SpatialVectorDB +from dimos.agents.memory.image_embedding import ImageEmbeddingProvider +from dimos.agents.memory.visual_memory import VisualMemory +from dimos.types.vector import Vector +from dimos.types.robot_location import RobotLocation + +logger = setup_logger("dimos.perception.spatial_memory") + + +class SpatialMemory: + """ + A class for building and querying Robot spatial memory. + + This class processes video frames from ROSControl, associates them with + XY locations, and stores them in a vector database for later retrieval. + It also maintains a list of named robot locations that can be queried by name. 
+ """ + + def __init__( + self, + collection_name: str = "spatial_memory", + embedding_model: str = "clip", + embedding_dimensions: int = 512, + min_distance_threshold: float = 0.01, # Min distance in meters to store a new frame + min_time_threshold: float = 1.0, # Min time in seconds to record a new frame + db_path: Optional[str] = None, # Path for ChromaDB persistence + visual_memory_path: Optional[str] = None, # Path for saving/loading visual memory + new_memory: bool = False, # Whether to create a new memory from scratch + output_dir: Optional[str] = None, # Directory for storing visual memory data + chroma_client: Any = None, # Optional ChromaDB client for persistence + visual_memory: Optional[ + "VisualMemory" + ] = None, # Optional VisualMemory instance for storing images + video_stream: Optional[Observable] = None, # Video stream to process + get_pose: Optional[callable] = None, # Function that returns position and rotation + ): + """ + Initialize the spatial perception system. + + Args: + collection_name: Name of the vector database collection + embedding_model: Model to use for image embeddings ("clip", "resnet", etc.) 
+ embedding_dimensions: Dimensions of the embedding vectors + min_distance_threshold: Minimum distance in meters to record a new frame + min_time_threshold: Minimum time in seconds to record a new frame + chroma_client: Optional ChromaDB client for persistent storage + visual_memory: Optional VisualMemory instance for storing images + output_dir: Directory for storing visual memory data if visual_memory is not provided + """ + self.collection_name = collection_name + self.embedding_model = embedding_model + self.embedding_dimensions = embedding_dimensions + self.min_distance_threshold = min_distance_threshold + self.min_time_threshold = min_time_threshold + + # Set up paths for persistence + self.db_path = db_path + self.visual_memory_path = visual_memory_path + self.output_dir = output_dir + + # Setup ChromaDB client if not provided + self._chroma_client = chroma_client + if chroma_client is None and db_path is not None: + # Create db directory if needed + os.makedirs(db_path, exist_ok=True) + + # Clean up existing DB if creating new memory + if new_memory and os.path.exists(db_path): + try: + logger.info("Creating new ChromaDB database (new_memory=True)") + # Try to delete any existing database files + import shutil + + for item in os.listdir(db_path): + item_path = os.path.join(db_path, item) + if os.path.isfile(item_path): + os.unlink(item_path) + elif os.path.isdir(item_path): + shutil.rmtree(item_path) + logger.info(f"Removed existing ChromaDB files from {db_path}") + except Exception as e: + logger.error(f"Error clearing ChromaDB directory: {e}") + + from chromadb.config import Settings + import chromadb + + self._chroma_client = chromadb.PersistentClient( + path=db_path, settings=Settings(anonymized_telemetry=False) + ) + + # Initialize or load visual memory + self._visual_memory = visual_memory + if visual_memory is None: + if new_memory or not os.path.exists(visual_memory_path or ""): + logger.info("Creating new visual memory") + self._visual_memory = 
VisualMemory(output_dir=output_dir) + else: + try: + logger.info(f"Loading existing visual memory from {visual_memory_path}...") + self._visual_memory = VisualMemory.load( + visual_memory_path, output_dir=output_dir + ) + logger.info(f"Loaded {self._visual_memory.count()} images from previous runs") + except Exception as e: + logger.error(f"Error loading visual memory: {e}") + self._visual_memory = VisualMemory(output_dir=output_dir) + + # Initialize vector database + self.vector_db: SpatialVectorDB = SpatialVectorDB( + collection_name=collection_name, + chroma_client=self._chroma_client, + visual_memory=self._visual_memory, + ) + + self.embedding_provider: ImageEmbeddingProvider = ImageEmbeddingProvider( + model_name=embedding_model, dimensions=embedding_dimensions + ) + + self.last_position: Optional[Vector] = None + self.last_record_time: Optional[float] = None + + self.frame_count: int = 0 + self.stored_frame_count: int = 0 + + # For tracking stream subscription + self._subscription = None + + # List to store robot locations + self.robot_locations: List[RobotLocation] = [] + + logger.info(f"SpatialMemory initialized with model {embedding_model}") + + # Start processing video stream if provided + if video_stream is not None and get_pose is not None: + self.start_continuous_processing(video_stream, get_pose) + + def query_by_location( + self, x: float, y: float, radius: float = 2.0, limit: int = 5 + ) -> List[Dict]: + """ + Query the vector database for images near the specified location. + + Args: + x: X coordinate + y: Y coordinate + radius: Search radius in meters + limit: Maximum number of results to return + + Returns: + List of results, each containing the image and its metadata + """ + return self.vector_db.query_by_location(x, y, radius, limit) + + def start_continuous_processing( + self, video_stream: Observable, get_pose: callable + ) -> disposable.Disposable: + """ + Start continuous processing of video frames from an Observable stream. 
+ + Args: + video_stream: Observable of video frames + get_pose: Callable that returns position and rotation for each frame + + Returns: + Disposable subscription that can be used to stop processing + """ + # Stop any existing subscription + self.stop_continuous_processing() + + # Map each video frame to include transform data + combined_stream = video_stream.pipe( + ops.map(lambda video_frame: {"frame": video_frame, **get_pose()}), + # Filter out bad transforms + ops.filter( + lambda data: data.get("position") is not None and data.get("rotation") is not None + ), + ) + + # Process with spatial memory + result_stream = self.process_stream(combined_stream) + + # Subscribe to the result stream + self._subscription = result_stream.subscribe( + on_next=self._on_frame_processed, + on_error=lambda e: logger.error(f"Error in spatial memory stream: {e}"), + on_completed=lambda: logger.info("Spatial memory stream completed"), + ) + + logger.info("Continuous spatial memory processing started") + return self._subscription + + def stop_continuous_processing(self) -> None: + """ + Stop continuous processing of video frames. + """ + if self._subscription is not None: + try: + self._subscription.dispose() + self._subscription = None + logger.info("Stopped continuous spatial memory processing") + except Exception as e: + logger.error(f"Error stopping spatial memory processing: {e}") + + def _on_frame_processed(self, result: Dict[str, Any]) -> None: + """ + Handle updates from the spatial memory processing stream. 
+ """ + # Log successful frame storage (if stored) + position = result.get("position") + if position is not None: + logger.debug( + f"Spatial memory updated with frame at ({position[0]:.2f}, {position[1]:.2f}, {position[2]:.2f})" + ) + + # Periodically save visual memory to disk (e.g., every 100 frames) + if self._visual_memory is not None and self.visual_memory_path is not None: + if self.stored_frame_count % 100 == 0: + self.save() + + def save(self) -> bool: + """ + Save the visual memory component to disk. + + Returns: + True if memory was saved successfully, False otherwise + """ + if self._visual_memory is not None and self.visual_memory_path is not None: + try: + saved_path = self._visual_memory.save(self.visual_memory_path) + logger.info(f"Saved {self._visual_memory.count()} images to {saved_path}") + return True + except Exception as e: + logger.error(f"Failed to save visual memory: {e}") + return False + + def process_stream(self, combined_stream: Observable) -> Observable: + """ + Process a combined stream of video frames and positions. + + This method handles a stream where each item already contains both the frame and position, + such as the stream created by combining video and transform streams with the + with_latest_from operator. 
+ + Args: + combined_stream: Observable stream of dictionaries containing 'frame' and 'position' + + Returns: + Observable of processing results, including the stored frame and its metadata + """ + self.last_position = None + self.last_record_time = None + + def process_combined_data(data): + self.frame_count += 1 + + frame = data.get("frame") + position_vec = data.get("position") # Use .get() for consistency + rotation_vec = data.get("rotation") # Get rotation data if available + + if not position_vec or not rotation_vec: + logger.info("No position or rotation data available, skipping frame") + return None + + if ( + self.last_position is not None + and (self.last_position - position_vec).length() < self.min_distance_threshold + ): + logger.debug("Position has not moved, skipping frame") + return None + + if ( + self.last_record_time is not None + and (time.time() - self.last_record_time) < self.min_time_threshold + ): + logger.debug("Time since last record too short, skipping frame") + return None + + current_time = time.time() + + frame_embedding = self.embedding_provider.get_embedding(frame) + + frame_id = f"frame_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}" + + # Create metadata dictionary with primitive types only + metadata = { + "pos_x": float(position_vec.x), + "pos_y": float(position_vec.y), + "pos_z": float(position_vec.z), + "rot_x": float(rotation_vec.x), + "rot_y": float(rotation_vec.y), + "rot_z": float(rotation_vec.z), + "timestamp": current_time, + "frame_id": frame_id, + } + + self.vector_db.add_image_vector( + vector_id=frame_id, image=frame, embedding=frame_embedding, metadata=metadata + ) + + self.last_position = position_vec + self.last_record_time = current_time + self.stored_frame_count += 1 + + logger.info( + f"Stored frame at position {position_vec}, rotation {rotation_vec})" + f" stored {self.stored_frame_count}/{self.frame_count} frames" + ) + + # Create return dictionary with primitive-compatible values + return { 
+ "frame": frame, + "position": (position_vec.x, position_vec.y, position_vec.z), + "rotation": (rotation_vec.x, rotation_vec.y, rotation_vec.z), + "frame_id": frame_id, + "timestamp": current_time, + } + + return combined_stream.pipe( + ops.map(process_combined_data), ops.filter(lambda result: result is not None) + ) + + def query_by_image(self, image: np.ndarray, limit: int = 5) -> List[Dict]: + """ + Query the vector database for images similar to the provided image. + + Args: + image: Query image + limit: Maximum number of results to return + + Returns: + List of results, each containing the image and its metadata + """ + embedding = self.embedding_provider.get_embedding(image) + return self.vector_db.query_by_embedding(embedding, limit) + + def query_by_text(self, text: str, limit: int = 5) -> List[Dict]: + """ + Query the vector database for images matching the provided text description. + + This method uses CLIP's text-to-image matching capability to find images + that semantically match the text query (e.g., "where is the kitchen"). + + Args: + text: Text query to search for + limit: Maximum number of results to return + + Returns: + List of results, each containing the image, its metadata, and similarity score + """ + logger.info(f"Querying spatial memory with text: '{text}'") + return self.vector_db.query_by_text(text, limit) + + def add_robot_location(self, location: RobotLocation) -> bool: + """ + Add a named robot location to spatial memory. + + Args: + location: The RobotLocation object to add + + Returns: + True if successfully added, False otherwise + """ + try: + # Add to our list of robot locations + self.robot_locations.append(location) + logger.info(f"Added robot location '{location.name}' at position {location.position}") + return True + + except Exception as e: + logger.error(f"Error adding robot location: {e}") + return False + + def get_robot_locations(self) -> List[RobotLocation]: + """ + Get all stored robot locations. 
+ + Returns: + List of RobotLocation objects + """ + return self.robot_locations + + def find_robot_location(self, name: str) -> Optional[RobotLocation]: + """ + Find a robot location by name. + + Args: + name: Name of the location to find + + Returns: + RobotLocation object if found, None otherwise + """ + # Simple search through our list of locations + for location in self.robot_locations: + if location.name.lower() == name.lower(): + return location + + return None + + def cleanup(self): + """Clean up resources.""" + # Stop any ongoing processing + self.stop_continuous_processing() + + # Save data if possible + self.save() + + # Log cleanup + if self.vector_db: + logger.info(f"Cleaning up SpatialMemory, stored {self.stored_frame_count} frames") diff --git a/build/lib/dimos/perception/test_spatial_memory.py b/build/lib/dimos/perception/test_spatial_memory.py new file mode 100644 index 0000000000..9a519fe59c --- /dev/null +++ b/build/lib/dimos/perception/test_spatial_memory.py @@ -0,0 +1,214 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import shutil +import tempfile +import time + +import cv2 +import numpy as np +import pytest +import reactivex as rx +from reactivex import Observable +from reactivex import operators as ops +from reactivex.subject import Subject + +from dimos.perception.spatial_perception import SpatialMemory +from dimos.stream.video_provider import VideoProvider +from dimos.types.pose import Pose +from dimos.types.vector import Vector + + +@pytest.mark.heavy +class TestSpatialMemory: + @pytest.fixture(scope="function") + def temp_dir(self): + # Create a temporary directory for storing spatial memory data + temp_dir = tempfile.mkdtemp() + yield temp_dir + # Clean up + shutil.rmtree(temp_dir) + + def test_spatial_memory_initialization(self): + """Test SpatialMemory initializes correctly with CLIP model.""" + try: + # Initialize spatial memory with default CLIP model + memory = SpatialMemory( + collection_name="test_collection", embedding_model="clip", new_memory=True + ) + assert memory is not None + assert memory.embedding_model == "clip" + assert memory.embedding_provider is not None + except Exception as e: + # If the model doesn't initialize, skip the test + pytest.fail(f"Failed to initialize model: {e}") + + def test_image_embedding(self): + """Test generating image embeddings using CLIP.""" + try: + # Initialize spatial memory with CLIP model + memory = SpatialMemory( + collection_name="test_collection", embedding_model="clip", new_memory=True + ) + + # Create a test image - use a simple colored square + test_image = np.zeros((224, 224, 3), dtype=np.uint8) + test_image[50:150, 50:150] = [0, 0, 255] # Blue square + + # Generate embedding + embedding = memory.embedding_provider.get_embedding(test_image) + + # Check embedding shape and characteristics + assert embedding is not None + assert isinstance(embedding, np.ndarray) + assert embedding.shape[0] == memory.embedding_dimensions + + # Check that embedding is normalized (unit vector) + assert 
np.isclose(np.linalg.norm(embedding), 1.0, atol=1e-5) + + # Test text embedding + text_embedding = memory.embedding_provider.get_text_embedding("a blue square") + assert text_embedding is not None + assert isinstance(text_embedding, np.ndarray) + assert text_embedding.shape[0] == memory.embedding_dimensions + assert np.isclose(np.linalg.norm(text_embedding), 1.0, atol=1e-5) + except Exception as e: + pytest.fail(f"Error in test: {e}") + + def test_spatial_memory_processing(self, temp_dir): + """Test processing video frames and building spatial memory with CLIP embeddings.""" + try: + # Initialize spatial memory with temporary storage + memory = SpatialMemory( + collection_name="test_collection", + embedding_model="clip", + new_memory=True, + db_path=os.path.join(temp_dir, "chroma_db"), + visual_memory_path=os.path.join(temp_dir, "visual_memory.pkl"), + output_dir=os.path.join(temp_dir, "images"), + min_distance_threshold=0.01, + min_time_threshold=0.01, + ) + + from dimos.utils.data import get_data + + video_path = get_data("assets") / "trimmed_video_office.mov" + assert os.path.exists(video_path), f"Test video not found: {video_path}" + video_provider = VideoProvider(dev_name="test_video", video_source=video_path) + video_stream = video_provider.capture_video_as_observable(realtime=False, fps=15) + + # Create a frame counter for position generation + frame_counter = 0 + + # Process each video frame directly + def process_frame(frame): + nonlocal frame_counter + + # Generate a unique position for this frame to ensure minimum distance threshold is met + pos = Pose(frame_counter * 0.5, frame_counter * 0.5, 0) + transform = {"position": pos, "timestamp": time.time()} + frame_counter += 1 + + # Create a dictionary with frame, position and rotation for SpatialMemory.process_stream + return { + "frame": frame, + "position": transform["position"], + "rotation": transform["position"], # Using position as rotation for testing + } + + # Create a stream that processes each 
frame + formatted_stream = video_stream.pipe(ops.map(process_frame)) + + # Process the stream using SpatialMemory's built-in processing + print("Creating spatial memory stream...") + spatial_stream = memory.process_stream(formatted_stream) + + # Stream is now created above using memory.process_stream() + + # Collect results from the stream + results = [] + + frames_processed = 0 + target_frames = 100 # Process more frames for thorough testing + + def on_next(result): + nonlocal results, frames_processed + if not result: # Skip None results + return + + results.append(result) + frames_processed += 1 + + # Stop processing after target frames + if frames_processed >= target_frames: + subscription.dispose() + + def on_error(error): + pytest.fail(f"Error in spatial stream: {error}") + + def on_completed(): + pass + + # Subscribe and wait for results + subscription = spatial_stream.subscribe( + on_next=on_next, on_error=on_error, on_completed=on_completed + ) + + # Wait for frames to be processed + timeout = 30.0 # seconds + start_time = time.time() + while frames_processed < target_frames and time.time() - start_time < timeout: + time.sleep(0.5) + + subscription.dispose() + + assert len(results) > 0, "Failed to process any frames with spatial memory" + + relevant_queries = ["office", "room with furniture"] + irrelevant_query = "star wars" + + for query in relevant_queries: + results = memory.query_by_text(query, limit=2) + print(f"\nResults for query: '{query}'") + + assert len(results) > 0, f"No results found for relevant query: {query}" + + similarities = [1 - r.get("distance") for r in results] + print(f"Similarities: {similarities}") + + assert any(d > 0.24 for d in similarities), ( + f"Expected at least one result with similarity > 0.24 for query '{query}'" + ) + + results = memory.query_by_text(irrelevant_query, limit=2) + print(f"\nResults for query: '{irrelevant_query}'") + + if results: + similarities = [1 - r.get("distance") for r in results] + 
print(f"Similarities: {similarities}") + + assert all(d < 0.25 for d in similarities), ( + f"Expected all results to have similarity < 0.25 for irrelevant query '{irrelevant_query}'" + ) + + except Exception as e: + pytest.fail(f"Error in test: {e}") + finally: + memory.cleanup() + video_provider.dispose_all() + + +if __name__ == "__main__": + pytest.main(["-v", __file__]) diff --git a/build/lib/dimos/perception/visual_servoing.py b/build/lib/dimos/perception/visual_servoing.py new file mode 100644 index 0000000000..40cee7c60c --- /dev/null +++ b/build/lib/dimos/perception/visual_servoing.py @@ -0,0 +1,500 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import time +import threading +from typing import Dict, Optional, List, Tuple +import logging +import numpy as np + +from dimos.utils.simple_controller import VisualServoingController + +# Configure logging +logger = logging.getLogger(__name__) + + +def calculate_iou(box1, box2): + """Calculate Intersection over Union between two bounding boxes.""" + x1 = max(box1[0], box2[0]) + y1 = max(box1[1], box2[1]) + x2 = min(box1[2], box2[2]) + y2 = min(box1[3], box2[3]) + + intersection = max(0, x2 - x1) * max(0, y2 - y1) + area1 = (box1[2] - box1[0]) * (box1[3] - box1[1]) + area2 = (box2[2] - box2[0]) * (box2[3] - box2[1]) + union = area1 + area2 - intersection + + return intersection / union if union > 0 else 0 + + +class VisualServoing: + """ + A class that performs visual servoing to track and follow a human target. + + The class will use the provided tracking stream to detect people and estimate + their distance and angle, then use a VisualServoingController to generate + appropriate velocity commands to track the target. + """ + + def __init__( + self, + tracking_stream=None, + max_linear_speed=0.8, + max_angular_speed=1.5, + desired_distance=1.5, + max_lost_frames=10000, + iou_threshold=0.6, + ): + """Initialize the visual servoing. 
+ + Args: + tracking_stream: Observable tracking stream (must be already set up) + max_linear_speed: Maximum linear speed in m/s + max_angular_speed: Maximum angular speed in rad/s + desired_distance: Desired distance to maintain from target in meters + max_lost_frames: Maximum number of frames target can be lost before stopping tracking + iou_threshold: Minimum IOU threshold to consider bounding boxes as matching + """ + self.tracking_stream = tracking_stream + self.max_linear_speed = max_linear_speed + self.max_angular_speed = max_angular_speed + self.desired_distance = desired_distance + self.max_lost_frames = max_lost_frames + self.iou_threshold = iou_threshold + + # Initialize the controller with PID parameters tuned for slow-moving robot + # Distance PID: (kp, ki, kd, output_limits, integral_limit, deadband, output_deadband) + distance_pid_params = ( + 1.0, # kp: Moderate proportional gain for smooth approach + 0.2, # ki: Small integral gain to eliminate steady-state error + 0.1, # kd: Some damping for smooth motion + (-self.max_linear_speed, self.max_linear_speed), # output_limits + 0.5, # integral_limit: Prevent windup + 0.1, # deadband: Small deadband for distance control + 0.05, # output_deadband: Minimum output to overcome friction + ) + + # Angle PID: (kp, ki, kd, output_limits, integral_limit, deadband, output_deadband) + angle_pid_params = ( + 1.4, # kp: Higher proportional gain for responsive turning + 0.1, # ki: Small integral gain + 0.05, # kd: Light damping to prevent oscillation + (-self.max_angular_speed, self.max_angular_speed), # output_limits + 0.3, # integral_limit: Prevent windup + 0.1, # deadband: Small deadband for angle control + 0.1, # output_deadband: Minimum output to overcome friction + True, # Invert output for angular control + ) + + # Initialize the visual servoing controller + self.controller = VisualServoingController( + distance_pid_params=distance_pid_params, angle_pid_params=angle_pid_params + ) + + # Initialize tracking 
state + self.last_control_time = time.time() + self.running = False + self.current_target = None # (target_id, bbox) + self.target_lost_frames = 0 + + # Add variables to track current distance and angle + self.current_distance = None + self.current_angle = None + + # Stream subscription management + self.subscription = None + self.latest_result = None + self.result_lock = threading.Lock() + self.stop_event = threading.Event() + + # Subscribe to the tracking stream + self._subscribe_to_tracking_stream() + + def start_tracking( + self, + desired_distance: int = None, + point: Tuple[int, int] = None, + timeout_wait_for_target: float = 20.0, + ) -> bool: + """ + Start tracking a human target using visual servoing. + + Args: + point: Optional tuple of (x, y) coordinates in image space. If provided, + will find the target whose bounding box contains this point. + If None, will track the closest person. + + Returns: + bool: True if tracking was successfully started, False otherwise + """ + if desired_distance is not None: + self.desired_distance = desired_distance + + if self.tracking_stream is None: + self.running = False + return False + + # Get the latest frame and targets from person tracker + try: + # Try getting the result multiple times with delays + for attempt in range(10): + result = self._get_current_tracking_result() + + if result is not None: + break + + logger.warning( + f"Attempt {attempt + 1}: No tracking result, retrying in 1 second..." 
+ ) + time.sleep(3) # Wait 1 second between attempts + + if result is None: + logger.warning("Stream error, no targets found after multiple attempts") + return False + + targets = result.get("targets") + + # If bbox is provided, find matching target based on IOU + if point is not None and not self.running: + # Find the target with highest IOU to the provided bbox + best_target = self._find_target_by_point(point, targets) + # If no bbox is provided, find the closest person + elif not self.running: + if timeout_wait_for_target > 0.0 and len(targets) == 0: + # Wait for target to appear + start_time = time.time() + while time.time() - start_time < timeout_wait_for_target: + time.sleep(0.2) + result = self._get_current_tracking_result() + targets = result.get("targets") + if len(targets) > 0: + break + best_target = self._find_closest_target(targets) + else: + # Already tracking + return True + + if best_target: + # Set as current target and reset lost counter + target_id = best_target.get("target_id") + target_bbox = best_target.get("bbox") + self.current_target = (target_id, target_bbox) + self.target_lost_frames = 0 + self.running = True + logger.info(f"Started tracking target ID: {target_id}") + + # Get distance and angle and compute control (store as initial control values) + distance = best_target.get("distance") + angle = best_target.get("angle") + self._compute_control(distance, angle) + return True + else: + if point is not None: + logger.warning("No matching target found") + else: + logger.warning("No suitable target found for tracking") + self.running = False + return False + except Exception as e: + logger.error(f"Error starting tracking: {e}") + self.running = False + return False + + def _find_target_by_point(self, point, targets): + """Find the target whose bounding box contains the given point. 
+ + Args: + point: Tuple of (x, y) coordinates in image space + targets: List of target dictionaries + + Returns: + dict: The target whose bbox contains the point, or None if no match + """ + x, y = point + for target in targets: + bbox = target.get("bbox") + if not bbox: + continue + + x1, y1, x2, y2 = bbox + if x1 <= x <= x2 and y1 <= y <= y2: + return target + return None + + def updateTracking(self) -> Dict[str, any]: + """ + Update tracking of current target. + + Returns: + Dict with linear_vel, angular_vel, and running state + """ + if not self.running or self.current_target is None: + self.running = False + self.current_distance = None + self.current_angle = None + return {"linear_vel": 0.0, "angular_vel": 0.0} + + # Get the latest tracking result + result = self._get_current_tracking_result() + + # Get targets from result + targets = result.get("targets") + + # Try to find current target by ID or IOU + current_target_id, current_bbox = self.current_target + target_found = False + + # First try to find by ID + for target in targets: + if target.get("target_id") == current_target_id: + # Found by ID, update bbox + self.current_target = (current_target_id, target.get("bbox")) + self.target_lost_frames = 0 + target_found = True + + # Store current distance and angle + self.current_distance = target.get("distance") + self.current_angle = target.get("angle") + + # Compute control + control = self._compute_control(self.current_distance, self.current_angle) + return control + + # If not found by ID, try to find by IOU + if not target_found and current_bbox is not None: + best_target = self._find_best_target_by_iou(current_bbox, targets) + if best_target: + # Update target + new_id = best_target.get("target_id") + new_bbox = best_target.get("bbox") + self.current_target = (new_id, new_bbox) + self.target_lost_frames = 0 + logger.info(f"Target ID updated: {current_target_id} -> {new_id}") + + # Store current distance and angle + self.current_distance = 
best_target.get("distance") + self.current_angle = best_target.get("angle") + + # Compute control + control = self._compute_control(self.current_distance, self.current_angle) + return control + + # Target not found, increment lost counter + if not target_found: + self.target_lost_frames += 1 + logger.warning(f"Target lost: frame {self.target_lost_frames}/{self.max_lost_frames}") + + # Check if target is lost for too many frames + if self.target_lost_frames >= self.max_lost_frames: + logger.info("Target lost for too many frames, stopping tracking") + self.stop_tracking() + return {"linear_vel": 0.0, "angular_vel": 0.0, "running": False} + + return {"linear_vel": 0.0, "angular_vel": 0.0} + + def _compute_control(self, distance: float, angle: float) -> Dict[str, float]: + """ + Compute control commands based on measured distance and angle. + + Args: + distance: Measured distance to target in meters + angle: Measured angle to target in radians + + Returns: + Dict with linear_vel and angular_vel keys + """ + current_time = time.time() + dt = current_time - self.last_control_time + self.last_control_time = current_time + + # Compute control with visual servoing controller + linear_vel, angular_vel = self.controller.compute_control( + measured_distance=distance, + measured_angle=angle, + desired_distance=self.desired_distance, + desired_angle=0.0, # Keep target centered + dt=dt, + ) + + # Log control values for debugging + logger.debug(f"Distance: {distance:.2f}m, Angle: {np.rad2deg(angle):.1f}°") + logger.debug(f"Control: linear={linear_vel:.2f}m/s, angular={angular_vel:.2f}rad/s") + + return {"linear_vel": linear_vel, "angular_vel": angular_vel} + + def _find_best_target_by_iou(self, bbox: List[float], targets: List[Dict]) -> Optional[Dict]: + """ + Find the target with highest IOU to the given bbox. 
+ + Args: + bbox: Bounding box to match [x1, y1, x2, y2] + targets: List of target dictionaries + + Returns: + Best matching target or None if no match found + """ + if not targets: + return None + + best_iou = self.iou_threshold + best_target = None + + for target in targets: + target_bbox = target.get("bbox") + if target_bbox is None: + continue + + iou = calculate_iou(bbox, target_bbox) + if iou > best_iou: + best_iou = iou + best_target = target + + return best_target + + def _find_closest_target(self, targets: List[Dict]) -> Optional[Dict]: + """ + Find the target with shortest distance to the camera. + + Args: + targets: List of target dictionaries + + Returns: + The closest target or None if no targets available + """ + if not targets: + return None + + closest_target = None + min_distance = float("inf") + + for target in targets: + distance = target.get("distance") + if distance is not None and distance < min_distance: + min_distance = distance + closest_target = target + + return closest_target + + def _subscribe_to_tracking_stream(self): + """ + Subscribe to the already set up tracking stream. + """ + if self.tracking_stream is None: + logger.warning("No tracking stream provided to subscribe to") + return + + try: + # Set up subscription to process frames + self.subscription = self.tracking_stream.subscribe( + on_next=self._on_tracking_result, + on_error=self._on_tracking_error, + on_completed=self._on_tracking_completed, + ) + + logger.info("Subscribed to tracking stream successfully") + except Exception as e: + logger.error(f"Error subscribing to tracking stream: {e}") + + def _on_tracking_result(self, result): + """ + Callback for tracking stream results. + + This updates the latest result for use by _get_current_tracking_result. 
def stop_tracking(self):
    """
    Stop tracking and clear the per-target tracking state.

    Clears the running flag, the current target, the lost-frame counter and
    the cached distance/angle.

    Returns:
        A zero-velocity command dict that also carries "running": False
    """
    self.running = False
    self.target_lost_frames = 0
    self.current_target = self.current_distance = self.current_angle = None
    return dict(linear_vel=0.0, angular_vel=0.0, running=False)
def cleanup(self):
    """
    Clean up all resources used by the visual servoing.

    Sets the stop event and disposes the tracking-stream subscription. The
    method is also invoked from ``__del__``, which Python may run on an
    instance whose ``__init__`` did not finish, so attributes that were never
    assigned are tolerated instead of raising ``AttributeError``. Calling it
    more than once is harmless.
    """
    stop_event = getattr(self, "stop_event", None)
    if stop_event is not None:
        stop_event.set()

    subscription = getattr(self, "subscription", None)
    if subscription:
        subscription.dispose()
        self.subscription = None
class ConnectionInterface(ABC):
    """Contract that every robot connection backend has to fulfil.

    Concrete connection types (ROS, WebRTC, etc.) subclass this and implement
    the four abstract methods below, which cover velocity control, the camera
    stream, stopping and disconnecting.
    """

    @abstractmethod
    def move(self, velocity: Vector, duration: float = 0.0) -> bool:
        """Command the robot to move with the given velocity.

        Args:
            velocity: Velocity vector [x, y, yaw]:
                x is forward/backward velocity (m/s),
                y is left/right velocity (m/s),
                yaw is rotational velocity (rad/s)
            duration: Seconds to keep moving; 0 means the command is continuous

        Returns:
            bool: True if the command was sent successfully
        """

    @abstractmethod
    def get_video_stream(self, fps: int = 30) -> Optional[Observable]:
        """Provide the robot camera's video stream.

        Args:
            fps: Requested frames per second

        Returns:
            Observable: Stream of video frames, or None if not available
        """

    @abstractmethod
    def stop(self) -> bool:
        """Halt the robot's movement.

        Returns:
            bool: True if the stop command was sent successfully
        """

    @abstractmethod
    def disconnect(self) -> None:
        """Close the connection to the robot and release its resources."""
class FoxgloveBridge(Module):
    """Module that runs the LCM Foxglove bridge on a background daemon thread.

    The bridge coroutine (``bridge.main()``) runs on its own asyncio event
    loop, created inside the worker thread. The module starts itself on
    construction.
    """

    # Class-level defaults: the loop is only assigned from inside the worker
    # thread, so stop() may run before it exists. Without these defaults that
    # access raised AttributeError. String annotations avoid needing new imports.
    _thread: "threading.Thread | None" = None
    _loop: "asyncio.AbstractEventLoop | None" = None

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.start()

    @rpc
    def start(self):
        """Start the bridge thread; a no-op while a bridge thread is still alive."""
        # __init__ already calls start(); a second RPC start() must not spawn
        # a duplicate bridge.
        if self._thread is not None and self._thread.is_alive():
            return

        def run_bridge():
            self._loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self._loop)
            try:
                self._loop.run_until_complete(bridge.main())
            except Exception as e:
                # Best-effort: report the failure and let the daemon thread end.
                print(f"Foxglove bridge error: {e}")

        self._thread = threading.Thread(target=run_bridge, daemon=True)
        self._thread.start()

    @rpc
    def stop(self):
        """Ask the bridge loop to stop and wait up to 2 seconds for the thread."""
        loop, thread = self._loop, self._thread
        if loop is not None and loop.is_running():
            # The loop lives on another thread, so schedule the stop there.
            loop.call_soon_threadsafe(loop.stop)
        if thread is not None:
            thread.join(timeout=2)
class QwenFrontierPredictor:
    """
    Qwen-based frontier exploration goal predictor.

    Renders the costmap as an image annotated with the robot position and the
    previously predicted goals, asks a Qwen vision model for the pixel
    coordinates of the next exploration goal, and converts the answer back to
    world coordinates.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name: str = "qwen2.5-vl-72b-instruct",
        use_smoothed_costmap: bool = True,
        image_scale_factor: int = 4,
    ):
        """
        Initialize the Qwen frontier predictor.

        Args:
            api_key: Alibaba API key for Qwen access; falls back to the
                ALIBABA_API_KEY environment variable
            model_name: Qwen model to use for predictions
            use_smoothed_costmap: Whether to run the costmap through
                smooth_costmap_for_frontiers before rendering it
            image_scale_factor: Pixels per costmap cell in the rendered image

        Raises:
            ValueError: If no API key is available, or if image_scale_factor
                is not positive
        """
        self.api_key = api_key or os.getenv("ALIBABA_API_KEY")
        if not self.api_key:
            raise ValueError(
                "Alibaba API key must be provided or set in ALIBABA_API_KEY environment variable"
            )
        # The scale factor is used as a divisor in _image_to_world_coords.
        if image_scale_factor <= 0:
            raise ValueError("image_scale_factor must be a positive integer")

        self.model_name = model_name
        self.image_scale_factor = image_scale_factor
        self.use_smoothed_costmap = use_smoothed_costmap

        # Storage for previously explored goals
        self.explored_goals: List[Vector] = []

    def _world_to_image_coords(self, world_pos: Vector, costmap: Costmap) -> Tuple[int, int]:
        """Convert world coordinates to image pixel coordinates."""
        grid_pos = costmap.world_to_grid(world_pos)
        img_x = int(grid_pos.x * self.image_scale_factor)
        img_y = int((costmap.height - grid_pos.y) * self.image_scale_factor)  # Flip Y
        return img_x, img_y

    def _image_to_world_coords(self, img_x: int, img_y: int, costmap: Costmap) -> Vector:
        """Convert image pixel coordinates to world coordinates."""
        # Unscale and flip Y coordinate
        grid_x = img_x / self.image_scale_factor
        grid_y = costmap.height - (img_y / self.image_scale_factor)

        # Convert grid to world coordinates
        return costmap.grid_to_world(Vector([grid_x, grid_y]))

    def _draw_dot(self, draw, center: Tuple[int, int], radius: int, fill, outline, width: int):
        """Draw one filled circle of the given radius around center."""
        cx, cy = center
        draw.ellipse(
            [cx - radius, cy - radius, cx + radius, cy + radius],
            fill=fill,
            outline=outline,
            width=width,
        )

    def _draw_goals_on_image(
        self,
        image: Image.Image,
        robot_pose: Vector,
        costmap: Costmap,
        latest_goal: Optional[Vector] = None,
    ) -> Image.Image:
        """
        Draw explored goals and robot position on the costmap image.

        Args:
            image: PIL Image to draw on (left unmodified; a copy is annotated)
            robot_pose: Current robot position
            costmap: Costmap for coordinate conversion
            latest_goal: Latest predicted goal to highlight in red

        Returns:
            PIL Image with goals drawn
        """
        img_copy = image.copy()
        draw = ImageDraw.Draw(img_copy)

        # Previously explored goals: green dots
        for explored_goal in self.explored_goals:
            self._draw_dot(
                draw,
                self._world_to_image_coords(explored_goal, costmap),
                8,
                (0, 255, 0),
                (0, 128, 0),
                2,
            )

        # Robot position: blue dot
        self._draw_dot(
            draw,
            self._world_to_image_coords(robot_pose, costmap),
            10,
            (0, 0, 255),
            (0, 0, 128),
            3,
        )

        # Latest predicted goal: red dot. Compare against None explicitly so
        # the check does not depend on how Vector defines truthiness.
        if latest_goal is not None:
            self._draw_dot(
                draw,
                self._world_to_image_coords(latest_goal, costmap),
                12,
                (255, 0, 0),
                (128, 0, 0),
                3,
            )

        return img_copy

    def _create_vision_prompt(self) -> str:
        """Create the vision prompt for Qwen model."""
        prompt = """You are an expert robot navigation system analyzing a costmap for frontier exploration.

COSTMAP LEGEND:
- Light gray pixels (205,205,205): FREE SPACE - areas the robot can navigate
- Dark gray pixels (128,128,128): UNKNOWN SPACE - unexplored areas that need exploration
- Black pixels (0,0,0): OBSTACLES - walls, furniture, blocked areas
- Blue dot: CURRENT ROBOT POSITION
- Green dots: PREVIOUSLY EXPLORED GOALS - avoid these areas

TASK: Find the best frontier exploration goal by identifying the optimal point where:
1. It's at the boundary between FREE SPACE (light gray) and UNKNOWN SPACE (dark gray) (HIGHEST Priority)
2. It's reasonably far from the robot position (blue dot) (MEDIUM Priority)
3. It's reasonably far from previously explored goals (green dots) (MEDIUM Priority)
4. It leads to a large area of unknown space to explore (HIGH Priority)
5. It's accessible from the robot's current position through free space (MEDIUM Priority)
6. It's not near or on obstacles (HIGHEST Priority)

RESPONSE FORMAT: Return ONLY the pixel coordinates as a JSON object:
{"x": pixel_x_coordinate, "y": pixel_y_coordinate, "reasoning": "brief explanation"}

Example: {"x": 245, "y": 187, "reasoning": "Large unknown area to the north, good distance from robot and previous goals"}

Analyze the image and identify the single best frontier exploration goal."""

        return prompt

    def _parse_prediction_response(self, response: str) -> Optional[Tuple[int, int, str]]:
        """
        Parse the model's response to extract coordinates and reasoning.

        A JSON object with "x" and "y" is preferred. If the response contains
        no such object, or the object cannot be decoded or converted, the
        first two integers found in the text are used instead.

        Args:
            response: Raw response from Qwen model

        Returns:
            Tuple of (x, y, reasoning) or None if parsing failed
        """
        if not isinstance(response, str):
            return None

        # Try to find JSON object in response
        json_match = re.search(r"\{[^}]*\}", response)
        if json_match:
            try:
                data = json.loads(json_match.group())
                if "x" in data and "y" in data:
                    x = int(data["x"])
                    y = int(data["y"])
                    reasoning = data.get("reasoning", "No reasoning provided")
                    return (x, y, reasoning)
            except (json.JSONDecodeError, ValueError, TypeError, KeyError) as e:
                # TypeError covers e.g. "x": null. Fall through to the regex
                # fallback rather than giving up on the whole response.
                print(f"DEBUG: Failed to parse prediction response: {e}")

        # Fallback: try to extract coordinates with regex
        coord_match = re.search(r"[^\d]*(\d+)[^\d]+(\d+)", response)
        if coord_match:
            x = int(coord_match.group(1))
            y = int(coord_match.group(2))
            return (x, y, "Coordinates extracted from response")

        return None

    def get_exploration_goal(self, robot_pose: Vector, costmap: Costmap) -> Optional[Vector]:
        """
        Get the best exploration goal using Qwen vision analysis.

        Args:
            robot_pose: Current robot position in world coordinates
            costmap: Current costmap for analysis

        Returns:
            Single best frontier goal in world coordinates, or None if no suitable goal found
        """
        print(
            f"DEBUG: Qwen frontier prediction starting with {len(self.explored_goals)} explored goals"
        )

        # Create costmap image
        if self.use_smoothed_costmap:
            # NOTE(review): the previous call passed alpha=4.0, but the
            # smooth_costmap_for_frontiers shipped alongside this module takes
            # only the costmap, so the keyword raised TypeError - confirm no
            # newer helper expects alpha.
            costmap = smooth_costmap_for_frontiers(costmap)

        base_image = costmap_to_pil_image(costmap, self.image_scale_factor)

        # Draw goals on image (without latest goal initially)
        annotated_image = self._draw_goals_on_image(base_image, robot_pose, costmap)

        # Query Qwen model for frontier prediction
        try:
            prompt = self._create_vision_prompt()
            response = query_single_frame(
                annotated_image, prompt, api_key=self.api_key, model_name=self.model_name
            )

            print(f"DEBUG: Qwen response: {response}")

            # Parse response to get coordinates
            parsed_result = self._parse_prediction_response(response)
            if not parsed_result:
                print("DEBUG: Failed to parse Qwen response")
                return None

            img_x, img_y, reasoning = parsed_result
            print(f"DEBUG: Parsed coordinates: ({img_x}, {img_y}), Reasoning: {reasoning}")

            # The model output is untrusted: reject pixels outside the image
            # it was shown, which would map to a point outside the costmap.
            if not (0 <= img_x < annotated_image.width and 0 <= img_y < annotated_image.height):
                print(f"DEBUG: Predicted pixel ({img_x}, {img_y}) lies outside the costmap image")
                return None

            # Convert image coordinates to world coordinates
            predicted_goal = self._image_to_world_coords(img_x, img_y, costmap)
            print(
                f"DEBUG: Predicted goal in world coordinates: ({predicted_goal.x:.2f}, {predicted_goal.y:.2f})"
            )

            # Store the goal in explored_goals for future reference
            self.explored_goals.append(predicted_goal)

            print(f"DEBUG: Successfully predicted frontier goal: {predicted_goal}")
            return predicted_goal

        except Exception as e:
            # Boundary around the remote model call: report and yield no goal.
            print(f"DEBUG: Error during Qwen prediction: {e}")
            return None
+ """ + + # Path to saved costmaps + saved_maps_dir = os.path.join(os.getcwd(), "assets", "saved_maps") + + if not os.path.exists(saved_maps_dir): + print(f"Error: Saved maps directory not found: {saved_maps_dir}") + return + + # Get all pickle files + pickle_files = sorted(glob.glob(os.path.join(saved_maps_dir, "*.pickle"))) + + if not pickle_files: + print(f"No pickle files found in {saved_maps_dir}") + return + + print(f"Found {len(pickle_files)} costmap files for Qwen testing") + + # Initialize Qwen frontier predictor + predictor = QwenFrontierPredictor(image_scale_factor=4, use_smoothed_costmap=False) + + # Track the robot pose across iterations + robot_pose = None + + # Process each costmap + for i, pickle_file in enumerate(pickle_files): + print( + f"\n--- Processing costmap {i + 1}/{len(pickle_files)}: {os.path.basename(pickle_file)} ---" + ) + + try: + # Load the costmap + costmap = Costmap.from_pickle(pickle_file) + print( + f"Loaded costmap: {costmap.width}x{costmap.height}, resolution: {costmap.resolution}" + ) + + # Set robot pose: first iteration uses center, subsequent use last predicted goal + if robot_pose is None: + # First iteration: use center of costmap as robot position + center_world = costmap.grid_to_world( + Vector([costmap.width / 2, costmap.height / 2]) + ) + robot_pose = Vector([center_world.x, center_world.y]) + # else: robot_pose remains the last predicted goal + + print(f"Using robot position: {robot_pose}") + + # Get frontier prediction from Qwen + print("Getting Qwen frontier prediction...") + predicted_goal = predictor.get_exploration_goal(robot_pose, costmap) + + if predicted_goal: + distance = np.sqrt( + (predicted_goal.x - robot_pose.x) ** 2 + (predicted_goal.y - robot_pose.y) ** 2 + ) + print(f"Predicted goal: {predicted_goal}, Distance: {distance:.2f}m") + + # Show the final visualization + base_image = costmap_to_pil_image(costmap, predictor.image_scale_factor) + final_image = predictor._draw_goals_on_image( + base_image, 
robot_pose, costmap, predicted_goal + ) + + # Display image + title = f"Qwen Frontier Prediction {i + 1:04d}" + final_image.show(title=title) + + # Update robot pose for next iteration + robot_pose = predicted_goal + + else: + print("No suitable frontier goal predicted by Qwen") + + except Exception as e: + print(f"Error processing {pickle_file}: {e}") + continue + + print(f"\n=== Qwen Frontier Detection Test Complete ===") + print(f"Final explored goals count: {len(predictor.explored_goals)}") + + +if __name__ == "__main__": + test_qwen_frontier_detection() diff --git a/build/lib/dimos/robot/frontier_exploration/test_wavefront_frontier_goal_selector.py b/build/lib/dimos/robot/frontier_exploration/test_wavefront_frontier_goal_selector.py new file mode 100644 index 0000000000..c9b75b28d8 --- /dev/null +++ b/build/lib/dimos/robot/frontier_exploration/test_wavefront_frontier_goal_selector.py @@ -0,0 +1,297 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def get_office_lidar_costmap(take_frames: int = 1, voxel_size: float = 0.5) -> tuple:
    """
    Build a costmap from the first frames of the recorded office_lidar data.

    Args:
        take_frames: Number of lidar frames to take (default 1)
        voxel_size: Voxel size for map construction

    Returns:
        Tuple of (costmap, first_lidar_message) for testing
    """
    replay = SensorReplay("office_lidar", autocast=LidarMessage.from_msg)
    world_map = Map(voxel_size=voxel_size)

    # Holder for the first message seen; a dict avoids needing ``nonlocal``.
    captured = {"first": None}

    def ingest(lidar_msg):
        if captured["first"] is None:
            captured["first"] = lidar_msg
        return world_map.add_frame(lidar_msg)

    # Limit the replay to the requested number of frames and feed each one
    # into the map; run() blocks until the limited stream has completed.
    replay.stream().pipe(ops.take(take_frames)).pipe(ops.map(ingest)).run()

    return world_map.costmap(), captured["first"]
costmap.width > 0 and costmap.height > 0, "Costmap should have valid dimensions" + + print(f"Costmap dimensions: {costmap.width}x{costmap.height}") + print(f"Costmap resolution: {costmap.resolution}") + print(f"Unknown percent: {costmap.unknown_percent:.1f}%") + print(f"Free percent: {costmap.free_percent:.1f}%") + print(f"Occupied percent: {costmap.occupied_percent:.1f}%") + + # Initialize frontier explorer with default parameters + explorer = WavefrontFrontierExplorer() + + # Set robot pose near the center of free space in the costmap + # We'll use the lidar origin as a reasonable robot position + robot_pose = first_lidar.origin + print(f"Robot pose: {robot_pose}") + + # Detect frontiers + frontiers = explorer.detect_frontiers(robot_pose, costmap) + + # Verify frontier detection results + assert isinstance(frontiers, list), "Frontiers should be returned as a list" + print(f"Detected {len(frontiers)} frontiers") + + # Test that we get some frontiers (office environment should have unexplored areas) + if len(frontiers) > 0: + print("Frontier detection successful - found unexplored areas") + + # Verify frontiers are Vector objects with valid coordinates + for i, frontier in enumerate(frontiers[:5]): # Check first 5 + assert isinstance(frontier, Vector), f"Frontier {i} should be a Vector" + assert hasattr(frontier, "x") and hasattr(frontier, "y"), ( + f"Frontier {i} should have x,y coordinates" + ) + print(f" Frontier {i}: ({frontier.x:.2f}, {frontier.y:.2f})") + else: + print("No frontiers detected - map may be fully explored or parameters too restrictive") + + +def test_exploration_goal_selection(): + """Test the complete exploration goal selection pipeline.""" + # Get costmap from office lidar data + costmap, first_lidar = get_office_lidar_costmap(take_frames=1, voxel_size=0.3) + + # Initialize frontier explorer with default parameters + explorer = WavefrontFrontierExplorer() + + # Use lidar origin as robot position + robot_pose = first_lidar.origin + + # Get 
def test_exploration_session_reset():
    """Test exploration session reset functionality."""
    # Get costmap
    costmap, first_lidar = get_office_lidar_costmap(take_frames=1, voxel_size=0.3)

    # Initialize explorer and select a goal
    explorer = WavefrontFrontierExplorer()
    robot_pose = first_lidar.origin

    # Select a goal to populate exploration state
    goal = explorer.get_exploration_goal(robot_pose, costmap)

    # Verify state is populated. Without this check the post-reset assertions
    # below could pass even if goal selection had never recorded anything.
    if goal is not None:
        assert len(explorer.explored_goals) > 0, (
            "Selecting a goal should populate explored goals before reset"
        )

    # Reset exploration session
    explorer.reset_exploration_session()

    # Verify state is cleared
    assert len(explorer.explored_goals) == 0, "Explored goals should be cleared after reset"
    assert explorer.exploration_direction.x == 0.0 and explorer.exploration_direction.y == 0.0, (
        "Exploration direction should be reset"
    )
    assert explorer.last_costmap is None, "Last costmap should be cleared"
    assert explorer.num_no_gain_attempts == 0, "No-gain attempts should be reset"

    print("Exploration session reset successfully")
"""Test frontier detection with visualization (marked with @pytest.mark.vis).""" + # Get costmap from office lidar data + costmap, first_lidar = get_office_lidar_costmap(take_frames=1, voxel_size=0.2) + + # Initialize frontier explorer with default parameters + explorer = WavefrontFrontierExplorer() + + # Use lidar origin as robot position + robot_pose = first_lidar.origin + + # Detect all frontiers for visualization + all_frontiers = explorer.detect_frontiers(robot_pose, costmap) + + # Get selected goal + selected_goal = explorer.get_exploration_goal(robot_pose, costmap) + + print(f"Visualizing {len(all_frontiers)} frontier candidates") + if selected_goal: + print(f"Selected goal: ({selected_goal.x:.2f}, {selected_goal.y:.2f})") + + # Create visualization + image_scale_factor = 4 + base_image = costmap_to_pil_image(costmap, image_scale_factor) + + # Helper function to convert world coordinates to image coordinates + def world_to_image_coords(world_pos: Vector) -> tuple[int, int]: + grid_pos = costmap.world_to_grid(world_pos) + img_x = int(grid_pos.x * image_scale_factor) + img_y = int((costmap.height - grid_pos.y) * image_scale_factor) # Flip Y + return img_x, img_y + + # Draw visualization + draw = ImageDraw.Draw(base_image) + + # Draw frontier candidates as gray dots + for frontier in all_frontiers[:20]: # Limit to top 20 + x, y = world_to_image_coords(frontier) + radius = 6 + draw.ellipse( + [x - radius, y - radius, x + radius, y + radius], + fill=(128, 128, 128), # Gray + outline=(64, 64, 64), + width=1, + ) + + # Draw robot position as blue dot + robot_x, robot_y = world_to_image_coords(robot_pose) + robot_radius = 10 + draw.ellipse( + [ + robot_x - robot_radius, + robot_y - robot_radius, + robot_x + robot_radius, + robot_y + robot_radius, + ], + fill=(0, 0, 255), # Blue + outline=(0, 0, 128), + width=3, + ) + + # Draw selected goal as red dot + if selected_goal: + goal_x, goal_y = world_to_image_coords(selected_goal) + goal_radius = 12 + draw.ellipse( + [ + 
def test_multi_frame_exploration():
    """Print frontier and goal statistics for maps built from 1, 3 and 5 lidar frames."""
    print("=== Multi-Frame Exploration Analysis ===")

    for frame_count in (1, 3, 5):
        print(f"\n--- Testing with {frame_count} lidar frame(s) ---")

        # Build the costmap from the requested number of frames
        costmap, first_lidar = get_office_lidar_costmap(take_frames=frame_count, voxel_size=0.3)
        summary = (
            f"Costmap: {costmap.width}x{costmap.height}, "
            f"unknown: {costmap.unknown_percent:.1f}%, "
            f"free: {costmap.free_percent:.1f}%, "
            f"occupied: {costmap.occupied_percent:.1f}%"
        )
        print(summary)

        # Fresh explorer with default parameters for every frame count
        explorer = WavefrontFrontierExplorer()
        robot_pose = first_lidar.origin

        frontiers = explorer.detect_frontiers(robot_pose, costmap)
        print(f"Detected {len(frontiers)} frontiers")

        goal = explorer.get_exploration_goal(robot_pose, costmap)
        if not goal:
            print("No exploration goal selected")
            continue

        distance = np.sqrt((goal.x - robot_pose.x) ** 2 + (goal.y - robot_pose.y) ** 2)
        print(f"Selected goal at distance {distance:.2f}m")
def costmap_to_pil_image(costmap: Costmap, scale_factor: int = 2) -> Image.Image:
    """
    Convert costmap to PIL Image with ROS-style coloring and optional scaling.

    Colors: free cells and any other low-cost value are light grey
    (205, 205, 205), unknown cells are dark gray (128, 128, 128), and cells at
    or above the occupied threshold are black.

    Args:
        costmap: Costmap to convert
        scale_factor: Factor to scale up the image for better visibility

    Returns:
        PIL Image with ROS-style colors
    """
    grid = np.asarray(costmap.grid)[: costmap.height, : costmap.width]

    # Start from the "anything else" color, then paint the classes with
    # whole-array masks instead of a per-pixel Python loop. Masks are applied
    # from lowest to highest priority so the result matches the original
    # if/elif order (FREE, then UNKNOWN, then >= OCCUPIED).
    img_array = np.full((costmap.height, costmap.width, 3), 205, dtype=np.uint8)
    img_array[grid >= CostValues.OCCUPIED] = (0, 0, 0)
    img_array[grid == CostValues.UNKNOWN] = (128, 128, 128)
    img_array[grid == CostValues.FREE] = (205, 205, 205)

    # Flip vertically to match ROS convention (origin at bottom-left)
    img_array = np.flipud(img_array)

    # A (H, W, 3) uint8 array is interpreted as RGB; the explicit mode
    # argument is deprecated in recent Pillow releases.
    img = Image.fromarray(img_array)

    # Scale up if requested
    if scale_factor > 1:
        new_size = (img.width * scale_factor, img.height * scale_factor)
        img = img.resize(new_size, Image.NEAREST)  # Use NEAREST to keep sharp pixels

    return img
radius, x + radius, y + radius], + fill=(255, 0, 0), + outline=(128, 0, 0), + width=3, + ) # Red + + # Add "BEST" label + draw.text((x + radius + 2, y - radius), "BEST", fill=(255, 0, 0)) + + return img_copy + + +def smooth_costmap_for_frontiers( + costmap: Costmap, +) -> Costmap: + """ + Smooth a costmap using morphological operations for frontier exploration. + + This function applies OpenCV morphological operations to smooth free space + areas and improve connectivity for better frontier detection. It's designed + specifically for frontier exploration. + + Args: + costmap: Input Costmap object + + Returns: + Smoothed Costmap object with enhanced free space connectivity + """ + # Extract grid data and metadata from costmap + grid = costmap.grid + resolution = costmap.resolution + + # Work with a copy to avoid modifying input + filtered_grid = grid.copy() + + # 1. Create binary mask for free space + free_mask = (grid == CostValues.FREE).astype(np.uint8) * 255 + + # 2. Apply morphological operations for smoothing + kernel_size = 7 + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size)) + + # Dilate free space to connect nearby areas + dilated = cv2.dilate(free_mask, kernel, iterations=1) + + # Morphological closing to fill small gaps + closed = cv2.morphologyEx(dilated, cv2.MORPH_CLOSE, kernel, iterations=1) + + eroded = cv2.erode(closed, kernel, iterations=1) + + # Apply the smoothed free space back to costmap + # Only change unknown areas to free, don't override obstacles + smoothed_free = eroded == 255 + unknown_mask = grid == CostValues.UNKNOWN + filtered_grid[smoothed_free & unknown_mask] = CostValues.FREE + + return Costmap(grid=filtered_grid, origin=costmap.origin, resolution=resolution) diff --git a/build/lib/dimos/robot/frontier_exploration/wavefront_frontier_goal_selector.py b/build/lib/dimos/robot/frontier_exploration/wavefront_frontier_goal_selector.py new file mode 100644 index 0000000000..76f2ddbb0a --- /dev/null +++ 
class PointClassification(IntFlag):
    """Point classification flags for frontier detection algorithm."""

    NoInformation = 0
    MapOpen = 1  # queued on the outer (map) BFS
    MapClosed = 2  # already expanded by the outer BFS
    FrontierOpen = 4  # queued on an inner (frontier) BFS
    FrontierClosed = 8  # already expanded by an inner BFS


@dataclass
class GridPoint:
    """Represents a point in the grid map with classification."""

    x: int
    y: int
    classification: int = PointClassification.NoInformation


class FrontierCache:
    """Cache for grid points to avoid duplicate point creation."""

    def __init__(self):
        self.points = {}

    def get_point(self, x: int, y: int) -> GridPoint:
        """Get or create the grid point at the given coordinates.

        The same GridPoint instance is returned for repeated calls with the
        same coordinates, so classification flags set on it persist.
        """
        key = (x, y)
        point = self.points.get(key)
        if point is None:
            point = self.points[key] = GridPoint(x, y)
        return point

    def clear(self):
        """Clear the point cache."""
        self.points.clear()


class WavefrontFrontierExplorer:
    """
    Wavefront frontier exploration algorithm implementation.

    This class encapsulates the frontier detection and exploration goal selection
    functionality using the wavefront algorithm with BFS exploration.
    """

    # 8-connected neighbour offsets as (dx, dy). The order determines BFS
    # tie-breaking and must not be changed.
    _NEIGHBOR_OFFSETS = tuple(
        (dx, dy) for dx in (-1, 0, 1) for dy in (-1, 0, 1) if (dx, dy) != (0, 0)
    )

    def __init__(
        self,
        min_frontier_size: int = 10,
        occupancy_threshold: int = 65,
        subsample_resolution: int = 2,
        min_distance_from_robot: float = 0.5,
        explored_area_buffer: float = 0.5,
        min_distance_from_obstacles: float = 0.6,
        info_gain_threshold: float = 0.03,
        num_no_gain_attempts: int = 4,
        set_goal: Optional[Callable] = None,
        get_costmap: Optional[Callable] = None,
        get_robot_pos: Optional[Callable] = None,
    ):
        """
        Initialize the frontier explorer.

        Args:
            min_frontier_size: Minimum number of points to consider a valid frontier
            occupancy_threshold: Cost threshold above which a cell is considered occupied
            subsample_resolution: Factor by which to subsample the costmap for faster
                processing (1=no subsampling, 2=half resolution, 4=quarter resolution)
            min_distance_from_robot: Minimum distance frontier must be from robot (meters)
            explored_area_buffer: Buffer distance around free areas to consider as explored
                (meters); stored but not read by any method of this class
            min_distance_from_obstacles: Minimum distance frontier must be from obstacles (meters)
            info_gain_threshold: Minimum fractional increase in known costmap cells required
                between consecutive goal requests (0.05 = 5%)
            num_no_gain_attempts: Number of consecutive low-gain goal requests after which
                exploration is abandoned
            set_goal: Callable to set navigation goal, called as
                set_goal(goal, stop_event=stop_event) and expected to return a bool
            get_costmap: Callable to get current costmap, signature: () -> Costmap
            get_robot_pos: Callable to get current robot position, signature: () -> Vector
        """
        self.min_frontier_size = min_frontier_size
        self.occupancy_threshold = occupancy_threshold
        self.subsample_resolution = subsample_resolution
        self.min_distance_from_robot = min_distance_from_robot
        self.explored_area_buffer = explored_area_buffer
        self.min_distance_from_obstacles = min_distance_from_obstacles
        self.info_gain_threshold = info_gain_threshold
        self.num_no_gain_attempts = num_no_gain_attempts  # configured limit, never mutated
        self.no_gain_counter = 0  # consecutive low-gain goal requests seen so far
        self.set_goal = set_goal
        self.get_costmap = get_costmap
        self.get_robot_pos = get_robot_pos
        self._cache = FrontierCache()
        self.explored_goals = []  # list of explored goals
        self.exploration_direction = Vector([0.0, 0.0])  # current exploration direction
        self.last_costmap = None  # store last costmap for information comparison

    def _count_costmap_information(self, costmap: Costmap) -> int:
        """
        Count the amount of information in a costmap (free space + obstacles).

        Args:
            costmap: Costmap to analyze

        Returns:
            Number of cells that are free space or at/above the occupancy threshold
        """
        free_count = np.sum(costmap.grid == CostValues.FREE)
        obstacle_count = np.sum(costmap.grid >= self.occupancy_threshold)
        return int(free_count + obstacle_count)

    def _get_neighbors(self, point: GridPoint, costmap: Costmap) -> List[GridPoint]:
        """Get the in-bounds 8-connected neighbours of a grid point (cached instances)."""
        neighbors = []
        for dx, dy in self._NEIGHBOR_OFFSETS:
            nx, ny = point.x + dx, point.y + dy
            if 0 <= nx < costmap.width and 0 <= ny < costmap.height:
                neighbors.append(self._cache.get_point(nx, ny))
        return neighbors

    def _cost_at(self, point: GridPoint, costmap: Costmap):
        """Return the costmap value of a grid cell, looked up through its world position."""
        world_pos = costmap.grid_to_world(Vector([float(point.x), float(point.y)]))
        return costmap.get_value(world_pos)

    def _is_frontier_point(self, point: GridPoint, costmap: Costmap) -> bool:
        """
        Check if a point is a frontier point.

        A frontier point is an unknown cell adjacent to at least one free cell
        and not adjacent to any cell whose cost exceeds the occupancy threshold.
        """
        if self._cost_at(point, costmap) != CostValues.UNKNOWN:
            return False

        has_free = False
        for neighbor in self._get_neighbors(point, costmap):
            neighbor_cost = self._cost_at(neighbor, costmap)

            # Adjacent to occupied space: not a frontier. The truthiness test
            # also skips a None (or zero) cost before comparing.
            if neighbor_cost and neighbor_cost > self.occupancy_threshold:
                return False

            if neighbor_cost == CostValues.FREE:
                has_free = True

        return has_free

    def _find_free_space(self, start_x: int, start_y: int, costmap: Costmap) -> Tuple[int, int]:
        """
        Find the nearest free space point using BFS from the starting position.

        Returns the starting coordinates unchanged if no free cell is reachable.
        """
        queue = deque([self._cache.get_point(start_x, start_y)])
        visited = set()

        while queue:
            point = queue.popleft()
            key = (point.x, point.y)
            if key in visited:
                continue
            visited.add(key)

            if self._cost_at(point, costmap) == CostValues.FREE:
                return key

            for neighbor in self._get_neighbors(point, costmap):
                if (neighbor.x, neighbor.y) not in visited:
                    queue.append(neighbor)

        # If no free space found, return original position
        return (start_x, start_y)

    def _compute_centroid(self, frontier_points: List[Vector]) -> Vector:
        """Compute the centroid of a list of frontier points ((0, 0) for an empty list)."""
        if not frontier_points:
            return Vector([0.0, 0.0])

        points_array = np.array([[point.x, point.y] for point in frontier_points])
        centroid = np.mean(points_array, axis=0)
        return Vector([centroid[0], centroid[1]])

    def _trace_frontier(self, seed: GridPoint, costmap: Costmap) -> List[GridPoint]:
        """Collect the connected frontier cells reachable from seed with an inner BFS.

        Only cells that are themselves frontier points are expanded, so the
        search stays on the frontier instead of flooding the map.
        """
        frontier_flags = PointClassification.FrontierOpen | PointClassification.FrontierClosed
        seed.classification |= PointClassification.FrontierOpen
        frontier_queue = deque([seed])
        members = []

        while frontier_queue:
            frontier_point = frontier_queue.popleft()
            if frontier_point.classification & PointClassification.FrontierClosed:
                continue

            if self._is_frontier_point(frontier_point, costmap):
                members.append(frontier_point)
                for neighbor in self._get_neighbors(frontier_point, costmap):
                    if not (neighbor.classification & frontier_flags):
                        neighbor.classification |= PointClassification.FrontierOpen
                        frontier_queue.append(neighbor)

            frontier_point.classification |= PointClassification.FrontierClosed

        return members

    def detect_frontiers(self, robot_pose: Vector, costmap: Costmap) -> List[Vector]:
        """
        Main frontier detection algorithm using wavefront exploration.

        Args:
            robot_pose: Current robot position in world coordinates (Vector with x, y)
            costmap: Costmap to search; it is smoothed and subsampled internally

        Returns:
            List of frontier centroids in world coordinates, best-scoring first
        """
        self._cache.clear()

        working_costmap = smooth_costmap_for_frontiers(costmap)
        if self.subsample_resolution > 1:
            subsampled_costmap = working_costmap.subsample(self.subsample_resolution)
        else:
            subsampled_costmap = working_costmap

        # Start the outer BFS from the free cell nearest to the robot.
        subsampled_grid_pos = subsampled_costmap.world_to_grid(robot_pose)
        grid_x, grid_y = int(subsampled_grid_pos.x), int(subsampled_grid_pos.y)
        free_x, free_y = self._find_free_space(grid_x, grid_y, subsampled_costmap)
        start_point = self._cache.get_point(free_x, free_y)
        start_point.classification = PointClassification.MapOpen

        map_flags = PointClassification.MapOpen | PointClassification.MapClosed
        map_queue = deque([start_point])
        frontiers = []
        frontier_sizes = []

        while map_queue:
            current_point = map_queue.popleft()
            if current_point.classification & PointClassification.MapClosed:
                continue
            current_point.classification |= PointClassification.MapClosed

            # A frontier cell seeds an inner BFS that gathers its whole frontier.
            if self._is_frontier_point(current_point, subsampled_costmap):
                new_frontier = self._trace_frontier(current_point, subsampled_costmap)

                if len(new_frontier) >= self.min_frontier_size:
                    world_points = [
                        subsampled_costmap.grid_to_world(Vector([float(p.x), float(p.y)]))
                        for p in new_frontier
                    ]
                    frontiers.append(self._compute_centroid(world_points))
                    frontier_sizes.append(len(new_frontier))

            # Keep expanding through free and unknown cells only.
            for neighbor in self._get_neighbors(current_point, subsampled_costmap):
                if neighbor.classification & map_flags:
                    continue
                neighbor_cost = self._cost_at(neighbor, subsampled_costmap)
                if neighbor_cost is not None and (
                    neighbor_cost == CostValues.FREE or neighbor_cost == CostValues.UNKNOWN
                ):
                    neighbor.classification |= PointClassification.MapOpen
                    map_queue.append(neighbor)

        if not frontiers:
            return []

        # Rank frontiers using original costmap for proper filtering
        return self._rank_frontiers(frontiers, frontier_sizes, robot_pose, costmap)

    def _update_exploration_direction(self, robot_pose: Vector, goal_pose: Optional[Vector] = None):
        """Set the exploration direction to the unit vector from the robot to goal_pose.

        Nothing changes when goal_pose is None or lies within 0.1 m of the robot.
        """
        if goal_pose is None:
            return

        direction = Vector([goal_pose.x - robot_pose.x, goal_pose.y - robot_pose.y])
        magnitude = np.sqrt(direction.x**2 + direction.y**2)
        if magnitude > 0.1:  # Avoid division by zero for very close goals
            self.exploration_direction = Vector(
                [direction.x / magnitude, direction.y / magnitude]
            )

    def _compute_direction_momentum_score(self, frontier: Vector, robot_pose: Vector) -> float:
        """Return the alignment in [0, 1] between the current direction and the frontier bearing."""
        if self.exploration_direction.x == 0 and self.exploration_direction.y == 0:
            return 0.0  # No momentum if no previous direction

        frontier_direction = Vector([frontier.x - robot_pose.x, frontier.y - robot_pose.y])
        magnitude = np.sqrt(frontier_direction.x**2 + frontier_direction.y**2)
        if magnitude < 0.1:
            return 0.0  # Too close to calculate meaningful direction

        frontier_direction = Vector(
            [frontier_direction.x / magnitude, frontier_direction.y / magnitude]
        )
        dot_product = (
            self.exploration_direction.x * frontier_direction.x
            + self.exploration_direction.y * frontier_direction.y
        )

        # Only positive momentum, no penalty for different directions
        return max(0.0, dot_product)

    def _compute_distance_to_explored_goals(self, frontier: Vector) -> float:
        """Compute distance from frontier to the nearest explored goal (5.0 if there are none)."""
        if not self.explored_goals:
            return 5.0  # Default consistent value when no explored goals

        return min(
            np.sqrt((frontier.x - goal.x) ** 2 + (frontier.y - goal.y) ** 2)
            for goal in self.explored_goals
        )

    def _obstacle_search_radius(self, costmap: Costmap) -> int:
        """Half-width, in cells, of the square window scanned for obstacles."""
        return int(self.min_distance_from_obstacles / costmap.resolution) + 5

    def _compute_distance_to_obstacles(self, frontier: Vector, costmap: Costmap) -> float:
        """
        Compute the minimum distance from a frontier point to the nearest obstacle.

        Only a square window around the frontier is searched.

        Args:
            frontier: Frontier point in world coordinates
            costmap: Costmap to check for obstacles

        Returns:
            Distance in meters to the nearest cell at/above the occupancy threshold,
            0.0 if the frontier lies outside the costmap, or infinity if the
            window contains no obstacle
        """
        grid_pos = costmap.world_to_grid(frontier)
        grid_x, grid_y = int(grid_pos.x), int(grid_pos.y)

        if grid_x < 0 or grid_x >= costmap.width or grid_y < 0 or grid_y >= costmap.height:
            return 0.0  # Consider out-of-bounds as obstacle

        # Clip the search window to the costmap and test it in one vectorised pass.
        radius = self._obstacle_search_radius(costmap)
        x0, x1 = max(0, grid_x - radius), min(costmap.width, grid_x + radius + 1)
        y0, y1 = max(0, grid_y - radius), min(costmap.height, grid_y + radius + 1)
        rows, cols = np.nonzero(costmap.grid[y0:y1, x0:x1] >= self.occupancy_threshold)
        if rows.size == 0:
            return float("inf")

        squared = (cols + x0 - grid_x) ** 2 + (rows + y0 - grid_y) ** 2
        return float(np.sqrt(squared.min()) * costmap.resolution)

    def _compute_comprehensive_frontier_score(
        self, frontier: Vector, frontier_size: int, robot_pose: Vector, costmap: Costmap
    ) -> float:
        """Compute the weighted score of a frontier (higher is better)."""
        # 1. Distance from robot: prefer moderate distances (not too close, not too far)
        robot_distance = np.sqrt(
            (frontier.x - robot_pose.x) ** 2 + (frontier.y - robot_pose.y) ** 2
        )
        optimal_distance = 4.0  # meters
        distance_score = 1.0 / (1.0 + abs(robot_distance - optimal_distance))

        # 2. Information gain (frontier size)
        info_gain_score = frontier_size

        # 3. Distance to explored goals (bonus for being far from explored areas)
        explored_goals_score = self._compute_distance_to_explored_goals(frontier)

        # 4. Clearance from obstacles. The raw distance is infinite when the
        #    search window holds no obstacle, which would make every such
        #    frontier score infinity; saturate at the window radius instead.
        max_clearance = self._obstacle_search_radius(costmap) * costmap.resolution
        obstacles_score = min(
            self._compute_distance_to_obstacles(frontier, costmap), max_clearance
        )

        # 5. Direction momentum (if we have a current direction)
        momentum_score = self._compute_direction_momentum_score(frontier, robot_pose)

        return (
            0.3 * info_gain_score  # 30% information gain
            + 0.3 * explored_goals_score  # 30% distance from explored goals
            + 0.2 * distance_score  # 20% distance optimization
            + 0.15 * obstacles_score  # 15% distance from obstacles
            + 0.05 * momentum_score  # 5% direction momentum
        )

    def _rank_frontiers(
        self,
        frontier_centroids: List[Vector],
        frontier_sizes: List[int],
        robot_pose: Vector,
        costmap: Costmap,
    ) -> List[Vector]:
        """
        Filter and rank frontiers using comprehensive scoring.

        Args:
            frontier_centroids: List of frontier centroids
            frontier_sizes: List of frontier sizes, parallel to frontier_centroids
            robot_pose: Current robot position
            costmap: Costmap for additional analysis

        Returns:
            All frontiers that pass the distance filters, highest score first
            (empty list if none are suitable)
        """
        if not frontier_centroids:
            return []

        valid_frontiers = []
        for i, frontier in enumerate(frontier_centroids):
            robot_distance = np.sqrt(
                (frontier.x - robot_pose.x) ** 2 + (frontier.y - robot_pose.y) ** 2
            )

            # Filter 1: Skip frontiers too close to robot
            if robot_distance < self.min_distance_from_robot:
                continue

            # Filter 2: Skip frontiers too close to obstacles
            obstacle_distance = self._compute_distance_to_obstacles(frontier, costmap)
            if obstacle_distance < self.min_distance_from_obstacles:
                continue

            frontier_size = frontier_sizes[i] if i < len(frontier_sizes) else 1
            score = self._compute_comprehensive_frontier_score(
                frontier, frontier_size, robot_pose, costmap
            )
            valid_frontiers.append((frontier, score))

        logger.info(f"Valid frontiers: {len(valid_frontiers)}")

        valid_frontiers.sort(key=lambda item: item[1], reverse=True)
        return [frontier for frontier, _ in valid_frontiers]

    def _register_info_gain(self, info_increase: float) -> bool:
        """Update the consecutive low-gain counter for one goal request.

        Args:
            info_increase: Fractional growth of known cells since the last request

        Returns:
            True when the counter has reached num_no_gain_attempts and
            exploration should stop, False otherwise
        """
        if info_increase >= self.info_gain_threshold:
            self.no_gain_counter = 0  # the attempts must be consecutive
            return False

        self.no_gain_counter += 1
        return self.no_gain_counter >= self.num_no_gain_attempts

    def get_exploration_goal(self, robot_pose: Vector, costmap: Costmap) -> Optional[Vector]:
        """
        Get the single best exploration goal using comprehensive frontier scoring.

        Once more than five goals have been issued, the known-cell count of
        costmap is compared with that of the previous call; after
        num_no_gain_attempts consecutive calls below info_gain_threshold the
        session is reset and None is returned.

        Args:
            robot_pose: Current robot position in world coordinates (Vector with x, y)
            costmap: Costmap for additional analysis

        Returns:
            Single best frontier goal in world coordinates, or None if no suitable
            frontiers found or exploration stopped for lack of information gain
        """
        if len(self.explored_goals) > 5 and self.last_costmap is not None:
            current_info = self._count_costmap_information(costmap)
            last_info = self._count_costmap_information(self.last_costmap)

            if last_info > 0:  # Avoid division by zero
                info_increase_percent = (current_info - last_info) / last_info
                if info_increase_percent < self.info_gain_threshold:
                    logger.info(
                        f"Information increase ({info_increase_percent:.2f}) below threshold ({self.info_gain_threshold:.2f})"
                    )
                    logger.info(
                        f"Current information: {current_info}, Last information: {last_info}"
                    )
                if self._register_info_gain(info_increase_percent):
                    logger.info(
                        "No information gain for {} consecutive attempts, skipping frontier selection".format(
                            self.no_gain_counter
                        )
                    )
                    self.reset_exploration_session()
                    return None

        # Always detect new frontiers to get most up-to-date information
        frontiers = self.detect_frontiers(robot_pose, costmap)

        if not frontiers:
            # Nothing left to explore: forget the session (this also clears last_costmap).
            self.reset_exploration_session()
            return None

        selected_goal = frontiers[0]
        self._update_exploration_direction(robot_pose, selected_goal)
        self.mark_explored_goal(selected_goal)

        # Store current costmap for next comparison
        self.last_costmap = costmap
        return selected_goal

    def mark_explored_goal(self, goal: Vector):
        """Mark a goal as explored."""
        self.explored_goals.append(goal)

    def reset_exploration_session(self):
        """
        Reset all exploration state variables for a new exploration session.

        Call this method when starting a new exploration or when the robot
        needs to forget its previous exploration history. Configuration values
        passed to the constructor are left untouched.
        """
        self.explored_goals.clear()  # Clear all previously explored goals
        self.exploration_direction = Vector([0.0, 0.0])  # Reset exploration direction
        self.last_costmap = None  # Clear last costmap comparison
        self.no_gain_counter = 0  # Reset no-gain attempt counter
        self._cache.clear()  # Clear frontier point cache

        logger.info("Exploration session reset - all state variables cleared")

    def explore(self, stop_event: Optional[threading.Event] = None) -> bool:
        """
        Perform autonomous frontier exploration by continuously finding and navigating to frontiers.

        Args:
            stop_event: Optional threading.Event to signal when exploration should stop

        Returns:
            bool: True if exploration completed successfully, False if stopped or failed
        """
        if self.get_robot_pos is None or self.get_costmap is None or self.set_goal is None:
            logger.error(
                "Cannot explore: set_goal, get_costmap and get_robot_pos callbacks are required"
            )
            return False

        logger.info("Starting autonomous frontier exploration")

        while True:
            if stop_event and stop_event.is_set():
                logger.info("Exploration stopped by stop event")
                return False

            # Get fresh robot position and costmap data
            robot_pose = self.get_robot_pos()
            costmap = self.get_costmap()

            next_goal = self.get_exploration_goal(robot_pose, costmap)
            if not next_goal:
                logger.info("No more frontiers found, exploration complete")
                return True

            logger.info(f"Navigating to frontier at {next_goal}")
            navigation_successful = self.set_goal(next_goal, stop_event=stop_event)

            if not navigation_successful:
                # Keep going: the next iteration picks another frontier.
                logger.warning("Failed to navigate to frontier, continuing exploration")
def find_nearest_free_cell(
    costmap: Costmap, position: VectorLike, cost_threshold: int = 90, max_search_radius: int = 20
) -> Tuple[int, int]:
    """
    Locate the closest cell whose cost is below the threshold, searching breadth-first.

    Args:
        costmap: Costmap object containing the environment
        position: World position to search around
        cost_threshold: Cells with a cost at or above this value count as obstacles
        max_search_radius: Give up once the BFS depth exceeds this many cells

    Returns:
        (x, y) grid coordinates of the nearest free cell, or the grid cell of
        position itself when nothing free is found within max_search_radius
    """
    origin = costmap.world_to_grid(position)
    start_x, start_y = int(origin.x), int(origin.y)

    def is_free(cx, cy):
        # A cell is usable only if it lies inside the grid and is below the threshold.
        inside = 0 <= cx < costmap.width and 0 <= cy < costmap.height
        return inside and costmap.grid[cy, cx] < cost_threshold

    # Nothing to search for when the requested cell is already free.
    if is_free(start_x, start_y):
        return (start_x, start_y)

    # 8-connected steps: axis-aligned first, then diagonals (the order fixes tie-breaking).
    steps = (
        (0, 1),
        (1, 0),
        (0, -1),
        (-1, 0),
        (1, 1),
        (1, -1),
        (-1, 1),
        (-1, -1),
    )

    seen = {(start_x, start_y)}
    pending = deque([(start_x, start_y, 0)])  # entries are (x, y, depth)

    while pending:
        x, y, dist = pending.popleft()

        if dist > max_search_radius:
            print(
                f"Could not find free cell within {max_search_radius} cells of ({start_x}, {start_y})"
            )
            return (start_x, start_y)

        if is_free(x, y):
            print(
                f"Found free cell at ({x}, {y}), {dist} cells away from ({start_x}, {start_y})"
            )
            return (x, y)

        # Out-of-bounds cells are expanded too, so a start outside the map can reach it.
        for dx, dy in steps:
            candidate = (x + dx, y + dy)
            if candidate in seen:
                continue
            seen.add(candidate)
            pending.append((candidate[0], candidate[1], dist + 1))

    # Queue exhausted without a hit: fall back to the original cell.
    return (start_x, start_y)
def astar(
    costmap: Costmap,
    goal: VectorLike,
    start: VectorLike = (0.0, 0.0),
    cost_threshold: int = 90,
    allow_diagonal: bool = True,
) -> Optional[Path]:
    """
    A* path planning algorithm from start to goal position.

    A start or goal that is out of bounds or inside an obstacle is replaced by
    the nearest free cell before planning. Each step costs its movement cost
    plus a penalty proportional to the entered cell's cost (cell value / 25).

    Args:
        costmap: Costmap object containing the environment
        goal: Goal position as any vector-like object
        start: Start position as any vector-like object (default: origin [0,0])
        cost_threshold: Cost threshold above which a cell is considered an obstacle
        allow_diagonal: Whether to allow diagonal movements

    Returns:
        Path object containing waypoints in world coordinates, or None if no path found
    """

    def locate(position):
        """Return (grid cell, in_bounds, needs_adjustment) for a world position."""
        grid_vec = costmap.world_to_grid(position)
        cell = (int(grid_vec.x), int(grid_vec.y))
        in_bounds = 0 <= grid_vec.x < costmap.width and 0 <= grid_vec.y < costmap.height
        blocked = in_bounds and costmap.grid[cell[1], cell[0]] >= cost_threshold
        return cell, in_bounds, (not in_bounds) or blocked

    original_start, _, start_needs_fix = locate(start)
    original_goal, goal_valid, goal_needs_fix = locate(goal)

    start_tuple = original_start
    if start_needs_fix:
        print("Start position is out of bounds or in an obstacle, finding nearest free cell")
        start_tuple = find_nearest_free_cell(costmap, start, cost_threshold)

    goal_tuple = original_goal
    if goal_needs_fix:
        print("Goal position is out of bounds or in an obstacle, finding nearest free cell")
        goal_tuple = find_nearest_free_cell(costmap, goal, cost_threshold)

    # (dx, dy, movement cost): straight moves first, then diagonals if enabled.
    straight_cost = 1.0
    diagonal_cost = 1.42
    moves = [
        (0, 1, straight_cost),
        (1, 0, straight_cost),
        (0, -1, straight_cost),
        (-1, 0, straight_cost),
    ]
    if allow_diagonal:
        moves += [
            (1, 1, diagonal_cost),
            (1, -1, diagonal_cost),
            (-1, 1, diagonal_cost),
            (-1, -1, diagonal_cost),
        ]

    def heuristic(x1, y1, x2, y2):
        """Euclidean distance between two grid cells."""
        return math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)

    def build_path(end_cell):
        """Walk the parent links back from end_cell and return the world-space Path."""
        waypoints = []
        cell = end_cell
        while cell in parents:
            waypoints.append(costmap.grid_to_world(cell))
            cell = parents[cell]
        waypoints.append(costmap.grid_to_world(start_tuple))
        waypoints.reverse()  # start -> goal

        # Add the goal position if it's not already included
        goal_point = costmap.grid_to_world(goal_tuple)
        if not waypoints or waypoints[-1].distance(goal_point) > 1e-5:
            waypoints.append(goal_point)

        # If the goal was moved off an in-bounds obstacle cell, still finish at the
        # originally requested cell.
        if goal_tuple != original_goal and goal_valid:
            waypoints.append(costmap.grid_to_world(original_goal))

        return Path(waypoints)

    g_score = {start_tuple: 0}  # best known cost from the start to each cell
    parents = {}  # cell -> predecessor on the best known path
    closed_set = set()  # cells already expanded
    open_set = []  # min-heap of (f_score, cell)
    heapq.heappush(
        open_set,
        (g_score[start_tuple] + heuristic(*start_tuple, *goal_tuple), start_tuple),
    )

    while open_set:
        _, current = heapq.heappop(open_set)

        # A cell is pushed again whenever a cheaper route to it is found, so older
        # heap entries may surface after the cell was expanded; re-expanding them
        # cannot improve any neighbour, so skip them.
        if current in closed_set:
            continue

        if current == goal_tuple:
            return build_path(current)

        closed_set.add(current)
        current_x, current_y = current

        for dx, dy, step_cost in moves:
            neighbor_x, neighbor_y = current_x + dx, current_y + dy
            neighbor = (neighbor_x, neighbor_y)

            if not (0 <= neighbor_x < costmap.width and 0 <= neighbor_y < costmap.height):
                continue
            if neighbor in closed_set:
                continue

            neighbor_val = costmap.grid[neighbor_y, neighbor_x]
            if neighbor_val >= cost_threshold:
                continue

            # Higher-cost cells are passable but more expensive to enter.
            obstacle_proximity_penalty = neighbor_val / 25
            tentative_g_score = (
                g_score[current] + step_cost + (obstacle_proximity_penalty * step_cost)
            )

            if tentative_g_score < g_score.get(neighbor, float("inf")):
                parents[neighbor] = current
                g_score[neighbor] = tentative_g_score
                f_score = tentative_g_score + heuristic(
                    neighbor_x, neighbor_y, goal_tuple[0], goal_tuple[1]
                )
                heapq.heappush(open_set, (f_score, neighbor))

    # If we get here, no path was found
    return None
/dev/null +++ b/build/lib/dimos/robot/global_planner/planner.py @@ -0,0 +1,96 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import threading +from abc import abstractmethod +from dataclasses import dataclass +from typing import Callable, Optional + +from dimos.core import In, Module, Out, rpc +from dimos.msgs.geometry_msgs import Vector3 +from dimos.robot.global_planner.algo import astar +from dimos.types.costmap import Costmap +from dimos.types.path import Path +from dimos.types.vector import Vector, VectorLike, to_vector +from dimos.utils.logging_config import setup_logger +from dimos.web.websocket_vis.helpers import Visualizable + +logger = setup_logger("dimos.robot.unitree.global_planner") + + +@dataclass +class Planner(Visualizable, Module): + target: In[Vector3] = None + path: Out[Path] = None + + def __init__(self): + Module.__init__(self) + Visualizable.__init__(self) + + # def set_goal( + # self, + # goal: VectorLike, + # goal_theta: Optional[float] = None, + # stop_event: Optional[threading.Event] = None, + # ): + # path = self.plan(goal) + # if not path: + # logger.warning("No path found to the goal.") + # return False + + # print("pathing success", path) + # return self.set_local_nav(path, stop_event=stop_event, goal_theta=goal_theta) + + +class AstarPlanner(Planner): + target: In[Vector3] = None + path: Out[Path] = None + + get_costmap: Callable[[], Costmap] + get_robot_pos: Callable[[], Vector3] + + 
conservativism: int = 8 + + def __init__( + self, + get_costmap: Callable[[], Costmap], + get_robot_pos: Callable[[], Vector3], + ): + super().__init__() + self.get_costmap = get_costmap + self.get_robot_pos = get_robot_pos + + @rpc + def start(self): + self.target.subscribe(self.plan) + + def plan(self, goal: VectorLike) -> Path: + print("planning path to goal", goal) + goal = to_vector(goal).to_2d() + pos = self.get_robot_pos() + print("current pos", pos) + costmap = self.get_costmap().smudge() + + print("current costmap", costmap) + self.vis("target", goal) + + print("ASTAR ", costmap, goal, pos) + path = astar(costmap, goal, pos) + + if path: + path = path.resample(0.1) + self.vis("a*", path) + self.path.publish(path) + return path + logger.warning("No path found to the goal.") diff --git a/build/lib/dimos/robot/local_planner/__init__.py b/build/lib/dimos/robot/local_planner/__init__.py new file mode 100644 index 0000000000..472b58dcd2 --- /dev/null +++ b/build/lib/dimos/robot/local_planner/__init__.py @@ -0,0 +1,7 @@ +from dimos.robot.local_planner.local_planner import ( + BaseLocalPlanner, + navigate_to_goal_local, + navigate_path_local, +) + +from dimos.robot.local_planner.vfh_local_planner import VFHPurePursuitPlanner diff --git a/build/lib/dimos/robot/local_planner/local_planner.py b/build/lib/dimos/robot/local_planner/local_planner.py new file mode 100644 index 0000000000..286ee94f2b --- /dev/null +++ b/build/lib/dimos/robot/local_planner/local_planner.py @@ -0,0 +1,1442 @@ +#!/usr/bin/env python3 + +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import numpy as np +from typing import Dict, Tuple, Optional, Callable, Any +from abc import ABC, abstractmethod +import cv2 +from reactivex import Observable +from reactivex.subject import Subject +import threading +import time +import logging +from collections import deque +from dimos.utils.logging_config import setup_logger +from dimos.utils.transform_utils import normalize_angle, distance_angle_to_goal_xy + +from dimos.types.vector import VectorLike, Vector, to_tuple +from dimos.types.path import Path +from dimos.types.costmap import Costmap + +logger = setup_logger("dimos.robot.unitree.local_planner", level=logging.DEBUG) + + +class BaseLocalPlanner(ABC): + """ + Abstract base class for local planners that handle obstacle avoidance and path following. + + This class defines the common interface and shared functionality that all local planners + must implement, regardless of the specific algorithm used. 
+ + Args: + get_costmap: Function to get the latest local costmap + get_robot_pose: Function to get the latest robot pose (returning odom object) + move: Function to send velocity commands + safety_threshold: Distance to maintain from obstacles (meters) + max_linear_vel: Maximum linear velocity (m/s) + max_angular_vel: Maximum angular velocity (rad/s) + lookahead_distance: Lookahead distance for path following (meters) + goal_tolerance: Distance at which the goal is considered reached (meters) + angle_tolerance: Angle at which the goal orientation is considered reached (radians) + robot_width: Width of the robot for visualization (meters) + robot_length: Length of the robot for visualization (meters) + visualization_size: Size of the visualization image in pixels + control_frequency: Frequency at which the planner is called (Hz) + safe_goal_distance: Distance at which to adjust the goal and ignore obstacles (meters) + max_recovery_attempts: Maximum number of recovery attempts before failing navigation. + If the robot gets stuck and cannot recover within this many attempts, navigation will fail. + global_planner_plan: Optional callable to plan a global path to the goal. + If provided, this will be used to generate a path to the goal before local planning. 
+ """ + + def __init__( + self, + get_costmap: Callable[[], Optional[Costmap]], + get_robot_pose: Callable[[], Any], + move: Callable[[Vector], None], + safety_threshold: float = 0.5, + max_linear_vel: float = 0.8, + max_angular_vel: float = 1.0, + lookahead_distance: float = 1.0, + goal_tolerance: float = 0.75, + angle_tolerance: float = 0.5, + robot_width: float = 0.5, + robot_length: float = 0.7, + visualization_size: int = 400, + control_frequency: float = 10.0, + safe_goal_distance: float = 1.5, + max_recovery_attempts: int = 4, + global_planner_plan: Optional[Callable[[VectorLike], Optional[Any]]] = None, + ): # Control frequency in Hz + # Store callables for robot interactions + self.get_costmap = get_costmap + self.get_robot_pose = get_robot_pose + self.move = move + + # Store parameters + self.safety_threshold = safety_threshold + self.max_linear_vel = max_linear_vel + self.max_angular_vel = max_angular_vel + self.lookahead_distance = lookahead_distance + self.goal_tolerance = goal_tolerance + self.angle_tolerance = angle_tolerance + self.robot_width = robot_width + self.robot_length = robot_length + self.visualization_size = visualization_size + self.control_frequency = control_frequency + self.control_period = 1.0 / control_frequency # Period in seconds + self.safe_goal_distance = safe_goal_distance # Distance to ignore obstacles at goal + self.ignore_obstacles = False # Flag for derived classes to check + self.max_recovery_attempts = max_recovery_attempts # Maximum recovery attempts + self.recovery_attempts = 0 # Current number of recovery attempts + self.global_planner_plan = global_planner_plan # Global planner function for replanning + + # Goal and Waypoint Tracking + self.goal_xy: Optional[Tuple[float, float]] = None # Current target for planning + self.goal_theta: Optional[float] = None # Goal orientation (radians) + self.waypoints: Optional[Path] = None # List of waypoints to follow + self.waypoints_in_absolute: Optional[Path] = None # Full path 
in absolute frame + self.waypoint_is_relative: bool = False # Whether waypoints are in relative frame + self.current_waypoint_index: int = 0 # Index of the next waypoint to reach + self.final_goal_reached: bool = False # Flag indicating if the final waypoint is reached + self.position_reached: bool = False # Flag indicating if position goal is reached + + # Stuck detection + self.stuck_detection_window_seconds = 4.0 # Time window for stuck detection (seconds) + self.position_history_size = int(self.stuck_detection_window_seconds * control_frequency) + self.position_history = deque( + maxlen=self.position_history_size + ) # History of recent positions + self.stuck_distance_threshold = 0.15 # Distance threshold for stuck detection (meters) + self.unstuck_distance_threshold = ( + 0.5 # Distance threshold for unstuck detection (meters) - increased hysteresis + ) + self.stuck_time_threshold = 3.0 # Time threshold for stuck detection (seconds) - increased + self.is_recovery_active = False # Whether recovery behavior is active + self.recovery_start_time = 0.0 # When recovery behavior started + self.recovery_duration = ( + 10.0 # How long to run recovery before giving up (seconds) - increased + ) + self.last_update_time = time.time() # Last time position was updated + self.navigation_failed = False # Flag indicating if navigation should be terminated + + # Recovery improvements + self.recovery_cooldown_time = ( + 3.0 # Seconds to wait after recovery before checking stuck again + ) + self.last_recovery_end_time = 0.0 # When the last recovery ended + self.pre_recovery_position = ( + None # Position when recovery started (for better stuck detection) + ) + self.backup_duration = 4.0 # How long to backup when stuck (seconds) + + # Cached data updated periodically for consistent plan() execution time + self._robot_pose = None + self._costmap = None + self._update_frequency = 10.0 # Hz - how often to update cached data + self._update_timer = None + self._start_periodic_updates() 
+ + def _start_periodic_updates(self): + self._update_timer = threading.Thread(target=self._periodic_update, daemon=True) + self._update_timer.start() + + def _periodic_update(self): + while True: + self._robot_pose = self.get_robot_pose() + self._costmap = self.get_costmap() + time.sleep(1.0 / self._update_frequency) + + def reset(self): + """ + Reset all navigation and state tracking variables. + Should be called whenever a new goal is set. + """ + # Reset stuck detection state + self.position_history.clear() + self.is_recovery_active = False + self.recovery_start_time = 0.0 + self.last_update_time = time.time() + + # Reset navigation state flags + self.navigation_failed = False + self.position_reached = False + self.final_goal_reached = False + self.ignore_obstacles = False + + # Reset recovery improvements + self.last_recovery_end_time = 0.0 + self.pre_recovery_position = None + + # Reset recovery attempts + self.recovery_attempts = 0 + + # Clear waypoint following state + self.waypoints = None + self.current_waypoint_index = 0 + self.goal_xy = None # Clear previous goal + self.goal_theta = None # Clear previous goal orientation + + logger.info("Local planner state has been reset") + + def _get_robot_pose(self) -> Tuple[Tuple[float, float], float]: + """ + Get the current robot position and orientation. 
+ + Returns: + Tuple containing: + - position as (x, y) tuple + - orientation (theta) in radians + """ + if self._robot_pose is None: + return ((0.0, 0.0), 0.0) # Fallback if not yet initialized + pos, rot = self._robot_pose.pos, self._robot_pose.rot + return (pos.x, pos.y), rot.z + + def _get_costmap(self): + """Get cached costmap data.""" + return self._costmap + + def clear_cache(self): + """Clear all cached data to force fresh retrieval on next access.""" + self._robot_pose = None + self._costmap = None + + def set_goal( + self, goal_xy: VectorLike, is_relative: bool = False, goal_theta: Optional[float] = None + ): + """Set a single goal position, converting to absolute frame if necessary. + This clears any existing waypoints being followed. + + Args: + goal_xy: The goal position to set. + is_relative: Whether the goal is in the robot's relative frame. + goal_theta: Optional goal orientation in radians + """ + # Reset all state variables + self.reset() + + target_goal_xy: Optional[Tuple[float, float]] = None + + # Transform goal to absolute frame if it's relative + if is_relative: + # Get current robot pose + odom = self._robot_pose + if odom is None: + logger.warning("Robot pose not yet available, cannot set relative goal") + return + robot_pos, robot_rot = odom.pos, odom.rot + + # Extract current position and orientation + robot_x, robot_y = robot_pos.x, robot_pos.y + robot_theta = robot_rot.z # Assuming rotation is euler angles + + # Transform the relative goal into absolute coordinates + goal_x, goal_y = to_tuple(goal_xy) + # Rotate + abs_x = goal_x * math.cos(robot_theta) - goal_y * math.sin(robot_theta) + abs_y = goal_x * math.sin(robot_theta) + goal_y * math.cos(robot_theta) + # Translate + target_goal_xy = (robot_x + abs_x, robot_y + abs_y) + + logger.info( + f"Goal set in relative frame, converted to absolute: ({target_goal_xy[0]:.2f}, {target_goal_xy[1]:.2f})" + ) + else: + target_goal_xy = to_tuple(goal_xy) + logger.info( + f"Goal set directly in 
absolute frame: ({target_goal_xy[0]:.2f}, {target_goal_xy[1]:.2f})" + ) + + # Check if goal is valid (in bounds and not colliding) + if not self.is_goal_in_costmap_bounds(target_goal_xy) or self.check_goal_collision( + target_goal_xy + ): + logger.warning( + "Goal is in collision or out of bounds. Adjusting goal to valid position." + ) + self.goal_xy = self.adjust_goal_to_valid_position(target_goal_xy) + else: + self.goal_xy = target_goal_xy # Set the adjusted or original valid goal + + # Set goal orientation if provided + if goal_theta is not None: + if is_relative: + # Transform the orientation to absolute frame + odom = self._robot_pose + if odom is None: + logger.warning( + "Robot pose not yet available, cannot set relative goal orientation" + ) + return + robot_theta = odom.rot.z + self.goal_theta = normalize_angle(goal_theta + robot_theta) + else: + self.goal_theta = goal_theta + + def set_goal_waypoints(self, waypoints: Path, goal_theta: Optional[float] = None): + """Sets a path of waypoints for the robot to follow. + + Args: + waypoints: A list of waypoints to follow. Each waypoint is a tuple of (x, y) coordinates in absolute frame. + goal_theta: Optional final orientation in radians + """ + # Reset all state variables + self.reset() + + if not isinstance(waypoints, Path) or len(waypoints) == 0: + logger.warning("Invalid or empty path provided to set_goal_waypoints. 
Ignoring.") + self.waypoints = None + self.waypoint_is_relative = False + self.goal_xy = None + self.goal_theta = None + self.current_waypoint_index = 0 + return + + logger.info(f"Setting goal waypoints with {len(waypoints)} points.") + self.waypoints = waypoints + self.waypoint_is_relative = False + self.current_waypoint_index = 0 + + # Waypoints are always in absolute frame + self.waypoints_in_absolute = waypoints + + # Set the initial target to the first waypoint, adjusting if necessary + first_waypoint = self.waypoints_in_absolute[0] + if not self.is_goal_in_costmap_bounds(first_waypoint) or self.check_goal_collision( + first_waypoint + ): + logger.warning("First waypoint is invalid. Adjusting...") + self.goal_xy = self.adjust_goal_to_valid_position(first_waypoint) + else: + self.goal_xy = to_tuple(first_waypoint) # Initial target + + # Set goal orientation if provided + if goal_theta is not None: + self.goal_theta = goal_theta + + def _get_final_goal_position(self) -> Optional[Tuple[float, float]]: + """ + Get the final goal position (either last waypoint or direct goal). + + Returns: + Tuple (x, y) of the final goal, or None if no goal is set + """ + if self.waypoints_in_absolute is not None and len(self.waypoints_in_absolute) > 0: + return to_tuple(self.waypoints_in_absolute[-1]) + elif self.goal_xy is not None: + return self.goal_xy + return None + + def _distance_to_position(self, target_position: Tuple[float, float]) -> float: + """ + Calculate distance from the robot to a target position. + + Args: + target_position: Target (x, y) position + + Returns: + Distance in meters + """ + robot_pos, _ = self._get_robot_pose() + return np.linalg.norm( + [target_position[0] - robot_pos[0], target_position[1] - robot_pos[1]] + ) + + def plan(self) -> Dict[str, float]: + """ + Main planning method that computes velocity commands. + This includes common planning logic like waypoint following, + with algorithm-specific calculations delegated to subclasses. 
+ + Returns: + Dict[str, float]: Velocity commands with 'x_vel' and 'angular_vel' keys + """ + # If goal orientation is specified, rotate to match it + if ( + self.position_reached + and self.goal_theta is not None + and not self._is_goal_orientation_reached() + ): + return self._rotate_to_goal_orientation() + elif self.position_reached and self.goal_theta is None: + self.final_goal_reached = True + logger.info("Position goal reached. Stopping.") + return {"x_vel": 0.0, "angular_vel": 0.0} + + # Check if the robot is stuck and handle accordingly + if self.check_if_stuck() and not self.position_reached: + # Check if we're stuck but close to our goal + final_goal_pos = self._get_final_goal_position() + + # If we have a goal position, check distance to it + if final_goal_pos is not None: + distance_to_goal = self._distance_to_position(final_goal_pos) + + # If we're stuck but within 2x safe_goal_distance of the goal, consider it a success + if distance_to_goal < 2.0 * self.safe_goal_distance: + logger.info( + f"Robot is stuck but within {distance_to_goal:.2f}m of goal (< {2.0 * self.safe_goal_distance:.2f}m). Considering navigation successful." 
+ ) + self.position_reached = True + return {"x_vel": 0.0, "angular_vel": 0.0} + + if self.navigation_failed: + return {"x_vel": 0.0, "angular_vel": 0.0} + + # Otherwise, execute normal recovery behavior + logger.warning("Robot is stuck - executing recovery behavior") + return self.execute_recovery_behavior() + + # Reset obstacle ignore flag + self.ignore_obstacles = False + + # --- Waypoint Following Mode --- + if self.waypoints is not None: + if self.final_goal_reached: + return {"x_vel": 0.0, "angular_vel": 0.0} + + # Get current robot pose + robot_pos, robot_theta = self._get_robot_pose() + robot_pos_np = np.array(robot_pos) + + # Check if close to final waypoint + if self.waypoints_in_absolute is not None and len(self.waypoints_in_absolute) > 0: + final_waypoint = self.waypoints_in_absolute[-1] + dist_to_final = np.linalg.norm(robot_pos_np - final_waypoint) + + # If we're close to the final waypoint, adjust it and ignore obstacles + if dist_to_final < self.safe_goal_distance: + final_wp_tuple = to_tuple(final_waypoint) + adjusted_goal = self.adjust_goal_to_valid_position(final_wp_tuple) + # Create a new Path with the adjusted final waypoint + new_waypoints = self.waypoints_in_absolute[:-1] # Get all but the last waypoint + new_waypoints.append(adjusted_goal) # Append the adjusted goal + self.waypoints_in_absolute = new_waypoints + self.ignore_obstacles = True + + # Update the target goal based on waypoint progression + just_reached_final = self._update_waypoint_target(robot_pos_np) + + # If the helper indicates the final goal was just reached, stop immediately + if just_reached_final: + return {"x_vel": 0.0, "angular_vel": 0.0} + + # --- Single Goal or Current Waypoint Target Set --- + if self.goal_xy is None: + # If no goal is set (e.g., empty path or rejected goal), stop. + return {"x_vel": 0.0, "angular_vel": 0.0} + + # Get necessary data for planning + costmap = self._get_costmap() + if costmap is None: + logger.warning("Local costmap is None. 
Cannot plan.") + return {"x_vel": 0.0, "angular_vel": 0.0} + + # Check if close to single goal mode goal + if self.waypoints is None: + # Get distance to goal + goal_distance = self._distance_to_position(self.goal_xy) + + # If within safe distance of goal, adjust it and ignore obstacles + if goal_distance < self.safe_goal_distance: + self.goal_xy = self.adjust_goal_to_valid_position(self.goal_xy) + self.ignore_obstacles = True + + # First check position + if goal_distance < self.goal_tolerance or self.position_reached: + self.position_reached = True + + else: + self.position_reached = False + + # Call the algorithm-specific planning implementation + return self._compute_velocity_commands() + + @abstractmethod + def _compute_velocity_commands(self) -> Dict[str, float]: + """ + Algorithm-specific method to compute velocity commands. + Must be implemented by derived classes. + + Returns: + Dict[str, float]: Velocity commands with 'x_vel' and 'angular_vel' keys + """ + pass + + def _rotate_to_goal_orientation(self) -> Dict[str, float]: + """Compute velocity commands to rotate to the goal orientation. + + Returns: + Dict[str, float]: Velocity commands with zero linear velocity + """ + # Get current robot orientation + _, robot_theta = self._get_robot_pose() + + # Calculate the angle difference + angle_diff = normalize_angle(self.goal_theta - robot_theta) + + # Determine rotation direction and speed + if abs(angle_diff) < self.angle_tolerance: + # Already at correct orientation + return {"x_vel": 0.0, "angular_vel": 0.0} + + # Calculate rotation speed - proportional to the angle difference + # but capped at max_angular_vel + direction = 1.0 if angle_diff > 0 else -1.0 + angular_vel = direction * min(abs(angle_diff), self.max_angular_vel) + + return {"x_vel": 0.0, "angular_vel": angular_vel} + + def _is_goal_orientation_reached(self) -> bool: + """Check if the current robot orientation matches the goal orientation. 
+ + Returns: + bool: True if orientation is reached or no orientation goal is set + """ + if self.goal_theta is None: + return True # No orientation goal set + + # Get current robot orientation + _, robot_theta = self._get_robot_pose() + + # Calculate the angle difference and normalize + angle_diff = abs(normalize_angle(self.goal_theta - robot_theta)) + + return angle_diff <= self.angle_tolerance + + def _update_waypoint_target(self, robot_pos_np: np.ndarray) -> bool: + """Helper function to manage waypoint progression and update the target goal. + + Args: + robot_pos_np: Current robot position as a numpy array [x, y]. + + Returns: + bool: True if the final waypoint has just been reached, False otherwise. + """ + if self.waypoints is None or len(self.waypoints) == 0: + return False # Not in waypoint mode or empty path + + # Waypoints are always in absolute frame + self.waypoints_in_absolute = self.waypoints + + # Check if final goal is reached + final_waypoint = self.waypoints_in_absolute[-1] + dist_to_final = np.linalg.norm(robot_pos_np - final_waypoint) + + if dist_to_final <= self.goal_tolerance: + # Final waypoint position reached + if self.goal_theta is not None: + # Check orientation if specified + if self._is_goal_orientation_reached(): + self.final_goal_reached = True + return True + # Continue rotating + self.position_reached = True + return False + else: + # No orientation goal, mark as reached + self.final_goal_reached = True + return True + + # Always find the lookahead point + lookahead_point = None + for i in range(self.current_waypoint_index, len(self.waypoints_in_absolute)): + wp = self.waypoints_in_absolute[i] + dist_to_wp = np.linalg.norm(robot_pos_np - wp) + if dist_to_wp >= self.lookahead_distance: + lookahead_point = wp + # Update current waypoint index to this point + self.current_waypoint_index = i + break + + # If no point is far enough, target the final waypoint + if lookahead_point is None: + lookahead_point = 
self.waypoints_in_absolute[-1] + self.current_waypoint_index = len(self.waypoints_in_absolute) - 1 + + # Set the lookahead point as the immediate target, adjusting if needed + if not self.is_goal_in_costmap_bounds(lookahead_point) or self.check_goal_collision( + lookahead_point + ): + adjusted_lookahead = self.adjust_goal_to_valid_position(lookahead_point) + # Only update if adjustment didn't fail completely + if adjusted_lookahead is not None: + self.goal_xy = adjusted_lookahead + else: + self.goal_xy = to_tuple(lookahead_point) + + return False # Final goal not reached in this update cycle + + @abstractmethod + def update_visualization(self) -> np.ndarray: + """ + Generate visualization of the planning state. + Must be implemented by derived classes. + + Returns: + np.ndarray: Visualization image as numpy array + """ + pass + + def create_stream(self, frequency_hz: float = None) -> Observable: + """ + Create an Observable stream that emits the visualization image at a fixed frequency. 
+ + Args: + frequency_hz: Optional frequency override (defaults to 1/4 of control_frequency if None) + + Returns: + Observable: Stream of visualization frames + """ + # Default to 1/4 of control frequency if not specified (to reduce CPU usage) + if frequency_hz is None: + frequency_hz = self.control_frequency / 4.0 + + subject = Subject() + sleep_time = 1.0 / frequency_hz + + def frame_emitter(): + while True: + try: + # Generate the frame using the updated method + frame = self.update_visualization() + subject.on_next(frame) + except Exception as e: + logger.error(f"Error in frame emitter thread: {e}") + # Optionally, emit an error frame or simply skip + # subject.on_error(e) # This would terminate the stream + time.sleep(sleep_time) + + emitter_thread = threading.Thread(target=frame_emitter, daemon=True) + emitter_thread.start() + logger.info(f"Started visualization frame emitter thread at {frequency_hz:.1f} Hz") + return subject + + @abstractmethod + def check_collision(self, direction: float) -> bool: + """ + Check if there's a collision in the given direction. + Must be implemented by derived classes. 
+ + Args: + direction: Direction to check for collision in radians + + Returns: + bool: True if collision detected, False otherwise + """ + pass + + def is_goal_reached(self) -> bool: + """Check if the final goal (single or last waypoint) is reached, including orientation.""" + if self.waypoints is not None: + # Waypoint mode: check if the final waypoint and orientation have been reached + return self.final_goal_reached and self._is_goal_orientation_reached() + else: + # Single goal mode: check distance to the single goal and orientation + if self.goal_xy is None: + return False # No goal set + + if self.goal_theta is None: + return self.position_reached + + return self.position_reached and self._is_goal_orientation_reached() + + def check_goal_collision(self, goal_xy: VectorLike) -> bool: + """Check if the current goal is in collision with obstacles in the costmap. + + Returns: + bool: True if goal is in collision, False if goal is safe or cannot be checked + """ + + costmap = self._get_costmap() + if costmap is None: + logger.warning("Cannot check collision: No costmap available") + return False + + # Check if the position is occupied + collision_threshold = 80 # Consider values above 80 as obstacles + + # Use Costmap's is_occupied method + return costmap.is_occupied(goal_xy, threshold=collision_threshold) + + def is_goal_in_costmap_bounds(self, goal_xy: VectorLike) -> bool: + """Check if the goal position is within the bounds of the costmap. 
+ + Args: + goal_xy: Goal position (x, y) in odom frame + + Returns: + bool: True if the goal is within the costmap bounds, False otherwise + """ + costmap = self._get_costmap() + if costmap is None: + logger.warning("Cannot check bounds: No costmap available") + return False + + # Get goal position in grid coordinates + goal_point = costmap.world_to_grid(goal_xy) + goal_cell_x, goal_cell_y = goal_point.x, goal_point.y + + # Check if goal is within the costmap bounds + is_in_bounds = 0 <= goal_cell_x < costmap.width and 0 <= goal_cell_y < costmap.height + + if not is_in_bounds: + logger.warning(f"Goal ({goal_xy[0]:.2f}, {goal_xy[1]:.2f}) is outside costmap bounds") + + return is_in_bounds + + def adjust_goal_to_valid_position( + self, goal_xy: VectorLike, clearance: float = 0.5 + ) -> Tuple[float, float]: + """Find a valid (non-colliding) goal position by moving it towards the robot. + + Args: + goal_xy: Original goal position (x, y) in odom frame + clearance: Additional distance to move back from obstacles for better clearance (meters) + + Returns: + Tuple[float, float]: A valid goal position, or the original goal if already valid + """ + [pos, rot] = self._get_robot_pose() + + robot_x, robot_y = pos[0], pos[1] + + # Original goal + goal_x, goal_y = to_tuple(goal_xy) + + if not self.check_goal_collision((goal_x, goal_y)): + return (goal_x, goal_y) + + # Calculate vector from goal to robot + dx = robot_x - goal_x + dy = robot_y - goal_y + distance = np.sqrt(dx * dx + dy * dy) + + if distance < 0.001: # Goal is at robot position + return to_tuple(goal_xy) + + # Normalize direction vector + dx /= distance + dy /= distance + + # Step size + step_size = 0.25 # meters + + # Move goal towards robot step by step + current_x, current_y = goal_x, goal_y + steps = 0 + max_steps = 50 # Safety limit + + # Variables to store the first valid position found + valid_found = False + valid_x, valid_y = None, None + + while steps < max_steps: + # Move towards robot + current_x += dx 
def check_if_stuck(self) -> bool:
    """
    Decide whether the robot is currently stuck, based on its recent positions.

    While a recovery is active no samples are recorded; the call only decides
    whether the recovery has succeeded (robot moved far enough from where the
    recovery started), has run for too long (navigation is flagged as failed),
    or should simply continue. Outside of recovery the current position is
    appended to ``self.position_history`` and, once the post-recovery cooldown
    has passed and enough samples exist, the displacement over the detection
    window is compared with the configured thresholds.

    Returns:
        bool: True if the robot is considered stuck (or still recovering),
        False otherwise.
    """

    def _planar_gap(ax, ay, bx, by):
        # Euclidean distance between (ax, ay) and (bx, by).
        return np.sqrt((ax - bx) ** 2 + (ay - by) ** 2)

    now = time.time()
    pose_xy, _ = self._get_robot_pose()
    sample = (pose_xy[0], pose_xy[1], now)

    # --- Recovery in progress: motion is intentional, so nothing is recorded ---
    if self.is_recovery_active:
        anchor = self.pre_recovery_position
        if anchor is not None:
            anchor_x, anchor_y = anchor[:2]
            displacement_from_start = _planar_gap(pose_xy[0], pose_xy[1], anchor_x, anchor_y)

            if displacement_from_start > self.unstuck_distance_threshold:
                logger.info(
                    f"Robot has escaped from stuck state (moved {displacement_from_start:.3f}m from start)"
                )
                self.is_recovery_active = False
                self.last_recovery_end_time = now
                # The attempt counter is deliberately left alone here; the
                # history is wiped so tracking restarts from scratch.
                self.position_history.clear()
                return False

        # Recovery that outlives its budget marks the whole navigation as failed.
        if now - self.recovery_start_time > self.recovery_duration:
            logger.error(
                f"Recovery behavior has been active for {self.recovery_duration}s without success"
            )
            self.navigation_failed = True

        # Either way the robot is still treated as stuck/recovering.
        return True

    # --- Normal operation: always record the sample ---
    cooling_down = now - self.last_recovery_end_time < self.recovery_cooldown_time
    self.position_history.append(sample)
    if cooling_down:
        # Right after a recovery we only collect data, no verdict yet.
        return False

    # Require roughly 60% of a full window worth of samples.
    required_samples = int(self.stuck_detection_window_seconds * self.control_frequency * 0.6)
    if len(self.position_history) < required_samples:
        return False

    # Keep only the samples that fall inside the detection window.
    window_begin = now - self.stuck_detection_window_seconds
    recent = [(x, y, stamp) for x, y, stamp in self.position_history if stamp >= window_begin]
    if len(recent) < 3:
        return False

    # Oldest first, newest last.
    recent = sorted(recent, key=lambda entry: entry[2])
    first_x, first_y, first_stamp = recent[0]
    last_x, last_y, last_stamp = recent[-1]

    elapsed = last_stamp - first_stamp
    total_shift = _planar_gap(last_x, last_y, first_x, first_y)

    # Average displacement over overlapping sub-windows, used as a second
    # opinion to reduce false positives.
    span = max(3, len(recent) // 3)
    last_index = len(recent) - 1
    sub_shifts = []
    for begin in range(0, len(recent) - span, span // 2):
        head = recent[begin]
        tail = recent[min(begin + span, last_index)]
        sub_shifts.append(_planar_gap(tail[0], tail[1], head[0], head[1]))

    mean_shift = sum(sub_shifts, 0.0)
    if sub_shifts:
        mean_shift /= len(sub_shifts)

    stuck_now = (
        self.stuck_time_threshold <= elapsed <= self.stuck_detection_window_seconds
        and total_shift < self.stuck_distance_threshold
        and mean_shift < self.stuck_distance_threshold * 1.5
    )
    if not stuck_now:
        return False

    logger.warning(
        f"Robot appears to be stuck! Total displacement: {total_shift:.3f}m, "
        f"avg displacement: {mean_shift:.3f}m over {elapsed:.1f}s"
    )

    # Enter recovery and remember where it started.
    self.is_recovery_active = True
    self.recovery_start_time = now
    self.pre_recovery_position = sample

    # Drop the history so recovery motion does not pollute later checks.
    self.position_history.clear()

    self.recovery_attempts += 1
    logger.warning(
        f"Starting recovery attempt {self.recovery_attempts}/{self.max_recovery_attempts}"
    )

    if self.recovery_attempts > self.max_recovery_attempts:
        logger.error(
            f"Maximum recovery attempts ({self.max_recovery_attempts}) exceeded. Navigation failed."
        )
        self.navigation_failed = True

    return True
def _run_local_planner_loop(
    robot,
    timeout: float,
    stop_event: Optional[threading.Event],
) -> str:
    """Run the local-planner control loop until it terminates.

    On every cycle the planner is asked for a velocity command, which is sent
    to the robot, and the loop then sleeps for one control period.

    Args:
        robot: Robot instance whose ``local_planner`` is driven.
        timeout: Maximum time (in seconds) the loop may run.
        stop_event: Optional event; when set, the loop ends early.

    Returns:
        str: Why the loop ended - ``"reached"``, ``"failed"`` (the planner set
        ``navigation_failed``), ``"stopped"`` (stop event set) or ``"timeout"``.
    """
    # Use the planner's own control frequency for consistent timing.
    control_period = 1.0 / robot.local_planner.control_frequency
    start_time = time.time()

    while time.time() - start_time < timeout:
        if stop_event is not None and stop_event.is_set():
            return "stopped"

        if robot.local_planner.is_goal_reached():
            return "reached"

        if robot.local_planner.navigation_failed:
            return "failed"

        # Get planned velocity towards the current target and send it.
        vel_command = robot.local_planner.plan()
        x_vel = vel_command.get("x_vel", 0.0)
        angular_vel = vel_command.get("angular_vel", 0.0)
        robot.local_planner.move(Vector(x_vel, 0, angular_vel))

        time.sleep(control_period)

    return "timeout"


def navigate_to_goal_local(
    robot,
    goal_xy_robot: Tuple[float, float],
    goal_theta: Optional[float] = None,
    distance: float = 0.0,
    timeout: float = 60.0,
    stop_event: Optional[threading.Event] = None,
) -> bool:
    """
    Navigates the robot to a goal specified in the robot's local frame
    using the local planner.

    Args:
        robot: Robot instance to control
        goal_xy_robot: Tuple (x, y) representing the goal position relative
                       to the robot's current position and orientation.
        goal_theta: Optional final orientation in radians. When omitted, the
                    robot is asked to face the target (bearing of the goal).
        distance: Desired distance to maintain from the goal in meters.
                  If non-zero, the robot will stop this far away from the goal.
        timeout: Maximum time (in seconds) allowed to reach the goal.
        stop_event: Optional threading.Event to signal when navigation should stop

    Returns:
        bool: True if the goal was reached within the timeout, False otherwise.
    """
    logger.info(
        f"Starting navigation to local goal {goal_xy_robot} with distance {distance}m and timeout {timeout}s."
    )

    robot.local_planner.reset()

    goal_x, goal_y = goal_xy_robot

    # Direction from the robot to the target, independent of the final heading.
    bearing = np.arctan2(goal_y, goal_x)

    # Default final orientation: face the target.
    if goal_theta is None:
        goal_theta = bearing

    # If distance is non-zero, adjust the goal to stop at the desired distance
    if distance > 0:
        goal_distance = np.sqrt(goal_x**2 + goal_y**2)

        # Only adjust if goal is further than the desired distance
        if goal_distance > distance:
            # Pull the goal back along the line to the target. Using the
            # bearing (not goal_theta) keeps this correct when the caller
            # supplies an explicit final heading.
            goal_x, goal_y = distance_angle_to_goal_xy(goal_distance - distance, bearing)

    # Set the goal in the robot's frame together with the requested final orientation
    robot.local_planner.set_goal((goal_x, goal_y), is_relative=True, goal_theta=goal_theta)

    goal_reached = False

    try:
        outcome = _run_local_planner_loop(robot, timeout, stop_event)

        if outcome == "reached":
            logger.info("Goal reached successfully.")
            goal_reached = True
        elif outcome == "failed":
            logger.error("Navigation aborted due to repeated recovery failures.")
        elif outcome == "stopped":
            # Not a timeout: report the real cause.
            logger.info("Navigation to local goal stopped by stop event.")
        else:
            logger.warning(f"Navigation timed out after {timeout} seconds before reaching goal.")

    except KeyboardInterrupt:
        logger.info("Navigation to local goal interrupted by user.")
        goal_reached = False  # Consider interruption as failure
    except Exception as e:
        logger.error(f"Error during navigation to local goal: {e}")
        goal_reached = False  # Consider error as failure
    finally:
        logger.info("Stopping robot after navigation attempt.")
        robot.local_planner.move(Vector(0, 0, 0))  # Stop the robot

    return goal_reached


def navigate_path_local(
    robot,
    path: Path,
    timeout: float = 120.0,
    goal_theta: Optional[float] = None,
    stop_event: Optional[threading.Event] = None,
) -> bool:
    """
    Navigates the robot along a path of waypoints using the waypoint following capability
    of the local planner.

    Args:
        robot: Robot instance to control
        path: Path object containing waypoints in absolute frame
        timeout: Maximum time (in seconds) allowed to follow the complete path
        goal_theta: Optional final orientation in radians
        stop_event: Optional threading.Event to signal when navigation should stop

    Returns:
        bool: True if the entire path was successfully followed, False otherwise
    """
    logger.info(
        f"Starting navigation along path with {len(path)} waypoints and timeout {timeout}s."
    )

    robot.local_planner.reset()

    # Set the path in the local planner
    robot.local_planner.set_goal_waypoints(path, goal_theta=goal_theta)

    path_completed = False

    try:
        outcome = _run_local_planner_loop(robot, timeout, stop_event)

        if outcome == "reached":
            logger.info("Path traversed successfully.")
            path_completed = True
        elif outcome == "failed":
            logger.error("Navigation aborted due to repeated recovery failures.")
        elif outcome == "stopped":
            # Not a timeout: report the real cause.
            logger.info("Path navigation stopped by stop event.")
        else:
            logger.warning(
                f"Path following timed out after {timeout} seconds before completing the path."
            )

    except KeyboardInterrupt:
        logger.info("Path navigation interrupted by user.")
        path_completed = False
    except Exception as e:
        logger.error(f"Error during path navigation: {e}")
        path_completed = False
    finally:
        logger.info("Stopping robot after path navigation attempt.")
        robot.local_planner.move(Vector(0, 0, 0))  # Stop the robot

    return path_completed
+ + Args: + occupancy_grid: 2D numpy array of the occupancy grid + grid_resolution: Resolution of the grid in meters/cell + grid_origin: Tuple (x, y) of the grid origin in the odom frame + robot_pose: Tuple (x, y, theta) of the robot pose in the odom frame + visualization_size: Size of the visualization image in pixels + robot_width: Width of the robot in meters + robot_length: Length of the robot in meters + map_size_meters: Size of the map to visualize in meters + goal_xy: Optional tuple (x, y) of the goal position in the odom frame + goal_theta: Optional goal orientation in radians (in odom frame) + histogram: Optional numpy array of the VFH histogram + selected_direction: Optional selected direction angle in radians + waypoints: Optional Path object containing waypoints to visualize + current_waypoint_index: Optional index of the current target waypoint + """ + + robot_x, robot_y, robot_theta = robot_pose + grid_origin_x, grid_origin_y = grid_origin + vis_size = visualization_size + scale = vis_size / map_size_meters + + vis_img = np.ones((vis_size, vis_size, 3), dtype=np.uint8) * 255 + center_x = vis_size // 2 + center_y = vis_size // 2 + + grid_height, grid_width = occupancy_grid.shape + + # Calculate robot position relative to grid origin + robot_rel_x = robot_x - grid_origin_x + robot_rel_y = robot_y - grid_origin_y + robot_cell_x = int(robot_rel_x / grid_resolution) + robot_cell_y = int(robot_rel_y / grid_resolution) + + half_size_cells = int(map_size_meters / grid_resolution / 2) + + # Draw grid cells (using standard occupancy coloring) + for y in range( + max(0, robot_cell_y - half_size_cells), min(grid_height, robot_cell_y + half_size_cells) + ): + for x in range( + max(0, robot_cell_x - half_size_cells), min(grid_width, robot_cell_x + half_size_cells) + ): + cell_rel_x_meters = (x - robot_cell_x) * grid_resolution + cell_rel_y_meters = (y - robot_cell_y) * grid_resolution + + img_x = int(center_x + cell_rel_x_meters * scale) + img_y = int(center_y - 
cell_rel_y_meters * scale) # Flip y-axis + + if 0 <= img_x < vis_size and 0 <= img_y < vis_size: + cell_value = occupancy_grid[y, x] + if cell_value == -1: + color = (200, 200, 200) # Unknown (Light gray) + elif cell_value == 0: + color = (255, 255, 255) # Free (White) + else: # Occupied + # Scale darkness based on occupancy value (0-100) + darkness = 255 - int(155 * (cell_value / 100)) - 100 + color = (darkness, darkness, darkness) # Shades of gray/black + + cell_size_px = max(1, int(grid_resolution * scale)) + cv2.rectangle( + vis_img, + (img_x - cell_size_px // 2, img_y - cell_size_px // 2), + (img_x + cell_size_px // 2, img_y + cell_size_px // 2), + color, + -1, + ) + + # Draw waypoints path if provided + if waypoints is not None and len(waypoints) > 0: + try: + path_points = [] + for i, waypoint in enumerate(waypoints): + # Convert waypoint from odom frame to visualization frame + wp_x, wp_y = waypoint[0], waypoint[1] + wp_rel_x = wp_x - robot_x + wp_rel_y = wp_y - robot_y + + wp_img_x = int(center_x + wp_rel_x * scale) + wp_img_y = int(center_y - wp_rel_y * scale) # Flip y-axis + + if 0 <= wp_img_x < vis_size and 0 <= wp_img_y < vis_size: + path_points.append((wp_img_x, wp_img_y)) + + # Draw each waypoint as a small circle + cv2.circle(vis_img, (wp_img_x, wp_img_y), 3, (0, 128, 0), -1) # Dark green dots + + # Highlight current target waypoint + if current_waypoint_index is not None and i == current_waypoint_index: + cv2.circle(vis_img, (wp_img_x, wp_img_y), 6, (0, 0, 255), 2) # Red circle + + # Connect waypoints with lines to show the path + if len(path_points) > 1: + for i in range(len(path_points) - 1): + cv2.line( + vis_img, path_points[i], path_points[i + 1], (0, 200, 0), 1 + ) # Green line + except Exception as e: + logger.error(f"Error drawing waypoints: {e}") + + # Draw histogram + if histogram is not None: + num_bins = len(histogram) + # Find absolute maximum value (ignoring any negative debug values) + abs_histogram = np.abs(histogram) + 
max_hist_value = np.max(abs_histogram) if np.max(abs_histogram) > 0 else 1.0 + hist_scale = (vis_size / 2) * 0.8 # Scale histogram lines to 80% of half the viz size + + for i in range(num_bins): + # Angle relative to robot's forward direction + angle_relative_to_robot = (i / num_bins) * 2 * math.pi - math.pi + # Angle in the visualization frame (relative to image +X axis) + vis_angle = angle_relative_to_robot + robot_theta + + # Get the value and check if it's a special debug value (negative) + hist_val = histogram[i] + is_debug_value = hist_val < 0 + + # Use absolute value for line length + normalized_val = min(1.0, abs(hist_val) / max_hist_value) + line_length = normalized_val * hist_scale + + # Calculate endpoint using the visualization angle + end_x = int(center_x + line_length * math.cos(vis_angle)) + end_y = int(center_y - line_length * math.sin(vis_angle)) # Flipped Y + + # Color based on value and whether it's a debug value + if is_debug_value: + # Use green for debug values (minimum cost bin) + color = (0, 255, 0) # Green + line_width = 2 # Thicker line for emphasis + else: + # Regular coloring for normal values (blue to red gradient based on obstacle density) + blue = max(0, 255 - int(normalized_val * 255)) + red = min(255, int(normalized_val * 255)) + color = (blue, 0, red) # BGR format: obstacles are redder, clear areas are bluer + line_width = 1 + + cv2.line(vis_img, (center_x, center_y), (end_x, end_y), color, line_width) + + # Draw robot + robot_length_px = int(robot_length * scale) + robot_width_px = int(robot_width * scale) + robot_pts = np.array( + [ + [-robot_length_px / 2, -robot_width_px / 2], + [robot_length_px / 2, -robot_width_px / 2], + [robot_length_px / 2, robot_width_px / 2], + [-robot_length_px / 2, robot_width_px / 2], + ], + dtype=np.float32, + ) + rotation_matrix = np.array( + [ + [math.cos(robot_theta), -math.sin(robot_theta)], + [math.sin(robot_theta), math.cos(robot_theta)], + ] + ) + robot_pts = np.dot(robot_pts, 
rotation_matrix.T) + robot_pts[:, 0] += center_x + robot_pts[:, 1] = center_y - robot_pts[:, 1] # Flip y-axis + cv2.fillPoly( + vis_img, [robot_pts.reshape((-1, 1, 2)).astype(np.int32)], (0, 0, 255) + ) # Red robot + + # Draw robot direction line + front_x = int(center_x + (robot_length_px / 2) * math.cos(robot_theta)) + front_y = int(center_y - (robot_length_px / 2) * math.sin(robot_theta)) + cv2.line(vis_img, (center_x, center_y), (front_x, front_y), (255, 0, 0), 2) # Blue line + + # Draw selected direction + if selected_direction is not None: + # selected_direction is relative to robot frame + # Angle in the visualization frame (relative to image +X axis) + vis_angle_selected = selected_direction + robot_theta + + # Make slightly longer than max histogram line + sel_dir_line_length = (vis_size / 2) * 0.9 + + sel_end_x = int(center_x + sel_dir_line_length * math.cos(vis_angle_selected)) + sel_end_y = int(center_y - sel_dir_line_length * math.sin(vis_angle_selected)) # Flipped Y + + cv2.line( + vis_img, (center_x, center_y), (sel_end_x, sel_end_y), (0, 165, 255), 2 + ) # BGR for Orange + + # Draw goal + if goal_xy is not None: + goal_x, goal_y = goal_xy + goal_rel_x_map = goal_x - robot_x + goal_rel_y_map = goal_y - robot_y + goal_img_x = int(center_x + goal_rel_x_map * scale) + goal_img_y = int(center_y - goal_rel_y_map * scale) # Flip y-axis + if 0 <= goal_img_x < vis_size and 0 <= goal_img_y < vis_size: + cv2.circle(vis_img, (goal_img_x, goal_img_y), 5, (0, 255, 0), -1) # Green circle + cv2.circle(vis_img, (goal_img_x, goal_img_y), 8, (0, 0, 0), 1) # Black outline + + # Draw goal orientation + if goal_theta is not None and goal_xy is not None: + # For waypoint mode, only draw orientation at the final waypoint + if waypoints is not None and len(waypoints) > 0: + # Use the final waypoint position + final_waypoint = waypoints[-1] + goal_x, goal_y = final_waypoint[0], final_waypoint[1] + else: + # Use the current goal position + goal_x, goal_y = goal_xy + + 
goal_rel_x_map = goal_x - robot_x + goal_rel_y_map = goal_y - robot_y + goal_img_x = int(center_x + goal_rel_x_map * scale) + goal_img_y = int(center_y - goal_rel_y_map * scale) # Flip y-axis + + # Calculate goal orientation vector direction in visualization frame + # goal_theta is already in odom frame, need to adjust for visualization orientation + goal_dir_length = 30 # Length of direction indicator in pixels + goal_dir_end_x = int(goal_img_x + goal_dir_length * math.cos(goal_theta)) + goal_dir_end_y = int(goal_img_y - goal_dir_length * math.sin(goal_theta)) # Flip y-axis + + # Draw goal orientation arrow + if 0 <= goal_img_x < vis_size and 0 <= goal_img_y < vis_size: + cv2.arrowedLine( + vis_img, + (goal_img_x, goal_img_y), + (goal_dir_end_x, goal_dir_end_y), + (255, 0, 255), + 4, + ) # Magenta arrow + + # Add scale bar + scale_bar_length_px = int(1.0 * scale) + scale_bar_x = vis_size - scale_bar_length_px - 10 + scale_bar_y = vis_size - 20 + cv2.line( + vis_img, + (scale_bar_x, scale_bar_y), + (scale_bar_x + scale_bar_length_px, scale_bar_y), + (0, 0, 0), + 2, + ) + cv2.putText( + vis_img, "1m", (scale_bar_x, scale_bar_y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 0), 1 + ) + + # Add status info + status_text = [] + if waypoints is not None: + if current_waypoint_index is not None: + status_text.append(f"WP: {current_waypoint_index}/{len(waypoints)}") + else: + status_text.append(f"WPs: {len(waypoints)}") + + y_pos = 20 + for text in status_text: + cv2.putText(vis_img, text, (10, y_pos), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1) + y_pos += 20 + + return vis_img diff --git a/build/lib/dimos/robot/local_planner/simple.py b/build/lib/dimos/robot/local_planner/simple.py new file mode 100644 index 0000000000..8eaf20ba6c --- /dev/null +++ b/build/lib/dimos/robot/local_planner/simple.py @@ -0,0 +1,265 @@ +# Copyright 2025 Dimensional Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import time +from dataclasses import dataclass +from typing import Any, Callable, Optional + +import reactivex as rx +from plum import dispatch +from reactivex import operators as ops + +from dimos.core import In, Module, Out, rpc +from dimos.msgs.geometry_msgs import PoseStamped, Vector3 +from dimos.robot.unitree_webrtc.type.odometry import Odometry + +# from dimos.robot.local_planner.local_planner import LocalPlanner +from dimos.types.costmap import Costmap +from dimos.types.path import Path +from dimos.types.pose import Pose +from dimos.types.vector import Vector, VectorLike, to_vector +from dimos.utils.logging_config import setup_logger +from dimos.utils.threadpool import get_scheduler + +logger = setup_logger("dimos.robot.unitree.global_planner") + + +def transform_to_robot_frame(global_vector: Vector, robot_position: Pose) -> Vector: + """Transform a global coordinate vector to robot-relative coordinates. 
class SimplePlanner(Module):
    """Minimal planner module that turns a goal into periodic movement commands.

    The module listens on ``path`` for goals and on ``odom`` for the latest
    pose, and publishes commands on ``movecmd``. In its current form the
    published stream is a left/right translation test pattern (see
    ``get_move_stream``); the rotation and debug helpers are not wired into
    that stream.
    """

    path: In[Path] = None
    odom: In[PoseStamped] = None
    movecmd: Out[Vector3] = None

    get_costmap: Callable[[], Costmap]

    # Most recent message received on ``odom`` (None until the first one arrives).
    latest_odom: PoseStamped = None

    # Current 2D goal, or None when no goal has been set.
    goal: Optional[Vector] = None
    # Magnitude applied to the normalized direction in ``calc_move``.
    speed: float = 0.3

    def __init__(
        self,
        get_costmap: Callable[[], Costmap],
    ):
        """Store the costmap getter.

        Args:
            get_costmap: Callable returning the current costmap.
        """
        Module.__init__(self)
        self.get_costmap = get_costmap

    def get_move_stream(self, frequency: float = 40.0) -> rx.Observable:
        """Build the observable that emits movement commands.

        Args:
            frequency: Emission rate in Hz.

        Returns:
            Observable that, while a goal is set, emits the test movement
            command on every tick.
        """
        return rx.interval(1.0 / frequency, scheduler=get_scheduler()).pipe(
            # do we have a goal?
            ops.filter(lambda _: self.goal is not None),
            # For testing: make robot move left/right instead of rotating
            ops.map(lambda _: self._test_translational_movement()),
            self.frequency_spy("movement_test"),
        )

    @rpc
    def start(self):
        """Subscribe to the inputs and start publishing movement commands."""
        self.path.subscribe(self.set_goal)

        # NOTE(review): annotated as Odometry while the stream is declared
        # In[PoseStamped] - confirm which message type is actually delivered.
        def setodom(odom: Odometry):
            self.latest_odom = odom

        self.odom.subscribe(setodom)
        self.get_move_stream(frequency=20.0).subscribe(self.movecmd.publish)

    @dispatch
    def set_goal(self, goal: Path, stop_event=None, goal_theta=None) -> bool:
        """Use the last point of *goal* (projected to 2D) as the new goal."""
        self.goal = goal.last().to_2d()
        logger.info(f"Setting goal: {self.goal}")
        return True

    @dispatch
    def set_goal(self, goal: VectorLike, stop_event=None, goal_theta=None) -> bool:
        """Use *goal* (converted to a 2D vector) as the new goal."""
        self.goal = to_vector(goal).to_2d()
        logger.info(f"Setting goal: {self.goal}")
        return True

    def calc_move(self, direction: Vector) -> Vector:
        """Calculate the movement vector based on the direction to the goal.

        Args:
            direction: Direction vector towards the goal

        Returns:
            Movement vector scaled by speed. If the computation fails, a zero
            vector (no motion) is returned instead of ``None`` so callers
            always receive a usable command.
        """
        try:
            # Normalize the direction vector and scale by speed
            normalized_direction = direction.normalize()
            move_vector = normalized_direction * self.speed
            print("CALC MOVE", direction, normalized_direction, move_vector)
            return move_vector
        except Exception as e:
            print("Error calculating move vector:", e)
            # Best-effort fallback: command a stop rather than returning None.
            return Vector(0, 0, 0)

    def spy(self, name: str):
        """Return an rx operator that prints every item, tagged with *name*."""

        def spyfun(x):
            print(f"SPY {name}:", x)
            return x

        return ops.map(spyfun)

    def frequency_spy(self, name: str, window_size: int = 10):
        """Create a frequency spy that logs message rate over a sliding window.

        Args:
            name: Name for the spy output
            window_size: Number of messages to average frequency over
        """
        # Imported locally so the block does not depend on a file-level import.
        from collections import deque

        # Bounded deque: the oldest timestamp is dropped automatically once
        # more than window_size samples have been seen.
        timestamps = deque(maxlen=window_size)

        def freq_spy_fun(x):
            timestamps.append(time.time())
            print(x)

            # Calculate frequency if we have enough samples
            if len(timestamps) >= 2:
                time_span = timestamps[-1] - timestamps[0]
                if time_span > 0:
                    frequency = (len(timestamps) - 1) / time_span
                    print(f"FREQ SPY {name}: {frequency:.2f} Hz ({len(timestamps)} samples)")
                else:
                    print(f"FREQ SPY {name}: calculating... ({len(timestamps)} samples)")
            else:
                print(f"FREQ SPY {name}: warming up... ({len(timestamps)} samples)")

            return x

        return ops.map(freq_spy_fun)

    def _test_translational_movement(self) -> Vector3:
        """Test translational movement by alternating the commanded direction.

        The command alternates every 3 seconds (6 second cycle) between
        ``Vector3(0.2, 0, 0)`` and ``Vector3(-0.2, 0, 0)``; only the X
        component is ever non-zero.

        Returns:
            Vector3 with (x=+/-0.2, y=0, z=0).
        """
        # Use time to alternate between left and right movement every 3 seconds
        current_time = time.time()
        cycle_time = 6.0  # 6 second cycle (3 seconds each direction)
        phase = (current_time % cycle_time) / cycle_time

        if phase < 0.5:
            # First half: move LEFT (positive X according to our documentation)
            movement = Vector3(0.2, 0, 0)  # Move left at 0.2 m/s
            direction = "LEFT (positive X)"
        else:
            # Second half: move RIGHT (negative X according to our documentation)
            movement = Vector3(-0.2, 0, 0)  # Move right at 0.2 m/s
            direction = "RIGHT (negative X)"

        print("=== LEFT-RIGHT MOVEMENT TEST ===")
        print(f"Phase: {phase:.2f}, Direction: {direction}")
        print(f"Sending movement command: {movement}")
        print(f"Expected: Robot should move {direction.split()[0]} relative to its body")
        print("===================================")
        return movement

    def _calculate_rotation_to_target(self, direction_to_goal: Vector) -> Vector:
        """Calculate the rotation needed for the robot to face the target.

        Args:
            direction_to_goal: Vector pointing from robot position to goal in global coordinates

        Returns:
            Vector with (x=0, y=0, z=angular_velocity) for rotation only
        """
        # Calculate the desired yaw angle to face the target
        desired_yaw = math.atan2(direction_to_goal.y, direction_to_goal.x)

        # Get current robot yaw
        # NOTE(review): other helpers in this class read yaw from `.rot.z`;
        # confirm that `.orientation.z` is a yaw angle and not a quaternion component.
        current_yaw = self.latest_odom.orientation.z

        # Wrap the yaw error into [-pi, pi] via atan2(sin, cos)
        yaw_error = math.atan2(
            math.sin(desired_yaw - current_yaw), math.cos(desired_yaw - current_yaw)
        )

        print(
            f"DEBUG: direction_to_goal={direction_to_goal}, desired_yaw={math.degrees(desired_yaw):.1f}°, current_yaw={math.degrees(current_yaw):.1f}°"
        )
        print(
            f"DEBUG: yaw_error={math.degrees(yaw_error):.1f}°, abs_error={abs(yaw_error):.3f}, tolerance=0.1"
        )

        # Calculate angular velocity (proportional control)
        max_angular_speed = 0.15  # rad/s
        raw_angular_velocity = yaw_error * 2.0
        angular_velocity = max(-max_angular_speed, min(max_angular_speed, raw_angular_velocity))

        print(
            f"DEBUG: raw_ang_vel={raw_angular_velocity:.3f}, clamped_ang_vel={angular_velocity:.3f}"
        )

        # Stop rotating if we're close enough to the target angle
        if abs(yaw_error) < 0.1:  # ~5.7 degrees tolerance
            print("DEBUG: Within tolerance - stopping rotation")
            angular_velocity = 0.0
        else:
            print("DEBUG: Outside tolerance - continuing rotation")

        print(
            f"Rotation control: current_yaw={math.degrees(current_yaw):.1f}°, desired_yaw={math.degrees(desired_yaw):.1f}°, error={math.degrees(yaw_error):.1f}°, ang_vel={angular_velocity:.3f}"
        )

        # Return movement command: no translation (x=0, y=0), only rotation (z=angular_velocity)
        # Try flipping the sign in case the rotation convention is opposite
        return Vector(0, 0, -angular_velocity)

    def _debug_direction(self, name: str, direction: Vector) -> Vector:
        """Debug helper to log direction information; returns *direction* unchanged."""
        robot_pos = self.latest_odom
        print(
            f"DEBUG {name}: direction={direction}, robot_pos={robot_pos.position.to_2d()}, robot_yaw={math.degrees(robot_pos.rot.z):.1f}°, goal={self.goal}"
        )
        return direction

    def _debug_robot_command(self, robot_cmd: Vector) -> Vector:
        """Debug helper to log robot command information; returns *robot_cmd* unchanged."""
        print(
            f"DEBUG robot_command: x={robot_cmd.x:.3f}, y={robot_cmd.y:.3f} (forward/backward, left/right)"
        )
        return robot_cmd
class VFHPurePursuitPlanner(BaseLocalPlanner):
    """
    A local planner that combines Vector Field Histogram (VFH) for obstacle avoidance
    with Pure Pursuit for goal tracking.

    Each control step builds a polar histogram of obstacle density around the
    robot, picks the direction with the lowest weighted cost (obstacles, goal
    alignment, consistency with the previous choice) and converts that direction
    into linear/angular velocities.
    """

    def __init__(
        self,
        get_costmap: Callable[[], Optional[Costmap]],
        get_robot_pose: Callable[[], Any],
        move: Callable[[Vector], None],
        safety_threshold: float = 0.8,
        histogram_bins: int = 144,
        max_linear_vel: float = 0.8,
        max_angular_vel: float = 1.0,
        lookahead_distance: float = 1.0,
        goal_tolerance: float = 0.4,
        angle_tolerance: float = 0.1,  # ~5.7 degrees
        robot_width: float = 0.5,
        robot_length: float = 0.7,
        visualization_size: int = 400,
        control_frequency: float = 10.0,
        safe_goal_distance: float = 1.0,
        max_recovery_attempts: int = 3,
        global_planner_plan: Optional[Callable[[VectorLike], Optional[Any]]] = None,
    ):
        """
        Initialize the VFH + Pure Pursuit planner.

        Args:
            get_costmap: Function to get the latest local costmap
            get_robot_pose: Function to get the latest robot pose (returning odom object)
            move: Function to send velocity commands
            safety_threshold: Distance to maintain from obstacles (meters)
            histogram_bins: Number of directional bins in the polar histogram
            max_linear_vel: Maximum linear velocity (m/s)
            max_angular_vel: Maximum angular velocity (rad/s)
            lookahead_distance: Lookahead distance for pure pursuit (meters)
            goal_tolerance: Distance at which the goal is considered reached (meters)
            angle_tolerance: Angle at which the goal orientation is considered reached (radians)
            robot_width: Width of the robot for visualization (meters)
            robot_length: Length of the robot for visualization (meters)
            visualization_size: Size of the visualization image in pixels
            control_frequency: Frequency at which the planner is called (Hz)
            safe_goal_distance: Distance at which to adjust the goal and ignore obstacles (meters)
            max_recovery_attempts: Maximum number of recovery attempts
            global_planner_plan: Optional function to get the global plan
        """
        # Initialize base class
        super().__init__(
            get_costmap=get_costmap,
            get_robot_pose=get_robot_pose,
            move=move,
            safety_threshold=safety_threshold,
            max_linear_vel=max_linear_vel,
            max_angular_vel=max_angular_vel,
            lookahead_distance=lookahead_distance,
            goal_tolerance=goal_tolerance,
            angle_tolerance=angle_tolerance,
            robot_width=robot_width,
            robot_length=robot_length,
            visualization_size=visualization_size,
            control_frequency=control_frequency,
            safe_goal_distance=safe_goal_distance,
            max_recovery_attempts=max_recovery_attempts,
            global_planner_plan=global_planner_plan,
        )

        # VFH specific parameters
        self.histogram_bins = histogram_bins
        self.histogram = None
        self.selected_direction = None

        # VFH tuning parameters
        self.alpha = 0.25  # Histogram smoothing factor
        self.obstacle_weight = 5.0
        self.goal_weight = 2.0
        self.prev_direction_weight = 1.0
        self.prev_selected_angle = 0.0
        self.prev_linear_vel = 0.0
        self.linear_vel_filter_factor = 0.4
        self.low_speed_nudge = 0.1

        # Centre angle (robot frame) of each histogram bin, covering [-pi, pi)
        self.angle_mapping = np.linspace(-np.pi, np.pi, self.histogram_bins, endpoint=False)
        self.smoothing_kernel = np.array([self.alpha, (1 - 2 * self.alpha), self.alpha])

    def _compute_velocity_commands(self) -> Dict[str, float]:
        """
        VFH + Pure Pursuit specific implementation of velocity command computation.

        Returns:
            Dict[str, float]: Velocity commands with 'x_vel' and 'angular_vel' keys
        """
        # Get necessary data for planning
        costmap = self._get_costmap()
        if costmap is None:
            logger.warning("No costmap available for planning")
            return {"x_vel": 0.0, "angular_vel": 0.0}

        robot_pos, robot_theta = self._get_robot_pose()
        robot_x, robot_y = robot_pos
        robot_pose = (robot_x, robot_y, robot_theta)

        # Calculate goal-related parameters
        goal_x, goal_y = self.goal_xy
        dx = goal_x - robot_x
        dy = goal_y - robot_y
        goal_distance = np.linalg.norm([dx, dy])
        goal_direction = np.arctan2(dy, dx) - robot_theta
        goal_direction = normalize_angle(goal_direction)

        self.histogram = self.build_polar_histogram(costmap, robot_pose)

        # If we're ignoring obstacles near the goal, zero out the histogram
        if self.ignore_obstacles:
            self.histogram = np.zeros_like(self.histogram)

        self.selected_direction = self.select_direction(
            self.goal_weight,
            self.obstacle_weight,
            self.prev_direction_weight,
            self.histogram,
            goal_direction,
        )

        # Calculate Pure Pursuit Velocities
        linear_vel, angular_vel = self.compute_pure_pursuit(goal_distance, self.selected_direction)

        # Slow down when turning sharply
        if abs(self.selected_direction) > 0.25:  # ~15 degrees
            # Scale from 1.0 (small turn) down to a floor of 0.25 (sharp turns)
            turn_factor = max(0.25, 1.0 - (abs(self.selected_direction) / (np.pi / 2)))
            linear_vel *= turn_factor

        # Apply Collision Avoidance Stop - skip if ignoring obstacles
        if not self.ignore_obstacles and self.check_collision(
            self.selected_direction, safety_threshold=0.5
        ):
            # Re-select direction prioritizing obstacle avoidance if colliding
            self.selected_direction = self.select_direction(
                self.goal_weight * 0.2,
                self.obstacle_weight,
                self.prev_direction_weight * 0.2,
                self.histogram,
                goal_direction,
            )
            linear_vel, angular_vel = self.compute_pure_pursuit(
                goal_distance, self.selected_direction
            )

        if self.check_collision(0.0, safety_threshold=self.safety_threshold):
            linear_vel = 0.0

        # Low-pass filter the linear velocity against the PREVIOUS command.
        # The previous value must be read before it is overwritten; otherwise the
        # blend collapses to `linear_vel` and the filter does nothing.
        if linear_vel == 0.0:
            # A commanded stop (goal reached / obstacle ahead) must take effect
            # immediately, so it bypasses the filter.
            filtered_linear_vel = 0.0
        else:
            filtered_linear_vel = self.prev_linear_vel * self.linear_vel_filter_factor + (
                linear_vel * (1 - self.linear_vel_filter_factor)
            )
        self.prev_linear_vel = filtered_linear_vel

        return {"x_vel": filtered_linear_vel, "angular_vel": angular_vel}

    def _smooth_histogram(self, histogram: np.ndarray) -> np.ndarray:
        """
        Smooth the polar histogram and sharpen its local extrema.

        A circular 5-point weighted average is applied first; then strict local
        minima are scaled by 0.8 and strict local maxima by 1.2.

        Args:
            histogram: Raw histogram to smooth

        Returns:
            np.ndarray: Smoothed histogram
        """
        hist = np.asarray(histogram, dtype=float)

        # First pass: circular weighted average, more weight on the centre bin.
        # np.roll(hist, -offset)[i] == hist[(i + offset) % bins]
        weights = (0.1, 0.2, 0.4, 0.2, 0.1)  # Sum = 1.0
        smoothed = np.zeros_like(hist)
        for offset, weight in zip(range(-2, 3), weights):
            smoothed += weight * np.roll(hist, -offset)

        # Second pass: peak and valley enhancement (with wrap-around neighbours)
        prev_vals = np.roll(smoothed, 1)
        next_vals = np.roll(smoothed, -1)
        is_valley = (smoothed < prev_vals) & (smoothed < next_vals)
        is_peak = (smoothed > prev_vals) & (smoothed > next_vals)

        enhanced = smoothed.copy()
        enhanced[is_valley] *= 0.8
        # No upper clamp here: the histogram is not normalised at this stage
        # (inverse-square weighting routinely exceeds 1.0), so clamping peaks to
        # 1.0 would push the densest obstacle direction BELOW its neighbours.
        # select_direction() normalises by the maximum afterwards.
        enhanced[is_peak] *= 1.2

        return enhanced

    def build_polar_histogram(self, costmap: Costmap, robot_pose: Tuple[float, float, float]):
        """
        Build a polar histogram of obstacle densities around the robot.

        Args:
            costmap: Costmap object with grid and metadata
            robot_pose: Tuple (x, y, theta) of the robot pose in the odom frame

        Returns:
            np.ndarray: Polar histogram of obstacle densities
        """
        # Get grid and find all obstacle cells
        occupancy_grid = costmap.grid
        y_indices, x_indices = np.where(occupancy_grid > 0)
        if len(y_indices) == 0:  # No obstacles
            return np.zeros(self.histogram_bins)

        # Get robot position in grid coordinates
        robot_x, robot_y, robot_theta = robot_pose
        robot_point = costmap.world_to_grid((robot_x, robot_y))
        robot_cell_x, robot_cell_y = robot_point.x, robot_point.y

        # Vectorized distance and angle calculation
        dx_cells = x_indices - robot_cell_x
        dy_cells = y_indices - robot_cell_y
        distances = np.sqrt(dx_cells**2 + dy_cells**2) * costmap.resolution
        angles_grid = np.arctan2(dy_cells, dx_cells)
        angles_robot = normalize_angle(angles_grid - robot_theta)

        # Convert to bin indices (bin 0 starts at -pi, matching self.angle_mapping)
        bin_indices = ((angles_robot + np.pi) / (2 * np.pi) * self.histogram_bins).astype(
            int
        ) % self.histogram_bins

        # Get obstacle values
        obstacle_values = occupancy_grid[y_indices, x_indices] / 100.0

        # Build histogram; the robot's own cell (distance 0) is excluded to avoid
        # dividing by zero.
        histogram = np.zeros(self.histogram_bins)
        mask = distances > 0
        # Weight obstacles by inverse square of distance and cell value
        np.add.at(histogram, bin_indices[mask], obstacle_values[mask] / (distances[mask] ** 2))

        # Apply the enhanced smoothing
        return self._smooth_histogram(histogram)

    def select_direction(
        self, goal_weight, obstacle_weight, prev_direction_weight, histogram, goal_direction
    ):
        """
        Select best direction based on a simple weighted cost function.

        Also stores the chosen angle in ``self.prev_selected_angle`` for the next call.

        Args:
            goal_weight: Weight for the goal direction component
            obstacle_weight: Weight for the obstacle avoidance component
            prev_direction_weight: Weight for previous direction consistency
            histogram: Polar histogram of obstacle density
            goal_direction: Desired direction to goal

        Returns:
            float: Selected direction in radians
        """
        # Normalize histogram if needed
        peak = np.max(histogram)
        if peak > 0:
            histogram = histogram / peak

        # Calculate costs for each possible direction
        angle_diffs = np.abs(normalize_angle(self.angle_mapping - goal_direction))
        prev_diffs = np.abs(normalize_angle(self.angle_mapping - self.prev_selected_angle))

        # Combine costs with weights
        obstacle_costs = obstacle_weight * histogram
        goal_costs = goal_weight * angle_diffs
        prev_costs = prev_direction_weight * prev_diffs

        total_costs = obstacle_costs + goal_costs + prev_costs

        # Select direction with lowest cost
        min_cost_idx = np.argmin(total_costs)
        selected_angle = self.angle_mapping[min_cost_idx]

        # Update history for next iteration
        self.prev_selected_angle = selected_angle

        return selected_angle

    def compute_pure_pursuit(
        self, goal_distance: float, goal_direction: float
    ) -> Tuple[float, float]:
        """Compute pure pursuit velocities.

        Args:
            goal_distance: Distance to the goal (meters)
            goal_direction: Direction to steer towards, relative to the robot heading (radians)

        Returns:
            Tuple ``(linear_vel, angular_vel)``; ``(0.0, 0.0)`` once the goal is
            within ``goal_tolerance``.
        """
        if goal_distance < self.goal_tolerance:
            return 0.0, 0.0

        lookahead = min(self.lookahead_distance, goal_distance)
        if lookahead <= 0:
            # Only reachable when goal_tolerance is 0 and the robot sits exactly
            # on the goal; avoid dividing by zero below.
            return 0.0, 0.0

        linear_vel = min(self.max_linear_vel, goal_distance)
        angular_vel = 2.0 * np.sin(goal_direction) / lookahead
        angular_vel = max(-self.max_angular_vel, min(angular_vel, self.max_angular_vel))

        return linear_vel, angular_vel

    def check_collision(self, selected_direction: float, safety_threshold: float = 1.0) -> bool:
        """Check if there's an obstacle in the selected direction within safety threshold.

        Args:
            selected_direction: Direction to probe, relative to the robot heading (radians)
            safety_threshold: Probe length (meters)

        Returns:
            True if a cell with cost above 50 lies on the probed ray; False
            otherwise, including when obstacles are ignored or no costmap exists.
        """
        # Skip collision check if ignoring obstacles
        if self.ignore_obstacles:
            return False

        # Get the latest costmap and robot pose
        costmap = self._get_costmap()
        if costmap is None:
            return False  # No costmap available

        robot_pos, robot_theta = self._get_robot_pose()
        robot_x, robot_y = robot_pos

        # Direction in world frame (unit step components are loop-invariant)
        direction_world = robot_theta + selected_direction
        step_x = np.cos(direction_world)
        step_y = np.sin(direction_world)

        # Safety distance in cells
        safety_cells = int(safety_threshold / costmap.resolution)

        # Get robot position in grid coordinates
        robot_point = costmap.world_to_grid((robot_x, robot_y))
        robot_cell_x, robot_cell_y = robot_point.x, robot_point.y

        # Check for obstacles along the selected direction
        for dist in range(1, safety_cells + 1):
            # Calculate cell position
            cell_x = robot_cell_x + int(dist * step_x)
            cell_y = robot_cell_y + int(dist * step_y)

            # Check if cell is within grid bounds
            if not (0 <= cell_x < costmap.width and 0 <= cell_y < costmap.height):
                continue

            # Check if cell contains an obstacle (threshold at 50)
            if costmap.grid[int(cell_y), int(cell_x)] > 50:
                return True

        return False  # No collision detected

    def update_visualization(self) -> np.ndarray:
        """Generate visualization of the planning state.

        Returns:
            The image produced by ``visualize_local_planner_state``; on any error a
            white square image with the text "Viz Error" is returned instead.
        """
        try:
            costmap = self._get_costmap()
            if costmap is None:
                raise ValueError("Costmap is None")

            robot_pos, robot_theta = self._get_robot_pose()
            robot_x, robot_y = robot_pos
            robot_pose = (robot_x, robot_y, robot_theta)

            goal_xy = self.goal_xy  # This could be a lookahead point or final goal

            # Get the latest histogram and selected direction, if available
            histogram = getattr(self, "histogram", None)
            selected_direction = getattr(self, "selected_direction", None)

            # Get waypoint data if in waypoint mode
            waypoints_to_draw = self.waypoints_in_absolute
            current_wp_index_to_draw = (
                self.current_waypoint_index if self.waypoints_in_absolute is not None else None
            )
            # Ensure index is valid before passing
            if waypoints_to_draw is not None and current_wp_index_to_draw is not None:
                if not (0 <= current_wp_index_to_draw < len(waypoints_to_draw)):
                    current_wp_index_to_draw = None  # Invalidate index if out of bounds

            return visualize_local_planner_state(
                occupancy_grid=costmap.grid,
                grid_resolution=costmap.resolution,
                grid_origin=(costmap.origin.x, costmap.origin.y),
                robot_pose=robot_pose,
                goal_xy=goal_xy,  # Current target (lookahead or final)
                goal_theta=self.goal_theta,  # Pass goal orientation if available
                visualization_size=self.visualization_size,
                robot_width=self.robot_width,
                robot_length=self.robot_length,
                histogram=histogram,
                selected_direction=selected_direction,
                waypoints=waypoints_to_draw,  # Pass the full path
                current_waypoint_index=current_wp_index_to_draw,  # Pass the target index
            )
        except Exception as e:
            # Visualization is best-effort: log and fall back to a placeholder
            # image so a drawing failure never interrupts planning.
            logger.error(f"Error during visualization update: {e}")
            # Return a blank image with error text
            blank = (
                np.ones((self.visualization_size, self.visualization_size, 3), dtype=np.uint8) * 255
            )
            cv2.putText(
                blank,
                "Viz Error",
                (self.visualization_size // 4, self.visualization_size // 2),
                cv2.FONT_HERSHEY_SIMPLEX,
                1,
                (0, 0, 0),
                2,
            )
            return blank
class PositionStreamProvider:
    """
    A provider for streaming position updates from ROS.

    This class creates an Observable stream of position updates by subscribing
    to ROS odometry or pose topics.
    """

    def __init__(
        self,
        ros_node: Node,
        odometry_topic: str = "/odom",
        pose_topic: Optional[str] = None,
        use_odometry: bool = True,
    ):
        """
        Initialize the position stream provider.

        Args:
            ros_node: ROS node to use for subscriptions
            odometry_topic: Name of the odometry topic (if use_odometry is True)
            pose_topic: Name of the pose topic (if use_odometry is False)
            use_odometry: Whether to use odometry (True) or pose (False) for position

        Raises:
            ValueError: If ``use_odometry`` is False and no ``pose_topic`` is given.
        """
        self.ros_node = ros_node
        self.odometry_topic = odometry_topic
        self.pose_topic = pose_topic
        self.use_odometry = use_odometry

        self._subject = Subject()

        self.last_position: Optional[Tuple[float, float]] = None
        self.last_update_time: Optional[float] = None

        # Defined up front so cleanup() is safe even if subscribing fails below.
        self.subscription = None

        self._create_subscription()

        logger.info(
            f"PositionStreamProvider initialized with "
            f"{'odometry topic' if use_odometry else 'pose topic'}: "
            f"{odometry_topic if use_odometry else pose_topic}"
        )

    def _create_subscription(self):
        """Create the appropriate ROS subscription based on configuration.

        Raises:
            ValueError: If pose mode is selected but ``self.pose_topic`` is empty.
        """
        if self.use_odometry:
            self.subscription = self.ros_node.create_subscription(
                Odometry, self.odometry_topic, self._odometry_callback, 10
            )
            logger.info(f"Subscribed to odometry topic: {self.odometry_topic}")
        else:
            if not self.pose_topic:
                raise ValueError("Pose topic must be specified when use_odometry is False")

            self.subscription = self.ros_node.create_subscription(
                PoseStamped, self.pose_topic, self._pose_callback, 10
            )
            logger.info(f"Subscribed to pose topic: {self.pose_topic}")

    def _odometry_callback(self, msg: Odometry):
        """
        Process odometry messages and extract position.

        Args:
            msg: Odometry message from ROS
        """
        x = msg.pose.pose.position.x
        y = msg.pose.pose.position.y

        self._update_position(x, y)

    def _pose_callback(self, msg: PoseStamped):
        """
        Process pose messages and extract position.

        Args:
            msg: PoseStamped message from ROS
        """
        x = msg.pose.position.x
        y = msg.pose.position.y

        self._update_position(x, y)

    def _update_position(self, x: float, y: float):
        """
        Update the current position and emit to subscribers.

        Args:
            x: X coordinate
            y: Y coordinate
        """
        current_time = time.time()
        position = (x, y)

        if self.last_update_time is not None:
            elapsed = current_time - self.last_update_time
            # Two messages can carry the same wall-clock timestamp (coarse clock
            # resolution) or the clock can step backwards; skip the rate log
            # instead of dividing by zero inside a ROS callback.
            if elapsed > 0:
                update_rate = 1.0 / elapsed
                logger.debug(f"Position update rate: {update_rate:.1f} Hz")

        self.last_position = position
        self.last_update_time = current_time

        self._subject.on_next(position)
        logger.debug(f"Position updated: ({x:.2f}, {y:.2f})")

    def get_position_stream(self) -> Observable:
        """
        Get an Observable stream of position updates.

        Returns:
            Observable that emits (x, y) tuples
        """
        return self._subject.pipe(
            ops.share()  # Share the stream among multiple subscribers
        )

    def get_current_position(self) -> Optional[Tuple[float, float]]:
        """
        Get the most recent position.

        Returns:
            Tuple of (x, y) coordinates, or None if no position has been received
        """
        return self.last_position

    def cleanup(self):
        """Destroy the ROS subscription, if any. Safe to call more than once."""
        subscription = getattr(self, "subscription", None)
        if subscription:
            self.ros_node.destroy_subscription(subscription)
            # Drop the handle so a second cleanup() does not destroy it again.
            self.subscription = None
            logger.info("Position subscription destroyed")
class RobotRecorder:
    """A class for recording robot observation and actions.

    Recording at a specified frequency on the observation and action of a robot. It leverages a queue and a worker
    thread to handle the recording asynchronously, ensuring that the main operations of the
    robot are not blocked.

    Robot class must pass in the `get_state`, `get_observation`, `prepare_action` methods.
    get_state() gets the current state/pose of the robot.
    get_observation() captures the observation/image of the robot.
    prepare_action() calculates the action between the new and old states.
    """

    def __init__(
        self,
        get_state: Callable,
        get_observation: Callable,
        prepare_action: Callable,
        frequency_hz: int = 5,
        recorder_kwargs: dict = None,
        on_static: Literal["record", "omit"] = "omit",
    ) -> None:
        """Initializes the RobotRecorder.

        This constructor sets up the recording mechanism on the given robot, including the recorder instance,
        recording frequency, and the asynchronous processing queue and worker thread. It also
        initializes attributes to track the last recorded pose and the current instruction.

        Args:
            get_state: A function that returns the current state of the robot.
            get_observation: A function that captures the observation/image of the robot.
            prepare_action: A function that calculates the action between the new and old states.
            frequency_hz: Frequency at which to record pose and image data (in Hz). Must be positive.
            recorder_kwargs: Keyword arguments to pass to the Recorder constructor.
            on_static: Whether to record on static poses or not. If "record", it will record when the robot is not moving.

        Raises:
            ValueError: If ``frequency_hz`` is not positive.
        """
        # Fail fast: a non-positive frequency would otherwise only surface as a
        # ZeroDivisionError (or a busy loop) inside the recording thread.
        if frequency_hz <= 0:
            raise ValueError(f"frequency_hz must be positive, got {frequency_hz}")

        if recorder_kwargs is None:
            recorder_kwargs = {}
        # NOTE(review): `Recorder` is not imported in this module (its import is
        # commented out at the top of the file), so this line raises NameError
        # until that import is restored.
        self.recorder = Recorder(**recorder_kwargs)
        self.task = None

        self.last_recorded_state = None
        self.last_image = None

        self.recording = False
        self.frequency_hz = frequency_hz
        self.record_on_static = on_static == "record"
        self.recording_queue = Queue()

        self.get_state = get_state
        self.get_observation = get_observation
        self.prepare_action = prepare_action

        # Daemon worker drains the queue so recording never blocks the caller.
        self._worker_thread = threading.Thread(target=self._process_queue, daemon=True)
        self._worker_thread.start()

    def __enter__(self):
        """Enter the context manager, starting the recording.

        Returns:
            This recorder, so ``with recorder.record(task) as r:`` binds it.
        """
        self.start_recording(self.task)
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Exit the context manager, stopping the recording."""
        self.stop_recording()

    def record(self, task: str) -> "RobotRecorder":
        """Set the task and return the context manager."""
        self.task = task
        return self

    def reset_recorder(self) -> None:
        """Wait until recording has stopped, then reset the underlying recorder."""
        while self.recording:
            time.sleep(0.1)
        self.recorder.reset()

    def record_from_robot(self) -> None:
        """Records the current pose and captures an image at the specified frequency."""
        period = 1.0 / self.frequency_hz
        while self.recording:
            start_time = time.perf_counter()
            self.record_current_state()
            elapsed_time = time.perf_counter() - start_time
            # Sleep for the remaining time to maintain the desired frequency
            sleep_time = max(0, period - elapsed_time)
            time.sleep(sleep_time)

    def start_recording(self, task: str = "") -> None:
        """Starts the recording of pose and image. No-op if already recording."""
        if not self.recording:
            self.task = task
            self.recording = True
            self.recording_thread = threading.Thread(target=self.record_from_robot)
            self.recording_thread.start()

    def stop_recording(self) -> None:
        """Stops the recording of pose and image and waits for the recording thread."""
        if self.recording:
            self.recording = False
            self.recording_thread.join()

    def _process_queue(self) -> None:
        """Processes the recording queue asynchronously (runs forever on the worker thread)."""
        while True:
            image, instruction, action, state = self.recording_queue.get()
            self.recorder.record(
                observation={"image": image, "instruction": instruction}, action=action, state=state
            )
            self.recording_queue.task_done()

    def record_current_state(self) -> None:
        """Records the current pose and image if the pose has changed.

        The first call only stores the state/image as the baseline. Later calls
        enqueue ``(previous image, task, action, previous state)`` when the state
        differs from the previous one, or always when recording static poses.
        """
        state = self.get_state()
        image = self.get_observation()

        # This is the beginning of the episode
        if self.last_recorded_state is None:
            self.last_recorded_state = state
            self.last_image = image
            return

        if state != self.last_recorded_state or self.record_on_static:
            action = self.prepare_action(self.last_recorded_state, state)
            self.recording_queue.put(
                (
                    self.last_image,
                    self.task,
                    action,
                    self.last_recorded_state,
                ),
            )
            self.last_image = image
            self.last_recorded_state = state

    def record_last_state(self) -> None:
        """Records the final pose and image after the movement completes."""
        self.record_current_state()
+ +This module provides the foundation for all DIMOS robots, including both physical +and simulated implementations, with common functionality for movement, control, +and video streaming. +""" + +from abc import ABC, abstractmethod +import os +from typing import Optional, List, Union, Dict, Any + +from dimos.hardware.interface import HardwareInterface +from dimos.perception.spatial_perception import SpatialMemory +from dimos.manipulation.manipulation_interface import ManipulationInterface +from dimos.types.robot_capabilities import RobotCapability +from dimos.types.vector import Vector +from dimos.utils.logging_config import setup_logger +from dimos.robot.connection_interface import ConnectionInterface + +from dimos.skills.skills import SkillLibrary +from reactivex import Observable, operators as ops +from reactivex.disposable import CompositeDisposable +from reactivex.scheduler import ThreadPoolScheduler + +from dimos.utils.threadpool import get_scheduler +from dimos.utils.reactive import backpressure +from dimos.stream.video_provider import VideoProvider + +logger = setup_logger("dimos.robot.robot") + + +class Robot(ABC): + """Base class for all DIMOS robots. + + This abstract base class defines the common interface and functionality for all + DIMOS robots, whether physical or simulated. It provides methods for movement, + rotation, video streaming, and hardware configuration management. + + Attributes: + agent_config: Configuration for the robot's agent. + hardware_interface: Interface to the robot's hardware components. + ros_control: ROS-based control system for the robot. + output_dir: Directory for storing output files. + disposables: Collection of disposable resources for cleanup. + pool_scheduler: Thread pool scheduler for managing concurrent operations. 
    def __init__(
        self,
        hardware_interface: HardwareInterface = None,
        connection_interface: ConnectionInterface = None,
        output_dir: str = os.path.join(os.getcwd(), "assets", "output"),
        pool_scheduler: ThreadPoolScheduler = None,
        skill_library: SkillLibrary = None,
        spatial_memory_collection: str = "spatial_memory",
        new_memory: bool = False,
        capabilities: List[RobotCapability] = None,
        video_stream: Optional[Observable] = None,
        enable_perception: bool = True,
    ):
        """Initialize a Robot instance.

        Creates the output, memory and spatial-memory directories on disk, and
        optionally builds the spatial memory and manipulation interface.

        Args:
            hardware_interface: Interface to the robot's hardware. Defaults to None.
            connection_interface: Connection interface for robot control and communication.
            output_dir: Directory for storing output files. The default is computed
                from the working directory at the time this module is imported.
            pool_scheduler: Thread pool scheduler. If None, ``get_scheduler()`` is used.
            skill_library: Skill library instance. If None, one will be created.
            spatial_memory_collection: Name of the collection in the ChromaDB database.
            new_memory: If True, creates a new spatial memory from scratch. Defaults to False.
            capabilities: List of robot capabilities. Defaults to None (treated as empty).
            video_stream: Optional video stream. Defaults to None.
            enable_perception: If True, enables perception streams and spatial memory. Defaults to True.
        """
        self.hardware_interface = hardware_interface
        self.connection_interface = connection_interface
        self.output_dir = output_dir
        self.disposables = CompositeDisposable()
        self.pool_scheduler = pool_scheduler if pool_scheduler else get_scheduler()
        self.skill_library = skill_library if skill_library else SkillLibrary()
        self.enable_perception = enable_perception

        # Initialize robot capabilities
        self.capabilities = capabilities or []

        # Create output directory if it doesn't exist
        os.makedirs(self.output_dir, exist_ok=True)
        logger.info(f"Robot outputs will be saved to: {self.output_dir}")

        # Initialize memory properties
        self.memory_dir = os.path.join(self.output_dir, "memory")
        os.makedirs(self.memory_dir, exist_ok=True)

        # Initialize spatial memory properties
        self.spatial_memory_dir = os.path.join(self.memory_dir, "spatial_memory")
        self.spatial_memory_collection = spatial_memory_collection
        self.db_path = os.path.join(self.spatial_memory_dir, "chromadb_data")
        self.visual_memory_path = os.path.join(self.spatial_memory_dir, "visual_memory.pkl")

        # Create spatial memory directory
        os.makedirs(self.spatial_memory_dir, exist_ok=True)
        os.makedirs(self.db_path, exist_ok=True)

        # Start from the caller-supplied stream (may be None)
        self._video_stream = video_stream

        # Only create video stream if connection interface is available
        # NOTE(review): this overwrites a caller-supplied `video_stream` whenever a
        # connection interface is present — confirm that is intended.
        if self.connection_interface is not None:
            # Get video stream - always create this, regardless of enable_perception
            self._video_stream = self.get_video_stream(fps=10)  # Lower FPS for processing

        # Create SpatialMemory instance only if perception is enabled
        # NOTE(review): assumed to run independently of the connection check above
        # (source indentation was lost) — TODO confirm.
        if self.enable_perception:
            self._spatial_memory = SpatialMemory(
                collection_name=self.spatial_memory_collection,
                db_path=self.db_path,
                visual_memory_path=self.visual_memory_path,
                new_memory=new_memory,
                output_dir=self.spatial_memory_dir,
                video_stream=self._video_stream,
                get_pose=self.get_pose,
            )
            logger.info("Spatial memory initialized")
        else:
            self._spatial_memory = None
            logger.info("Spatial memory disabled (enable_perception=False)")

        # Initialize manipulation interface if the robot has manipulation capability
        self._manipulation_interface = None
        if RobotCapability.MANIPULATION in self.capabilities:
            # Initialize manipulation memory properties if the robot has manipulation capability
            self.manipulation_memory_dir = os.path.join(self.memory_dir, "manipulation_memory")

            # Create manipulation memory directory
            os.makedirs(self.manipulation_memory_dir, exist_ok=True)

            self._manipulation_interface = ManipulationInterface(
                output_dir=self.output_dir,  # Use the main output directory
                new_memory=new_memory,
            )
            logger.info("Manipulation interface initialized")
+ """ + if self.connection_interface is None: + raise RuntimeError("No connection interface available for movement") + + return self.connection_interface.move(velocity, duration) + + def spin(self, degrees: float, speed: float = 45.0) -> bool: + """Rotate the robot by a specified angle. + + Args: + degrees: Angle to rotate in degrees (positive for counter-clockwise, + negative for clockwise). + speed: Angular speed in degrees/second. Defaults to 45.0. + + Returns: + bool: True if rotation succeeded. + + Raises: + RuntimeError: If no connection interface is available. + """ + if self.connection_interface is None: + raise RuntimeError("No connection interface available for rotation") + + # Convert degrees to radians + import math + + angular_velocity = math.radians(speed) + duration = abs(degrees) / speed if speed > 0 else 0 + + # Set direction based on sign of degrees + if degrees < 0: + angular_velocity = -angular_velocity + + velocity = Vector(0.0, 0.0, angular_velocity) + return self.connection_interface.move(velocity, duration) + + @abstractmethod + def get_pose(self) -> dict: + """ + Get the current pose (position and rotation) of the robot. + + Returns: + Dictionary containing: + - position: Tuple[float, float, float] (x, y, z) + - rotation: Tuple[float, float, float] (roll, pitch, yaw) in radians + """ + pass + + def webrtc_req( + self, + api_id: int, + topic: str = None, + parameter: str = "", + priority: int = 0, + request_id: str = None, + data=None, + timeout: float = 1000.0, + ): + """Send a WebRTC request command to the robot. + + Args: + api_id: The API ID for the command. + topic: The API topic to publish to. Defaults to ROSControl.webrtc_api_topic. + parameter: Additional parameter data. Defaults to "". + priority: Priority of the request. Defaults to 0. + request_id: Unique identifier for the request. If None, one will be generated. + data: Additional data to include with the request. Defaults to None. 
+ timeout: Timeout for the request in milliseconds. Defaults to 1000.0. + + Returns: + The result of the WebRTC request. + + Raises: + RuntimeError: If no connection interface with WebRTC capability is available. + """ + if self.connection_interface is None: + raise RuntimeError("No connection interface available for WebRTC commands") + + # WebRTC requests are only available on ROS control interfaces + if hasattr(self.connection_interface, "queue_webrtc_req"): + return self.connection_interface.queue_webrtc_req( + api_id=api_id, + topic=topic, + parameter=parameter, + priority=priority, + request_id=request_id, + data=data, + timeout=timeout, + ) + else: + raise RuntimeError("WebRTC requests not supported by this connection interface") + + def pose_command(self, roll: float, pitch: float, yaw: float) -> bool: + """Send a pose command to the robot. + + Args: + roll: Roll angle in radians. + pitch: Pitch angle in radians. + yaw: Yaw angle in radians. + + Returns: + bool: True if command was sent successfully. + + Raises: + RuntimeError: If no connection interface with pose command capability is available. + """ + # Pose commands are only available on ROS control interfaces + if hasattr(self.connection_interface, "pose_command"): + return self.connection_interface.pose_command(roll, pitch, yaw) + else: + raise RuntimeError("Pose commands not supported by this connection interface") + + def update_hardware_interface(self, new_hardware_interface: HardwareInterface): + """Update the hardware interface with a new configuration. + + Args: + new_hardware_interface: New hardware interface to use for the robot. + """ + self.hardware_interface = new_hardware_interface + + def get_hardware_configuration(self): + """Retrieve the current hardware configuration. + + Returns: + The current hardware configuration from the hardware interface. + + Raises: + AttributeError: If hardware_interface is None. 
+ """ + return self.hardware_interface.get_configuration() + + def set_hardware_configuration(self, configuration): + """Set a new hardware configuration. + + Args: + configuration: The new hardware configuration to set. + + Raises: + AttributeError: If hardware_interface is None. + """ + self.hardware_interface.set_configuration(configuration) + + @property + def spatial_memory(self) -> Optional[SpatialMemory]: + """Get the robot's spatial memory. + + Returns: + SpatialMemory: The robot's spatial memory system, or None if perception is disabled. + """ + return self._spatial_memory + + @property + def manipulation_interface(self) -> Optional[ManipulationInterface]: + """Get the robot's manipulation interface. + + Returns: + ManipulationInterface: The robot's manipulation interface or None if not available. + """ + return self._manipulation_interface + + def has_capability(self, capability: RobotCapability) -> bool: + """Check if the robot has a specific capability. + + Args: + capability: The capability to check for + + Returns: + bool: True if the robot has the capability, False otherwise + """ + return capability in self.capabilities + + def get_spatial_memory(self) -> Optional[SpatialMemory]: + """Simple getter for the spatial memory instance. + (For backwards compatibility) + + Returns: + The spatial memory instance or None if not set. + """ + return self._spatial_memory if self._spatial_memory else None + + @property + def video_stream(self) -> Optional[Observable]: + """Get the robot's video stream. + + Returns: + Observable: The robot's video stream or None if not available. + """ + return self._video_stream + + def get_skills(self): + """Get the robot's skill library. + + Returns: + The robot's skill library for adding/managing skills. + """ + return self.skill_library + + def cleanup(self): + """Clean up resources used by the robot. 
+ + This method should be called when the robot is no longer needed to + ensure proper release of resources such as ROS connections and + subscriptions. + """ + # Dispose of resources + if self.disposables: + self.disposables.dispose() + + # Clean up connection interface + if self.connection_interface: + self.connection_interface.disconnect() + + self.disposables.dispose() + + +class MockRobot(Robot): + def __init__(self): + super().__init__() + self.ros_control = None + self.hardware_interface = None + self.skill_library = SkillLibrary() + + def my_print(self): + print("Hello, world!") + + +class MockManipulationRobot(Robot): + def __init__(self, skill_library: Optional[SkillLibrary] = None): + video_provider = VideoProvider("webcam", video_source=0) # Default camera + video_stream = backpressure( + video_provider.capture_video_as_observable(realtime=True, fps=30) + ) + + super().__init__( + capabilities=[RobotCapability.MANIPULATION], + video_stream=video_stream, + skill_library=skill_library, + ) + self.camera_intrinsics = [489.33, 367.0, 320.0, 240.0] + self.ros_control = None + self.hardware_interface = None diff --git a/build/lib/dimos/robot/ros_command_queue.py b/build/lib/dimos/robot/ros_command_queue.py new file mode 100644 index 0000000000..fc48ce5cde --- /dev/null +++ b/build/lib/dimos/robot/ros_command_queue.py @@ -0,0 +1,471 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +Queue-based command management system for robot commands. + +This module provides a unified approach to queueing and processing all robot commands, +including WebRTC requests and action client commands. +Commands are processed sequentially and only when the robot is in IDLE state. +""" + +import threading +import time +import uuid +from enum import Enum, auto +from queue import PriorityQueue, Empty +from typing import Callable, Optional, NamedTuple, Dict, Any +from dimos.utils.logging_config import setup_logger + +# Initialize logger for the ros command queue module +logger = setup_logger("dimos.robot.ros_command_queue") + + +class CommandType(Enum): + """Types of commands that can be queued""" + + WEBRTC = auto() # WebRTC API requests + ACTION = auto() # Any action client or function call + + +class WebRTCRequest(NamedTuple): + """Class to represent a WebRTC request in the queue""" + + id: str # Unique ID for tracking + api_id: int # API ID for the command + topic: str # Topic to publish to + parameter: str # Optional parameter string + priority: int # Priority level + timeout: float # How long to wait for this request to complete + + +class ROSCommand(NamedTuple): + """Class to represent a command in the queue""" + + id: str # Unique ID for tracking + cmd_type: CommandType # Type of command + execute_func: Callable # Function to execute the command + params: Dict[str, Any] # Parameters for the command (for debugging/logging) + priority: int # Priority level (lower is higher priority) + timeout: float # How long to wait for this command to complete + + +class ROSCommandQueue: + """ + Manages a queue of commands for the robot. + + Commands are executed sequentially, with only one command being processed at a time. + Commands are only executed when the robot is in the IDLE state. 
+ """ + + def __init__( + self, + webrtc_func: Callable, + is_ready_func: Callable[[], bool] = None, + is_busy_func: Optional[Callable[[], bool]] = None, + debug: bool = True, + ): + """ + Initialize the ROSCommandQueue. + + Args: + webrtc_func: Function to send WebRTC requests + is_ready_func: Function to check if the robot is ready for a command + is_busy_func: Function to check if the robot is busy + debug: Whether to enable debug logging + """ + self._webrtc_func = webrtc_func + self._is_ready_func = is_ready_func or (lambda: True) + self._is_busy_func = is_busy_func + self._debug = debug + + # Queue of commands to process + self._queue = PriorityQueue() + self._current_command = None + self._last_command_time = 0 + + # Last known robot state + self._last_ready_state = None + self._last_busy_state = None + self._stuck_in_busy_since = None + + # Command execution status + self._should_stop = False + self._queue_thread = None + + # Stats + self._command_count = 0 + self._success_count = 0 + self._failure_count = 0 + self._command_history = [] + + self._max_queue_wait_time = ( + 30.0 # Maximum time to wait for robot to be ready before forcing + ) + + logger.info("ROSCommandQueue initialized") + + def start(self): + """Start the queue processing thread""" + if self._queue_thread is not None and self._queue_thread.is_alive(): + logger.warning("Queue processing thread already running") + return + + self._should_stop = False + self._queue_thread = threading.Thread(target=self._process_queue, daemon=True) + self._queue_thread.start() + logger.info("Queue processing thread started") + + def stop(self, timeout=2.0): + """ + Stop the queue processing thread + + Args: + timeout: Maximum time to wait for the thread to stop + """ + if self._queue_thread is None or not self._queue_thread.is_alive(): + logger.warning("Queue processing thread not running") + return + + self._should_stop = True + try: + self._queue_thread.join(timeout=timeout) + if self._queue_thread.is_alive(): 
+ logger.warning(f"Queue processing thread did not stop within {timeout}s") + else: + logger.info("Queue processing thread stopped") + except Exception as e: + logger.error(f"Error stopping queue processing thread: {e}") + + def queue_webrtc_request( + self, + api_id: int, + topic: str = None, + parameter: str = "", + request_id: str = None, + data: Dict[str, Any] = None, + priority: int = 0, + timeout: float = 30.0, + ) -> str: + """ + Queue a WebRTC request + + Args: + api_id: API ID for the command + topic: Topic to publish to + parameter: Optional parameter string + request_id: Unique ID for the request (will be generated if not provided) + data: Data to include in the request + priority: Priority level (lower is higher priority) + timeout: Maximum time to wait for the command to complete + + Returns: + str: Unique ID for the request + """ + request_id = request_id or str(uuid.uuid4()) + + # Create a function that will execute this WebRTC request + def execute_webrtc(): + try: + logger.info(f"Executing WebRTC request: {api_id} (ID: {request_id})") + if self._debug: + logger.debug(f"[WebRTC Queue] SENDING request: API ID {api_id}") + + result = self._webrtc_func( + api_id=api_id, + topic=topic, + parameter=parameter, + request_id=request_id, + data=data, + ) + if not result: + logger.warning(f"WebRTC request failed: {api_id} (ID: {request_id})") + if self._debug: + logger.debug(f"[WebRTC Queue] Request API ID {api_id} FAILED to send") + return False + + if self._debug: + logger.debug(f"[WebRTC Queue] Request API ID {api_id} sent SUCCESSFULLY") + + # Allow time for the robot to process the command + start_time = time.time() + stabilization_delay = 0.5 # Half-second delay for stabilization + time.sleep(stabilization_delay) + + # Wait for the robot to complete the command (timeout check) + while self._is_busy_func() and (time.time() - start_time) < timeout: + if ( + self._debug and (time.time() - start_time) % 5 < 0.1 + ): # Print every ~5 seconds + logger.debug( + 
f"[WebRTC Queue] Still waiting on API ID {api_id} - elapsed: {time.time() - start_time:.1f}s" + ) + time.sleep(0.1) + + # Check if we timed out + if self._is_busy_func() and (time.time() - start_time) >= timeout: + logger.warning(f"WebRTC request timed out: {api_id} (ID: {request_id})") + return False + + wait_time = time.time() - start_time + if self._debug: + logger.debug( + f"[WebRTC Queue] Request API ID {api_id} completed after {wait_time:.1f}s" + ) + + logger.info(f"WebRTC request completed: {api_id} (ID: {request_id})") + return True + except Exception as e: + logger.error(f"Error executing WebRTC request: {e}") + if self._debug: + logger.debug(f"[WebRTC Queue] ERROR processing request: {e}") + return False + + # Create the command and queue it + command = ROSCommand( + id=request_id, + cmd_type=CommandType.WEBRTC, + execute_func=execute_webrtc, + params={"api_id": api_id, "topic": topic, "request_id": request_id}, + priority=priority, + timeout=timeout, + ) + + # Queue the command + self._queue.put((priority, self._command_count, command)) + self._command_count += 1 + if self._debug: + logger.debug( + f"[WebRTC Queue] Added request ID {request_id} for API ID {api_id} - Queue size now: {self.queue_size}" + ) + logger.info(f"Queued WebRTC request: {api_id} (ID: {request_id}, Priority: {priority})") + + return request_id + + def queue_action_client_request( + self, + action_name: str, + execute_func: Callable, + priority: int = 0, + timeout: float = 30.0, + **kwargs, + ) -> str: + """ + Queue any action client request or function + + Args: + action_name: Name of the action for logging/tracking + execute_func: Function to execute the command + priority: Priority level (lower is higher priority) + timeout: Maximum time to wait for the command to complete + **kwargs: Additional parameters to pass to the execute function + + Returns: + str: Unique ID for the request + """ + request_id = str(uuid.uuid4()) + + # Create the command + command = ROSCommand( + 
id=request_id, + cmd_type=CommandType.ACTION, + execute_func=execute_func, + params={"action_name": action_name, **kwargs}, + priority=priority, + timeout=timeout, + ) + + # Queue the command + self._queue.put((priority, self._command_count, command)) + self._command_count += 1 + + action_params = ", ".join([f"{k}={v}" for k, v in kwargs.items()]) + logger.info( + f"Queued action request: {action_name} (ID: {request_id}, Priority: {priority}, Params: {action_params})" + ) + + return request_id + + def _process_queue(self): + """Process commands in the queue""" + logger.info("Starting queue processing") + logger.info("[WebRTC Queue] Processing thread started") + + while not self._should_stop: + # Print queue status + self._print_queue_status() + + # Check if we're ready to process a command + if not self._queue.empty() and self._current_command is None: + current_time = time.time() + is_ready = self._is_ready_func() + is_busy = self._is_busy_func() if self._is_busy_func else False + + if self._debug: + logger.debug( + f"[WebRTC Queue] Status: {self.queue_size} requests waiting | Robot ready: {is_ready} | Robot busy: {is_busy}" + ) + + # Track robot state changes + if is_ready != self._last_ready_state: + logger.debug( + f"Robot ready state changed: {self._last_ready_state} -> {is_ready}" + ) + self._last_ready_state = is_ready + + if is_busy != self._last_busy_state: + logger.debug(f"Robot busy state changed: {self._last_busy_state} -> {is_busy}") + self._last_busy_state = is_busy + + # If the robot has transitioned to busy, record the time + if is_busy: + self._stuck_in_busy_since = current_time + else: + self._stuck_in_busy_since = None + + # Check if we've been waiting too long for the robot to be ready + force_processing = False + if ( + not is_ready + and is_busy + and self._stuck_in_busy_since is not None + and current_time - self._stuck_in_busy_since > self._max_queue_wait_time + ): + logger.warning( + f"Robot has been busy for {current_time - 
self._stuck_in_busy_since:.1f}s, " + f"forcing queue to continue" + ) + force_processing = True + + # Process the next command if ready or forcing + if is_ready or force_processing: + if self._debug and is_ready: + logger.debug("[WebRTC Queue] Robot is READY for next command") + + try: + # Get the next command + _, _, command = self._queue.get(block=False) + self._current_command = command + self._last_command_time = current_time + + # Log the command + cmd_info = f"ID: {command.id}, Type: {command.cmd_type.name}" + if command.cmd_type == CommandType.WEBRTC: + api_id = command.params.get("api_id") + cmd_info += f", API: {api_id}" + if self._debug: + logger.debug(f"[WebRTC Queue] DEQUEUED request: API ID {api_id}") + elif command.cmd_type == CommandType.ACTION: + action_name = command.params.get("action_name") + cmd_info += f", Action: {action_name}" + if self._debug: + logger.debug(f"[WebRTC Queue] DEQUEUED action: {action_name}") + + forcing_str = " (FORCED)" if force_processing else "" + logger.info(f"Processing command{forcing_str}: {cmd_info}") + + # Execute the command + try: + # Where command execution occurs + success = command.execute_func() + + if success: + self._success_count += 1 + logger.info(f"Command succeeded: {cmd_info}") + if self._debug: + logger.debug( + f"[WebRTC Queue] Command {command.id} marked as COMPLETED" + ) + else: + self._failure_count += 1 + logger.warning(f"Command failed: {cmd_info}") + if self._debug: + logger.debug(f"[WebRTC Queue] Command {command.id} FAILED") + + # Record command history + self._command_history.append( + { + "id": command.id, + "type": command.cmd_type.name, + "params": command.params, + "success": success, + "time": time.time() - self._last_command_time, + } + ) + + except Exception as e: + self._failure_count += 1 + logger.error(f"Error executing command: {e}") + if self._debug: + logger.debug(f"[WebRTC Queue] ERROR executing command: {e}") + + # Mark the command as complete + self._current_command = None + if 
self._debug: + logger.debug( + "[WebRTC Queue] Adding 0.5s stabilization delay before next command" + ) + time.sleep(0.5) + + except Empty: + pass + + # Sleep to avoid busy-waiting + time.sleep(0.1) + + logger.info("Queue processing stopped") + + def _print_queue_status(self): + """Print the current queue status""" + current_time = time.time() + + # Only print once per second to avoid spamming the log + if current_time - self._last_command_time < 1.0 and self._current_command is None: + return + + is_ready = self._is_ready_func() + is_busy = self._is_busy_func() if self._is_busy_func else False + queue_size = self.queue_size + + # Get information about the current command + current_command_info = "None" + if self._current_command is not None: + current_command_info = f"{self._current_command.cmd_type.name}" + if self._current_command.cmd_type == CommandType.WEBRTC: + api_id = self._current_command.params.get("api_id") + current_command_info += f" (API: {api_id})" + elif self._current_command.cmd_type == CommandType.ACTION: + action_name = self._current_command.params.get("action_name") + current_command_info += f" (Action: {action_name})" + + # Print the status + status = ( + f"Queue: {queue_size} items | " + f"Robot: {'READY' if is_ready else 'BUSY'} | " + f"Current: {current_command_info} | " + f"Stats: {self._success_count} OK, {self._failure_count} FAIL" + ) + + logger.debug(status) + self._last_command_time = current_time + + @property + def queue_size(self) -> int: + """Get the number of commands in the queue""" + return self._queue.qsize() + + @property + def current_command(self) -> Optional[ROSCommand]: + """Get the current command being processed""" + return self._current_command diff --git a/build/lib/dimos/robot/ros_control.py b/build/lib/dimos/robot/ros_control.py new file mode 100644 index 0000000000..6aa51fc3a8 --- /dev/null +++ b/build/lib/dimos/robot/ros_control.py @@ -0,0 +1,867 @@ +# Copyright 2025 Dimensional Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import rclpy +from rclpy.node import Node +from rclpy.executors import MultiThreadedExecutor +from rclpy.action import ActionClient +from geometry_msgs.msg import Twist +from nav2_msgs.action import Spin + +from sensor_msgs.msg import Image, CompressedImage +from cv_bridge import CvBridge +from enum import Enum, auto +import threading +import time +from typing import Optional, Dict, Any, Type +from abc import ABC, abstractmethod +from rclpy.qos import ( + QoSProfile, + QoSReliabilityPolicy, + QoSHistoryPolicy, + QoSDurabilityPolicy, +) +from dimos.stream.ros_video_provider import ROSVideoProvider +import math +from builtin_interfaces.msg import Duration +from geometry_msgs.msg import Point, Vector3 +from dimos.robot.ros_command_queue import ROSCommandQueue +from dimos.utils.logging_config import setup_logger + +from nav_msgs.msg import OccupancyGrid + +import tf2_ros +from dimos.robot.ros_transform import ROSTransformAbility +from dimos.robot.ros_observable_topic import ROSObservableTopicAbility +from dimos.robot.connection_interface import ConnectionInterface +from dimos.types.vector import Vector + +from nav_msgs.msg import Odometry + +logger = setup_logger("dimos.robot.ros_control") + +__all__ = ["ROSControl", "RobotMode"] + + +class RobotMode(Enum): + """Enum for robot modes""" + + UNKNOWN = auto() + INITIALIZING = auto() + IDLE = auto() + MOVING = auto() + ERROR = auto() + + +class ROSControl(ROSTransformAbility, 
ROSObservableTopicAbility, ConnectionInterface, ABC): + """Abstract base class for ROS-controlled robots""" + + def __init__( + self, + node_name: str, + camera_topics: Dict[str, str] = None, + max_linear_velocity: float = 1.0, + mock_connection: bool = False, + max_angular_velocity: float = 2.0, + state_topic: str = None, + imu_topic: str = None, + state_msg_type: Type = None, + imu_msg_type: Type = None, + webrtc_topic: str = None, + webrtc_api_topic: str = None, + webrtc_msg_type: Type = None, + move_vel_topic: str = None, + pose_topic: str = None, + odom_topic: str = "/odom", + global_costmap_topic: str = "map", + costmap_topic: str = "/local_costmap/costmap", + debug: bool = False, + ): + """ + Initialize base ROS control interface + Args: + node_name: Name for the ROS node + camera_topics: Dictionary of camera topics + max_linear_velocity: Maximum linear velocity (m/s) + max_angular_velocity: Maximum angular velocity (rad/s) + state_topic: Topic name for robot state (optional) + imu_topic: Topic name for IMU data (optional) + state_msg_type: The ROS message type for state data + imu_msg_type: The ROS message type for IMU data + webrtc_topic: Topic for WebRTC commands + webrtc_api_topic: Topic for WebRTC API commands + webrtc_msg_type: The ROS message type for webrtc data + move_vel_topic: Topic for direct movement commands + pose_topic: Topic for pose commands + odom_topic: Topic for odometry data + costmap_topic: Topic for local costmap data + """ + # Initialize rclpy and ROS node if not already running + if not rclpy.ok(): + rclpy.init() + + self._state_topic = state_topic + self._imu_topic = imu_topic + self._odom_topic = odom_topic + self._costmap_topic = costmap_topic + self._state_msg_type = state_msg_type + self._imu_msg_type = imu_msg_type + self._webrtc_msg_type = webrtc_msg_type + self._webrtc_topic = webrtc_topic + self._webrtc_api_topic = webrtc_api_topic + self._node = Node(node_name) + self._global_costmap_topic = global_costmap_topic + 
self._debug = debug + + # Prepare a multi-threaded executor + self._executor = MultiThreadedExecutor() + + # Movement constraints + self.MAX_LINEAR_VELOCITY = max_linear_velocity + self.MAX_ANGULAR_VELOCITY = max_angular_velocity + + self._subscriptions = [] + + # Track State variables + self._robot_state = None # Full state message + self._imu_state = None # Full IMU message + self._odom_data = None # Odometry data + self._costmap_data = None # Costmap data + self._mode = RobotMode.INITIALIZING + + # Create sensor data QoS profile + sensor_qos = QoSProfile( + reliability=QoSReliabilityPolicy.BEST_EFFORT, + history=QoSHistoryPolicy.KEEP_LAST, + durability=QoSDurabilityPolicy.VOLATILE, + depth=1, + ) + + command_qos = QoSProfile( + reliability=QoSReliabilityPolicy.RELIABLE, + history=QoSHistoryPolicy.KEEP_LAST, + durability=QoSDurabilityPolicy.VOLATILE, + depth=10, # Higher depth for commands to ensure delivery + ) + + if self._global_costmap_topic: + self._global_costmap_data = None + self._global_costmap_sub = self._node.create_subscription( + OccupancyGrid, + self._global_costmap_topic, + self._global_costmap_callback, + sensor_qos, + ) + self._subscriptions.append(self._global_costmap_sub) + else: + logger.warning("No costmap topic provided - costmap data tracking will be unavailable") + + # Initialize data handling + self._video_provider = None + self._bridge = None + if camera_topics: + self._bridge = CvBridge() + self._video_provider = ROSVideoProvider(dev_name=f"{node_name}_video") + + # Create subscribers for each topic with sensor QoS + for camera_config in camera_topics.values(): + topic = camera_config["topic"] + msg_type = camera_config["type"] + + logger.info( + f"Subscribing to {topic} with BEST_EFFORT QoS using message type {msg_type.__name__}" + ) + _camera_subscription = self._node.create_subscription( + msg_type, topic, self._image_callback, sensor_qos + ) + self._subscriptions.append(_camera_subscription) + + # Subscribe to state topic if 
provided + if self._state_topic and self._state_msg_type: + logger.info(f"Subscribing to {state_topic} with BEST_EFFORT QoS") + self._state_sub = self._node.create_subscription( + self._state_msg_type, + self._state_topic, + self._state_callback, + qos_profile=sensor_qos, + ) + self._subscriptions.append(self._state_sub) + else: + logger.warning( + "No state topic andor message type provided - robot state tracking will be unavailable" + ) + + if self._imu_topic and self._imu_msg_type: + self._imu_sub = self._node.create_subscription( + self._imu_msg_type, self._imu_topic, self._imu_callback, sensor_qos + ) + self._subscriptions.append(self._imu_sub) + else: + logger.warning( + "No IMU topic and/or message type provided - IMU data tracking will be unavailable" + ) + + if self._odom_topic: + self._odom_sub = self._node.create_subscription( + Odometry, self._odom_topic, self._odom_callback, sensor_qos + ) + self._subscriptions.append(self._odom_sub) + else: + logger.warning( + "No odometry topic provided - odometry data tracking will be unavailable" + ) + + if self._costmap_topic: + self._costmap_sub = self._node.create_subscription( + OccupancyGrid, self._costmap_topic, self._costmap_callback, sensor_qos + ) + self._subscriptions.append(self._costmap_sub) + else: + logger.warning("No costmap topic provided - costmap data tracking will be unavailable") + + # Nav2 Action Clients + self._spin_client = ActionClient(self._node, Spin, "spin") + + # Wait for action servers + if not mock_connection: + self._spin_client.wait_for_server() + + # Publishers + self._move_vel_pub = self._node.create_publisher(Twist, move_vel_topic, command_qos) + self._pose_pub = self._node.create_publisher(Vector3, pose_topic, command_qos) + + if webrtc_msg_type: + self._webrtc_pub = self._node.create_publisher( + webrtc_msg_type, webrtc_topic, qos_profile=command_qos + ) + + # Initialize command queue + self._command_queue = ROSCommandQueue( + webrtc_func=self.webrtc_req, + is_ready_func=lambda: 
self._mode == RobotMode.IDLE, + is_busy_func=lambda: self._mode == RobotMode.MOVING, + ) + # Start the queue processing thread + self._command_queue.start() + else: + logger.warning("No WebRTC message type provided - WebRTC commands will be unavailable") + + # Initialize TF Buffer and Listener for transform abilities + self._tf_buffer = tf2_ros.Buffer() + self._tf_listener = tf2_ros.TransformListener(self._tf_buffer, self._node) + logger.info(f"TF Buffer and Listener initialized for {node_name}") + + # Start ROS spin in a background thread via the executor + self._spin_thread = threading.Thread(target=self._ros_spin, daemon=True) + self._spin_thread.start() + + logger.info(f"{node_name} initialized with multi-threaded executor") + print(f"{node_name} initialized with multi-threaded executor") + + def get_global_costmap(self) -> Optional[OccupancyGrid]: + """ + Get current global_costmap data + + Returns: + Optional[OccupancyGrid]: Current global_costmap data or None if not available + """ + if not self._global_costmap_topic: + logger.warning( + "No global_costmap topic provided - global_costmap data tracking will be unavailable" + ) + return None + + if self._global_costmap_data: + return self._global_costmap_data + else: + return None + + def _global_costmap_callback(self, msg): + """Callback for costmap data""" + self._global_costmap_data = msg + + def _imu_callback(self, msg): + """Callback for IMU data""" + self._imu_state = msg + # Log IMU state (very verbose) + # logger.debug(f"IMU state updated: {self._imu_state}") + + def _odom_callback(self, msg): + """Callback for odometry data""" + self._odom_data = msg + + def _costmap_callback(self, msg): + """Callback for costmap data""" + self._costmap_data = msg + + def _state_callback(self, msg): + """Callback for state messages to track mode and progress""" + + # Call the abstract method to update RobotMode enum based on the received state + self._robot_state = msg + self._update_mode(msg) + # Log state changes 
(very verbose) + # logger.debug(f"Robot state updated: {self._robot_state}") + + @property + def robot_state(self) -> Optional[Any]: + """Get the full robot state message""" + return self._robot_state + + def _ros_spin(self): + """Background thread for spinning the multi-threaded executor.""" + self._executor.add_node(self._node) + try: + self._executor.spin() + finally: + self._executor.shutdown() + + def _clamp_velocity(self, velocity: float, max_velocity: float) -> float: + """Clamp velocity within safe limits""" + return max(min(velocity, max_velocity), -max_velocity) + + @abstractmethod + def _update_mode(self, *args, **kwargs): + """Update robot mode based on state - to be implemented by child classes""" + pass + + def get_state(self) -> Optional[Any]: + """ + Get current robot state + + Base implementation provides common state fields. Child classes should + extend this method to include their specific state information. + + Returns: + ROS msg containing the robot state information + """ + if not self._state_topic: + logger.warning("No state topic provided - robot state tracking will be unavailable") + return None + + return self._robot_state + + def get_imu_state(self) -> Optional[Any]: + """ + Get current IMU state + + Base implementation provides common state fields. Child classes should + extend this method to include their specific state information. 
+ + Returns: + ROS msg containing the IMU state information + """ + if not self._imu_topic: + logger.warning("No IMU topic provided - IMU data tracking will be unavailable") + return None + return self._imu_state + + def get_odometry(self) -> Optional[Odometry]: + """ + Get current odometry data + + Returns: + Optional[Odometry]: Current odometry data or None if not available + """ + if not self._odom_topic: + logger.warning( + "No odometry topic provided - odometry data tracking will be unavailable" + ) + return None + return self._odom_data + + def get_costmap(self) -> Optional[OccupancyGrid]: + """ + Get current costmap data + + Returns: + Optional[OccupancyGrid]: Current costmap data or None if not available + """ + if not self._costmap_topic: + logger.warning("No costmap topic provided - costmap data tracking will be unavailable") + return None + return self._costmap_data + + def _image_callback(self, msg): + """Convert ROS image to numpy array and push to data stream""" + if self._video_provider and self._bridge: + try: + if isinstance(msg, CompressedImage): + frame = self._bridge.compressed_imgmsg_to_cv2(msg) + elif isinstance(msg, Image): + frame = self._bridge.imgmsg_to_cv2(msg, "bgr8") + else: + logger.error(f"Unsupported image message type: {type(msg)}") + return + self._video_provider.push_data(frame) + except Exception as e: + logger.error(f"Error converting image: {e}") + print(f"Full conversion error: {str(e)}") + + @property + def video_provider(self) -> Optional[ROSVideoProvider]: + """Data provider property for streaming data""" + return self._video_provider + + def get_video_stream(self, fps: int = 30) -> Optional[Observable]: + """Get the video stream from the robot's camera. 
+ + Args: + fps: Frames per second for the video stream + + Returns: + Observable: An observable stream of video frames or None if not available + """ + if not self.video_provider: + return None + + return self.video_provider.get_stream(fps=fps) + + def _send_action_client_goal(self, client, goal_msg, description=None, time_allowance=20.0): + """ + Generic function to send any action client goal and wait for completion. + + Args: + client: The action client to use + goal_msg: The goal message to send + description: Optional description for logging + time_allowance: Maximum time to wait for completion + + Returns: + bool: True if action succeeded, False otherwise + """ + if description: + logger.info(description) + + print(f"[ROSControl] Sending action client goal: {description}") + print(f"[ROSControl] Goal message: {goal_msg}") + + # Reset action result tracking + self._action_success = None + + # Send the goal + send_goal_future = client.send_goal_async(goal_msg, feedback_callback=lambda feedback: None) + send_goal_future.add_done_callback(self._goal_response_callback) + + # Wait for completion + start_time = time.time() + while self._action_success is None and time.time() - start_time < time_allowance: + time.sleep(0.1) + + elapsed = time.time() - start_time + print( + f"[ROSControl] Action completed in {elapsed:.2f}s with result: {self._action_success}" + ) + + # Check result + if self._action_success is None: + logger.error(f"Action timed out after {time_allowance}s") + return False + elif self._action_success: + logger.info("Action succeeded") + return True + else: + logger.error("Action failed") + return False + + def move(self, velocity: Vector, duration: float = 0.0) -> bool: + """Send velocity commands to the robot. + + Args: + velocity: Velocity vector [x, y, yaw] where: + x: Linear velocity in x direction (m/s) + y: Linear velocity in y direction (m/s) + yaw: Angular velocity around z axis (rad/s) + duration: Duration to apply command (seconds). 
If 0, apply once. + + Returns: + bool: True if command was sent successfully + """ + x, y, yaw = velocity.x, velocity.y, velocity.z + + # Clamp velocities to safe limits + x = self._clamp_velocity(x, self.MAX_LINEAR_VELOCITY) + y = self._clamp_velocity(y, self.MAX_LINEAR_VELOCITY) + yaw = self._clamp_velocity(yaw, self.MAX_ANGULAR_VELOCITY) + + # Create and send command + cmd = Twist() + cmd.linear.x = float(x) + cmd.linear.y = float(y) + cmd.angular.z = float(yaw) + + try: + if duration > 0: + start_time = time.time() + while time.time() - start_time < duration: + self._move_vel_pub.publish(cmd) + time.sleep(0.1) # 10Hz update rate + # Stop after duration + self.stop() + else: + self._move_vel_pub.publish(cmd) + return True + + except Exception as e: + self._logger.error(f"Failed to send movement command: {e}") + return False + + def reverse(self, distance: float, speed: float = 0.5, time_allowance: float = 120) -> bool: + """ + Move the robot backward by a specified distance + + Args: + distance: Distance to move backward in meters (must be positive) + speed: Speed to move at in m/s (default 0.5) + time_allowance: Maximum time to wait for the request to complete + + Returns: + bool: True if movement succeeded + """ + try: + if distance <= 0: + logger.error("Distance must be positive") + return False + + speed = min(abs(speed), self.MAX_LINEAR_VELOCITY) + + # Define function to execute the reverse + def execute_reverse(): + # Create BackUp goal + goal = BackUp.Goal() + goal.target = Point() + goal.target.x = -distance # Negative for backward motion + goal.target.y = 0.0 + goal.target.z = 0.0 + goal.speed = speed # BackUp expects positive speed + goal.time_allowance = Duration(sec=time_allowance) + + print( + f"[ROSControl] execute_reverse: Creating BackUp goal with distance={distance}m, speed={speed}m/s" + ) + print( + f"[ROSControl] execute_reverse: Goal details: x={goal.target.x}, y={goal.target.y}, z={goal.target.z}, speed={goal.speed}" + ) + + 
logger.info(f"Moving backward: distance={distance}m, speed={speed}m/s") + + result = self._send_action_client_goal( + self._backup_client, + goal, + f"Moving backward {distance}m at {speed}m/s", + time_allowance, + ) + + print(f"[ROSControl] execute_reverse: BackUp action result: {result}") + return result + + # Queue the action + cmd_id = self._command_queue.queue_action_client_request( + action_name="reverse", + execute_func=execute_reverse, + priority=0, + timeout=time_allowance, + distance=distance, + speed=speed, + ) + logger.info( + f"Queued reverse command: {cmd_id} - Distance: {distance}m, Speed: {speed}m/s" + ) + return True + + except Exception as e: + logger.error(f"Backward movement failed: {e}") + import traceback + + logger.error(traceback.format_exc()) + return False + + def spin(self, degrees: float, speed: float = 45.0, time_allowance: float = 120) -> bool: + """ + Rotate the robot by a specified angle + + Args: + degrees: Angle to rotate in degrees (positive for counter-clockwise, negative for clockwise) + speed: Angular speed in degrees/second (default 45.0) + time_allowance: Maximum time to wait for the request to complete + + Returns: + bool: True if movement succeeded + """ + try: + # Convert degrees to radians + angle = math.radians(degrees) + angular_speed = math.radians(abs(speed)) + + # Clamp angular speed + angular_speed = min(angular_speed, self.MAX_ANGULAR_VELOCITY) + time_allowance = max( + int(abs(angle) / angular_speed * 2), 20 + ) # At least 20 seconds or double the expected time + + # Define function to execute the spin + def execute_spin(): + # Create Spin goal + goal = Spin.Goal() + goal.target_yaw = angle # Nav2 Spin action expects radians + goal.time_allowance = Duration(sec=time_allowance) + + logger.info(f"Spinning: angle={degrees}deg ({angle:.2f}rad)") + + return self._send_action_client_goal( + self._spin_client, + goal, + f"Spinning {degrees} degrees at {speed} deg/s", + time_allowance, + ) + + # Queue the action + cmd_id 
= self._command_queue.queue_action_client_request( + action_name="spin", + execute_func=execute_spin, + priority=0, + timeout=time_allowance, + degrees=degrees, + speed=speed, + ) + logger.info(f"Queued spin command: {cmd_id} - Degrees: {degrees}, Speed: {speed}deg/s") + return True + + except Exception as e: + logger.error(f"Spin movement failed: {e}") + import traceback + + logger.error(traceback.format_exc()) + return False + + def stop(self) -> bool: + """Stop all robot movement""" + try: + # self.navigator.cancelTask() + self._current_velocity = {"x": 0.0, "y": 0.0, "z": 0.0} + self._is_moving = False + return True + except Exception as e: + logger.error(f"Failed to stop movement: {e}") + return False + + def cleanup(self): + """Cleanup the executor, ROS node, and stop robot.""" + self.stop() + + # Stop the WebRTC queue manager + if self._command_queue: + logger.info("Stopping WebRTC queue manager...") + self._command_queue.stop() + + # Shut down the executor to stop spin loop cleanly + self._executor.shutdown() + + # Destroy node and shutdown rclpy + self._node.destroy_node() + rclpy.shutdown() + + def disconnect(self) -> None: + """Disconnect from the robot and clean up resources.""" + self.cleanup() + + def webrtc_req( + self, + api_id: int, + topic: str = None, + parameter: str = "", + priority: int = 0, + request_id: str = None, + data=None, + ) -> bool: + """ + Send a WebRTC request command to the robot + + Args: + api_id: The API ID for the command + topic: The API topic to publish to (defaults to self._webrtc_api_topic) + parameter: Optional parameter string + priority: Priority level (0 or 1) + request_id: Optional request ID for tracking (not used in ROS implementation) + data: Optional data dictionary (not used in ROS implementation) + params: Optional params dictionary (not used in ROS implementation) + + Returns: + bool: True if command was sent successfully + """ + try: + # Create and send command + cmd = self._webrtc_msg_type() + cmd.api_id = 
api_id + cmd.topic = topic if topic is not None else self._webrtc_api_topic + cmd.parameter = parameter + cmd.priority = priority + + self._webrtc_pub.publish(cmd) + logger.info(f"Sent WebRTC request: api_id={api_id}, topic={cmd.topic}") + return True + + except Exception as e: + logger.error(f"Failed to send WebRTC request: {e}") + return False + + def get_robot_mode(self) -> RobotMode: + """ + Get the current robot mode + + Returns: + RobotMode: The current robot mode enum value + """ + return self._mode + + def print_robot_mode(self): + """Print the current robot mode to the console""" + mode = self.get_robot_mode() + print(f"Current RobotMode: {mode.name}") + print(f"Mode enum: {mode}") + + def queue_webrtc_req( + self, + api_id: int, + topic: str = None, + parameter: str = "", + priority: int = 0, + timeout: float = 90.0, + request_id: str = None, + data=None, + ) -> str: + """ + Queue a WebRTC request to be sent when the robot is IDLE + + Args: + api_id: The API ID for the command + topic: The topic to publish to (defaults to self._webrtc_api_topic) + parameter: Optional parameter string + priority: Priority level (0 or 1) + timeout: Maximum time to wait for the request to complete + request_id: Optional request ID (if None, one will be generated) + data: Optional data dictionary (not used in ROS implementation) + + Returns: + str: Request ID that can be used to track the request + """ + return self._command_queue.queue_webrtc_request( + api_id=api_id, + topic=topic if topic is not None else self._webrtc_api_topic, + parameter=parameter, + priority=priority, + timeout=timeout, + request_id=request_id, + data=data, + ) + + def move_vel_control(self, x: float, y: float, yaw: float) -> bool: + """ + Send a single velocity command without duration handling. 
+ + Args: + x: Forward/backward velocity (m/s) + y: Left/right velocity (m/s) + yaw: Rotational velocity (rad/s) + + Returns: + bool: True if command was sent successfully + """ + # Clamp velocities to safe limits + x = self._clamp_velocity(x, self.MAX_LINEAR_VELOCITY) + y = self._clamp_velocity(y, self.MAX_LINEAR_VELOCITY) + yaw = self._clamp_velocity(yaw, self.MAX_ANGULAR_VELOCITY) + + # Create and send command + cmd = Twist() + cmd.linear.x = float(x) + cmd.linear.y = float(y) + cmd.angular.z = float(yaw) + + try: + self._move_vel_pub.publish(cmd) + return True + except Exception as e: + logger.error(f"Failed to send velocity command: {e}") + return False + + def pose_command(self, roll: float, pitch: float, yaw: float) -> bool: + """ + Send a pose command to the robot to adjust its body orientation + + Args: + roll: Roll angle in radians + pitch: Pitch angle in radians + yaw: Yaw angle in radians + + Returns: + bool: True if command was sent successfully + """ + # Create the pose command message + cmd = Vector3() + cmd.x = float(roll) # Roll + cmd.y = float(pitch) # Pitch + cmd.z = float(yaw) # Yaw + + try: + self._pose_pub.publish(cmd) + logger.debug(f"Sent pose command: roll={roll}, pitch={pitch}, yaw={yaw}") + return True + except Exception as e: + logger.error(f"Failed to send pose command: {e}") + return False + + def get_position_stream(self): + """ + Get a stream of position updates from ROS. 
+ + Returns: + Observable that emits (x, y) tuples representing the robot's position + """ + from dimos.robot.position_stream import PositionStreamProvider + + # Create a position stream provider + position_provider = PositionStreamProvider( + ros_node=self._node, + odometry_topic="/odom", # Default odometry topic + use_odometry=True, + ) + + return position_provider.get_position_stream() + + def _goal_response_callback(self, future): + """Handle the goal response.""" + goal_handle = future.result() + if not goal_handle.accepted: + logger.warn("Goal was rejected!") + print("[ROSControl] Goal was REJECTED by the action server") + self._action_success = False + return + + logger.info("Goal accepted") + print("[ROSControl] Goal was ACCEPTED by the action server") + result_future = goal_handle.get_result_async() + result_future.add_done_callback(self._goal_result_callback) + + def _goal_result_callback(self, future): + """Handle the goal result.""" + try: + result = future.result().result + logger.info("Goal completed") + print(f"[ROSControl] Goal COMPLETED with result: {result}") + self._action_success = True + except Exception as e: + logger.error(f"Goal failed with error: {e}") + print(f"[ROSControl] Goal FAILED with error: {e}") + self._action_success = False diff --git a/build/lib/dimos/robot/ros_observable_topic.py b/build/lib/dimos/robot/ros_observable_topic.py new file mode 100644 index 0000000000..697ddff398 --- /dev/null +++ b/build/lib/dimos/robot/ros_observable_topic.py @@ -0,0 +1,240 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import functools +import enum +import reactivex as rx +from reactivex import operators as ops +from reactivex.disposable import Disposable +from reactivex.scheduler import ThreadPoolScheduler +from rxpy_backpressure import BackPressure + +from nav_msgs import msg +from dimos.utils.logging_config import setup_logger +from dimos.utils.threadpool import get_scheduler +from dimos.types.costmap import Costmap +from dimos.types.vector import Vector + +from typing import Union, Callable, Any + +from rclpy.qos import ( + QoSProfile, + QoSReliabilityPolicy, + QoSHistoryPolicy, + QoSDurabilityPolicy, +) + +__all__ = ["ROSObservableTopicAbility", "QOS"] + +ConversionType = Costmap +TopicType = Union[ConversionType, msg.OccupancyGrid, msg.Odometry] + + +class QOS(enum.Enum): + SENSOR = "sensor" + COMMAND = "command" + + def to_profile(self) -> QoSProfile: + if self == QOS.SENSOR: + return QoSProfile( + reliability=QoSReliabilityPolicy.BEST_EFFORT, + history=QoSHistoryPolicy.KEEP_LAST, + durability=QoSDurabilityPolicy.VOLATILE, + depth=1, + ) + if self == QOS.COMMAND: + return QoSProfile( + reliability=QoSReliabilityPolicy.RELIABLE, + history=QoSHistoryPolicy.KEEP_LAST, + durability=QoSDurabilityPolicy.VOLATILE, + depth=10, # Higher depth for commands to ensure delivery + ) + + raise ValueError(f"Unknown QoS enum value: {self}") + + +logger = setup_logger("dimos.robot.ros_control.observable_topic") + + +class ROSObservableTopicAbility: + # Ensures that we can return multiple observables which have multiple subscribers + # consuming the same topic at 
different (blocking) rates while: + # + # - immediately returning latest value received to new subscribers + # - allowing slow subscribers to consume the topic without blocking fast ones + # - dealing with backpressure from slow subscribers (auto dropping unprocessed messages) + # + # (for more details see corresponding test file) + # + # ROS thread ─► ReplaySubject─► observe_on(pool) ─► backpressure.latest ─► sub1 (fast) + # ├──► observe_on(pool) ─► backpressure.latest ─► sub2 (slow) + # └──► observe_on(pool) ─► backpressure.latest ─► sub3 (slower) + # + def _maybe_conversion(self, msg_type: TopicType, callback) -> Callable[[TopicType], Any]: + if msg_type == Costmap: + return lambda msg: callback(Costmap.from_msg(msg)) + # just for test, not sure if this Vector auto-instantiation is used irl + if msg_type == Vector: + return lambda msg: callback(Vector.from_msg(msg)) + return callback + + def _sub_msg_type(self, msg_type): + if msg_type == Costmap: + return msg.OccupancyGrid + + if msg_type == Vector: + return msg.Odometry + + return msg_type + + @functools.lru_cache(maxsize=None) + def topic( + self, + topic_name: str, + msg_type: TopicType, + qos=QOS.SENSOR, + scheduler: ThreadPoolScheduler | None = None, + drop_unprocessed: bool = True, + ) -> rx.Observable: + if scheduler is None: + scheduler = get_scheduler() + + # Convert QOS to QoSProfile + qos_profile = qos.to_profile() + + # upstream ROS callback + def _on_subscribe(obs, _): + ros_sub = self._node.create_subscription( + self._sub_msg_type(msg_type), + topic_name, + self._maybe_conversion(msg_type, obs.on_next), + qos_profile, + ) + return Disposable(lambda: self._node.destroy_subscription(ros_sub)) + + upstream = rx.create(_on_subscribe) + + # hot, latest-cached core + core = upstream.pipe( + ops.replay(buffer_size=1), + ops.ref_count(), # still synchronous! 
+ ) + + # per-subscriber factory + def per_sub(): + # hop off the ROS thread into the pool + base = core.pipe(ops.observe_on(scheduler)) + + # optional back-pressure handling + if not drop_unprocessed: + return base + + def _subscribe(observer, sch=None): + return base.subscribe(BackPressure.LATEST(observer), scheduler=sch) + + return rx.create(_subscribe) + + # each `.subscribe()` call gets its own async backpressure chain + return rx.defer(lambda *_: per_sub()) + + # If you are not interested in processing streams, just want to fetch the latest stream + # value use this function. It runs a subscription in the background. + # caches latest value for you, always ready to return. + # + # odom = robot.topic_latest("/odom", msg.Odometry) + # the initial call to odom() will block until the first message is received + # + # any time you'd like you can call: + # + # print(f"Latest odom: {odom()}") + # odom.dispose() # clean up the subscription + # + # see test_ros_observable_topic.py test_topic_latest for more details + def topic_latest( + self, topic_name: str, msg_type: TopicType, timeout: float | None = 100.0, qos=QOS.SENSOR + ): + """ + Blocks the current thread until the first message is received, then + returns `reader()` (sync) and keeps one ROS subscription alive + in the background. + + latest_scan = robot.ros_control.topic_latest_blocking("scan", LaserScan) + do_something(latest_scan()) # instant + latest_scan.dispose() # clean up + """ + # one shared observable with a 1-element replay buffer + core = self.topic(topic_name, msg_type, qos=qos).pipe(ops.replay(buffer_size=1)) + conn = core.connect() # starts the ROS subscription immediately + + try: + first_val = core.pipe( + ops.first(), *([ops.timeout(timeout)] if timeout is not None else []) + ).run() + except Exception: + conn.dispose() + msg = f"{topic_name} message not received after {timeout} seconds. Is robot connected?" 
+ logger.error(msg) + raise Exception(msg) + + cache = {"val": first_val} + sub = core.subscribe(lambda v: cache.__setitem__("val", v)) + + def reader(): + return cache["val"] + + reader.dispose = lambda: (sub.dispose(), conn.dispose()) + return reader + + # If you are not interested in processing streams, just want to fetch the latest stream + # value use this function. It runs a subscription in the background. + # caches latest value for you, always ready to return + # + # odom = await robot.topic_latest_async("/odom", msg.Odometry) + # + # async nature of this function allows you to do other stuff while you wait + # for a first message to arrive + # + # any time you'd like you can call: + # + # print(f"Latest odom: {odom()}") + # odom.dispose() # clean up the subscription + # + # see test_ros_observable_topic.py test_topic_latest for more details + async def topic_latest_async( + self, topic_name: str, msg_type: TopicType, qos=QOS.SENSOR, timeout: float = 30.0 + ): + loop = asyncio.get_running_loop() + first = loop.create_future() + cache = {"val": None} + + core = self.topic(topic_name, msg_type, qos=qos) # single ROS callback + + def _on_next(v): + cache["val"] = v + if not first.done(): + loop.call_soon_threadsafe(first.set_result, v) + + subscription = core.subscribe(_on_next) + + try: + await asyncio.wait_for(first, timeout) + except Exception: + subscription.dispose() + raise + + def reader(): + return cache["val"] + + reader.dispose = subscription.dispose + return reader diff --git a/build/lib/dimos/robot/ros_transform.py b/build/lib/dimos/robot/ros_transform.py new file mode 100644 index 0000000000..b0c46fd275 --- /dev/null +++ b/build/lib/dimos/robot/ros_transform.py @@ -0,0 +1,243 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import rclpy +from typing import Optional +from geometry_msgs.msg import TransformStamped +from tf2_ros import Buffer +import tf2_ros +from tf2_geometry_msgs import PointStamped +from dimos.utils.logging_config import setup_logger +from dimos.types.vector import Vector +from dimos.types.path import Path +from scipy.spatial.transform import Rotation as R + +logger = setup_logger("dimos.robot.ros_transform") + +__all__ = ["ROSTransformAbility"] + + +def to_euler_rot(msg: TransformStamped) -> [Vector, Vector]: + q = msg.transform.rotation + rotation = R.from_quat([q.x, q.y, q.z, q.w]) + return Vector(rotation.as_euler("xyz", degrees=False)) + + +def to_euler_pos(msg: TransformStamped) -> [Vector, Vector]: + return Vector(msg.transform.translation).to_2d() + + +def to_euler(msg: TransformStamped) -> [Vector, Vector]: + return [to_euler_pos(msg), to_euler_rot(msg)] + + +class ROSTransformAbility: + """Mixin class for handling ROS transforms between coordinate frames""" + + @property + def tf_buffer(self) -> Buffer: + if not hasattr(self, "_tf_buffer"): + self._tf_buffer = tf2_ros.Buffer() + self._tf_listener = tf2_ros.TransformListener(self._tf_buffer, self._node) + logger.info("Transform listener initialized") + + return self._tf_buffer + + def transform_euler_pos( + self, source_frame: str, target_frame: str = "map", timeout: float = 1.0 + ): + return to_euler_pos(self.transform(source_frame, target_frame, timeout)) + + def transform_euler_rot( + self, source_frame: str, target_frame: str = "map", timeout: float = 1.0 + ): + return 
to_euler_rot(self.transform(source_frame, target_frame, timeout)) + + def transform_euler(self, source_frame: str, target_frame: str = "map", timeout: float = 1.0): + res = self.transform(source_frame, target_frame, timeout) + return to_euler(res) + + def transform( + self, source_frame: str, target_frame: str = "map", timeout: float = 1.0 + ) -> Optional[TransformStamped]: + try: + transform = self.tf_buffer.lookup_transform( + target_frame, + source_frame, + rclpy.time.Time(), + rclpy.duration.Duration(seconds=timeout), + ) + return transform + except ( + tf2_ros.LookupException, + tf2_ros.ConnectivityException, + tf2_ros.ExtrapolationException, + ) as e: + logger.error(f"Transform lookup failed: {e}") + return None + + def transform_point( + self, point: Vector, source_frame: str, target_frame: str = "map", timeout: float = 1.0 + ): + """Transform a point from source_frame to target_frame. + + Args: + point: The point to transform (x, y, z) + source_frame: The source frame of the point + target_frame: The target frame to transform to + timeout: Time to wait for the transform to become available (seconds) + + Returns: + The transformed point as a Vector, or None if the transform failed + """ + try: + # Wait for transform to become available + self.tf_buffer.can_transform( + target_frame, + source_frame, + rclpy.time.Time(), + rclpy.duration.Duration(seconds=timeout), + ) + + # Create a PointStamped message + ps = PointStamped() + ps.header.frame_id = source_frame + ps.header.stamp = rclpy.time.Time().to_msg() # Latest available transform + ps.point.x = point[0] + ps.point.y = point[1] + ps.point.z = point[2] if len(point) > 2 else 0.0 + + # Transform point + transformed_ps = self.tf_buffer.transform( + ps, target_frame, rclpy.duration.Duration(seconds=timeout) + ) + + # Return as Vector type + if len(point) > 2: + return Vector( + transformed_ps.point.x, transformed_ps.point.y, transformed_ps.point.z + ) + else: + return Vector(transformed_ps.point.x, 
transformed_ps.point.y) + except ( + tf2_ros.LookupException, + tf2_ros.ConnectivityException, + tf2_ros.ExtrapolationException, + ) as e: + logger.error(f"Transform from {source_frame} to {target_frame} failed: {e}") + return None + + def transform_path( + self, path: Path, source_frame: str, target_frame: str = "map", timeout: float = 1.0 + ): + """Transform a path from source_frame to target_frame. + + Args: + path: The path to transform + source_frame: The source frame of the path + target_frame: The target frame to transform to + timeout: Time to wait for the transform to become available (seconds) + + Returns: + The transformed path as a Path, or None if the transform failed + """ + transformed_path = Path() + for point in path: + transformed_point = self.transform_point(point, source_frame, target_frame, timeout) + if transformed_point is not None: + transformed_path.append(transformed_point) + return transformed_path + + def transform_rot( + self, rotation: Vector, source_frame: str, target_frame: str = "map", timeout: float = 1.0 + ): + """Transform a rotation from source_frame to target_frame. 
+ + Args: + rotation: The rotation to transform as Euler angles (x, y, z) in radians + source_frame: The source frame of the rotation + target_frame: The target frame to transform to + timeout: Time to wait for the transform to become available (seconds) + + Returns: + The transformed rotation as a Vector of Euler angles (x, y, z), or None if the transform failed + """ + try: + # Wait for transform to become available + self.tf_buffer.can_transform( + target_frame, + source_frame, + rclpy.time.Time(), + rclpy.duration.Duration(seconds=timeout), + ) + + # Create a rotation matrix from the input Euler angles + input_rotation = R.from_euler("xyz", rotation, degrees=False) + + # Get the transform from source to target frame + transform = self.transform(source_frame, target_frame, timeout) + if transform is None: + return None + + # Extract the rotation from the transform + q = transform.transform.rotation + transform_rotation = R.from_quat([q.x, q.y, q.z, q.w]) + + # Compose the rotations + # The resulting rotation is the composition of the transform rotation and input rotation + result_rotation = transform_rotation * input_rotation + + # Convert back to Euler angles + euler_angles = result_rotation.as_euler("xyz", degrees=False) + + # Return as Vector type + return Vector(euler_angles) + + except ( + tf2_ros.LookupException, + tf2_ros.ConnectivityException, + tf2_ros.ExtrapolationException, + ) as e: + logger.error(f"Transform rotation from {source_frame} to {target_frame} failed: {e}") + return None + + def transform_pose( + self, + position: Vector, + rotation: Vector, + source_frame: str, + target_frame: str = "map", + timeout: float = 1.0, + ): + """Transform a pose from source_frame to target_frame. 
+ + Args: + position: The position to transform + rotation: The rotation to transform + source_frame: The source frame of the pose + target_frame: The target frame to transform to + timeout: Time to wait for the transform to become available (seconds) + + Returns: + Tuple of (transformed_position, transformed_rotation) as Vectors, + or (None, None) if either transform failed + """ + # Transform position + transformed_position = self.transform_point(position, source_frame, target_frame, timeout) + + # Transform rotation + transformed_rotation = self.transform_rot(rotation, source_frame, target_frame, timeout) + + # Return results (both might be None if transforms failed) + return transformed_position, transformed_rotation diff --git a/build/lib/dimos/robot/test_ros_observable_topic.py b/build/lib/dimos/robot/test_ros_observable_topic.py new file mode 100644 index 0000000000..71a1484de3 --- /dev/null +++ b/build/lib/dimos/robot/test_ros_observable_topic.py @@ -0,0 +1,255 @@ +#!/usr/bin/env python3 +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import threading +import time +import pytest +from dimos.utils.logging_config import setup_logger +from dimos.types.vector import Vector +import asyncio + + +class MockROSNode: + def __init__(self): + self.logger = setup_logger("ROS") + + self.sub_id_cnt = 0 + self.subs = {} + + def _get_sub_id(self): + sub_id = self.sub_id_cnt + self.sub_id_cnt += 1 + return sub_id + + def create_subscription(self, msg_type, topic_name, callback, qos): + # Mock implementation of ROS subscription + + sub_id = self._get_sub_id() + stop_event = threading.Event() + self.subs[sub_id] = stop_event + self.logger.info(f"Subscribed {topic_name} subid {sub_id}") + + # Create message simulation thread + def simulate_messages(): + message_count = 0 + while not stop_event.is_set(): + message_count += 1 + time.sleep(0.1) # 20Hz default publication rate + if topic_name == "/vector": + callback([message_count, message_count]) + else: + callback(message_count) + # cleanup + self.subs.pop(sub_id) + + thread = threading.Thread(target=simulate_messages, daemon=True) + thread.start() + return sub_id + + def destroy_subscription(self, subscription): + if subscription in self.subs: + self.subs[subscription].set() + self.logger.info(f"Destroyed subscription: {subscription}") + else: + self.logger.info(f"Unknown subscription: {subscription}") + + +# we are doing this in order to avoid importing ROS dependencies if ros tests aren't runnin +@pytest.fixture +def robot(): + from dimos.robot.ros_observable_topic import ROSObservableTopicAbility + + class MockRobot(ROSObservableTopicAbility): + def __init__(self): + self.logger = setup_logger("ROBOT") + # Initialize the mock ROS node + self._node = MockROSNode() + + return MockRobot() + + +# This test verifies a bunch of basics: +# +# 1. that the system creates a single ROS sub for multiple reactivex subs +# 2. that the system creates a single ROS sub for multiple observers +# 3. that the system unsubscribes from ROS when observers are disposed +# 4. 
that the system replays the last message to new observers, +# before the new ROS sub starts producing +@pytest.mark.ros +def test_parallel_and_cleanup(robot): + from nav_msgs import msg + + received_messages = [] + + obs1 = robot.topic("/odom", msg.Odometry) + + print(f"Created subscription: {obs1}") + + subscription1 = obs1.subscribe(lambda x: received_messages.append(x + 2)) + + subscription2 = obs1.subscribe(lambda x: received_messages.append(x + 3)) + + obs2 = robot.topic("/odom", msg.Odometry) + subscription3 = obs2.subscribe(lambda x: received_messages.append(x + 5)) + + time.sleep(0.25) + + # We have 2 messages and 3 subscribers + assert len(received_messages) == 6, "Should have received exactly 6 messages" + + # [1, 1, 1, 2, 2, 2] + + # [2, 3, 5, 2, 3, 5] + # = + for i in [3, 4, 6, 4, 5, 7]: + assert i in received_messages, f"Expected {i} in received messages, got {received_messages}" + + # ensure that ROS end has only a single subscription + assert len(robot._node.subs) == 1, ( + f"Expected 1 subscription, got {len(robot._node.subs)}: {robot._node.subs}" + ) + + subscription1.dispose() + subscription2.dispose() + subscription3.dispose() + + # Make sure that ros end was unsubscribed, thread terminated + time.sleep(0.1) + assert not robot._node.subs, f"Expected empty subs dict, got: {robot._node.subs}" + + # Ensure we replay the last message + second_received = [] + second_sub = obs1.subscribe(lambda x: second_received.append(x)) + + time.sleep(0.075) + # we immediately receive the stored topic message + assert len(second_received) == 1 + + # now that sub is hot, we wait for a second one + time.sleep(0.2) + + # we expect 2, 1 since first message was preserved from a previous ros topic sub + # second one is the first message of the second ros topic sub + assert second_received == [2, 1, 2] + + print(f"Second subscription immediately received {len(second_received)} message(s)") + + second_sub.dispose() + + time.sleep(0.1) + assert not robot._node.subs, 
f"Expected empty subs dict, got: {robot._node.subs}" + + print("Test completed successfully") + + +# here we test parallel subs and slow observers hogging our topic +# we expect slow observers to skip messages by default +# +# ROS thread ─► ReplaySubject─► observe_on(pool) ─► backpressure.latest ─► sub1 (fast) +# ├──► observe_on(pool) ─► backpressure.latest ─► sub2 (slow) +# └──► observe_on(pool) ─► backpressure.latest ─► sub3 (slower) +@pytest.mark.ros +def test_parallel_and_hog(robot): + from nav_msgs import msg + + obs1 = robot.topic("/odom", msg.Odometry) + obs2 = robot.topic("/odom", msg.Odometry) + + subscriber1_messages = [] + subscriber2_messages = [] + subscriber3_messages = [] + + subscription1 = obs1.subscribe(lambda x: subscriber1_messages.append(x)) + subscription2 = obs1.subscribe(lambda x: time.sleep(0.15) or subscriber2_messages.append(x)) + subscription3 = obs2.subscribe(lambda x: time.sleep(0.25) or subscriber3_messages.append(x)) + + assert len(robot._node.subs) == 1 + + time.sleep(2) + + subscription1.dispose() + subscription2.dispose() + subscription3.dispose() + + print("Subscriber 1 messages:", len(subscriber1_messages), subscriber1_messages) + print("Subscriber 2 messages:", len(subscriber2_messages), subscriber2_messages) + print("Subscriber 3 messages:", len(subscriber3_messages), subscriber3_messages) + + assert len(subscriber1_messages) == 19 + assert len(subscriber2_messages) == 12 + assert len(subscriber3_messages) == 7 + + assert subscriber2_messages[1] != [2] + assert subscriber3_messages[1] != [2] + + time.sleep(0.1) + + assert robot._node.subs == {} + + +@pytest.mark.asyncio +@pytest.mark.ros +async def test_topic_latest_async(robot): + from nav_msgs import msg + + odom = await robot.topic_latest_async("/odom", msg.Odometry) + assert odom() == 1 + await asyncio.sleep(0.45) + assert odom() == 5 + odom.dispose() + await asyncio.sleep(0.1) + assert robot._node.subs == {} + + +@pytest.mark.ros +def test_topic_auto_conversion(robot): + 
odom = robot.topic("/vector", Vector).subscribe(lambda x: print(x)) + time.sleep(0.5) + odom.dispose() + + +@pytest.mark.ros +def test_topic_latest_sync(robot): + from nav_msgs import msg + + odom = robot.topic_latest("/odom", msg.Odometry) + assert odom() == 1 + time.sleep(0.45) + assert odom() == 5 + odom.dispose() + time.sleep(0.1) + assert robot._node.subs == {} + + +@pytest.mark.ros +def test_topic_latest_sync_benchmark(robot): + from nav_msgs import msg + + odom = robot.topic_latest("/odom", msg.Odometry) + + start_time = time.time() + for i in range(100): + odom() + end_time = time.time() + elapsed = end_time - start_time + avg_time = elapsed / 100 + + print("avg time", avg_time) + + assert odom() == 1 + time.sleep(0.45) + assert odom() >= 5 + odom.dispose() + time.sleep(0.1) + assert robot._node.subs == {} diff --git a/build/lib/dimos/robot/unitree/__init__.py b/build/lib/dimos/robot/unitree/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/build/lib/dimos/robot/unitree/unitree_go2.py b/build/lib/dimos/robot/unitree/unitree_go2.py new file mode 100644 index 0000000000..ca878e7134 --- /dev/null +++ b/build/lib/dimos/robot/unitree/unitree_go2.py @@ -0,0 +1,208 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import multiprocessing +from typing import Optional, Union, List +import numpy as np +from dimos.robot.robot import Robot +from dimos.robot.unitree.unitree_skills import MyUnitreeSkills +from dimos.skills.skills import AbstractRobotSkill, AbstractSkill, SkillLibrary +from reactivex.disposable import CompositeDisposable +import logging +import os +from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl +from reactivex.scheduler import ThreadPoolScheduler +from dimos.utils.logging_config import setup_logger +from dimos.perception.person_tracker import PersonTrackingStream +from dimos.perception.object_tracker import ObjectTrackingStream +from dimos.robot.local_planner.local_planner import navigate_path_local +from dimos.robot.local_planner.vfh_local_planner import VFHPurePursuitPlanner +from dimos.robot.global_planner.planner import AstarPlanner +from dimos.types.costmap import Costmap +from dimos.types.robot_capabilities import RobotCapability +from dimos.types.vector import Vector + +# Set up logging +logger = setup_logger("dimos.robot.unitree.unitree_go2", level=logging.DEBUG) + +# UnitreeGo2 Print Colors (Magenta) +UNITREE_GO2_PRINT_COLOR = "\033[35m" +UNITREE_GO2_RESET_COLOR = "\033[0m" + + +class UnitreeGo2(Robot): + """Unitree Go2 robot implementation using ROS2 control interface. + + This class extends the base Robot class to provide specific functionality + for the Unitree Go2 quadruped robot using ROS2 for communication and control. + """ + + def __init__( + self, + video_provider=None, + output_dir: str = os.path.join(os.getcwd(), "assets", "output"), + skill_library: SkillLibrary = None, + robot_capabilities: List[RobotCapability] = None, + spatial_memory_collection: str = "spatial_memory", + new_memory: bool = False, + disable_video_stream: bool = False, + mock_connection: bool = False, + enable_perception: bool = True, + ): + """Initialize UnitreeGo2 robot with ROS control interface. 
+ + Args: + video_provider: Provider for video streams + output_dir: Directory for output files + skill_library: Library of robot skills + robot_capabilities: List of robot capabilities + spatial_memory_collection: Collection name for spatial memory + new_memory: Whether to create new memory collection + disable_video_stream: Whether to disable video streaming + mock_connection: Whether to use mock connection for testing + enable_perception: Whether to enable perception streams and spatial memory + """ + # Create ROS control interface + ros_control = UnitreeROSControl( + node_name="unitree_go2", + video_provider=video_provider, + disable_video_stream=disable_video_stream, + mock_connection=mock_connection, + ) + + # Initialize skill library if not provided + if skill_library is None: + skill_library = MyUnitreeSkills() + + # Initialize base robot with connection interface + super().__init__( + connection_interface=ros_control, + output_dir=output_dir, + skill_library=skill_library, + capabilities=robot_capabilities + or [ + RobotCapability.LOCOMOTION, + RobotCapability.VISION, + RobotCapability.AUDIO, + ], + spatial_memory_collection=spatial_memory_collection, + new_memory=new_memory, + enable_perception=enable_perception, + ) + + if self.skill_library is not None: + for skill in self.skill_library: + if isinstance(skill, AbstractRobotSkill): + self.skill_library.create_instance(skill.__name__, robot=self) + if isinstance(self.skill_library, MyUnitreeSkills): + self.skill_library._robot = self + self.skill_library.init() + self.skill_library.initialize_skills() + + # Camera stuff + self.camera_intrinsics = [819.553492, 820.646595, 625.284099, 336.808987] + self.camera_pitch = np.deg2rad(0) # negative for downward pitch + self.camera_height = 0.44 # meters + + # Initialize UnitreeGo2-specific attributes + self.disposables = CompositeDisposable() + self.main_stream_obs = None + + # Initialize thread pool scheduler + self.optimal_thread_count = 
multiprocessing.cpu_count() + self.thread_pool_scheduler = ThreadPoolScheduler(self.optimal_thread_count // 2) + + # Initialize visual servoing if enabled + if not disable_video_stream: + self.video_stream_ros = self.get_video_stream(fps=8) + if enable_perception: + self.person_tracker = PersonTrackingStream( + camera_intrinsics=self.camera_intrinsics, + camera_pitch=self.camera_pitch, + camera_height=self.camera_height, + ) + self.object_tracker = ObjectTrackingStream( + camera_intrinsics=self.camera_intrinsics, + camera_pitch=self.camera_pitch, + camera_height=self.camera_height, + ) + person_tracking_stream = self.person_tracker.create_stream(self.video_stream_ros) + object_tracking_stream = self.object_tracker.create_stream(self.video_stream_ros) + + self.person_tracking_stream = person_tracking_stream + self.object_tracking_stream = object_tracking_stream + else: + # Video stream is available but perception tracking is disabled + self.person_tracker = None + self.object_tracker = None + self.person_tracking_stream = None + self.object_tracking_stream = None + else: + # Video stream is disabled + self.video_stream_ros = None + self.person_tracker = None + self.object_tracker = None + self.person_tracking_stream = None + self.object_tracking_stream = None + + # Initialize the local planner and create BEV visualization stream + # Note: These features require ROS-specific methods that may not be available on all connection interfaces + if hasattr(self.connection_interface, "topic_latest") and hasattr( + self.connection_interface, "transform_euler" + ): + self.local_planner = VFHPurePursuitPlanner( + get_costmap=self.connection_interface.topic_latest( + "/local_costmap/costmap", Costmap + ), + transform=self.connection_interface, + move_vel_control=self.connection_interface.move_vel_control, + robot_width=0.36, # Unitree Go2 width in meters + robot_length=0.6, # Unitree Go2 length in meters + max_linear_vel=0.5, + lookahead_distance=2.0, + visualization_size=500, # 
500x500 pixel visualization + ) + + self.global_planner = AstarPlanner( + conservativism=20, # how close to obstacles robot is allowed to path plan + set_local_nav=lambda path, stop_event=None, goal_theta=None: navigate_path_local( + self, path, timeout=120.0, goal_theta=goal_theta, stop_event=stop_event + ), + get_costmap=self.connection_interface.topic_latest("map", Costmap), + get_robot_pos=lambda: self.connection_interface.transform_euler_pos("base_link"), + ) + + # Create the visualization stream at 5Hz + self.local_planner_viz_stream = self.local_planner.create_stream(frequency_hz=5.0) + else: + self.local_planner = None + self.global_planner = None + self.local_planner_viz_stream = None + + def get_skills(self) -> Optional[SkillLibrary]: + return self.skill_library + + def get_pose(self) -> dict: + """ + Get the current pose (position and rotation) of the robot in the map frame. + + Returns: + Dictionary containing: + - position: Vector (x, y, z) + - rotation: Vector (roll, pitch, yaw) in radians + """ + position_tuple, orientation_tuple = self.connection_interface.get_pose_odom_transform() + position = Vector(position_tuple[0], position_tuple[1], position_tuple[2]) + rotation = Vector(orientation_tuple[0], orientation_tuple[1], orientation_tuple[2]) + return {"position": position, "rotation": rotation} diff --git a/build/lib/dimos/robot/unitree/unitree_ros_control.py b/build/lib/dimos/robot/unitree/unitree_ros_control.py new file mode 100644 index 0000000000..56e83cb30f --- /dev/null +++ b/build/lib/dimos/robot/unitree/unitree_ros_control.py @@ -0,0 +1,157 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from go2_interfaces.msg import Go2State, IMU +from unitree_go.msg import WebRtcReq +from typing import Type +from sensor_msgs.msg import Image, CompressedImage, CameraInfo +from dimos.robot.ros_control import ROSControl, RobotMode +from dimos.utils.logging_config import setup_logger + +logger = setup_logger("dimos.robot.unitree.unitree_ros_control") + + +class UnitreeROSControl(ROSControl): + """Hardware interface for Unitree Go2 robot using ROS2""" + + # ROS Camera Topics + CAMERA_TOPICS = { + "raw": {"topic": "camera/image_raw", "type": Image}, + "compressed": {"topic": "camera/compressed", "type": CompressedImage}, + "info": {"topic": "camera/camera_info", "type": CameraInfo}, + } + # Hard coded ROS Message types and Topic names for Unitree Go2 + DEFAULT_STATE_MSG_TYPE = Go2State + DEFAULT_IMU_MSG_TYPE = IMU + DEFAULT_WEBRTC_MSG_TYPE = WebRtcReq + DEFAULT_STATE_TOPIC = "go2_states" + DEFAULT_IMU_TOPIC = "imu" + DEFAULT_WEBRTC_TOPIC = "webrtc_req" + DEFAULT_CMD_VEL_TOPIC = "cmd_vel_out" + DEFAULT_POSE_TOPIC = "pose_cmd" + DEFAULT_ODOM_TOPIC = "odom" + DEFAULT_COSTMAP_TOPIC = "local_costmap/costmap" + DEFAULT_MAX_LINEAR_VELOCITY = 1.0 + DEFAULT_MAX_ANGULAR_VELOCITY = 2.0 + + # Hard coded WebRTC API parameters for Unitree Go2 + DEFAULT_WEBRTC_API_TOPIC = "rt/api/sport/request" + + def __init__( + self, + node_name: str = "unitree_hardware_interface", + state_topic: str = None, + imu_topic: str = None, + webrtc_topic: str = None, + webrtc_api_topic: str = None, + move_vel_topic: str = None, + pose_topic: str = None, + odom_topic: str = None, + 
costmap_topic: str = None, + state_msg_type: Type = None, + imu_msg_type: Type = None, + webrtc_msg_type: Type = None, + max_linear_velocity: float = None, + max_angular_velocity: float = None, + use_raw: bool = False, + debug: bool = False, + disable_video_stream: bool = False, + mock_connection: bool = False, + ): + """ + Initialize Unitree ROS control interface with default values for Unitree Go2 + + Args: + node_name: Name for the ROS node + state_topic: ROS Topic name for robot state (defaults to DEFAULT_STATE_TOPIC) + imu_topic: ROS Topic name for IMU data (defaults to DEFAULT_IMU_TOPIC) + webrtc_topic: ROS Topic for WebRTC commands (defaults to DEFAULT_WEBRTC_TOPIC) + cmd_vel_topic: ROS Topic for direct movement velocity commands (defaults to DEFAULT_CMD_VEL_TOPIC) + pose_topic: ROS Topic for pose commands (defaults to DEFAULT_POSE_TOPIC) + odom_topic: ROS Topic for odometry data (defaults to DEFAULT_ODOM_TOPIC) + costmap_topic: ROS Topic for local costmap data (defaults to DEFAULT_COSTMAP_TOPIC) + state_msg_type: ROS Message type for state data (defaults to DEFAULT_STATE_MSG_TYPE) + imu_msg_type: ROS message type for IMU data (defaults to DEFAULT_IMU_MSG_TYPE) + webrtc_msg_type: ROS message type for webrtc data (defaults to DEFAULT_WEBRTC_MSG_TYPE) + max_linear_velocity: Maximum linear velocity in m/s (defaults to DEFAULT_MAX_LINEAR_VELOCITY) + max_angular_velocity: Maximum angular velocity in rad/s (defaults to DEFAULT_MAX_ANGULAR_VELOCITY) + use_raw: Whether to use raw camera topics (defaults to False) + debug: Whether to enable debug logging + disable_video_stream: Whether to run without video stream for testing. + mock_connection: Whether to run without active ActionClient servers for testing. 
+ """ + + logger.info("Initializing Unitree ROS control interface") + # Select which camera topics to use + active_camera_topics = None + if not disable_video_stream: + active_camera_topics = {"main": self.CAMERA_TOPICS["raw" if use_raw else "compressed"]} + + # Use default values if not provided + state_topic = state_topic or self.DEFAULT_STATE_TOPIC + imu_topic = imu_topic or self.DEFAULT_IMU_TOPIC + webrtc_topic = webrtc_topic or self.DEFAULT_WEBRTC_TOPIC + move_vel_topic = move_vel_topic or self.DEFAULT_CMD_VEL_TOPIC + pose_topic = pose_topic or self.DEFAULT_POSE_TOPIC + odom_topic = odom_topic or self.DEFAULT_ODOM_TOPIC + costmap_topic = costmap_topic or self.DEFAULT_COSTMAP_TOPIC + webrtc_api_topic = webrtc_api_topic or self.DEFAULT_WEBRTC_API_TOPIC + state_msg_type = state_msg_type or self.DEFAULT_STATE_MSG_TYPE + imu_msg_type = imu_msg_type or self.DEFAULT_IMU_MSG_TYPE + webrtc_msg_type = webrtc_msg_type or self.DEFAULT_WEBRTC_MSG_TYPE + max_linear_velocity = max_linear_velocity or self.DEFAULT_MAX_LINEAR_VELOCITY + max_angular_velocity = max_angular_velocity or self.DEFAULT_MAX_ANGULAR_VELOCITY + + super().__init__( + node_name=node_name, + camera_topics=active_camera_topics, + mock_connection=mock_connection, + state_topic=state_topic, + imu_topic=imu_topic, + state_msg_type=state_msg_type, + imu_msg_type=imu_msg_type, + webrtc_msg_type=webrtc_msg_type, + webrtc_topic=webrtc_topic, + webrtc_api_topic=webrtc_api_topic, + move_vel_topic=move_vel_topic, + pose_topic=pose_topic, + odom_topic=odom_topic, + costmap_topic=costmap_topic, + max_linear_velocity=max_linear_velocity, + max_angular_velocity=max_angular_velocity, + debug=debug, + ) + + # Unitree-specific RobotMode State update conditons + def _update_mode(self, msg: Go2State): + """ + Implementation of abstract method to update robot mode + + Logic: + - If progress is 0 and mode is 1, then state is IDLE + - If progress is 1 OR mode is NOT equal to 1, then state is MOVING + """ + # Direct access to 
protected instance variables from the parent class + mode = msg.mode + progress = msg.progress + + if progress == 0 and mode == 1: + self._mode = RobotMode.IDLE + logger.debug("Robot mode set to IDLE (progress=0, mode=1)") + elif progress == 1 or mode != 1: + self._mode = RobotMode.MOVING + logger.debug(f"Robot mode set to MOVING (progress={progress}, mode={mode})") + else: + self._mode = RobotMode.UNKNOWN + logger.debug(f"Robot mode set to UNKNOWN (progress={progress}, mode={mode})") diff --git a/build/lib/dimos/robot/unitree/unitree_skills.py b/build/lib/dimos/robot/unitree/unitree_skills.py new file mode 100644 index 0000000000..5029123ed1 --- /dev/null +++ b/build/lib/dimos/robot/unitree/unitree_skills.py @@ -0,0 +1,314 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from typing import TYPE_CHECKING, List, Optional, Tuple, Union +import time +from pydantic import Field + +if TYPE_CHECKING: + from dimos.robot.robot import Robot, MockRobot +else: + Robot = "Robot" + MockRobot = "MockRobot" + +from dimos.skills.skills import AbstractRobotSkill, AbstractSkill, SkillLibrary +from dimos.types.constants import Colors +from dimos.types.vector import Vector + +# Module-level constant for Unitree ROS control definitions +UNITREE_ROS_CONTROLS: List[Tuple[str, int, str]] = [ + ("Damp", 1001, "Lowers the robot to the ground fully."), + ( + "BalanceStand", + 1002, + "Activates a mode that maintains the robot in a balanced standing position.", + ), + ( + "StandUp", + 1004, + "Commands the robot to transition from a sitting or prone position to a standing posture.", + ), + ( + "StandDown", + 1005, + "Instructs the robot to move from a standing position to a sitting or prone posture.", + ), + ( + "RecoveryStand", + 1006, + "Recovers the robot to a state from which it can take more commands. 
Useful to run after multiple dynamic commands like front flips.", + ), + # ( + # "Euler", + # 1007, + # "Adjusts the robot's orientation using Euler angles, providing precise control over its rotation.", + # ), + # ("Move", 1008, "Move the robot using velocity commands."), # Intentionally omitted + ("Sit", 1009, "Commands the robot to sit down from a standing or moving stance."), + # ( + # "RiseSit", + # 1010, + # "Commands the robot to rise back to a standing position from a sitting posture.", + # ), + # ( + # "SwitchGait", + # 1011, + # "Switches the robot's walking pattern or style dynamically, suitable for different terrains or speeds.", + # ), + # ("Trigger", 1012, "Triggers a specific action or custom routine programmed into the robot."), + # ( + # "BodyHeight", + # 1013, + # "Adjusts the height of the robot's body from the ground, useful for navigating various obstacles.", + # ), + # ( + # "FootRaiseHeight", + # 1014, + # "Controls how high the robot lifts its feet during movement, which can be adjusted for different surfaces.", + # ), + ( + "SpeedLevel", + 1015, + "Sets or adjusts the speed at which the robot moves, with various levels available for different operational needs.", + ), + ( + "ShakeHand", + 1016, + "Performs a greeting action, which could involve a wave or other friendly gesture.", + ), + ("Stretch", 1017, "Engages the robot in a stretching routine."), + # ( + # "TrajectoryFollow", + # 1018, + # "Directs the robot to follow a predefined trajectory, which could involve complex paths or maneuvers.", + # ), + # ( + # "ContinuousGait", + # 1019, + # "Enables a mode for continuous walking or running, ideal for long-distance travel.", + # ), + ("Content", 1020, "To display or trigger when the robot is happy."), + ("Wallow", 1021, "The robot falls onto its back and rolls around."), + ( + "Dance1", + 1022, + "Performs a predefined dance routine 1, programmed for entertainment or demonstration.", + ), + ("Dance2", 1023, "Performs another variant of a 
predefined dance routine 2."), + # ("GetBodyHeight", 1024, "Retrieves the current height of the robot's body from the ground."), + # ( + # "GetFootRaiseHeight", + # 1025, + # "Retrieves the current height at which the robot's feet are being raised during movement.", + # ), + # ("GetSpeedLevel", 1026, "Returns the current speed level at which the robot is operating."), + # ( + # "SwitchJoystick", + # 1027, + # "Toggles the control mode to joystick input, allowing for manual direction of the robot's movements.", + # ), + ( + "Pose", + 1028, + "Directs the robot to take a specific pose or stance, which could be used for tasks or performances.", + ), + ( + "Scrape", + 1029, + "Robot falls to its hind legs and makes scraping motions with its front legs.", + ), + ("FrontFlip", 1030, "Executes a front flip, a complex and dynamic maneuver."), + ("FrontJump", 1031, "Commands the robot to perform a forward jump."), + ( + "FrontPounce", + 1032, + "Initiates a pouncing movement forward, mimicking animal-like pouncing behavior.", + ), + # ("WiggleHips", 1033, "Causes the robot to wiggle its hips."), + # ( + # "GetState", + # 1034, + # "Retrieves the current operational state of the robot, including status reports or diagnostic information.", + # ), + # ( + # "EconomicGait", + # 1035, + # "Engages a more energy-efficient walking or running mode to conserve battery life.", + # ), + # ("FingerHeart", 1036, "Performs a finger heart gesture while on its hind legs."), + # ( + # "Handstand", + # 1301, + # "Commands the robot to perform a handstand, demonstrating balance and control.", + # ), + # ( + # "CrossStep", + # 1302, + # "Engages the robot in a cross-stepping routine, useful for complex locomotion or dance moves.", + # ), + # ( + # "OnesidedStep", + # 1303, + # "Commands the robot to perform a stepping motion that predominantly uses one side.", + # ), + # ( + # "Bound", + # 1304, + # "Initiates a bounding motion, similar to a light, repetitive hopping or leaping.", + # ), + # ( 
+ # "LeadFollow", + # 1045, + # "Engages follow-the-leader behavior, where the robot follows a designated leader or follows a signal.", + # ), + # ("LeftFlip", 1042, "Executes a flip towards the left side."), + # ("RightFlip", 1043, "Performs a flip towards the right side."), + # ("Backflip", 1044, "Executes a backflip, a complex and dynamic maneuver."), +] + +# region MyUnitreeSkills + + +class MyUnitreeSkills(SkillLibrary): + """My Unitree Skills.""" + + _robot: Optional[Robot] = None + + @classmethod + def register_skills(cls, skill_classes: Union["AbstractSkill", list["AbstractSkill"]]): + """Add multiple skill classes as class attributes. + + Args: + skill_classes: List of skill classes to add + """ + if isinstance(skill_classes, list): + for skill_class in skill_classes: + setattr(cls, skill_class.__name__, skill_class) + else: + setattr(cls, skill_classes.__name__, skill_classes) + + def __init__(self, robot: Optional[Robot] = None): + super().__init__() + self._robot: Robot = None + + # Add dynamic skills to this class + self.register_skills(self.create_skills_live()) + + if robot is not None: + self._robot = robot + self.initialize_skills() + + def initialize_skills(self): + # Create the skills and add them to the list of skills + self.register_skills(self.create_skills_live()) + + # Provide the robot instance to each skill + for skill_class in self: + print( + f"{Colors.GREEN_PRINT_COLOR}Creating instance for skill: {skill_class}{Colors.RESET_COLOR}" + ) + self.create_instance(skill_class.__name__, robot=self._robot) + + # Refresh the class skills + self.refresh_class_skills() + + def create_skills_live(self) -> List[AbstractRobotSkill]: + # ================================================ + # Procedurally created skills + # ================================================ + class BaseUnitreeSkill(AbstractRobotSkill): + """Base skill for dynamic skill creation.""" + + def __call__(self): + string = f"{Colors.GREEN_PRINT_COLOR}This is a base skill, created 
for the specific skill: {self._app_id}{Colors.RESET_COLOR}" + print(string) + super().__call__() + if self._app_id is None: + raise RuntimeError( + f"{Colors.RED_PRINT_COLOR}" + f"No App ID provided to {self.__class__.__name__} Skill" + f"{Colors.RESET_COLOR}" + ) + else: + self._robot.webrtc_req(api_id=self._app_id) + string = f"{Colors.GREEN_PRINT_COLOR}{self.__class__.__name__} was successful: id={self._app_id}{Colors.RESET_COLOR}" + print(string) + return string + + skills_classes = [] + for name, app_id, description in UNITREE_ROS_CONTROLS: + skill_class = type( + name, # Name of the class + (BaseUnitreeSkill,), # Base classes + {"__doc__": description, "_app_id": app_id}, + ) + skills_classes.append(skill_class) + + return skills_classes + + # region Class-based Skills + + class Move(AbstractRobotSkill): + """Move the robot using direct velocity commands. Determine duration required based on user distance instructions.""" + + x: float = Field(..., description="Forward velocity (m/s).") + y: float = Field(default=0.0, description="Left/right velocity (m/s)") + yaw: float = Field(default=0.0, description="Rotational velocity (rad/s)") + duration: float = Field(default=0.0, description="How long to move (seconds).") + + def __call__(self): + super().__call__() + return self._robot.move(Vector(self.x, self.y, self.yaw), duration=self.duration) + + class Reverse(AbstractRobotSkill): + """Reverse the robot using direct velocity commands. Determine duration required based on user distance instructions.""" + + x: float = Field(..., description="Backward velocity (m/s). 
Positive values move backward.") + y: float = Field(default=0.0, description="Left/right velocity (m/s)") + yaw: float = Field(default=0.0, description="Rotational velocity (rad/s)") + duration: float = Field(default=0.0, description="How long to move (seconds).") + + def __call__(self): + super().__call__() + # Use move with negative x for backward movement + return self._robot.move(Vector(-self.x, self.y, self.yaw), duration=self.duration) + + class SpinLeft(AbstractRobotSkill): + """Spin the robot left using degree commands.""" + + degrees: float = Field(..., description="Distance to spin left in degrees") + + def __call__(self): + super().__call__() + return self._robot.spin(degrees=self.degrees) # Spinning left is positive degrees + + class SpinRight(AbstractRobotSkill): + """Spin the robot right using degree commands.""" + + degrees: float = Field(..., description="Distance to spin right in degrees") + + def __call__(self): + super().__call__() + return self._robot.spin(degrees=-self.degrees) # Spinning right is negative degrees + + class Wait(AbstractSkill): + """Wait for a specified amount of time.""" + + seconds: float = Field(..., description="Seconds to wait") + + def __call__(self): + time.sleep(self.seconds) + return f"Wait completed with length={self.seconds}s" diff --git a/build/lib/dimos/robot/unitree_webrtc/__init__.py b/build/lib/dimos/robot/unitree_webrtc/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/build/lib/dimos/robot/unitree_webrtc/connection.py b/build/lib/dimos/robot/unitree_webrtc/connection.py new file mode 100644 index 0000000000..86fe5f6a85 --- /dev/null +++ b/build/lib/dimos/robot/unitree_webrtc/connection.py @@ -0,0 +1,309 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import functools +import threading +import time +from typing import Literal, TypeAlias + +import numpy as np +from aiortc import MediaStreamTrack +from go2_webrtc_driver.constants import RTC_TOPIC, SPORT_CMD, VUI_COLOR +from go2_webrtc_driver.webrtc_driver import ( # type: ignore[import-not-found] + Go2WebRTCConnection, + WebRTCConnectionMethod, +) +from reactivex import operators as ops +from reactivex.observable import Observable +from reactivex.subject import Subject + +from dimos.core import In, Module, Out, rpc +from dimos.msgs.sensor_msgs import Image +from dimos.robot.connection_interface import ConnectionInterface +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.lowstate import LowStateMsg +from dimos.robot.unitree_webrtc.type.odometry import Odometry +from dimos.types.pose import Pose +from dimos.types.vector import Vector +from dimos.utils.reactive import backpressure, callback_to_observable + +VideoMessage: TypeAlias = np.ndarray[tuple[int, int, Literal[3]], np.uint8] + + +class WebRTCRobot(ConnectionInterface): + def __init__(self, ip: str, mode: str = "ai"): + self.ip = ip + self.mode = mode + self.conn = Go2WebRTCConnection(WebRTCConnectionMethod.LocalSTA, ip=self.ip) + self.connect() + + def connect(self): + self.loop = asyncio.new_event_loop() + self.task = None + self.connected_event = asyncio.Event() + self.connection_ready = threading.Event() + + async def async_connect(): + await self.conn.connect() + await self.conn.datachannel.disableTrafficSaving(True) + + 
self.conn.datachannel.set_decoder(decoder_type="native") + + await self.conn.datachannel.pub_sub.publish_request_new( + RTC_TOPIC["MOTION_SWITCHER"], {"api_id": 1002, "parameter": {"name": self.mode}} + ) + + self.connected_event.set() + self.connection_ready.set() + + while True: + await asyncio.sleep(1) + + def start_background_loop(): + asyncio.set_event_loop(self.loop) + self.task = self.loop.create_task(async_connect()) + self.loop.run_forever() + + self.loop = asyncio.new_event_loop() + self.thread = threading.Thread(target=start_background_loop, daemon=True) + self.thread.start() + self.connection_ready.wait() + + def move(self, velocity: Vector, duration: float = 0.0) -> bool: + """Send movement command to the robot using velocity commands. + + Args: + velocity: Velocity vector [x, y, yaw] where: + x: Forward/backward velocity (m/s) + y: Left/right velocity (m/s) + yaw: Rotational velocity (rad/s) + duration: How long to move (seconds). If 0, command is continuous + + Returns: + bool: True if command was sent successfully + """ + x, y, yaw = velocity.x, velocity.y, velocity.z + + # WebRTC coordinate mapping: + # x - Positive right, negative left + # y - positive forward, negative backwards + # yaw - Positive rotate right, negative rotate left + async def async_move(): + self.conn.datachannel.pub_sub.publish_without_callback( + RTC_TOPIC["WIRELESS_CONTROLLER"], + data={"lx": y, "ly": x, "rx": -yaw, "ry": 0}, + ) + + async def async_move_duration(): + """Send movement commands continuously for the specified duration.""" + start_time = time.time() + sleep_time = 0.01 + + while time.time() - start_time < duration: + await async_move() + await asyncio.sleep(sleep_time) + + try: + if duration > 0: + # Send continuous move commands for the duration + future = asyncio.run_coroutine_threadsafe(async_move_duration(), self.loop) + future.result() + # Stop after duration + self.stop() + else: + # Single command for continuous movement + future = 
asyncio.run_coroutine_threadsafe(async_move(), self.loop) + future.result() + return True + except Exception as e: + print(f"Failed to send movement command: {e}") + return False + + # Generic conversion of unitree subscription to Subject (used for all subs) + def unitree_sub_stream(self, topic_name: str): + def subscribe_in_thread(cb): + # Run the subscription in the background thread that has the event loop + def run_subscription(): + self.conn.datachannel.pub_sub.subscribe(topic_name, cb) + + # Use call_soon_threadsafe to run in the background thread + self.loop.call_soon_threadsafe(run_subscription) + + def unsubscribe_in_thread(cb): + # Run the unsubscription in the background thread that has the event loop + def run_unsubscription(): + self.conn.datachannel.pub_sub.unsubscribe(topic_name) + + # Use call_soon_threadsafe to run in the background thread + self.loop.call_soon_threadsafe(run_unsubscription) + + return callback_to_observable( + start=subscribe_in_thread, + stop=unsubscribe_in_thread, + ) + + # Generic sync API call (we jump into the client thread) + def publish_request(self, topic: str, data: dict): + future = asyncio.run_coroutine_threadsafe( + self.conn.datachannel.pub_sub.publish_request_new(topic, data), self.loop + ) + return future.result() + + @functools.cache + def raw_lidar_stream(self) -> Subject[LidarMessage]: + return backpressure(self.unitree_sub_stream(RTC_TOPIC["ULIDAR_ARRAY"])) + + @functools.cache + def raw_odom_stream(self) -> Subject[Pose]: + return backpressure(self.unitree_sub_stream(RTC_TOPIC["ROBOTODOM"])) + + @functools.cache + def lidar_stream(self) -> Subject[LidarMessage]: + return backpressure( + self.raw_lidar_stream().pipe( + ops.map(lambda raw_frame: LidarMessage.from_msg(raw_frame)) + ) + ) + + @functools.cache + def odom_stream(self) -> Subject[Pose]: + return backpressure(self.raw_odom_stream().pipe(ops.map(Odometry.from_msg))) + + @functools.cache + def lowstate_stream(self) -> Subject[LowStateMsg]: + return 
backpressure(self.unitree_sub_stream(RTC_TOPIC["LOW_STATE"])) + + def standup_ai(self): + return self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["BalanceStand"]}) + + def standup_normal(self): + self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["StandUp"]}) + time.sleep(0.5) + self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["RecoveryStand"]}) + return True + + @rpc + def standup(self): + if self.mode == "ai": + return self.standup_ai() + else: + return self.standup_normal() + + @rpc + def liedown(self): + return self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["StandDown"]}) + + async def handstand(self): + return self.publish_request( + RTC_TOPIC["SPORT_MOD"], + {"api_id": SPORT_CMD["Standup"], "parameter": {"data": True}}, + ) + + @rpc + def color(self, color: VUI_COLOR = VUI_COLOR.RED, colortime: int = 60) -> bool: + return self.publish_request( + RTC_TOPIC["VUI"], + { + "api_id": 1001, + "parameter": { + "color": color, + "time": colortime, + }, + }, + ) + + @functools.lru_cache(maxsize=None) + def video_stream(self) -> Observable[VideoMessage]: + subject: Subject[VideoMessage] = Subject() + stop_event = threading.Event() + + async def accept_track(track: MediaStreamTrack) -> VideoMessage: + while True: + if stop_event.is_set(): + return + frame = await track.recv() + subject.on_next(Image.from_numpy(frame.to_ndarray(format="bgr24"))) + + self.conn.video.add_track_callback(accept_track) + + # Run the video channel switching in the background thread + def switch_video_channel(): + self.conn.video.switchVideoChannel(True) + + self.loop.call_soon_threadsafe(switch_video_channel) + + def stop(cb): + stop_event.set() # Signal the loop to stop + self.conn.video.track_callbacks.remove(accept_track) + + # Run the video channel switching off in the background thread + def switch_video_channel_off(): + self.conn.video.switchVideoChannel(False) + + self.loop.call_soon_threadsafe(switch_video_channel_off) 
+ + return subject.pipe(ops.finally_action(stop)) + + def get_video_stream(self, fps: int = 30) -> Observable[VideoMessage]: + """Get the video stream from the robot's camera. + + Implements the AbstractRobot interface method. + + Args: + fps: Frames per second. This parameter is included for API compatibility, + but doesn't affect the actual frame rate which is determined by the camera. + + Returns: + Observable: An observable stream of video frames or None if video is not available. + """ + try: + print("Starting WebRTC video stream...") + stream = self.video_stream() + if stream is None: + print("Warning: Video stream is not available") + return stream + + except Exception as e: + print(f"Error getting video stream: {e}") + return None + + def stop(self) -> bool: + """Stop the robot's movement. + + Returns: + bool: True if stop command was sent successfully + """ + return self.move(Vector(0.0, 0.0, 0.0)) + + def disconnect(self) -> None: + """Disconnect from the robot and clean up resources.""" + if hasattr(self, "task") and self.task: + self.task.cancel() + if hasattr(self, "conn"): + + async def async_disconnect(): + try: + await self.conn.disconnect() + except: + pass + + if hasattr(self, "loop") and self.loop.is_running(): + asyncio.run_coroutine_threadsafe(async_disconnect(), self.loop) + + if hasattr(self, "loop") and self.loop.is_running(): + self.loop.call_soon_threadsafe(self.loop.stop) + + if hasattr(self, "thread") and self.thread.is_alive(): + self.thread.join(timeout=2.0) diff --git a/build/lib/dimos/robot/unitree_webrtc/testing/__init__.py b/build/lib/dimos/robot/unitree_webrtc/testing/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/build/lib/dimos/robot/unitree_webrtc/testing/helpers.py b/build/lib/dimos/robot/unitree_webrtc/testing/helpers.py new file mode 100644 index 0000000000..8d01cb76cc --- /dev/null +++ b/build/lib/dimos/robot/unitree_webrtc/testing/helpers.py @@ -0,0 +1,168 @@ +# Copyright 2025 Dimensional Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +import open3d as o3d +from typing import Callable, Union, Any, Protocol, Iterable +from reactivex.observable import Observable + +color1 = [1, 0.706, 0] +color2 = [0, 0.651, 0.929] +color3 = [0.8, 0.196, 0.6] +color4 = [0.235, 0.702, 0.443] +color = [color1, color2, color3, color4] + + +# benchmarking function can return int, which will be applied to the time. +# +# (in case there is some preparation within the fuction and this time needs to be subtracted +# from the benchmark target) +def benchmark(calls: int, targetf: Callable[[], Union[int, None]]) -> float: + start = time.time() + timemod = 0 + for _ in range(calls): + res = targetf() + if res is not None: + timemod += res + end = time.time() + return (end - start + timemod) * 1000 / calls + + +O3dDrawable = ( + o3d.geometry.Geometry + | o3d.geometry.LineSet + | o3d.geometry.TriangleMesh + | o3d.geometry.PointCloud +) + + +class ReturnsDrawable(Protocol): + def o3d_geometry(self) -> O3dDrawable: ... 
+ + +Drawable = O3dDrawable | ReturnsDrawable + + +def show3d(*components: Iterable[Drawable], title: str = "open3d") -> o3d.visualization.Visualizer: + vis = o3d.visualization.Visualizer() + vis.create_window(window_name=title) + for component in components: + # our custom drawable components should return an open3d geometry + if hasattr(component, "o3d_geometry"): + vis.add_geometry(component.o3d_geometry) + else: + vis.add_geometry(component) + + opt = vis.get_render_option() + opt.background_color = [0, 0, 0] + opt.point_size = 10 + vis.poll_events() + vis.update_renderer() + return vis + + +def multivis(*vis: o3d.visualization.Visualizer) -> None: + while True: + for v in vis: + v.poll_events() + v.update_renderer() + + +def show3d_stream( + geometry_observable: Observable[Any], + clearframe: bool = False, + title: str = "open3d", +) -> o3d.visualization.Visualizer: + """ + Visualize a stream of geometries using Open3D. The first geometry initializes the visualizer. + Subsequent geometries update the visualizer. If no new geometry, just poll events. 
+ geometry_observable: Observable of objects with .o3d_geometry or Open3D geometry + """ + import threading + import queue + import time + from typing import Any + + q: queue.Queue[Any] = queue.Queue() + stop_flag = threading.Event() + + def on_next(geometry: O3dDrawable) -> None: + q.put(geometry) + + def on_error(e: Exception) -> None: + print(f"Visualization error: {e}") + stop_flag.set() + + def on_completed() -> None: + print("Geometry stream completed") + stop_flag.set() + + subscription = geometry_observable.subscribe( + on_next=on_next, + on_error=on_error, + on_completed=on_completed, + ) + + def geom(geometry: Drawable) -> O3dDrawable: + """Extracts the Open3D geometry from the given object.""" + return geometry.o3d_geometry if hasattr(geometry, "o3d_geometry") else geometry + + # Wait for the first geometry + first_geometry = None + while first_geometry is None and not stop_flag.is_set(): + try: + first_geometry = q.get(timeout=100) + except queue.Empty: + print("No geometry received to visualize.") + return + + scene_geometries = [] + first_geom_obj = geom(first_geometry) + + scene_geometries.append(first_geom_obj) + + vis = show3d(first_geom_obj, title=title) + + try: + while not stop_flag.is_set(): + try: + geometry = q.get_nowait() + geom_obj = geom(geometry) + if clearframe: + scene_geometries = [] + vis.clear_geometries() + + vis.add_geometry(geom_obj) + scene_geometries.append(geom_obj) + else: + if geom_obj in scene_geometries: + print("updating existing geometry") + vis.update_geometry(geom_obj) + else: + print("new geometry") + vis.add_geometry(geom_obj) + scene_geometries.append(geom_obj) + except queue.Empty: + pass + vis.poll_events() + vis.update_renderer() + time.sleep(0.1) + + except KeyboardInterrupt: + print("closing visualizer...") + stop_flag.set() + vis.destroy_window() + subscription.dispose() + + return vis diff --git a/build/lib/dimos/robot/unitree_webrtc/testing/mock.py b/build/lib/dimos/robot/unitree_webrtc/testing/mock.py new 
file mode 100644 index 0000000000..f929d33c5c --- /dev/null +++ b/build/lib/dimos/robot/unitree_webrtc/testing/mock.py @@ -0,0 +1,91 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import pickle +import glob +from typing import Union, Iterator, cast, overload +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage, RawLidarMsg + +from reactivex import operators as ops +from reactivex import interval, from_iterable +from reactivex.observable import Observable + + +class Mock: + def __init__(self, root="office", autocast: bool = True): + current_dir = os.path.dirname(os.path.abspath(__file__)) + self.root = os.path.join(current_dir, f"mockdata/{root}") + self.autocast = autocast + self.cnt = 0 + + @overload + def load(self, name: Union[int, str], /) -> LidarMessage: ... + @overload + def load(self, *names: Union[int, str]) -> list[LidarMessage]: ... 
+ + def load(self, *names: Union[int, str]) -> Union[LidarMessage, list[LidarMessage]]: + if len(names) == 1: + return self.load_one(names[0]) + return list(map(lambda name: self.load_one(name), names)) + + def load_one(self, name: Union[int, str]) -> LidarMessage: + if isinstance(name, int): + file_name = f"/lidar_data_{name:03d}.pickle" + else: + file_name = f"/{name}.pickle" + + full_path = self.root + file_name + with open(full_path, "rb") as f: + return LidarMessage.from_msg(cast(RawLidarMsg, pickle.load(f))) + + def iterate(self) -> Iterator[LidarMessage]: + pattern = os.path.join(self.root, "lidar_data_*.pickle") + print("loading data", pattern) + for file_path in sorted(glob.glob(pattern)): + basename = os.path.basename(file_path) + filename = os.path.splitext(basename)[0] + yield self.load_one(filename) + + def stream(self, rate_hz=10.0): + sleep_time = 1.0 / rate_hz + + return from_iterable(self.iterate()).pipe( + ops.zip(interval(sleep_time)), + ops.map(lambda x: x[0] if isinstance(x, tuple) else x), + ) + + def save_stream(self, observable: Observable[LidarMessage]): + return observable.pipe(ops.map(lambda frame: self.save_one(frame))) + + def save(self, *frames): + [self.save_one(frame) for frame in frames] + return self.cnt + + def save_one(self, frame): + file_name = f"/lidar_data_{self.cnt:03d}.pickle" + full_path = self.root + file_name + + self.cnt += 1 + + if os.path.isfile(full_path): + raise Exception(f"file {full_path} exists") + + if frame.__class__ == LidarMessage: + frame = frame.raw_msg + + with open(full_path, "wb") as f: + pickle.dump(frame, f) + + return self.cnt diff --git a/build/lib/dimos/robot/unitree_webrtc/testing/multimock.py b/build/lib/dimos/robot/unitree_webrtc/testing/multimock.py new file mode 100644 index 0000000000..cfc2688129 --- /dev/null +++ b/build/lib/dimos/robot/unitree_webrtc/testing/multimock.py @@ -0,0 +1,142 @@ +#!/usr/bin/env python3 +# Copyright 2025 Dimensional Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Multimock – lightweight persistence & replay helper built on RxPy. + +A directory of pickle files acts as a tiny append-only log of (timestamp, data) +pairs. You can: + • save() / consume(): append new frames + • iterate(): read them back lazily + • interval_stream(): emit at a fixed cadence + • stream(): replay with original timing (optionally scaled) + +The implementation keeps memory usage constant by relying on reactive +operators instead of pre-materialising lists. Timing is reproduced via +`rx.timer`, and drift is avoided with `concat_map`. 
+""" + +from __future__ import annotations + +import glob +import os +import pickle +import time +from typing import Any, Generic, Iterator, List, Tuple, TypeVar, Union, Optional +from reactivex.scheduler import ThreadPoolScheduler + +from reactivex import from_iterable, interval, operators as ops +from reactivex.observable import Observable +from dimos.utils.threadpool import get_scheduler +from dimos.robot.unitree_webrtc.type.timeseries import TEvent, Timeseries + +T = TypeVar("T") + + +class Multimock(Generic[T], Timeseries[TEvent[T]]): + """Persist frames as pickle files and replay them with RxPy.""" + + def __init__(self, root: str = "office", file_prefix: str = "msg") -> None: + current_dir = os.path.dirname(os.path.abspath(__file__)) + self.root = os.path.join(current_dir, f"multimockdata/{root}") + self.file_prefix = file_prefix + + os.makedirs(self.root, exist_ok=True) + self.cnt: int = 0 + + def save(self, *frames: Any) -> int: + """Persist one or more frames; returns the new counter value.""" + for frame in frames: + self.save_one(frame) + return self.cnt + + def save_one(self, frame: Any) -> int: + """Persist a single frame and return the running count.""" + file_name = f"/{self.file_prefix}_{self.cnt:03d}.pickle" + full_path = os.path.join(self.root, file_name.lstrip("/")) + self.cnt += 1 + + if os.path.isfile(full_path): + raise FileExistsError(f"file {full_path} exists") + + # Optional convinience magic to extract raw messages from advanced types + # trying to deprecate for now + # if hasattr(frame, "raw_msg"): + # frame = frame.raw_msg # type: ignore[attr-defined] + + with open(full_path, "wb") as f: + pickle.dump([time.time(), frame], f) + + return self.cnt + + def load(self, *names: Union[int, str]) -> List[Tuple[float, T]]: + """Load multiple items by name or index.""" + return list(map(self.load_one, names)) + + def load_one(self, name: Union[int, str]) -> TEvent[T]: + """Load a single item by name or index.""" + if isinstance(name, int): + 
file_name = f"/{self.file_prefix}_{name:03d}.pickle" + else: + file_name = f"/{name}.pickle" + + full_path = os.path.join(self.root, file_name.lstrip("/")) + + with open(full_path, "rb") as f: + timestamp, data = pickle.load(f) + + return TEvent(timestamp, data) + + def iterate(self) -> Iterator[TEvent[T]]: + """Yield all persisted TEvent(timestamp, data) pairs lazily in order.""" + pattern = os.path.join(self.root, f"{self.file_prefix}_*.pickle") + for file_path in sorted(glob.glob(pattern)): + with open(file_path, "rb") as f: + timestamp, data = pickle.load(f) + yield TEvent(timestamp, data) + + def list(self) -> List[TEvent[T]]: + return list(self.iterate()) + + def interval_stream(self, rate_hz: float = 10.0) -> Observable[T]: + """Emit frames at a fixed rate, ignoring recorded timing.""" + sleep_time = 1.0 / rate_hz + return from_iterable(self.iterate()).pipe( + ops.zip(interval(sleep_time)), + ops.map(lambda pair: pair[1]), # keep only the frame + ) + + def stream( + self, + replay_speed: float = 1.0, + scheduler: Optional[ThreadPoolScheduler] = None, + ) -> Observable[T]: + def _generator(): + prev_ts: float | None = None + for event in self.iterate(): + if prev_ts is not None: + delay = (event.ts - prev_ts).total_seconds() / replay_speed + time.sleep(delay) + prev_ts = event.ts + yield event.data + + return from_iterable(_generator(), scheduler=scheduler or get_scheduler()) + + def consume(self, observable: Observable[Any]) -> Observable[int]: + """Side-effect: save every frame that passes through.""" + return observable.pipe(ops.map(self.save_one)) + + def __iter__(self) -> Iterator[TEvent[T]]: + """Allow iteration over the Multimock instance to yield TEvent(timestamp, data) pairs.""" + return self.iterate() diff --git a/build/lib/dimos/robot/unitree_webrtc/testing/test_mock.py b/build/lib/dimos/robot/unitree_webrtc/testing/test_mock.py new file mode 100644 index 0000000000..4852392943 --- /dev/null +++ 
b/build/lib/dimos/robot/unitree_webrtc/testing/test_mock.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +import pytest +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.testing.mock import Mock + + +@pytest.mark.needsdata +def test_mock_load_cast(): + mock = Mock("test") + + # Load a frame with type casting + frame = mock.load("a") + + # Verify it's a LidarMessage object + assert frame.__class__.__name__ == "LidarMessage" + assert hasattr(frame, "timestamp") + assert hasattr(frame, "origin") + assert hasattr(frame, "resolution") + assert hasattr(frame, "pointcloud") + + # Verify pointcloud has points + assert frame.pointcloud.has_points() + assert len(frame.pointcloud.points) > 0 + + +@pytest.mark.needsdata +def test_mock_iterate(): + """Test the iterate method of the Mock class.""" + mock = Mock("office") + + # Test iterate method + frames = list(mock.iterate()) + assert len(frames) > 0 + for frame in frames: + assert isinstance(frame, LidarMessage) + assert frame.pointcloud.has_points() + + +@pytest.mark.needsdata +def test_mock_stream(): + frames = [] + sub1 = Mock("office").stream(rate_hz=30.0).subscribe(on_next=frames.append) + time.sleep(0.1) + sub1.dispose() + + assert len(frames) >= 2 + assert isinstance(frames[0], LidarMessage) diff --git a/build/lib/dimos/robot/unitree_webrtc/testing/test_multimock.py 
b/build/lib/dimos/robot/unitree_webrtc/testing/test_multimock.py new file mode 100644 index 0000000000..1d64cbd3a0 --- /dev/null +++ b/build/lib/dimos/robot/unitree_webrtc/testing/test_multimock.py @@ -0,0 +1,111 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +import pytest + +from reactivex import operators as ops + +from dimos.utils.reactive import backpressure +from dimos.robot.unitree_webrtc.testing.helpers import show3d_stream +from dimos.web.websocket_vis.server import WebsocketVis +from dimos.robot.unitree_webrtc.type.map import Map +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.odometry import Odometry +from dimos.robot.unitree_webrtc.type.timeseries import to_datetime +from dimos.robot.unitree_webrtc.testing.multimock import Multimock + + +@pytest.mark.needsdata +@pytest.mark.vis +def test_multimock_stream(): + backpressure(Multimock("athens_odom").stream().pipe(ops.map(Odometry.from_msg))).subscribe( + lambda x: print(x) + ) + map = Map() + + def lidarmsg(msg): + frame = LidarMessage.from_msg(msg) + map.add_frame(frame) + return [map, map.costmap.smudge()] + + mapstream = Multimock("athens_lidar").stream().pipe(ops.map(lidarmsg)) + show3d_stream(mapstream.pipe(ops.map(lambda x: x[0])), clearframe=True).run() + time.sleep(5) + + +@pytest.mark.needsdata +def test_clock_mismatch(): + for odometry_raw in Multimock("athens_odom").iterate(): + print( + 
odometry_raw.ts - to_datetime(odometry_raw.data["data"]["header"]["stamp"]), + odometry_raw.data["data"]["header"]["stamp"], + ) + + +@pytest.mark.needsdata +def test_odom_stream(): + for odometry_raw in Multimock("athens_odom").iterate(): + print(Odometry.from_msg(odometry_raw.data)) + + +@pytest.mark.needsdata +def test_lidar_stream(): + for lidar_raw in Multimock("athens_lidar").iterate(): + lidarmsg = LidarMessage.from_msg(lidar_raw.data) + print(lidarmsg) + print(lidar_raw) + + +@pytest.mark.needsdata +def test_multimock_timeseries(): + odom = Odometry.from_msg(Multimock("athens_odom").load_one(1).data) + lidar_raw = Multimock("athens_lidar").load_one(1).data + lidar = LidarMessage.from_msg(lidar_raw) + map = Map() + map.add_frame(lidar) + print(odom) + print(lidar) + print(lidar_raw) + print(map.costmap) + + +@pytest.mark.needsdata +def test_origin_changes(): + for lidar_raw in Multimock("athens_lidar").iterate(): + print(LidarMessage.from_msg(lidar_raw.data).origin) + + +@pytest.mark.needsdata +@pytest.mark.vis +def test_webui_multistream(): + websocket_vis = WebsocketVis() + websocket_vis.start() + + odom_stream = Multimock("athens_odom").stream().pipe(ops.map(Odometry.from_msg)) + lidar_stream = backpressure( + Multimock("athens_lidar").stream().pipe(ops.map(LidarMessage.from_msg)) + ) + + map = Map() + map_stream = map.consume(lidar_stream) + + costmap_stream = map_stream.pipe( + ops.map(lambda x: ["costmap", map.costmap.smudge(preserve_unknown=False)]) + ) + + websocket_vis.connect(costmap_stream) + websocket_vis.connect(odom_stream.pipe(ops.map(lambda pos: ["robot_pos", pos.pos.to_2d()]))) + + show3d_stream(lidar_stream, clearframe=True).run() diff --git a/build/lib/dimos/robot/unitree_webrtc/type/__init__.py b/build/lib/dimos/robot/unitree_webrtc/type/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/build/lib/dimos/robot/unitree_webrtc/type/lidar.py b/build/lib/dimos/robot/unitree_webrtc/type/lidar.py new file mode 100644 
class LidarMessage(PointCloud2):
    """Point cloud decoded from a Unitree WebRTC LIDAR message.

    Adds the grid ``resolution`` and ``origin`` that a plain ``PointCloud2``
    does not carry, keeps the raw message the cloud was decoded from (when
    there is one) and caches the costmap derived from the cloud.
    """

    resolution: float  # we lose resolution when encoding PointCloud2
    origin: Vector3
    raw_msg: Optional[RawLidarMsg]
    _costmap: Optional[Costmap] = None
    # voxel size the cached ``_costmap`` was built with
    _costmap_voxel_size: Optional[float] = None

    def __init__(self, **kwargs):
        """Build a message from keyword arguments.

        Recognised keys: ``pointcloud``, ``ts``, ``origin``, ``resolution``
        and ``raw_msg``.  Missing keys default to ``None``.  The frame id is
        always ``"lidar"``.
        """
        super().__init__(
            pointcloud=kwargs.get("pointcloud"),
            ts=kwargs.get("ts"),
            frame_id="lidar",
        )

        self.origin = kwargs.get("origin")
        self.resolution = kwargs.get("resolution")
        # ``from_msg`` passes the raw message along; keep it so the declared
        # ``raw_msg`` attribute actually exists on every instance.
        self.raw_msg = kwargs.get("raw_msg")

    @classmethod
    def from_msg(cls, raw_message: RawLidarMsg) -> "LidarMessage":
        """Decode a raw LIDAR message into a ``LidarMessage``.

        Args:
            raw_message: Message whose ``data`` holds ``origin``,
                ``resolution``, ``stamp`` and the points under
                ``data["data"]["points"]``.

        Returns:
            A new instance of ``cls`` wrapping an Open3D point cloud.
        """
        data = raw_message["data"]
        points = data["data"]["points"]
        pointcloud = o3d.geometry.PointCloud()
        pointcloud.points = o3d.utility.Vector3dVector(points)

        origin = Vector3(data["origin"])
        # webrtc decoding via native decompression doesn't require us
        # to shift the pointcloud by its origin

        return cls(
            origin=origin,
            resolution=data["resolution"],
            pointcloud=pointcloud,
            ts=data["stamp"],
            raw_msg=raw_message,
        )

    def to_pointcloud2(self) -> PointCloud2:
        """Convert to PointCloud2 message format."""
        return PointCloud2(
            pointcloud=self.pointcloud,
            frame_id=self.frame_id,
            ts=self.ts,
        )

    def __repr__(self):
        return f"LidarMessage(ts={to_human_readable(self.ts)}, origin={self.origin}, resolution={self.resolution}, {self.pointcloud})"

    def __iadd__(self, other: "LidarMessage") -> "LidarMessage":
        """Merge *other*'s points into this message in place."""
        self.pointcloud += other.pointcloud
        # the cached costmap was built from the old cloud and is now stale
        self._costmap = None
        self._costmap_voxel_size = None
        return self

    def __add__(self, other: "LidarMessage") -> "LidarMessage":
        """Return a new message holding the points of both operands.

        Timestamp, origin and resolution are taken from the more recent
        operand (``self`` wins a tie).
        """
        # Determine which message is more recent
        newest = self if self.ts >= other.ts else other

        # Return a new LidarMessage with combined data
        # NOTE(review): ``estimate_normals`` is not defined in this class;
        # presumably inherited from PointCloud2 and returning the message.
        return LidarMessage(
            ts=newest.ts,
            origin=newest.origin,
            resolution=newest.resolution,
            pointcloud=self.pointcloud + other.pointcloud,
        ).estimate_normals()

    @property
    def o3d_geometry(self):
        """The underlying Open3D point cloud."""
        return self.pointcloud

    def costmap(self, voxel_size: float = 0.2) -> Costmap:
        """Return the costmap of this cloud, computing it on first use.

        The cloud is voxel-downsampled with *voxel_size* before rasterising
        at ``self.resolution``.  Obstacles are inflated by one voxel only when
        the voxels are coarser than the grid.  The result is cached and
        rebuilt when a different *voxel_size* is requested.
        """
        if self._costmap is None or self._costmap_voxel_size != voxel_size:
            down_sampled_pointcloud = self.pointcloud.voxel_down_sample(voxel_size=voxel_size)
            inflate_radius_m = 1.0 * voxel_size if voxel_size > self.resolution else 0.0
            grid, origin_xy = pointcloud_to_costmap(
                down_sampled_pointcloud,
                resolution=self.resolution,
                inflate_radius_m=inflate_radius_m,
            )
            self._costmap = Costmap(grid=grid, origin=[*origin_xy, 0.0], resolution=self.resolution)
            self._costmap_voxel_size = voxel_size

        return self._costmap
"temperature": 0, "lost": 0, "reserve": [0, 0]}, + {"q": 0, "temperature": 0, "lost": 0, "reserve": [0, 0]}, + {"q": 0, "temperature": 0, "lost": 0, "reserve": [0, 0]}, + {"q": 0, "temperature": 0, "lost": 0, "reserve": [0, 0]}, + {"q": 0, "temperature": 0, "lost": 0, "reserve": [0, 0]}, + {"q": 0, "temperature": 0, "lost": 0, "reserve": [0, 0]}, + {"q": 0, "temperature": 0, "lost": 0, "reserve": [0, 0]}, + {"q": 0, "temperature": 0, "lost": 0, "reserve": [0, 0]}, + ], + "bms_state": { + "version_high": 1, + "version_low": 18, + "soc": 55, + "current": -2481, + "cycle": 56, + "bq_ntc": [30, 29], + "mcu_ntc": [33, 32], + }, + "foot_force": [97, 84, 81, 81], + "temperature_ntc1": 48, + "power_v": 28.331045, + }, +} + + +class MotorState(TypedDict): + q: float + temperature: int + lost: int + reserve: List[int] + + +class ImuState(TypedDict): + rpy: List[float] + + +class BmsState(TypedDict): + version_high: int + version_low: int + soc: int + current: int + cycle: int + bq_ntc: List[int] + mcu_ntc: List[int] + + +class LowStateData(TypedDict): + imu_state: ImuState + motor_state: List[MotorState] + bms_state: BmsState + foot_force: List[int] + temperature_ntc1: int + power_v: float + + +class LowStateMsg(TypedDict): + type: Literal["msg"] + topic: str + data: LowStateData diff --git a/build/lib/dimos/robot/unitree_webrtc/type/map.py b/build/lib/dimos/robot/unitree_webrtc/type/map.py new file mode 100644 index 0000000000..898bd473b5 --- /dev/null +++ b/build/lib/dimos/robot/unitree_webrtc/type/map.py @@ -0,0 +1,150 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
class Map(Module):
    """Running global map assembled from incoming LIDAR frames.

    Every frame received on ``lidar`` is voxel-downsampled and spliced into
    ``pointcloud``.  When ``global_publish_interval`` is set, the accumulated
    map is published periodically on ``global_map``.
    """

    lidar: In[LidarMessage] = None
    global_map: Out[LidarMessage] = None
    pointcloud: o3d.geometry.PointCloud = o3d.geometry.PointCloud()

    def __init__(
        self,
        voxel_size: float = 0.05,
        cost_resolution: float = 0.05,
        global_publish_interval: Optional[float] = None,
        **kwargs,
    ):
        """Store the map settings and initialise the module.

        Args:
            voxel_size: Voxel edge used to downsample every incoming frame.
            cost_resolution: Cell size of the costmap built by ``costmap``.
            global_publish_interval: Period passed to ``interval`` for
                publishing the global map; ``None`` disables publishing.
            **kwargs: Forwarded to ``Module.__init__``.
        """
        self.voxel_size = voxel_size
        self.cost_resolution = cost_resolution
        self.global_publish_interval = global_publish_interval
        super().__init__(**kwargs)

    @rpc
    def start(self):
        """Subscribe to the LIDAR input and start periodic publishing if enabled."""
        self.lidar.subscribe(self.add_frame)

        if self.global_publish_interval is not None:
            interval(self.global_publish_interval).subscribe(
                lambda _: self.global_map.publish(self.to_lidar_message())
            )

    def to_lidar_message(self) -> LidarMessage:
        """Wrap the current map in a ``LidarMessage`` stamped with the current time."""
        return LidarMessage(
            pointcloud=self.pointcloud,
            origin=[0.0, 0.0, 0.0],
            resolution=self.voxel_size,
            ts=time.time(),
        )

    @rpc
    def add_frame(self, frame: LidarMessage) -> None:
        """Voxelise *frame* and splice it into the running map."""
        new_pct = frame.pointcloud.voxel_down_sample(voxel_size=self.voxel_size)
        self.pointcloud = splice_cylinder(self.pointcloud, new_pct, shrink=0.5)

    def consume(self, observable: Observable[LidarMessage]) -> Observable["Map"]:
        """Reactive operator that folds a stream of `LidarMessage` into the map.

        Each frame is added to the map and the map itself is emitted, so
        subscribers receive the updated ``Map`` after every frame.
        """

        def _fold(frame: LidarMessage) -> "Map":
            # ``add_frame`` returns nothing; emit the map so the stream
            # matches the declared ``Observable["Map"]``.
            self.add_frame(frame)
            return self

        return observable.pipe(ops.map(_fold))

    @property
    def o3d_geometry(self) -> o3d.geometry.PointCloud:
        """The accumulated Open3D point cloud."""
        return self.pointcloud

    @rpc
    def costmap(self) -> Costmap:
        """Return a fully inflated cost-map in a `Costmap` wrapper."""
        # inflate by half a voxel only when voxels are coarser than the grid
        inflate_radius_m = 0.5 * self.voxel_size if self.voxel_size > self.cost_resolution else 0.0
        grid, origin_xy = pointcloud_to_costmap(
            self.pointcloud,
            resolution=self.cost_resolution,
            inflate_radius_m=inflate_radius_m,
        )

        return Costmap(grid=grid, origin=[*origin_xy, 0.0], resolution=self.cost_resolution)
axis_min) + & (map_pts[:, axis] <= axis_max) + )[0] + + survivors = map_pcd.select_by_index(victims, invert=True) + return survivors + patch_pcd + + +def _inflate_lethal(costmap: np.ndarray, radius: int, lethal_val: int = 100) -> np.ndarray: + """Return *costmap* with lethal cells dilated by *radius* grid steps (circular).""" + if radius <= 0 or not np.any(costmap == lethal_val): + return costmap + + mask = costmap == lethal_val + dilated = mask.copy() + for dy in range(-radius, radius + 1): + for dx in range(-radius, radius + 1): + if dx * dx + dy * dy > radius * radius or (dx == 0 and dy == 0): + continue + dilated |= np.roll(mask, shift=(dy, dx), axis=(0, 1)) + + out = costmap.copy() + out[dilated] = lethal_val + return out diff --git a/build/lib/dimos/robot/unitree_webrtc/type/odometry.py b/build/lib/dimos/robot/unitree_webrtc/type/odometry.py new file mode 100644 index 0000000000..76def232e4 --- /dev/null +++ b/build/lib/dimos/robot/unitree_webrtc/type/odometry.py @@ -0,0 +1,108 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import math +from datetime import datetime +from io import BytesIO +from typing import BinaryIO, Literal, TypeAlias, TypedDict + +from scipy.spatial.transform import Rotation as R + +from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Vector3 +from dimos.robot.unitree_webrtc.type.timeseries import ( + EpochLike, + Timestamped, + to_datetime, + to_human_readable, +) +from dimos.types.timestamped import to_timestamp +from dimos.types.vector import Vector, VectorLike + +raw_odometry_msg_sample = { + "type": "msg", + "topic": "rt/utlidar/robot_pose", + "data": { + "header": {"stamp": {"sec": 1746565669, "nanosec": 448350564}, "frame_id": "odom"}, + "pose": { + "position": {"x": 5.961965, "y": -2.916958, "z": 0.319509}, + "orientation": {"x": 0.002787, "y": -0.000902, "z": -0.970244, "w": -0.242112}, + }, + }, +} + + +class TimeStamp(TypedDict): + sec: int + nanosec: int + + +class Header(TypedDict): + stamp: TimeStamp + frame_id: str + + +class RawPosition(TypedDict): + x: float + y: float + z: float + + +class Orientation(TypedDict): + x: float + y: float + z: float + w: float + + +class PoseData(TypedDict): + position: RawPosition + orientation: Orientation + + +class OdometryData(TypedDict): + header: Header + pose: PoseData + + +class RawOdometryMessage(TypedDict): + type: Literal["msg"] + topic: str + data: OdometryData + + +class Odometry(PoseStamped, Timestamped): + name = "geometry_msgs.PoseStamped" + + @classmethod + def from_msg(cls, msg: RawOdometryMessage) -> "Odometry": + pose = msg["data"]["pose"] + + # Extract position + pos = Vector3( + pose["position"].get("x"), + pose["position"].get("y"), + pose["position"].get("z"), + ) + + rot = Quaternion( + pose["orientation"].get("x"), + pose["orientation"].get("y"), + pose["orientation"].get("z"), + pose["orientation"].get("w"), + ) + + ts = to_timestamp(msg["data"]["header"]["stamp"]) + return Odometry(position=pos, orientation=rot, ts=ts, frame_id="lidar") + + def __repr__(self) -> str: + return 
f"Odom pos({self.position}), rot({self.orientation})" diff --git a/build/lib/dimos/robot/unitree_webrtc/type/test_lidar.py b/build/lib/dimos/robot/unitree_webrtc/type/test_lidar.py new file mode 100644 index 0000000000..912740a71a --- /dev/null +++ b/build/lib/dimos/robot/unitree_webrtc/type/test_lidar.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import time + +import pytest + +from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.protocol.pubsub.lcmpubsub import LCM, Topic +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.utils.testing import SensorReplay + + +def test_init(): + lidar = SensorReplay("office_lidar") + + for raw_frame in itertools.islice(lidar.iterate(), 5): + assert isinstance(raw_frame, dict) + frame = LidarMessage.from_msg(raw_frame) + assert isinstance(frame, LidarMessage) + data = frame.to_pointcloud2().lcm_encode() + assert len(data) > 0 + assert isinstance(data, bytes) + + +@pytest.mark.tool +def test_publish(): + lcm = LCM() + lcm.start() + + topic = Topic(topic="/lidar", lcm_type=PointCloud2) + lidar = SensorReplay("office_lidar", autocast=LidarMessage.from_msg) + + while True: + for frame in lidar.iterate(): + print(frame) + lcm.publish(topic, frame.to_pointcloud2()) + time.sleep(0.1) diff --git a/build/lib/dimos/robot/unitree_webrtc/type/test_map.py b/build/lib/dimos/robot/unitree_webrtc/type/test_map.py new 
@pytest.mark.vis
def test_costmap_vis():
    """Build a map from the mocked "office" frames and display its smudged costmap."""
    # named ``global_map`` so the builtin ``map`` is not shadowed
    global_map = Map()
    for frame in Mock("office").iterate():
        print(frame)
        global_map.add_frame(frame)
    # ``Map.costmap`` is a method: call it, otherwise ``.smudge()`` below is
    # looked up on the bound method instead of on the costmap.
    costmap = global_map.costmap()
    print(costmap)
    show3d(costmap.smudge().pointcloud, title="Costmap").run()
Map(voxel_size=0.5) + + # this will block until map has consumed the whole stream + map.consume(lidar_stream.stream()).run() + + # we investigate built map + costmap = map.costmap() + + assert costmap.grid.shape == (404, 276) + + assert 70 <= costmap.unknown_percent <= 80, ( + f"Unknown percent {costmap.unknown_percent} is not within the range 70-80" + ) + + assert 5 < costmap.free_percent < 10, ( + f"Free percent {costmap.free_percent} is not within the range 5-10" + ) + + assert 8 < costmap.occupied_percent < 15, ( + f"Occupied percent {costmap.occupied_percent} is not within the range 8-15" + ) diff --git a/build/lib/dimos/robot/unitree_webrtc/type/test_odometry.py b/build/lib/dimos/robot/unitree_webrtc/type/test_odometry.py new file mode 100644 index 0000000000..0bd76f1900 --- /dev/null +++ b/build/lib/dimos/robot/unitree_webrtc/type/test_odometry.py @@ -0,0 +1,109 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import os +import threading +from operator import add, sub +from typing import Optional + +import pytest +import reactivex.operators as ops +from dotenv import load_dotenv + +from dimos.robot.unitree_webrtc.type.odometry import Odometry +from dimos.utils.testing import SensorReplay, SensorStorage + +_EXPECTED_TOTAL_RAD = -4.05212 + + +def test_dataset_size() -> None: + """Ensure the replay contains the expected number of messages.""" + assert sum(1 for _ in SensorReplay(name="raw_odometry_rotate_walk").iterate()) == 179 + + +def test_odometry_conversion_and_count() -> None: + """Each replay entry converts to :class:`Odometry` and count is correct.""" + for raw in SensorReplay(name="raw_odometry_rotate_walk").iterate(): + odom = Odometry.from_msg(raw) + assert isinstance(raw, dict) + assert isinstance(odom, Odometry) + + +def test_last_yaw_value() -> None: + """Verify yaw of the final message (regression guard).""" + last_msg = SensorReplay(name="raw_odometry_rotate_walk").stream().pipe(ops.last()).run() + + assert last_msg is not None, "Replay is empty" + assert last_msg["data"]["pose"]["orientation"] == { + "x": 0.01077, + "y": 0.008505, + "z": 0.499171, + "w": -0.866395, + } + + +def test_total_rotation_travel_iterate() -> None: + total_rad = 0.0 + prev_yaw: Optional[float] = None + + for odom in SensorReplay(name="raw_odometry_rotate_walk", autocast=Odometry.from_msg).iterate(): + yaw = odom.orientation.radians.z + if prev_yaw is not None: + diff = yaw - prev_yaw + total_rad += diff + prev_yaw = yaw + + assert total_rad == pytest.approx(_EXPECTED_TOTAL_RAD, abs=0.001) + + +def test_total_rotation_travel_rxpy() -> None: + total_rad = ( + SensorReplay(name="raw_odometry_rotate_walk", autocast=Odometry.from_msg) + .stream() + .pipe( + ops.map(lambda odom: odom.orientation.radians.z), + ops.pairwise(), # [1,2,3,4] -> [[1,2], [2,3], [3,4]] + ops.starmap(sub), # [sub(1,2), sub(2,3), sub(3,4)] + ops.reduce(add), + ) + .run() + ) 
+ + assert total_rad == pytest.approx(4.05, abs=0.01) + + +# data collection tool +@pytest.mark.tool +def test_store_odometry_stream() -> None: + from dimos.robot.unitree_webrtc.unitree_go2 import UnitreeGo2 + + load_dotenv() + + robot = UnitreeGo2(ip=os.getenv("ROBOT_IP"), mode="ai") + robot.standup() + + storage = SensorStorage("raw_odometry_rotate_walk") + storage.save_stream(robot.raw_odom_stream()) + + shutdown = threading.Event() + + try: + while not shutdown.wait(0.1): + pass + except KeyboardInterrupt: + shutdown.set() + finally: + robot.liedown() diff --git a/build/lib/dimos/robot/unitree_webrtc/type/test_timeseries.py b/build/lib/dimos/robot/unitree_webrtc/type/test_timeseries.py new file mode 100644 index 0000000000..fe96d75eaf --- /dev/null +++ b/build/lib/dimos/robot/unitree_webrtc/type/test_timeseries.py @@ -0,0 +1,44 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from datetime import timedelta, datetime +from dimos.robot.unitree_webrtc.type.timeseries import TEvent, TList + + +fixed_date = datetime(2025, 5, 13, 15, 2, 5).astimezone() +start_event = TEvent(fixed_date, 1) +end_event = TEvent(fixed_date + timedelta(seconds=10), 9) + +sample_list = TList([start_event, TEvent(fixed_date + timedelta(seconds=2), 5), end_event]) + + +def test_repr(): + assert ( + str(sample_list) + == "Timeseries(date=2025-05-13, start=15:02:05, end=15:02:15, duration=0:00:10, events=3, freq=0.30Hz)" + ) + + +def test_equals(): + assert start_event == TEvent(start_event.ts, 1) + assert start_event != TEvent(start_event.ts, 2) + assert start_event != TEvent(start_event.ts + timedelta(seconds=1), 1) + + +def test_range(): + assert sample_list.time_range() == (start_event.ts, end_event.ts) + + +def test_duration(): + assert sample_list.duration() == timedelta(seconds=10) diff --git a/build/lib/dimos/robot/unitree_webrtc/type/timeseries.py b/build/lib/dimos/robot/unitree_webrtc/type/timeseries.py new file mode 100644 index 0000000000..48dfddcac5 --- /dev/null +++ b/build/lib/dimos/robot/unitree_webrtc/type/timeseries.py @@ -0,0 +1,146 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
PAYLOAD = TypeVar("PAYLOAD")


class RosStamp(TypedDict):
    """ROS-style timestamp split into whole seconds and nanoseconds."""

    sec: int
    nanosec: int


EpochLike = Union[int, float, datetime, RosStamp]


def from_ros_stamp(stamp: dict[str, int], tz: timezone = None) -> datetime:
    """Convert ROS-style timestamp {'sec': int, 'nanosec': int} to datetime."""
    return datetime.fromtimestamp(stamp["sec"] + stamp["nanosec"] / 1e9, tz=tz)


def to_human_readable(ts: EpochLike) -> str:
    """Format *ts* as ``YYYY-MM-DD HH:MM:SS``."""
    dt = to_datetime(ts)
    return dt.strftime("%Y-%m-%d %H:%M:%S")


def to_datetime(ts: EpochLike, tz: timezone = None) -> datetime:
    """Convert any supported timestamp representation to a ``datetime``.

    Args:
        ts: A ``datetime`` (returned unchanged, *tz* is not applied), epoch
            seconds as ``int``/``float``, or a ROS stamp dict.
        tz: Time zone handed to ``datetime.fromtimestamp`` for numeric and
            ROS inputs.

    Raises:
        TypeError: If *ts* is none of the supported representations.
    """
    if isinstance(ts, datetime):
        return ts
    if isinstance(ts, (int, float)):
        return datetime.fromtimestamp(ts, tz=tz)
    if isinstance(ts, dict) and "sec" in ts and "nanosec" in ts:
        return from_ros_stamp(ts, tz=tz)
    raise TypeError("unsupported timestamp type")


class Timestamped(ABC):
    """Abstract class for an event with a timestamp."""

    ts: datetime

    def __init__(self, ts: EpochLike):
        self.ts = to_datetime(ts)


class TEvent(Timestamped, Generic[PAYLOAD]):
    """Concrete class for an event with a timestamp and data."""

    def __init__(self, timestamp: EpochLike, data: PAYLOAD):
        super().__init__(timestamp)
        self.data = data

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, TEvent):
            return NotImplemented
        return self.ts == other.ts and self.data == other.data

    def __repr__(self) -> str:
        return f"TEvent(ts={self.ts}, data={self.data})"


EVENT = TypeVar("EVENT", bound=Timestamped)  # any object that is a subclass of Timestamped


class Timeseries(ABC, Generic[EVENT]):
    """Abstract class for an iterable of events with timestamps.

    All helpers assume that iteration yields the events in ascending
    timestamp order.
    """

    @abstractmethod
    def __iter__(self) -> Iterable[EVENT]: ...

    @property
    def start_time(self) -> datetime:
        """Return the timestamp of the earliest event, assuming the data is sorted.

        Raises:
            ValueError: If the series is empty.
        """
        for event in self:
            return event.ts
        raise ValueError("timeseries is empty")

    @property
    def end_time(self) -> datetime:
        """Return the timestamp of the latest event, assuming the data is sorted.

        Raises:
            ValueError: If the series is empty.
        """
        events = list(self)
        if not events:
            raise ValueError("timeseries is empty")
        return events[-1].ts

    @property
    def frequency(self) -> float:
        """Calculate the frequency of events in Hz."""
        # a zero duration (single event) is treated as one second
        return sum(1 for _ in self) / (self.duration().total_seconds() or 1)

    def time_range(self) -> Tuple[datetime, datetime]:
        """Return (earliest_ts, latest_ts). Empty input ⇒ ValueError."""
        return self.start_time, self.end_time

    def duration(self) -> timedelta:
        """Total time spanned by the iterable (Δ = last - first)."""
        return self.end_time - self.start_time

    def closest_to(self, timestamp: EpochLike) -> Union[EVENT, None]:
        """Return the event closest to the given timestamp. Assumes timeseries is sorted.

        Distances are compared as epoch seconds.  The scan stops at the first
        event that is farther away than the best one seen so far; on a tie
        the later event wins.  Returns ``None`` for an empty series.
        """
        target_ts = to_datetime(timestamp).timestamp()

        closest = None
        min_dist = float("inf")

        for event in self:
            # compare epoch seconds on both sides; subtracting a float from
            # a datetime raises TypeError
            dist = abs(event.ts.timestamp() - target_ts)
            if dist > min_dist:
                break

            min_dist = dist
            closest = event

        return closest

    def __repr__(self) -> str:
        """Return a string representation of the Timeseries."""
        return f"Timeseries(date={self.start_time.strftime('%Y-%m-%d')}, start={self.start_time.strftime('%H:%M:%S')}, end={self.end_time.strftime('%H:%M:%S')}, duration={self.duration()}, events={len(list(self))}, freq={self.frequency:.2f}Hz)"

    def __str__(self) -> str:
        """Return a string representation of the Timeseries."""
        return self.__repr__()
T = TypeVar("T", bound="Vector")


class Vector:
    """A wrapper around numpy arrays for vector operations with intuitive syntax."""

    def __init__(self, *args: Any) -> None:
        """Initialize a vector from components or another iterable.

        Examples:
            Vector(1, 2)       # 2D vector
            Vector(1, 2, 3)    # 3D vector
            Vector([1, 2, 3])  # From list
            Vector(np.array([1, 2, 3])) # From numpy array

        A single non-iterable argument exposing ``x``, ``y`` and ``z``
        attributes is converted to a 3D vector; any other single value
        becomes a one-component vector.
        """
        if len(args) == 1 and hasattr(args[0], "__iter__"):
            self._data = np.array(args[0], dtype=float)
        elif len(args) == 1 and hasattr(args[0], "x"):
            self._data = np.array([args[0].x, args[0].y, args[0].z], dtype=float)
        else:
            self._data = np.array(args, dtype=float)

    @property
    def yaw(self) -> float:
        # NOTE(review): returns the x component — confirm this is the
        # intended axis for yaw.
        return self.x

    @property
    def tuple(self) -> Tuple[float, ...]:
        """Tuple representation of the vector."""
        return tuple(self._data)

    @property
    def x(self) -> float:
        """X component of the vector."""
        return self._data[0] if len(self._data) > 0 else 0.0

    @property
    def y(self) -> float:
        """Y component of the vector."""
        return self._data[1] if len(self._data) > 1 else 0.0

    @property
    def z(self) -> float:
        """Z component of the vector."""
        return self._data[2] if len(self._data) > 2 else 0.0

    @property
    def dim(self) -> int:
        """Dimensionality of the vector."""
        return len(self._data)

    @property
    def data(self) -> NDArray[np.float64]:
        """Get the underlying numpy array."""
        return self._data

    def __len__(self) -> int:
        return len(self._data)

    def __getitem__(self, idx: int) -> float:
        return float(self._data[idx])

    def __iter__(self) -> Iterable[float]:
        return iter(self._data)

    def __repr__(self) -> str:
        components = ",".join(f"{x:.6g}" for x in self._data)
        return f"({components})"

    def __str__(self) -> str:
        """Return the repr prefixed with an arrow for the x/y direction (2D and up)."""
        if self.dim < 2:
            return self.__repr__()

        def getArrow() -> str:
            arrows = ["←", "↖", "↑", "↗", "→", "↘", "↓", "↙"]

            if self.y == 0 and self.x == 0:
                return "·"

            # Calculate angle in radians and convert to directional index
            angle = np.arctan2(self.y, self.x)
            # Map angle to 0-7 index (8 directions) with proper orientation
            dir_index = int(((angle + np.pi) * 4 / np.pi) % 8)
            # Get directional arrow symbol
            return arrows[dir_index]

        return f"{getArrow()} Vector {self.__repr__()}"

    def serialize(self) -> dict:
        """Serialize the vector to a dictionary."""
        return {"type": "vector", "c": self._data.tolist()}

    def __eq__(self, other: Any) -> bool:
        """Compare component-wise with another vector or numeric sequence.

        Operands that cannot be converted to a float array are treated as
        "not comparable" rather than raising.
        """
        if isinstance(other, Vector):
            return np.array_equal(self._data, other._data)
        try:
            other_data = np.array(other, dtype=float)
        except (TypeError, ValueError):
            return NotImplemented
        return np.array_equal(self._data, other_data)

    def __add__(self: T, other: Union["Vector", Iterable[float]]) -> T:
        if isinstance(other, Vector):
            return self.__class__(self._data + other._data)
        return self.__class__(self._data + np.array(other, dtype=float))

    def __sub__(self: T, other: Union["Vector", Iterable[float]]) -> T:
        if isinstance(other, Vector):
            return self.__class__(self._data - other._data)
        return self.__class__(self._data - np.array(other, dtype=float))

    def __mul__(self: T, scalar: float) -> T:
        return self.__class__(self._data * scalar)

    def __rmul__(self: T, scalar: float) -> T:
        return self.__mul__(scalar)

    def __truediv__(self: T, scalar: float) -> T:
        return self.__class__(self._data / scalar)

    def __neg__(self: T) -> T:
        return self.__class__(-self._data)

    def dot(self, other: Union["Vector", Iterable[float]]) -> float:
        """Compute dot product."""
        if isinstance(other, Vector):
            return float(np.dot(self._data, other._data))
        return float(np.dot(self._data, np.array(other, dtype=float)))

    def cross(self: T, other: Union["Vector", Iterable[float]]) -> T:
        """Compute cross product (3D vectors only).

        Raises:
            ValueError: If either operand is not three-dimensional.
        """
        if self.dim != 3:
            raise ValueError("Cross product is only defined for 3D vectors")

        if isinstance(other, Vector):
            other_data = other._data
        else:
            other_data = np.array(other, dtype=float)

        if len(other_data) != 3:
            raise ValueError("Cross product requires two 3D vectors")

        return self.__class__(np.cross(self._data, other_data))

    def length(self) -> float:
        """Compute the Euclidean length (magnitude) of the vector."""
        return float(np.linalg.norm(self._data))

    def length_squared(self) -> float:
        """Compute the squared length of the vector (faster than length())."""
        return float(np.sum(self._data * self._data))

    def normalize(self: T) -> T:
        """Return a normalized unit vector in the same direction."""
        length = self.length()
        if length < 1e-10:  # Avoid division by near-zero
            return self.__class__(np.zeros_like(self._data))
        return self.__class__(self._data / length)

    def to_2d(self: T) -> T:
        """Convert a vector to a 2D vector by taking only the x and y components."""
        return self.__class__(self._data[:2])

    def distance(self, other: Union["Vector", Iterable[float]]) -> float:
        """Compute Euclidean distance to another vector."""
        if isinstance(other, Vector):
            return float(np.linalg.norm(self._data - other._data))
        return float(np.linalg.norm(self._data - np.array(other, dtype=float)))

    def distance_squared(self, other: Union["Vector", Iterable[float]]) -> float:
        """Compute squared Euclidean distance to another vector (faster than distance())."""
        if isinstance(other, Vector):
            diff = self._data - other._data
        else:
            diff = self._data - np.array(other, dtype=float)
        return float(np.sum(diff * diff))

    def angle(self, other: Union["Vector", Iterable[float]]) -> float:
        """Compute the angle (in radians) between this vector and another.

        Returns ``0.0`` when either operand has (near-)zero length.
        """
        if isinstance(other, Vector):
            other_data = other._data
        else:
            other_data = np.array(other, dtype=float)

        self_norm = np.linalg.norm(self._data)
        other_norm = np.linalg.norm(other_data)
        # guard both operands, including plain sequences, against 0/0
        if self_norm < 1e-10 or other_norm < 1e-10:
            return 0.0

        cos_angle = np.clip(
            np.dot(self._data, other_data) / (self_norm * other_norm),
            -1.0,
            1.0,
        )
        return float(np.arccos(cos_angle))

    def project(self: T, onto: Union["Vector", Iterable[float]]) -> T:
        """Project this vector onto another vector."""
        if isinstance(onto, Vector):
            onto_data = onto._data
        else:
            onto_data = np.array(onto, dtype=float)

        onto_length_sq = np.sum(onto_data * onto_data)
        if onto_length_sq < 1e-10:
            return self.__class__(np.zeros_like(self._data))

        scalar_projection = np.dot(self._data, onto_data) / onto_length_sq
        return self.__class__(scalar_projection * onto_data)

    # this is here to test ros_observable_topic
    # doesn't happen irl afaik that we want a vector from ros message
    @classmethod
    def from_msg(cls: type[T], msg: Any) -> T:
        return cls(*msg)

    @classmethod
    def zeros(cls: type[T], dim: int) -> T:
        """Create a zero vector of given dimension."""
        return cls(np.zeros(dim))

    @classmethod
    def ones(cls: type[T], dim: int) -> T:
        """Create a vector of ones with given dimension."""
        return cls(np.ones(dim))

    @classmethod
    def unit_x(cls: type[T], dim: int = 3) -> T:
        """Create a unit vector in the x direction."""
        v = np.zeros(dim)
        v[0] = 1.0
        return cls(v)

    @classmethod
    def unit_y(cls: type[T], dim: int = 3) -> T:
        """Create a unit vector in the y direction."""
        v = np.zeros(dim)
        v[1] = 1.0
        return cls(v)

    @classmethod
    def unit_z(cls: type[T], dim: int = 3) -> T:
        """Create a unit vector in the z direction."""
        v = np.zeros(dim)
        if dim > 2:
            v[2] = 1.0
        return cls(v)

    def to_list(self) -> List[float]:
        """Convert the vector to a list."""
        return [float(x) for x in self._data]

    def to_tuple(self) -> Tuple[float, ...]:
        """Convert the vector to a tuple."""
        return tuple(self._data)

    def to_numpy(self) -> NDArray[np.float64]:
        """Convert the vector to a numpy array."""
        return self._data
+ + Args: + value: Any vector-like object (Vector, numpy array, tuple, list) + + Returns: + Numpy array representation + """ + if isinstance(value, Vector): + return value.data + elif isinstance(value, np.ndarray): + return value + else: + return np.array(value, dtype=float) + + +def to_vector(value: VectorLike) -> Vector: + """Convert a vector-compatible value to a Vector object. + + Args: + value: Any vector-like object (Vector, numpy array, tuple, list) + + Returns: + Vector object + """ + if isinstance(value, Vector): + return value + else: + return Vector(value) + + +def to_tuple(value: VectorLike) -> Tuple[float, ...]: + """Convert a vector-compatible value to a tuple. + + Args: + value: Any vector-like object (Vector, numpy array, tuple, list) + + Returns: + Tuple of floats + """ + if isinstance(value, Vector): + return tuple(float(x) for x in value.data) + elif isinstance(value, np.ndarray): + return tuple(float(x) for x in value) + elif isinstance(value, tuple): + return tuple(float(x) for x in value) + else: + # Convert to list first to ensure we have an indexable sequence + data = [value[i] for i in range(len(value))] + return tuple(float(x) for x in data) + + +def to_list(value: VectorLike) -> List[float]: + """Convert a vector-compatible value to a list. + + Args: + value: Any vector-like object (Vector, numpy array, tuple, list) + + Returns: + List of floats + """ + if isinstance(value, Vector): + return [float(x) for x in value.data] + elif isinstance(value, np.ndarray): + return [float(x) for x in value] + elif isinstance(value, list): + return [float(x) for x in value] + else: + # Convert to list using indexing + return [float(value[i]) for i in range(len(value))] + + +# Helper functions to check dimensionality +def is_2d(value: VectorLike) -> bool: + """Check if a vector-compatible value is 2D. 
+ + Args: + value: Any vector-like object (Vector, numpy array, tuple, list) + + Returns: + True if the value is 2D + """ + if isinstance(value, Vector): + return len(value) == 2 + elif isinstance(value, np.ndarray): + return value.shape[-1] == 2 or value.size == 2 + else: + return len(value) == 2 + + +def is_3d(value: VectorLike) -> bool: + """Check if a vector-compatible value is 3D. + + Args: + value: Any vector-like object (Vector, numpy array, tuple, list) + + Returns: + True if the value is 3D + """ + if isinstance(value, Vector): + return len(value) == 3 + elif isinstance(value, np.ndarray): + return value.shape[-1] == 3 or value.size == 3 + else: + return len(value) == 3 + + +# Extraction functions for XYZ components +def x(value: VectorLike) -> float: + """Get the X component of a vector-compatible value. + + Args: + value: Any vector-like object (Vector, numpy array, tuple, list) + + Returns: + X component as a float + """ + if isinstance(value, Vector): + return value.x + else: + return float(to_numpy(value)[0]) + + +def y(value: VectorLike) -> float: + """Get the Y component of a vector-compatible value. + + Args: + value: Any vector-like object (Vector, numpy array, tuple, list) + + Returns: + Y component as a float + """ + if isinstance(value, Vector): + return value.y + else: + arr = to_numpy(value) + return float(arr[1]) if len(arr) > 1 else 0.0 + + +def z(value: VectorLike) -> float: + """Get the Z component of a vector-compatible value. 
+ + Args: + value: Any vector-like object (Vector, numpy array, tuple, list) + + Returns: + Z component as a float + """ + if isinstance(value, Vector): + return value.z + else: + arr = to_numpy(value) + return float(arr[2]) if len(arr) > 2 else 0.0 diff --git a/build/lib/dimos/robot/unitree_webrtc/unitree_go2.py b/build/lib/dimos/robot/unitree_webrtc/unitree_go2.py new file mode 100644 index 0000000000..94676bfffc --- /dev/null +++ b/build/lib/dimos/robot/unitree_webrtc/unitree_go2.py @@ -0,0 +1,224 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Union, Optional, List +import time +import numpy as np +import os +from dimos.robot.robot import Robot +from dimos.robot.unitree_webrtc.type.map import Map +from dimos.robot.unitree_webrtc.connection import WebRTCRobot +from dimos.robot.global_planner.planner import AstarPlanner +from dimos.utils.reactive import getter_streaming +from dimos.skills.skills import AbstractRobotSkill, SkillLibrary +from go2_webrtc_driver.constants import VUI_COLOR +from go2_webrtc_driver.webrtc_driver import WebRTCConnectionMethod +from dimos.perception.person_tracker import PersonTrackingStream +from dimos.perception.object_tracker import ObjectTrackingStream +from dimos.robot.local_planner.local_planner import navigate_path_local +from dimos.robot.local_planner.vfh_local_planner import VFHPurePursuitPlanner +from dimos.types.robot_capabilities import RobotCapability +from dimos.types.vector import Vector +from dimos.robot.unitree_webrtc.unitree_skills import MyUnitreeSkills +from dimos.robot.frontier_exploration.qwen_frontier_predictor import QwenFrontierPredictor +from dimos.robot.frontier_exploration.wavefront_frontier_goal_selector import ( + WavefrontFrontierExplorer, +) +import threading + + +class Color(VUI_COLOR): ... + + +class UnitreeGo2(Robot): + def __init__( + self, + ip: str, + mode: str = "ai", + output_dir: str = os.path.join(os.getcwd(), "assets", "output"), + skill_library: SkillLibrary = None, + robot_capabilities: List[RobotCapability] = None, + spatial_memory_collection: str = "spatial_memory", + new_memory: bool = True, + enable_perception: bool = True, + ): + """Initialize Unitree Go2 robot with WebRTC control interface. + + Args: + ip: IP address of the robot + mode: Robot mode (ai, etc.) 
+ output_dir: Directory for output files + skill_library: Skill library instance + robot_capabilities: List of robot capabilities + spatial_memory_collection: Collection name for spatial memory + new_memory: Whether to create new spatial memory + enable_perception: Whether to enable perception streams and spatial memory + """ + # Create WebRTC connection interface + self.webrtc_connection = WebRTCRobot( + ip=ip, + mode=mode, + ) + + print("standing up") + self.webrtc_connection.standup() + + # Initialize WebRTC-specific features + self.lidar_stream = self.webrtc_connection.lidar_stream() + self.odom = getter_streaming(self.webrtc_connection.odom_stream()) + self.map = Map(voxel_size=0.2) + self.map_stream = self.map.consume(self.lidar_stream) + self.lidar_message = getter_streaming(self.lidar_stream) + + if skill_library is None: + skill_library = MyUnitreeSkills() + + # Initialize base robot with connection interface + super().__init__( + connection_interface=self.webrtc_connection, + output_dir=output_dir, + skill_library=skill_library, + capabilities=robot_capabilities + or [ + RobotCapability.LOCOMOTION, + RobotCapability.VISION, + RobotCapability.AUDIO, + ], + spatial_memory_collection=spatial_memory_collection, + new_memory=new_memory, + enable_perception=enable_perception, + ) + + if self.skill_library is not None: + for skill in self.skill_library: + if isinstance(skill, AbstractRobotSkill): + self.skill_library.create_instance(skill.__name__, robot=self) + if isinstance(self.skill_library, MyUnitreeSkills): + self.skill_library._robot = self + self.skill_library.init() + self.skill_library.initialize_skills() + + # Camera configuration + self.camera_intrinsics = [819.553492, 820.646595, 625.284099, 336.808987] + self.camera_pitch = np.deg2rad(0) # negative for downward pitch + self.camera_height = 0.44 # meters + + # Initialize visual servoing using connection interface + video_stream = self.get_video_stream() + if video_stream is not None and 
enable_perception: + self.person_tracker = PersonTrackingStream( + camera_intrinsics=self.camera_intrinsics, + camera_pitch=self.camera_pitch, + camera_height=self.camera_height, + ) + self.object_tracker = ObjectTrackingStream( + camera_intrinsics=self.camera_intrinsics, + camera_pitch=self.camera_pitch, + camera_height=self.camera_height, + ) + person_tracking_stream = self.person_tracker.create_stream(video_stream) + object_tracking_stream = self.object_tracker.create_stream(video_stream) + + self.person_tracking_stream = person_tracking_stream + self.object_tracking_stream = object_tracking_stream + else: + # Video stream not available or perception disabled + self.person_tracker = None + self.object_tracker = None + self.person_tracking_stream = None + self.object_tracking_stream = None + + self.global_planner = AstarPlanner( + set_local_nav=lambda path, stop_event=None, goal_theta=None: navigate_path_local( + self, path, timeout=120.0, goal_theta=goal_theta, stop_event=stop_event + ), + get_costmap=lambda: self.map.costmap, + get_robot_pos=lambda: self.odom().pos, + ) + + # Initialize the local planner using WebRTC-specific methods + self.local_planner = VFHPurePursuitPlanner( + get_costmap=lambda: self.lidar_message().costmap(), + get_robot_pose=lambda: self.odom(), + move=self.move, # Use the robot's move method directly + robot_width=0.36, # Unitree Go2 width in meters + robot_length=0.6, # Unitree Go2 length in meters + max_linear_vel=0.7, + max_angular_vel=0.65, + lookahead_distance=1.5, + visualization_size=500, # 500x500 pixel visualization + global_planner_plan=self.global_planner.plan, + ) + + # Initialize frontier exploration + self.frontier_explorer = WavefrontFrontierExplorer( + set_goal=self.global_planner.set_goal, + get_costmap=lambda: self.map.costmap, + get_robot_pos=lambda: self.odom().pos, + ) + + # Create the visualization stream at 5Hz + self.local_planner_viz_stream = self.local_planner.create_stream(frequency_hz=5.0) + + def 
get_pose(self) -> dict: + """ + Get the current pose (position and rotation) of the robot in the map frame. + + Returns: + Dictionary containing: + - position: Vector (x, y, z) + - rotation: Vector (roll, pitch, yaw) in radians + """ + position = Vector(self.odom().pos.x, self.odom().pos.y, self.odom().pos.z) + orientation = Vector(self.odom().rot.x, self.odom().rot.y, self.odom().rot.z) + return {"position": position, "rotation": orientation} + + def explore(self, stop_event: Optional[threading.Event] = None) -> bool: + """ + Start autonomous frontier exploration. + + Args: + stop_event: Optional threading.Event to signal when exploration should stop + + Returns: + bool: True if exploration completed successfully, False if stopped or failed + """ + return self.frontier_explorer.explore(stop_event=stop_event) + + def odom_stream(self): + """Get the odometry stream from the robot. + + Returns: + Observable stream of robot odometry data containing position and orientation. + """ + return self.webrtc_connection.odom_stream() + + def standup(self): + """Make the robot stand up. + + Uses AI mode standup if robot is in AI mode, otherwise uses normal standup. + """ + return self.webrtc_connection.standup() + + def liedown(self): + """Make the robot lie down. + + Commands the robot to lie down on the ground. + """ + return self.webrtc_connection.liedown() + + @property + def costmap(self): + """Access to the costmap for navigation.""" + return self.map.costmap diff --git a/build/lib/dimos/robot/unitree_webrtc/unitree_skills.py b/build/lib/dimos/robot/unitree_webrtc/unitree_skills.py new file mode 100644 index 0000000000..f9dfc1eede --- /dev/null +++ b/build/lib/dimos/robot/unitree_webrtc/unitree_skills.py @@ -0,0 +1,279 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, List, Optional, Tuple, Union +import time +from pydantic import Field + +if TYPE_CHECKING: + from dimos.robot.robot import Robot, MockRobot +else: + Robot = "Robot" + MockRobot = "MockRobot" + +from dimos.skills.skills import AbstractRobotSkill, AbstractSkill, SkillLibrary +from dimos.types.constants import Colors +from dimos.types.vector import Vector +from go2_webrtc_driver.constants import RTC_TOPIC, SPORT_CMD + +# Module-level constant for Unitree WebRTC control definitions +UNITREE_WEBRTC_CONTROLS: List[Tuple[str, int, str]] = [ + ("Damp", 1001, "Lowers the robot to the ground fully."), + ( + "BalanceStand", + 1002, + "Activates a mode that maintains the robot in a balanced standing position.", + ), + ( + "StandUp", + 1004, + "Commands the robot to transition from a sitting or prone position to a standing posture.", + ), + ( + "StandDown", + 1005, + "Instructs the robot to move from a standing position to a sitting or prone posture.", + ), + ( + "RecoveryStand", + 1006, + "Recovers the robot to a state from which it can take more commands. 
Useful to run after multiple dynamic commands like front flips, Must run after skills like sit and jump and standup.", + ), + ("Sit", 1009, "Commands the robot to sit down from a standing or moving stance."), + ( + "RiseSit", + 1010, + "Commands the robot to rise back to a standing position from a sitting posture.", + ), + ( + "SwitchGait", + 1011, + "Switches the robot's walking pattern or style dynamically, suitable for different terrains or speeds.", + ), + ("Trigger", 1012, "Triggers a specific action or custom routine programmed into the robot."), + ( + "BodyHeight", + 1013, + "Adjusts the height of the robot's body from the ground, useful for navigating various obstacles.", + ), + ( + "FootRaiseHeight", + 1014, + "Controls how high the robot lifts its feet during movement, which can be adjusted for different surfaces.", + ), + ( + "SpeedLevel", + 1015, + "Sets or adjusts the speed at which the robot moves, with various levels available for different operational needs.", + ), + ( + "Hello", + 1016, + "Performs a greeting action, which could involve a wave or other friendly gesture.", + ), + ("Stretch", 1017, "Engages the robot in a stretching routine."), + ( + "TrajectoryFollow", + 1018, + "Directs the robot to follow a predefined trajectory, which could involve complex paths or maneuvers.", + ), + ( + "ContinuousGait", + 1019, + "Enables a mode for continuous walking or running, ideal for long-distance travel.", + ), + ("Content", 1020, "To display or trigger when the robot is happy."), + ("Wallow", 1021, "The robot falls onto its back and rolls around."), + ( + "Dance1", + 1022, + "Performs a predefined dance routine 1, programmed for entertainment or demonstration.", + ), + ("Dance2", 1023, "Performs another variant of a predefined dance routine 2."), + ("GetBodyHeight", 1024, "Retrieves the current height of the robot's body from the ground."), + ( + "GetFootRaiseHeight", + 1025, + "Retrieves the current height at which the robot's feet are being raised 
during movement.", + ), + ( + "GetSpeedLevel", + 1026, + "Retrieves the current speed level setting of the robot.", + ), + ( + "SwitchJoystick", + 1027, + "Switches the robot's control mode to respond to joystick input for manual operation.", + ), + ( + "Pose", + 1028, + "Commands the robot to assume a specific pose or posture as predefined in its programming.", + ), + ("Scrape", 1029, "The robot performs a scraping motion."), + ( + "FrontFlip", + 1030, + "Commands the robot to perform a front flip, showcasing its agility and dynamic movement capabilities.", + ), + ( + "FrontJump", + 1031, + "Instructs the robot to jump forward, demonstrating its explosive movement capabilities.", + ), + ( + "FrontPounce", + 1032, + "Commands the robot to perform a pouncing motion forward.", + ), + ( + "WiggleHips", + 1033, + "The robot performs a hip wiggling motion, often used for entertainment or demonstration purposes.", + ), + ( + "GetState", + 1034, + "Retrieves the current operational state of the robot, including its mode, position, and status.", + ), + ( + "EconomicGait", + 1035, + "Engages a more energy-efficient walking or running mode to conserve battery life.", + ), + ("FingerHeart", 1036, "Performs a finger heart gesture while on its hind legs."), + ( + "Handstand", + 1301, + "Commands the robot to perform a handstand, demonstrating balance and control.", + ), + ( + "CrossStep", + 1302, + "Commands the robot to perform cross-step movements.", + ), + ( + "OnesidedStep", + 1303, + "Commands the robot to perform one-sided step movements.", + ), + ("Bound", 1304, "Commands the robot to perform bounding movements."), + ("MoonWalk", 1305, "Commands the robot to perform a moonwalk motion."), + ("LeftFlip", 1042, "Executes a flip towards the left side."), + ("RightFlip", 1043, "Performs a flip towards the right side."), + ("Backflip", 1044, "Executes a backflip, a complex and dynamic maneuver."), +] + +# region MyUnitreeSkills + + +class MyUnitreeSkills(SkillLibrary): + """My 
Unitree Skills for WebRTC interface.""" + + def __init__(self, robot: Optional[Robot] = None): + super().__init__() + self._robot: Robot = None + + # Add dynamic skills to this class + dynamic_skills = self.create_skills_live() + self.register_skills(dynamic_skills) + + @classmethod + def register_skills(cls, skill_classes: Union["AbstractSkill", list["AbstractSkill"]]): + """Add multiple skill classes as class attributes. + + Args: + skill_classes: List of skill classes to add + """ + if not isinstance(skill_classes, list): + skill_classes = [skill_classes] + + for skill_class in skill_classes: + # Add to the class as a skill + setattr(cls, skill_class.__name__, skill_class) + + def initialize_skills(self): + for skill_class in self.get_class_skills(): + self.create_instance(skill_class.__name__, robot=self._robot) + + # Refresh the class skills + self.refresh_class_skills() + + def create_skills_live(self) -> List[AbstractRobotSkill]: + # ================================================ + # Procedurally created skills + # ================================================ + class BaseUnitreeSkill(AbstractRobotSkill): + """Base skill for dynamic skill creation.""" + + def __call__(self): + string = f"{Colors.GREEN_PRINT_COLOR}This is a base skill, created for the specific skill: {self._app_id}{Colors.RESET_COLOR}" + print(string) + super().__call__() + if self._app_id is None: + raise RuntimeError( + f"{Colors.RED_PRINT_COLOR}" + f"No App ID provided to {self.__class__.__name__} Skill" + f"{Colors.RESET_COLOR}" + ) + else: + # Use WebRTC publish_request interface through the robot's webrtc_connection + result = self._robot.webrtc_connection.publish_request( + RTC_TOPIC["SPORT_MOD"], {"api_id": self._app_id} + ) + string = f"{Colors.GREEN_PRINT_COLOR}{self.__class__.__name__} was successful: id={self._app_id}{Colors.RESET_COLOR}" + print(string) + return string + + skills_classes = [] + for name, app_id, description in UNITREE_WEBRTC_CONTROLS: + if name not in 
["Reverse", "Spin"]: # Exclude reverse and spin skills + skill_class = type( + name, # Name of the class + (BaseUnitreeSkill,), # Base classes + {"__doc__": description, "_app_id": app_id}, + ) + skills_classes.append(skill_class) + + return skills_classes + + # region Class-based Skills + + class Move(AbstractRobotSkill): + """Move the robot using direct velocity commands. Determine duration required based on user distance instructions.""" + + x: float = Field(..., description="Forward velocity (m/s).") + y: float = Field(default=0.0, description="Left/right velocity (m/s)") + yaw: float = Field(default=0.0, description="Rotational velocity (rad/s)") + duration: float = Field(default=0.0, description="How long to move (seconds).") + + def __call__(self): + return self._robot.move(Vector(self.x, self.y, self.yaw), duration=self.duration) + + class Wait(AbstractSkill): + """Wait for a specified amount of time.""" + + seconds: float = Field(..., description="Seconds to wait") + + def __call__(self): + time.sleep(self.seconds) + return f"Wait completed with length={self.seconds}s" + + # endregion + + +# endregion diff --git a/build/lib/dimos/simulation/__init__.py b/build/lib/dimos/simulation/__init__.py new file mode 100644 index 0000000000..3d25363b30 --- /dev/null +++ b/build/lib/dimos/simulation/__init__.py @@ -0,0 +1,15 @@ +# Try to import Isaac Sim components +try: + from .isaac import IsaacSimulator, IsaacStream +except ImportError: + IsaacSimulator = None # type: ignore + IsaacStream = None # type: ignore + +# Try to import Genesis components +try: + from .genesis import GenesisSimulator, GenesisStream +except ImportError: + GenesisSimulator = None # type: ignore + GenesisStream = None # type: ignore + +__all__ = ["IsaacSimulator", "IsaacStream", "GenesisSimulator", "GenesisStream"] diff --git a/build/lib/dimos/simulation/base/__init__.py b/build/lib/dimos/simulation/base/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git 
a/build/lib/dimos/simulation/base/simulator_base.py b/build/lib/dimos/simulation/base/simulator_base.py new file mode 100644 index 0000000000..91633bb53a --- /dev/null +++ b/build/lib/dimos/simulation/base/simulator_base.py @@ -0,0 +1,48 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Union, List, Dict +from abc import ABC, abstractmethod + + +class SimulatorBase(ABC): + """Base class for simulators.""" + + @abstractmethod + def __init__( + self, + headless: bool = True, + open_usd: Optional[str] = None, # Keep for Isaac compatibility + entities: Optional[List[Dict[str, Union[str, dict]]]] = None, # Add for Genesis + ): + """Initialize the simulator. + + Args: + headless: Whether to run without visualization + open_usd: Path to USD file (for Isaac) + entities: List of entity configurations (for Genesis) + """ + self.headless = headless + self.open_usd = open_usd + self.stage = None + + @abstractmethod + def get_stage(self): + """Get the current stage/scene.""" + pass + + @abstractmethod + def close(self): + """Close the simulation.""" + pass diff --git a/build/lib/dimos/simulation/base/stream_base.py b/build/lib/dimos/simulation/base/stream_base.py new file mode 100644 index 0000000000..d20af296e2 --- /dev/null +++ b/build/lib/dimos/simulation/base/stream_base.py @@ -0,0 +1,116 @@ +# Copyright 2025 Dimensional Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from typing import Literal, Optional, Union +from pathlib import Path +import subprocess + +AnnotatorType = Literal["rgb", "normals", "bounding_box_3d", "motion_vectors"] +TransportType = Literal["tcp", "udp"] + + +class StreamBase(ABC): + """Base class for simulation streaming.""" + + @abstractmethod + def __init__( + self, + simulator, + width: int = 1920, + height: int = 1080, + fps: int = 60, + camera_path: str = "/World/camera", + annotator_type: AnnotatorType = "rgb", + transport: TransportType = "tcp", + rtsp_url: str = "rtsp://mediamtx:8554/stream", + usd_path: Optional[Union[str, Path]] = None, + ): + """Initialize the stream. 
+ + Args: + simulator: Simulator instance + width: Stream width in pixels + height: Stream height in pixels + fps: Frames per second + camera_path: Camera path in scene + annotator: Type of annotator to use + transport: Transport protocol + rtsp_url: RTSP stream URL + usd_path: Optional USD file path to load + """ + self.simulator = simulator + self.width = width + self.height = height + self.fps = fps + self.camera_path = camera_path + self.annotator_type = annotator_type + self.transport = transport + self.rtsp_url = rtsp_url + self.proc = None + + @abstractmethod + def _load_stage(self, usd_path: Union[str, Path]): + """Load stage from file.""" + pass + + @abstractmethod + def _setup_camera(self): + """Setup and validate camera.""" + pass + + def _setup_ffmpeg(self): + """Setup FFmpeg process for streaming.""" + command = [ + "ffmpeg", + "-y", + "-f", + "rawvideo", + "-vcodec", + "rawvideo", + "-pix_fmt", + "bgr24", + "-s", + f"{self.width}x{self.height}", + "-r", + str(self.fps), + "-i", + "-", + "-an", + "-c:v", + "h264_nvenc", + "-preset", + "fast", + "-f", + "rtsp", + "-rtsp_transport", + self.transport, + self.rtsp_url, + ] + self.proc = subprocess.Popen(command, stdin=subprocess.PIPE) + + @abstractmethod + def _setup_annotator(self): + """Setup annotator.""" + pass + + @abstractmethod + def stream(self): + """Start streaming.""" + pass + + @abstractmethod + def cleanup(self): + """Cleanup resources.""" + pass diff --git a/build/lib/dimos/simulation/genesis/__init__.py b/build/lib/dimos/simulation/genesis/__init__.py new file mode 100644 index 0000000000..5657d9167b --- /dev/null +++ b/build/lib/dimos/simulation/genesis/__init__.py @@ -0,0 +1,4 @@ +from .simulator import GenesisSimulator +from .stream import GenesisStream + +__all__ = ["GenesisSimulator", "GenesisStream"] diff --git a/build/lib/dimos/simulation/genesis/simulator.py b/build/lib/dimos/simulation/genesis/simulator.py new file mode 100644 index 0000000000..e531e6b422 --- /dev/null +++ 
b/build/lib/dimos/simulation/genesis/simulator.py @@ -0,0 +1,158 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Union, List, Dict +import genesis as gs # type: ignore +from ..base.simulator_base import SimulatorBase + + +class GenesisSimulator(SimulatorBase): + """Genesis simulator implementation.""" + + def __init__( + self, + headless: bool = True, + open_usd: Optional[str] = None, # Keep for compatibility + entities: Optional[List[Dict[str, Union[str, dict]]]] = None, + ): + """Initialize the Genesis simulation. + + Args: + headless: Whether to run without visualization + open_usd: Path to USD file (for Isaac) + entities: List of entity configurations to load. 
Each entity is a dict with: + - type: str ('mesh', 'urdf', 'mjcf', 'primitive') + - path: str (file path for mesh/urdf/mjcf) + - params: dict (parameters for primitives or loading options) + """ + super().__init__(headless, open_usd, entities) + + # Initialize Genesis + gs.init() + + # Create scene with viewer options + self.scene = gs.Scene( + show_viewer=not headless, + viewer_options=gs.options.ViewerOptions( + res=(1280, 960), + camera_pos=(3.5, 0.0, 2.5), + camera_lookat=(0.0, 0.0, 0.5), + camera_fov=40, + max_FPS=60, + ), + vis_options=gs.options.VisOptions( + show_world_frame=True, + world_frame_size=1.0, + show_link_frame=False, + show_cameras=False, + plane_reflection=True, + ambient_light=(0.1, 0.1, 0.1), + ), + renderer=gs.renderers.Rasterizer(), + ) + + # Handle USD parameter for compatibility + if open_usd: + print(f"[Warning] USD files not supported in Genesis. Ignoring: {open_usd}") + + # Load entities if provided + if entities: + self._load_entities(entities) + + # Don't build scene yet - let stream add camera first + self.is_built = False + + def _load_entities(self, entities: List[Dict[str, Union[str, dict]]]): + """Load multiple entities into the scene.""" + for entity in entities: + entity_type = entity.get("type", "").lower() + path = entity.get("path", "") + params = entity.get("params", {}) + + try: + if entity_type == "mesh": + mesh = gs.morphs.Mesh( + file=path, # Explicit file argument + **params, + ) + self.scene.add_entity(mesh) + print(f"[Genesis] Added mesh from {path}") + + elif entity_type == "urdf": + robot = gs.morphs.URDF( + file=path, # Explicit file argument + **params, + ) + self.scene.add_entity(robot) + print(f"[Genesis] Added URDF robot from {path}") + + elif entity_type == "mjcf": + mujoco = gs.morphs.MJCF( + file=path, # Explicit file argument + **params, + ) + self.scene.add_entity(mujoco) + print(f"[Genesis] Added MJCF model from {path}") + + elif entity_type == "primitive": + shape_type = params.pop("shape", "plane") + 
if shape_type == "plane": + morph = gs.morphs.Plane(**params) + elif shape_type == "box": + morph = gs.morphs.Box(**params) + elif shape_type == "sphere": + morph = gs.morphs.Sphere(**params) + else: + raise ValueError(f"Unsupported primitive shape: {shape_type}") + + # Add position if not specified + if "pos" not in params: + if shape_type == "plane": + morph.pos = [0, 0, 0] + else: + morph.pos = [0, 0, 1] # Lift objects above ground + + self.scene.add_entity(morph) + print(f"[Genesis] Added {shape_type} at position {morph.pos}") + + else: + raise ValueError(f"Unsupported entity type: {entity_type}") + + except Exception as e: + print(f"[Warning] Failed to load entity {entity}: {str(e)}") + + def add_entity(self, entity_type: str, path: str = "", **params): + """Add a single entity to the scene. + + Args: + entity_type: Type of entity ('mesh', 'urdf', 'mjcf', 'primitive') + path: File path for mesh/urdf/mjcf entities + **params: Additional parameters for entity creation + """ + self._load_entities([{"type": entity_type, "path": path, "params": params}]) + + def get_stage(self): + """Get the current stage/scene.""" + return self.scene + + def build(self): + """Build the scene if not already built.""" + if not self.is_built: + self.scene.build() + self.is_built = True + + def close(self): + """Close the simulation.""" + # Genesis handles cleanup automatically + pass diff --git a/build/lib/dimos/simulation/genesis/stream.py b/build/lib/dimos/simulation/genesis/stream.py new file mode 100644 index 0000000000..fbb70fea13 --- /dev/null +++ b/build/lib/dimos/simulation/genesis/stream.py @@ -0,0 +1,143 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import cv2 +import numpy as np +import time +from typing import Optional, Union +from pathlib import Path +from ..base.stream_base import StreamBase, AnnotatorType, TransportType + + +class GenesisStream(StreamBase): + """Genesis stream implementation.""" + + def __init__( + self, + simulator, + width: int = 1920, + height: int = 1080, + fps: int = 60, + camera_path: str = "/camera", + annotator_type: AnnotatorType = "rgb", + transport: TransportType = "tcp", + rtsp_url: str = "rtsp://mediamtx:8554/stream", + usd_path: Optional[Union[str, Path]] = None, + ): + """Initialize the Genesis stream.""" + super().__init__( + simulator=simulator, + width=width, + height=height, + fps=fps, + camera_path=camera_path, + annotator_type=annotator_type, + transport=transport, + rtsp_url=rtsp_url, + usd_path=usd_path, + ) + + self.scene = simulator.get_stage() + + # Initialize components + if usd_path: + self._load_stage(usd_path) + self._setup_camera() + self._setup_ffmpeg() + self._setup_annotator() + + # Build scene after camera is set up + simulator.build() + + def _load_stage(self, usd_path: Union[str, Path]): + """Load stage from file.""" + # Genesis handles stage loading through simulator + pass + + def _setup_camera(self): + """Setup and validate camera.""" + self.camera = self.scene.add_camera( + res=(self.width, self.height), + pos=(3.5, 0.0, 2.5), + lookat=(0, 0, 0.5), + fov=30, + GUI=False, + ) + + def _setup_annotator(self): + """Setup the specified annotator.""" + # Genesis handles different render types through camera.render() + pass + + def 
stream(self): + """Start the streaming loop.""" + try: + print("[Stream] Starting Genesis camera stream...") + frame_count = 0 + start_time = time.time() + + while True: + frame_start = time.time() + + # Step simulation and get frame + step_start = time.time() + self.scene.step() + step_time = time.time() - step_start + print(f"[Stream] Simulation step took {step_time * 1000:.2f}ms") + + # Get frame based on annotator type + if self.annotator_type == "rgb": + frame, _, _, _ = self.camera.render(rgb=True) + elif self.annotator_type == "normals": + _, _, _, frame = self.camera.render(normal=True) + else: + frame, _, _, _ = self.camera.render(rgb=True) # Default to RGB + + # Convert frame format if needed + if isinstance(frame, np.ndarray): + frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) + + # Write to FFmpeg + self.proc.stdin.write(frame.tobytes()) + self.proc.stdin.flush() + + # Log metrics + frame_time = time.time() - frame_start + print(f"[Stream] Total frame processing took {frame_time * 1000:.2f}ms") + frame_count += 1 + + if frame_count % 100 == 0: + elapsed_time = time.time() - start_time + current_fps = frame_count / elapsed_time + print( + f"[Stream] Processed {frame_count} frames | Current FPS: {current_fps:.2f}" + ) + + except KeyboardInterrupt: + print("\n[Stream] Received keyboard interrupt, stopping stream...") + finally: + self.cleanup() + + def cleanup(self): + """Cleanup resources.""" + print("[Cleanup] Stopping FFmpeg process...") + if hasattr(self, "proc"): + self.proc.stdin.close() + self.proc.wait() + print("[Cleanup] Closing simulation...") + try: + self.simulator.close() + except AttributeError: + print("[Cleanup] Warning: Could not close simulator properly") + print("[Cleanup] Successfully cleaned up resources") diff --git a/build/lib/dimos/simulation/isaac/__init__.py b/build/lib/dimos/simulation/isaac/__init__.py new file mode 100644 index 0000000000..2b9bdc082d --- /dev/null +++ b/build/lib/dimos/simulation/isaac/__init__.py @@ -0,0 +1,4 
@@ +from .simulator import IsaacSimulator +from .stream import IsaacStream + +__all__ = ["IsaacSimulator", "IsaacStream"] diff --git a/build/lib/dimos/simulation/isaac/simulator.py b/build/lib/dimos/simulation/isaac/simulator.py new file mode 100644 index 0000000000..ba6fe319b4 --- /dev/null +++ b/build/lib/dimos/simulation/isaac/simulator.py @@ -0,0 +1,43 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, List, Dict, Union +from isaacsim import SimulationApp +from ..base.simulator_base import SimulatorBase + + +class IsaacSimulator(SimulatorBase): + """Isaac Sim simulator implementation.""" + + def __init__( + self, + headless: bool = True, + open_usd: Optional[str] = None, + entities: Optional[List[Dict[str, Union[str, dict]]]] = None, # Add but ignore + ): + """Initialize the Isaac Sim simulation.""" + super().__init__(headless, open_usd) + self.app = SimulationApp({"headless": headless, "open_usd": open_usd}) + + def get_stage(self): + """Get the current USD stage.""" + import omni.usd + + self.stage = omni.usd.get_context().get_stage() + return self.stage + + def close(self): + """Close the simulation.""" + if hasattr(self, "app"): + self.app.close() diff --git a/build/lib/dimos/simulation/isaac/stream.py b/build/lib/dimos/simulation/isaac/stream.py new file mode 100644 index 0000000000..44560783bd --- /dev/null +++ b/build/lib/dimos/simulation/isaac/stream.py @@ -0,0 +1,136 @@ +# Copyright 2025 
Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import cv2 +import time +from typing import Optional, Union +from pathlib import Path +from ..base.stream_base import StreamBase, AnnotatorType, TransportType + + +class IsaacStream(StreamBase): + """Isaac Sim stream implementation.""" + + def __init__( + self, + simulator, + width: int = 1920, + height: int = 1080, + fps: int = 60, + camera_path: str = "/World/alfred_parent_prim/alfred_base_descr/chest_cam_rgb_camera_frame/chest_cam", + annotator_type: AnnotatorType = "rgb", + transport: TransportType = "tcp", + rtsp_url: str = "rtsp://mediamtx:8554/stream", + usd_path: Optional[Union[str, Path]] = None, + ): + """Initialize the Isaac Sim stream.""" + super().__init__( + simulator=simulator, + width=width, + height=height, + fps=fps, + camera_path=camera_path, + annotator_type=annotator_type, + transport=transport, + rtsp_url=rtsp_url, + usd_path=usd_path, + ) + + # Import omni.replicator after SimulationApp initialization + import omni.replicator.core as rep + + self.rep = rep + + # Initialize components + if usd_path: + self._load_stage(usd_path) + self._setup_camera() + self._setup_ffmpeg() + self._setup_annotator() + + def _load_stage(self, usd_path: Union[str, Path]): + """Load USD stage from file.""" + import omni.usd + + abs_path = str(Path(usd_path).resolve()) + omni.usd.get_context().open_stage(abs_path) + self.stage = self.simulator.get_stage() + if not self.stage: + raise 
RuntimeError(f"Failed to load stage: {abs_path}") + + def _setup_camera(self): + """Setup and validate camera.""" + self.stage = self.simulator.get_stage() + camera_prim = self.stage.GetPrimAtPath(self.camera_path) + if not camera_prim: + raise RuntimeError(f"Failed to find camera at path: {self.camera_path}") + + self.render_product = self.rep.create.render_product( + self.camera_path, resolution=(self.width, self.height) + ) + + def _setup_annotator(self): + """Setup the specified annotator.""" + self.annotator = self.rep.AnnotatorRegistry.get_annotator(self.annotator_type) + self.annotator.attach(self.render_product) + + def stream(self): + """Start the streaming loop.""" + try: + print("[Stream] Starting camera stream loop...") + frame_count = 0 + start_time = time.time() + + while True: + frame_start = time.time() + + # Step simulation and get frame + step_start = time.time() + self.rep.orchestrator.step() + step_time = time.time() - step_start + print(f"[Stream] Simulation step took {step_time * 1000:.2f}ms") + + frame = self.annotator.get_data() + frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2BGR) + + # Write to FFmpeg + self.proc.stdin.write(frame.tobytes()) + self.proc.stdin.flush() + + # Log metrics + frame_time = time.time() - frame_start + print(f"[Stream] Total frame processing took {frame_time * 1000:.2f}ms") + frame_count += 1 + + if frame_count % 100 == 0: + elapsed_time = time.time() - start_time + current_fps = frame_count / elapsed_time + print( + f"[Stream] Processed {frame_count} frames | Current FPS: {current_fps:.2f}" + ) + + except KeyboardInterrupt: + print("\n[Stream] Received keyboard interrupt, stopping stream...") + finally: + self.cleanup() + + def cleanup(self): + """Cleanup resources.""" + print("[Cleanup] Stopping FFmpeg process...") + if hasattr(self, "proc"): + self.proc.stdin.close() + self.proc.wait() + print("[Cleanup] Closing simulation...") + self.simulator.close() + print("[Cleanup] Successfully cleaned up resources") diff 
--git a/build/lib/dimos/skills/__init__.py b/build/lib/dimos/skills/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/build/lib/dimos/skills/kill_skill.py b/build/lib/dimos/skills/kill_skill.py new file mode 100644 index 0000000000..f7eb63e807 --- /dev/null +++ b/build/lib/dimos/skills/kill_skill.py @@ -0,0 +1,62 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Kill skill for terminating running skills. + +This module provides a skill that can terminate other running skills, +particularly those running in separate threads like the monitor skill. +""" + +from typing import Optional +from pydantic import Field + +from dimos.skills.skills import AbstractSkill, SkillLibrary +from dimos.utils.logging_config import setup_logger + +logger = setup_logger("dimos.skills.kill_skill") + + +class KillSkill(AbstractSkill): + """ + A skill that terminates other running skills. + + This skill can be used to stop long-running or background skills + like the monitor skill. It uses the centralized process management + in the SkillLibrary to track and terminate skills. + """ + + skill_name: str = Field(..., description="Name of the skill to terminate") + + def __init__(self, skill_library: Optional[SkillLibrary] = None, **data): + """ + Initialize the kill skill. 
+ + Args: + skill_library: The skill library instance + **data: Additional data for configuration + """ + super().__init__(**data) + self._skill_library = skill_library + + def __call__(self): + """ + Terminate the specified skill. + + Returns: + A message indicating whether the skill was successfully terminated + """ + print("running skills", self._skill_library.get_running_skills()) + # Terminate the skill using the skill library + return self._skill_library.terminate_skill(self.skill_name) diff --git a/build/lib/dimos/skills/navigation.py b/build/lib/dimos/skills/navigation.py new file mode 100644 index 0000000000..6d67ae04f2 --- /dev/null +++ b/build/lib/dimos/skills/navigation.py @@ -0,0 +1,699 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Semantic map skills for building and navigating spatial memory maps. + +This module provides two skills: +1. BuildSemanticMap - Builds a semantic map by recording video frames at different locations +2. 
Navigate - Queries an existing semantic map using natural language +""" + +import os +import time +import threading +from typing import Optional, Tuple +from dimos.utils.threadpool import get_scheduler + +from reactivex import operators as ops +from pydantic import Field + +from dimos.skills.skills import AbstractRobotSkill +from dimos.types.robot_location import RobotLocation +from dimos.utils.logging_config import setup_logger +from dimos.models.qwen.video_query import get_bbox_from_qwen_frame +from dimos.utils.transform_utils import distance_angle_to_goal_xy +from dimos.robot.local_planner.local_planner import navigate_to_goal_local + +logger = setup_logger("dimos.skills.semantic_map_skills") + + +def get_dimos_base_path(): + """ + Get the DiMOS base path from DIMOS_PATH environment variable or default to user's home directory. + + Returns: + Base path to use for DiMOS assets + """ + dimos_path = os.environ.get("DIMOS_PATH") + if dimos_path: + return dimos_path + # Get the current user's username + user = os.environ.get("USER", os.path.basename(os.path.expanduser("~"))) + return f"/home/{user}/dimos" + + +class NavigateWithText(AbstractRobotSkill): + """ + A skill that queries an existing semantic map using natural language or tries to navigate to an object in view. + + This skill first attempts to locate an object in the robot's camera view using vision. + If the object is found, it navigates to it. If not, it falls back to querying the + semantic map for a location matching the description. For example, "Find the Teddy Bear" + will first look for a Teddy Bear in view, then check the semantic map coordinates where + a Teddy Bear was previously observed. + + CALL THIS SKILL FOR ONE SUBJECT AT A TIME. For example: "Go to the person wearing a blue shirt in the living room", + you should call this skill twice, once for the person wearing a blue shirt and once for the living room. 
+ + If skip_visual_search is True, this skill will skip the visual search for the object in view. + This is useful if you want to navigate to a general location such as a kitchen or office. + For example, "Go to the kitchen" will not look for a kitchen in view, but will check the semantic map coordinates where + a kitchen was previously observed. + """ + + query: str = Field("", description="Text query to search for in the semantic map") + + limit: int = Field(1, description="Maximum number of results to return") + distance: float = Field(1.0, description="Desired distance to maintain from object in meters") + skip_visual_search: bool = Field(False, description="Skip visual search for object in view") + timeout: float = Field(40.0, description="Maximum time to spend navigating in seconds") + + def __init__(self, robot=None, **data): + """ + Initialize the Navigate skill. + + Args: + robot: The robot instance + **data: Additional data for configuration + """ + super().__init__(robot=robot, **data) + self._stop_event = threading.Event() + self._spatial_memory = None + self._scheduler = get_scheduler() # Use the shared DiMOS thread pool + self._navigation_disposable = None # Disposable returned by scheduler.schedule() + self._tracking_subscriber = None # For object tracking + self._similarity_threshold = 0.25 + + def _navigate_to_object(self): + """ + Helper method that attempts to navigate to an object visible in the camera view. + + Returns: + dict: Result dictionary with success status and details + """ + # Stop any existing operation + self._stop_event.clear() + + try: + logger.warning( + f"Attempting to navigate to visible object: {self.query} with desired distance {self.distance}m, timeout {self.timeout} seconds..." 
+ ) + + # Try to get a bounding box from Qwen - only try once + bbox = None + try: + # Use the robot's existing video stream instead of creating a new one + frame = self._robot.get_video_stream().pipe(ops.take(1)).run() + # Use the frame-based function + bbox, object_size = get_bbox_from_qwen_frame(frame, object_name=self.query) + except Exception as e: + logger.error(f"Error querying Qwen: {e}") + return { + "success": False, + "failure_reason": "Perception", + "error": f"Could not detect {self.query} in view: {e}", + } + + if bbox is None or self._stop_event.is_set(): + logger.error(f"Failed to get bounding box for {self.query}") + return { + "success": False, + "failure_reason": "Perception", + "error": f"Could not find {self.query} in view", + } + + logger.info(f"Found {self.query} at {bbox} with size {object_size}") + + # Start the object tracker with the detected bbox + self._robot.object_tracker.track(bbox, frame=frame) + + # Get the first tracking data with valid distance and angle + start_time = time.time() + target_acquired = False + goal_x_robot = 0 + goal_y_robot = 0 + goal_angle = 0 + + while ( + time.time() - start_time < 10.0 + and not self._stop_event.is_set() + and not target_acquired + ): + # Get the latest tracking data + tracking_data = self._robot.object_tracking_stream.pipe(ops.take(1)).run() + + if tracking_data and tracking_data.get("targets") and tracking_data["targets"]: + target = tracking_data["targets"][0] + + if "distance" in target and "angle" in target: + # Convert target distance and angle to xy coordinates in robot frame + goal_distance = ( + target["distance"] - self.distance + ) # Subtract desired distance to stop short + goal_angle = -target["angle"] + logger.info(f"Target distance: {goal_distance}, Target angle: {goal_angle}") + + goal_x_robot, goal_y_robot = distance_angle_to_goal_xy( + goal_distance, goal_angle + ) + target_acquired = True + break + + else: + logger.warning("No valid target tracking data found.") + + else: + 
logger.warning("No valid target tracking data found.") + + time.sleep(0.1) + + if not target_acquired: + logger.error("Failed to acquire valid target tracking data") + return { + "success": False, + "failure_reason": "Perception", + "error": "Failed to track object", + } + + logger.info( + f"Navigating to target at local coordinates: ({goal_x_robot:.2f}, {goal_y_robot:.2f}), angle: {goal_angle:.2f}" + ) + + # Use navigate_to_goal_local instead of directly controlling the local planner + success = navigate_to_goal_local( + robot=self._robot, + goal_xy_robot=(goal_x_robot, goal_y_robot), + goal_theta=goal_angle, + distance=0.0, # We already accounted for desired distance + timeout=self.timeout, + stop_event=self._stop_event, + ) + + if success: + logger.info(f"Successfully navigated to {self.query}") + return { + "success": True, + "failure_reason": None, + "query": self.query, + "message": f"Successfully navigated to {self.query} in view", + } + else: + logger.warning( + f"Failed to reach {self.query} within timeout or operation was stopped" + ) + return { + "success": False, + "failure_reason": "Navigation", + "error": f"Failed to reach {self.query} within timeout", + } + + except Exception as e: + logger.error(f"Error in navigate to object: {e}") + return {"success": False, "failure_reason": "Code Error", "error": f"Error: {e}"} + finally: + # Clean up + self._robot.object_tracker.cleanup() + + def _navigate_using_semantic_map(self): + """ + Helper method that attempts to navigate using the semantic map query. 
+ + Returns: + dict: Result dictionary with success status and details + """ + logger.info(f"Querying semantic map for: '{self.query}'") + + try: + self._spatial_memory = self._robot.get_spatial_memory() + + # Run the query + results = self._spatial_memory.query_by_text(self.query, limit=self.limit) + + if not results: + logger.warning(f"No results found for query: '{self.query}'") + return { + "success": False, + "query": self.query, + "error": "No matching location found in semantic map", + } + + # Get the best match + best_match = results[0] + metadata = best_match.get("metadata", {}) + + if isinstance(metadata, list) and metadata: + metadata = metadata[0] + + # Extract coordinates from metadata + if ( + isinstance(metadata, dict) + and "pos_x" in metadata + and "pos_y" in metadata + and "rot_z" in metadata + ): + pos_x = metadata.get("pos_x", 0) + pos_y = metadata.get("pos_y", 0) + theta = metadata.get("rot_z", 0) + + # Calculate similarity score (distance is inverse of similarity) + similarity = 1.0 - ( + best_match.get("distance", 0) if best_match.get("distance") is not None else 0 + ) + + logger.info( + f"Found match for '{self.query}' at ({pos_x:.2f}, {pos_y:.2f}, rotation {theta:.2f}) with similarity: {similarity:.4f}" + ) + + # Check if similarity is below the threshold + if similarity < self._similarity_threshold: + logger.warning( + f"Match found but similarity score ({similarity:.4f}) is below threshold ({self._similarity_threshold})" + ) + return { + "success": False, + "query": self.query, + "position": (pos_x, pos_y), + "rotation": theta, + "similarity": similarity, + "error": f"Match found but similarity score ({similarity:.4f}) is below threshold ({self._similarity_threshold})", + } + + # Reset the stop event before starting navigation + self._stop_event.clear() + + # The scheduler approach isn't working, switch to direct threading + # Define a navigation function that will run on a separate thread + def run_navigation(): + skill_library = 
self._robot.get_skills() + self.register_as_running("Navigate", skill_library) + + try: + logger.info( + f"Starting navigation to ({pos_x:.2f}, {pos_y:.2f}) with rotation {theta:.2f}" + ) + # Pass our stop_event to allow cancellation + result = False + try: + result = self._robot.global_planner.set_goal( + (pos_x, pos_y), goal_theta=theta, stop_event=self._stop_event + ) + except Exception as e: + logger.error(f"Error calling global_planner.set_goal: {e}") + + if result: + logger.info("Navigation completed successfully") + else: + logger.error("Navigation did not complete successfully") + return result + except Exception as e: + logger.error(f"Unexpected error in navigation thread: {e}") + return False + finally: + self.stop() + + # Cancel any existing navigation before starting a new one + # Signal stop to any running navigation + self._stop_event.set() + # Clear stop event for new navigation + self._stop_event.clear() + + # Run the navigation in the main thread + run_navigation() + + return { + "success": True, + "query": self.query, + "position": (pos_x, pos_y), + "rotation": theta, + "similarity": similarity, + "metadata": metadata, + } + else: + logger.warning(f"No valid position data found for query: '{self.query}'") + return { + "success": False, + "query": self.query, + "error": "No valid position data found in semantic map", + } + + except Exception as e: + logger.error(f"Error in semantic map navigation: {e}") + return {"success": False, "error": f"Semantic map error: {e}"} + + def __call__(self): + """ + First attempts to navigate to an object in view, then falls back to querying the semantic map. 
+ + Returns: + A dictionary with the result of the navigation attempt + """ + super().__call__() + + if not self.query: + error_msg = "No query provided to Navigate skill" + logger.error(error_msg) + return {"success": False, "error": error_msg} + + # First, try to find and navigate to the object in camera view + logger.info(f"First attempting to find and navigate to visible object: '{self.query}'") + + if not self.skip_visual_search: + object_result = self._navigate_to_object() + + if object_result and object_result["success"]: + logger.info(f"Successfully navigated to {self.query} in view") + return object_result + + elif object_result and object_result["failure_reason"] == "Navigation": + logger.info( + f"Failed to navigate to {self.query} in view: {object_result.get('error', 'Unknown error')}" + ) + return object_result + + # If object navigation failed, fall back to semantic map + logger.info( + f"Object not found in view. Falling back to semantic map query for: '{self.query}'" + ) + + return self._navigate_using_semantic_map() + + def stop(self): + """ + Stop the navigation skill and clean up resources. + + Returns: + A message indicating whether the navigation was stopped successfully + """ + logger.info("Stopping Navigate skill") + + # Signal any running processes to stop via the shared event + self._stop_event.set() + + skill_library = self._robot.get_skills() + self.unregister_as_running("Navigate", skill_library) + + # Dispose of any existing navigation task + if hasattr(self, "_navigation_disposable") and self._navigation_disposable: + logger.info("Disposing navigation task") + try: + self._navigation_disposable.dispose() + except Exception as e: + logger.error(f"Error disposing navigation task: {e}") + self._navigation_disposable = None + + return "Navigate skill stopped successfully." + + +class GetPose(AbstractRobotSkill): + """ + A skill that returns the current position and orientation of the robot. 
+ + This skill is useful for getting the current pose of the robot in the map frame. You call this skill + if you want to remember a location, for example, "remember this is where my favorite chair is" and then + call this skill to get the position and rotation of approximately where the chair is. You can then use + the position to navigate to the chair. + + When location_name is provided, this skill will also remember the current location with that name, + allowing you to navigate back to it later using the Navigate skill. + """ + + location_name: str = Field( + "", description="Optional name to assign to this location (e.g., 'kitchen', 'office')" + ) + + def __init__(self, robot=None, **data): + """ + Initialize the GetPose skill. + + Args: + robot: The robot instance + **data: Additional data for configuration + """ + super().__init__(robot=robot, **data) + + def __call__(self): + """ + Get the current pose of the robot. + + Returns: + A dictionary containing the position and rotation of the robot + """ + super().__call__() + + if self._robot is None: + error_msg = "No robot instance provided to GetPose skill" + logger.error(error_msg) + return {"success": False, "error": error_msg} + + try: + # Get the current pose using the robot's get_pose method + pose_data = self._robot.get_pose() + + # Extract position and rotation from the new dictionary format + position = pose_data["position"] + rotation = pose_data["rotation"] + + # Format the response + result = { + "success": True, + "position": { + "x": position.x, + "y": position.y, + "z": position.z, + }, + "rotation": {"roll": rotation.x, "pitch": rotation.y, "yaw": rotation.z}, + } + + # If location_name is provided, remember this location + if self.location_name: + # Get the spatial memory instance + spatial_memory = self._robot.get_spatial_memory() + + # Create a RobotLocation object + location = RobotLocation( + name=self.location_name, + position=(position.x, position.y, position.z), + rotation=(rotation.x, 
rotation.y, rotation.z), + ) + + # Add to spatial memory + if spatial_memory.add_robot_location(location): + result["location_saved"] = True + result["location_name"] = self.location_name + logger.info(f"Location '{self.location_name}' saved at {position}") + else: + result["location_saved"] = False + logger.error(f"Failed to save location '{self.location_name}'") + + return result + except Exception as e: + error_msg = f"Error getting robot pose: {e}" + logger.error(error_msg) + return {"success": False, "error": error_msg} + + +class NavigateToGoal(AbstractRobotSkill): + """ + A skill that navigates the robot to a specified position and orientation. + + This skill uses the global planner to generate a path to the target position + and then uses navigate_path_local to follow that path, achieving the desired + orientation at the goal position. + """ + + position: Tuple[float, float] = Field( + (0.0, 0.0), description="Target position (x, y) in map frame" + ) + rotation: Optional[float] = Field(None, description="Target orientation (yaw) in radians") + frame: str = Field("map", description="Reference frame for the position and rotation") + timeout: float = Field(120.0, description="Maximum time (in seconds) allowed for navigation") + + def __init__(self, robot=None, **data): + """ + Initialize the NavigateToGoal skill. + + Args: + robot: The robot instance + **data: Additional data for configuration + """ + super().__init__(robot=robot, **data) + self._stop_event = threading.Event() + + def __call__(self): + """ + Navigate to the specified goal position and orientation. 
+ + Returns: + A dictionary containing the result of the navigation attempt + """ + super().__call__() + + if self._robot is None: + error_msg = "No robot instance provided to NavigateToGoal skill" + logger.error(error_msg) + return {"success": False, "error": error_msg} + + # Reset stop event to make sure we don't immediately abort + self._stop_event.clear() + + skill_library = self._robot.get_skills() + self.register_as_running("NavigateToGoal", skill_library) + + logger.info( + f"Starting navigation to position=({self.position[0]:.2f}, {self.position[1]:.2f}) " + f"with rotation={self.rotation if self.rotation is not None else 'None'} " + f"in frame={self.frame}" + ) + + try: + # Use the global planner to set the goal and generate a path + result = self._robot.global_planner.set_goal( + self.position, goal_theta=self.rotation, stop_event=self._stop_event + ) + + if result: + logger.info("Navigation completed successfully") + return { + "success": True, + "position": self.position, + "rotation": self.rotation, + "message": "Goal reached successfully", + } + else: + logger.warning("Navigation did not complete successfully") + return { + "success": False, + "position": self.position, + "rotation": self.rotation, + "message": "Goal could not be reached", + } + + except Exception as e: + error_msg = f"Error during navigation: {e}" + logger.error(error_msg) + return { + "success": False, + "position": self.position, + "rotation": self.rotation, + "error": error_msg, + } + finally: + self.stop() + + def stop(self): + """ + Stop the navigation. + + Returns: + A message indicating that the navigation was stopped + """ + logger.info("Stopping NavigateToGoal") + skill_library = self._robot.get_skills() + self.unregister_as_running("NavigateToGoal", skill_library) + self._stop_event.set() + return "Navigation stopped" + + +class Explore(AbstractRobotSkill): + """ + A skill that performs autonomous frontier exploration. 
+ + This skill continuously finds and navigates to unknown frontiers in the environment + until no more frontiers are found or the exploration is stopped. + + Don't save GetPose locations when frontier exploring. Don't call any other skills except stop skill when needed. + """ + + timeout: float = Field(60.0, description="Maximum time (in seconds) allowed for exploration") + + def __init__(self, robot=None, **data): + """ + Initialize the Explore skill. + + Args: + robot: The robot instance + **data: Additional data for configuration + """ + super().__init__(robot=robot, **data) + self._stop_event = threading.Event() + + def __call__(self): + """ + Start autonomous frontier exploration. + + Returns: + A dictionary containing the result of the exploration + """ + super().__call__() + + if self._robot is None: + error_msg = "No robot instance provided to Explore skill" + logger.error(error_msg) + return {"success": False, "error": error_msg} + + # Reset stop event to make sure we don't immediately abort + self._stop_event.clear() + + skill_library = self._robot.get_skills() + self.register_as_running("Explore", skill_library) + + logger.info("Starting autonomous frontier exploration") + + try: + # Start exploration using the robot's explore method + result = self._robot.explore(stop_event=self._stop_event) + + if result: + logger.info("Exploration completed successfully - no more frontiers found") + return { + "success": True, + "message": "Exploration completed - all accessible areas explored", + } + else: + if self._stop_event.is_set(): + logger.info("Exploration stopped by user") + return { + "success": False, + "message": "Exploration stopped by user", + } + else: + logger.warning("Exploration did not complete successfully") + return { + "success": False, + "message": "Exploration failed or was interrupted", + } + + except Exception as e: + error_msg = f"Error during exploration: {e}" + logger.error(error_msg) + return { + "success": False, + "error": error_msg, 
# Copyright 2025 Dimensional Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Observer skill for an agent.

This module provides a skill that sends a single image from any
Robot Data Stream to the Qwen VLM for inference and adds the response
to the agent's conversation history.
"""

import time
from typing import Optional
import base64
import cv2
import numpy as np
import reactivex as rx
from reactivex import operators as ops
from pydantic import Field

from dimos.skills.skills import AbstractRobotSkill
from dimos.agents.agent import LLMAgent
from dimos.models.qwen.video_query import query_single_frame
from dimos.utils.logging_config import setup_logger

logger = setup_logger("dimos.skills.observe")


# NOTE(review): the class docstring and Field descriptions are kept verbatim
# because they are presumably exported as the tool description for the LLM.
# Nothing in this class writes to the agent's history itself; the VLM answer
# is only returned from __call__ -- confirm the caller records it.
class Observe(AbstractRobotSkill):
    """
    A skill that captures a single frame from a Robot Video Stream, sends it to a VLM,
    and adds the response to the agent's conversation history.

    This skill is used for visual reasoning, spatial understanding, or any queries involving visual information that require critical thinking.
    """

    query_text: str = Field(
        "What do you see in this image? Describe the environment in detail.",
        description="Query text to send to the VLM model with the image",
    )

    def __init__(self, robot=None, agent: Optional[LLMAgent] = None, **data):
        """
        Initialize the Observe skill.

        Args:
            robot: The robot instance (may be None and supplied later via set_robot)
            agent: The agent to store results in
            **data: Additional data for configuration
        """
        super().__init__(robot=robot, **data)
        self._agent = agent
        self._model_name = "qwen2.5-vl-72b-instruct"

        # Only touch the robot when one was actually supplied; the previous
        # unconditional attribute access crashed construction for robot=None.
        self._video_stream = self._robot.video_stream if self._robot is not None else None
        if self._video_stream is None:
            logger.error("Failed to get video stream from robot")

    def __call__(self):
        """
        Capture a single frame, query Qwen with it, and return the answer.

        Returns:
            "Observation complete: <response>" on success, otherwise a
            human-readable error message. Failures during capture or inference
            are caught and reported as a string.

        Raises:
            RuntimeError: From the base class when no robot is attached.
        """
        super().__call__()

        if self._agent is None:
            error_msg = "No agent provided to Observe skill"
            logger.error(error_msg)
            return error_msg

        if self._robot is None:
            error_msg = "No robot instance provided to Observe skill"
            logger.error(error_msg)
            return error_msg

        # The robot may have been attached after construction (set_robot), in
        # which case the stream captured in __init__ is still None.
        if self._video_stream is None:
            self._video_stream = self._robot.video_stream

        if self._video_stream is None:
            error_msg = "No video stream available"
            logger.error(error_msg)
            return error_msg

        try:
            logger.info("Capturing frame for Qwen observation")

            frame = self._get_frame_from_stream()
            if frame is None:
                error_msg = "Failed to capture frame from video stream"
                logger.error(error_msg)
                return error_msg

            response = self._process_frame_with_qwen(frame)

            logger.info("Added Qwen observation to conversation history")
            return f"Observation complete: {response}"

        except Exception as e:
            error_msg = f"Error in Observe skill: {e}"
            logger.error(error_msg)
            return error_msg

    def _get_frame_from_stream(self):
        """
        Get a single frame from the video stream.

        The observer is attached directly to the stream. The earlier version
        relayed frames through an intermediate Subject that only received its
        own observer after the source had been subscribed, so a frame emitted
        synchronously on subscribe was dropped.

        Returns:
            A single frame from the video stream, or None if no frame arrives
            within 5 seconds or the stream reports an error.
        """
        if self._video_stream is None:
            logger.error("Video stream is None")
            return None

        frame = None
        failed = False

        def on_frame(f):
            nonlocal frame
            frame = f

        def on_error(e):
            nonlocal failed
            failed = True
            logger.error(f"Error getting frame: {e}")

        subscription = self._video_stream.pipe(
            ops.take(1)  # Take just one frame
        ).subscribe(on_next=on_frame, on_error=on_error)

        # Wait up to 5 seconds for a frame; stop early if the stream errored.
        timeout = 5.0
        start_time = time.time()
        try:
            while frame is None and not failed and time.time() - start_time < timeout:
                time.sleep(0.1)
        finally:
            subscription.dispose()

        return frame

    def _process_frame_with_qwen(self, frame):
        """
        Process a frame with the Qwen model using query_single_frame.

        Args:
            frame: The video frame to process. numpy arrays whose last axis has
                length 3 are converted BGR -> RGB before being wrapped in a PIL
                image; any non-array value is passed through unchanged.

        Returns:
            The response from Qwen

        Raises:
            Exception: Any error from conversion or the model call is logged
                and re-raised.
        """
        logger.info(f"Processing frame with Qwen model: {self._model_name}")

        try:
            # Convert numpy array to PIL Image if needed
            from PIL import Image

            if isinstance(frame, np.ndarray):
                # OpenCV uses BGR, PIL uses RGB
                if frame.shape[-1] == 3:  # Check if it has color channels
                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    pil_image = Image.fromarray(frame_rgb)
                else:
                    pil_image = Image.fromarray(frame)
            else:
                pil_image = frame

            # Query Qwen with the frame (direct function call)
            response = query_single_frame(
                pil_image,
                self.query_text,
                model_name=self._model_name,
            )

            logger.info(f"Qwen response received: {response[:100]}...")
            return response

        except Exception as e:
            logger.error(f"Error processing frame with Qwen: {e}")
            raise
# Copyright 2025 Dimensional Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Observer skill for an agent.

This module provides a skill that periodically sends images from any
Robot Data Stream to an agent for inference.
"""

import time
import threading
from typing import Optional
import base64
import cv2
import numpy as np
import reactivex as rx
from reactivex import operators as ops
from pydantic import Field
from PIL import Image

from dimos.skills.skills import AbstractRobotSkill
from dimos.agents.agent import LLMAgent
from dimos.models.qwen.video_query import query_single_frame
from dimos.utils.threadpool import get_scheduler
from dimos.utils.logging_config import setup_logger

logger = setup_logger("dimos.skills.observe_stream")


# NOTE(review): the class docstring and Field descriptions are kept verbatim
# because they are presumably exported as the tool description for the LLM.
class ObserveStream(AbstractRobotSkill):
    """
    A skill that periodically Observes a Robot Video Stream and sends images to current instance of an agent for context.

    This skill runs in a non-halting manner, allowing other skills to run concurrently.
    It can be used for continuous perception and passive monitoring, such as waiting for a person to enter a room
    or to monitor changes in the environment.
    """

    timestep: float = Field(
        60.0, description="Time interval in seconds between observation queries"
    )
    query_text: str = Field(
        "What do you see in this image? Alert me if you see any people or important changes.",
        description="Query text to send to agent with each image",
    )
    max_duration: float = Field(
        0.0, description="Maximum duration to run the observer in seconds (0 for indefinite)"
    )

    def __init__(self, robot=None, agent: Optional[LLMAgent] = None, video_stream=None, **data):
        """
        Initialize the ObserveStream skill.

        Args:
            robot: The robot instance (may be None and supplied later via set_robot)
            agent: The agent to send queries to
            video_stream: Currently unused; see the TODO below
            **data: Additional data for configuration
        """
        super().__init__(robot=robot, **data)
        self._agent = agent
        self._stop_event = threading.Event()
        self._monitor_thread = None
        self._scheduler = get_scheduler()
        self._subscription = None

        # Get the video stream
        # TODO: Use the video stream provided in the constructor for dynamic video_stream selection by the agent
        # Only touch the robot when one was supplied; the previous unconditional
        # attribute access crashed construction for robot=None.
        self._video_stream = self._robot.video_stream if self._robot is not None else None
        if self._video_stream is None:
            logger.error("Failed to get video stream from robot")

    def __call__(self):
        """
        Start the periodic observer on the shared scheduler.

        Any observer previously started by this instance is stopped first. The
        interval subscription is registered with the robot's skill library
        under the name "ObserveStream" so it can be terminated later.

        Returns:
            A message indicating the observer has started, or an error message
            when no agent/robot is available.

        Raises:
            RuntimeError: From the base class when no robot is attached.
        """
        super().__call__()

        if self._agent is None:
            error_msg = "No agent provided to ObserveStream"
            logger.error(error_msg)
            return error_msg

        if self._robot is None:
            error_msg = "No robot instance provided to ObserveStream"
            logger.error(error_msg)
            return error_msg

        # The robot may have been attached after construction (set_robot).
        if self._video_stream is None:
            self._video_stream = self._robot.video_stream

        self.stop()

        self._stop_event.clear()

        # Initialize start time for duration tracking
        self._start_time = time.time()

        interval_observable = rx.interval(self.timestep, scheduler=self._scheduler).pipe(
            ops.take_while(lambda _: not self._stop_event.is_set())
        )

        # Subscribe to the interval observable
        self._subscription = interval_observable.subscribe(
            on_next=self._monitor_iteration,
            on_error=lambda e: logger.error(f"Error in monitor observable: {e}"),
            on_completed=lambda: logger.info("Monitor observable completed"),
        )

        skill_library = self._robot.get_skills()
        self.register_as_running("ObserveStream", skill_library, self._subscription)

        logger.info(f"Observer started with timestep={self.timestep}s, query='{self.query_text}'")
        return f"Observer started with timestep={self.timestep}s, query='{self.query_text}'"

    def _monitor_iteration(self, iteration):
        """
        Execute a single observer iteration.

        Stops the observer once max_duration (when > 0) has elapsed; otherwise
        grabs one frame and hands it to _process_frame. All exceptions are
        logged and swallowed so one bad tick does not end the interval stream.

        Args:
            iteration: The iteration number (provided by rx.interval)
        """
        try:
            if self.max_duration > 0:
                elapsed_time = time.time() - self._start_time
                if elapsed_time > self.max_duration:
                    logger.info(f"Observer reached maximum duration of {self.max_duration}s")
                    self.stop()
                    return

            logger.info(f"Observer iteration {iteration} executing")

            # Get a frame from the video stream
            frame = self._get_frame_from_stream()

            if frame is not None:
                self._process_frame(frame)
            else:
                logger.warning("Failed to get frame from video stream")

        except Exception as e:
            logger.error(f"Error in monitor iteration {iteration}: {e}")

    def _get_frame_from_stream(self):
        """
        Get a single frame from the video stream.

        The observer is attached directly to the stream. The earlier version
        relayed frames through an intermediate Subject that only received its
        own observer after the source had been subscribed, so a frame emitted
        synchronously on subscribe was dropped.

        Returns:
            A single frame from the video stream, or None if the stream is
            missing, reports an error, or yields nothing within 5 seconds.
        """
        if self._video_stream is None:
            logger.error("Video stream is None")
            return None

        frame = None
        failed = False

        def on_frame(f):
            nonlocal frame
            frame = f

        def on_error(e):
            nonlocal failed
            failed = True
            logger.error(f"Error getting frame: {e}")

        subscription = self._video_stream.pipe(
            ops.take(1)  # Take just one frame
        ).subscribe(on_next=on_frame, on_error=on_error)

        timeout = 5.0  # 5 seconds timeout
        start_time = time.time()
        try:
            while frame is None and not failed and time.time() - start_time < timeout:
                time.sleep(0.1)
        finally:
            subscription.dispose()

        return frame

    def _process_frame(self, frame):
        """
        Query the Qwen VLM with a frame and forward the answer to the agent.

        Errors are logged and swallowed.

        Args:
            frame: The video frame to process. numpy arrays whose last axis has
                length 3 are converted BGR -> RGB; non-array values are passed
                to the model unchanged.
        """
        logger.info("Processing frame with Qwen VLM")

        try:
            # Convert frame to PIL Image format
            if isinstance(frame, np.ndarray):
                # OpenCV uses BGR, PIL uses RGB
                if frame.shape[-1] == 3:  # Check if it has color channels
                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    pil_image = Image.fromarray(frame_rgb)
                else:
                    pil_image = Image.fromarray(frame)
            else:
                pil_image = frame

            # Use Qwen to process the frame
            model_name = "qwen2.5-vl-72b-instruct"  # Using the most capable model
            response = query_single_frame(pil_image, self.query_text, model_name=model_name)

            logger.info(f"Qwen response received: {response[:100]}...")

            # Hand the observation to the agent.
            # NOTE(review): the return value is not consumed here -- confirm
            # run_observable_query does its work without being subscribed to.
            self._agent.run_observable_query(f"Observation: {response}")

            logger.info("Added Qwen observation to conversation history")

        except Exception as e:
            logger.error(f"Error processing frame with Qwen VLM: {e}")

    def stop(self):
        """
        Stop the ObserveStream monitoring process.

        Signals the stop event, disposes the interval subscription and removes
        the skill from the robot's running-skill registry, so an observer that
        ends on its own (max_duration) is no longer reported as running.

        Returns:
            A message indicating the observer has stopped, or that it was not
            running.
        """
        if self._subscription is not None and not self._subscription.is_disposed:
            logger.info("Stopping ObserveStream")
            self._stop_event.set()
            self._subscription.dispose()
            self._subscription = None

            if self._robot is not None:
                self.unregister_as_running("ObserveStream", self._robot.get_skills())

            return "Observer stopped"
        return "Observer was not running"
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import requests
from dimos.skills.skills import AbstractSkill
from pydantic import Field
import logging

logger = logging.getLogger(__name__)


class GenericRestSkill(AbstractSkill):
    """Performs a configurable REST API call.

    This skill executes an HTTP request based on the provided parameters. It
    supports various HTTP methods and allows specifying URL, timeout.

    Attributes:
        url: The target URL for the API call.
        method: The HTTP method (e.g., 'GET', 'POST'). Case-insensitive.
        timeout: Request timeout in seconds.
    """

    # TODO: Add query parameters, request body data (form-encoded or JSON), and headers:
    #   params: Optional dictionary of URL query parameters.
    #   data: Optional dictionary for form-encoded request body data.
    #   json_payload: Optional dictionary for JSON request body data
    #       (use the alias 'json' when initializing).
    #   headers: Optional dictionary of HTTP headers.

    url: str = Field(..., description="The target URL for the API call.")
    method: str = Field(..., description="HTTP method (e.g., 'GET', 'POST').")
    timeout: int = Field(..., description="Request timeout in seconds.")

    def __call__(self) -> str:
        """Execute the configured REST API call.

        No exception escapes this method: request failures are logged and
        converted to a descriptive string so the caller always receives text.

        Returns:
            The response body text on success (HTTP 2xx). Otherwise one of:
            an "HTTP error making ..." message for 4xx/5xx responses, an
            "Error making ..." message for connection/timeout/other request
            failures, or an "An unexpected error occurred: ..." message for
            anything else.
        """
        try:
            method = self.method.upper()  # Normalize method to uppercase
            logger.debug(
                f"Executing {method} request to {self.url} with timeout={self.timeout}"
            )
            response = requests.request(
                method=method,
                url=self.url,
                timeout=self.timeout,
            )
            response.raise_for_status()  # Raises HTTPError for bad responses (4xx or 5xx)
            logger.debug(
                f"Request successful. Status: {response.status_code}, Response: {response.text[:100]}..."
            )
            return response.text  # Return text content directly
        except requests.exceptions.HTTPError as http_err:
            # HTTPError.response is optional. Compare against None explicitly:
            # a Response is falsy for 4xx/5xx, so a truthiness test would
            # discard exactly the responses handled here.
            err_response = http_err.response
            status_code = err_response.status_code if err_response is not None else "unknown"
            reason = err_response.reason if err_response is not None else ""
            logger.error(f"HTTP error occurred: {http_err} - Status Code: {status_code}")
            return f"HTTP error making {method} request to {self.url}: {status_code} {reason}"
        except requests.exceptions.RequestException as req_err:
            logger.error(f"Request exception occurred: {req_err}")
            return f"Error making {method} request to {self.url}: {req_err}"
        except Exception as e:
            logger.exception(f"An unexpected error occurred: {e}")  # Log the full traceback
            return f"An unexpected error occurred: {type(e).__name__}: {e}"
import copy
import logging
from typing import TYPE_CHECKING, Any, Optional

from pydantic import BaseModel
from openai import pydantic_function_tool

from dimos.types.constants import Colors

# Configure logging for the module
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# region SkillLibrary


class SkillLibrary:
    """Registry of skill classes plus bookkeeping for currently running skills.

    Skill classes are discovered from AbstractSkill subclasses declared as
    attributes of the concrete library class, and can also be added or removed
    explicitly. Running skills are tracked by lower-cased name.
    """

    # ==== Flat Skill Library ====

    def __init__(self):
        self.registered_skills: list["AbstractSkill"] = []
        self.class_skills: list["AbstractSkill"] = []
        self._running_skills = {}  # {skill_name: (instance, subscription)}

        self.init()

    def init(self):
        """Discover class-declared skills and seed the registered list with them."""
        # Collect all skills from the parent class and update self.skills
        self.refresh_class_skills()

        # Temporary
        self.registered_skills = self.class_skills.copy()

    def get_class_skills(self) -> list["AbstractSkill"]:
        """Extract all AbstractSkill subclasses from a class.

        Returns:
            List of skill classes found within the class
        """
        skills = []

        # Loop through all attributes of the class
        for attr_name in dir(self.__class__):
            # Skip special/dunder attributes
            if attr_name.startswith("__"):
                continue

            try:
                attr = getattr(self.__class__, attr_name)

                # Check if it's a class and inherits from AbstractSkill
                if (
                    isinstance(attr, type)
                    and issubclass(attr, AbstractSkill)
                    and attr is not AbstractSkill
                ):
                    skills.append(attr)
            except (AttributeError, TypeError):
                # Skip attributes that can't be accessed or aren't classes
                continue

        return skills

    def refresh_class_skills(self):
        """Re-scan the library class for nested skill classes."""
        self.class_skills = self.get_class_skills()

    def add(self, skill: "AbstractSkill") -> None:
        """Register a skill class unless it is already registered."""
        if skill not in self.registered_skills:
            self.registered_skills.append(skill)

    def get(self) -> list["AbstractSkill"]:
        """Return a shallow copy of the registered skill list."""
        return self.registered_skills.copy()

    def remove(self, skill: "AbstractSkill") -> None:
        """Unregister a skill class; log a warning if it was not registered."""
        try:
            self.registered_skills.remove(skill)
        except ValueError:
            logger.warning(f"Attempted to remove non-existent skill: {skill}")

    def clear(self) -> None:
        """Remove every registered skill."""
        self.registered_skills.clear()

    def __iter__(self):
        return iter(self.registered_skills)

    def __len__(self) -> int:
        return len(self.registered_skills)

    def __contains__(self, skill: "AbstractSkill") -> bool:
        return skill in self.registered_skills

    def __getitem__(self, index):
        return self.registered_skills[index]

    # ==== Calling a Function ====

    # NOTE(review): this dict lives on the class and create_instance mutates
    # it in place, so stored constructor args are shared by every SkillLibrary
    # instance in the process. Left unchanged because callers may rely on the
    # sharing -- confirm whether per-instance storage was intended.
    _instances: dict[str, dict] = {}

    def create_instance(self, name, **kwargs):
        """Remember constructor kwargs for skill `name`; the first call per name wins.

        Despite the name, no skill object is created here: the kwargs are
        merged into the arguments of every later call(name, ...).
        """
        # Key based only on the name
        key = name

        print(f"Preparing to create instance with name: {name} and args: {kwargs}")

        if key not in self._instances:
            # Instead of creating an instance, store the args for later use
            self._instances[key] = kwargs
            print(f"Stored args for later instance creation: {name} with args: {kwargs}")

    def call(self, name, **args):
        """Instantiate the skill called `name` and invoke it.

        Args:
            name: Skill class name, resolved first as an attribute of this
                library and then among the registered skills by __name__.
            **args: Call-time arguments; stored create_instance kwargs
                override them on key collisions.

        Returns:
            Whatever the skill instance returns, or an error string when the
            skill is unknown or raises.
        """
        try:
            # Get the stored args if available; otherwise, use an empty dict
            stored_args = self._instances.get(name, {})

            # Merge the arguments with priority given to stored arguments
            complete_args = {**args, **stored_args}

            # Look for a skill class declared on the library itself. Only
            # accept real AbstractSkill subclasses: a bare getattr would also
            # return library methods ("clear", "remove", ...) and the code
            # below would then call them with the supplied arguments.
            skill_class = getattr(self, name, None)
            if not (isinstance(skill_class, type) and issubclass(skill_class, AbstractSkill)):
                skill_class = None

            if skill_class is None:
                for skill in self.get():
                    if name == skill.__name__:
                        skill_class = skill
                        break
                if skill_class is None:
                    error_msg = f"Skill '{name}' is not available. Please check if it's properly registered."
                    logger.error(f"Skill class not found: {name}")
                    return error_msg

            # Initialize the instance with the merged arguments
            instance = skill_class(**complete_args)
            print(f"Instance created and function called for: {name} with args: {complete_args}")

            # Call the instance directly
            return instance()
        except Exception as e:
            error_msg = f"Error executing skill '{name}': {str(e)}"
            logger.error(error_msg)
            return error_msg

    # ==== Tools ====

    def get_tools(self) -> Any:
        """Return the registered skills converted by pydantic_function_tool."""
        tools_json = self.get_list_of_skills_as_json(list_of_skills=self.registered_skills)
        # print(f"{Colors.YELLOW_PRINT_COLOR}Tools JSON: {tools_json}{Colors.RESET_COLOR}")
        return tools_json

    def get_list_of_skills_as_json(self, list_of_skills: list["AbstractSkill"]) -> list[str]:
        """Apply pydantic_function_tool to each skill class and return the results."""
        return list(map(pydantic_function_tool, list_of_skills))

    def register_running_skill(self, name: str, instance: Any, subscription=None):
        """
        Register a running skill with its subscription.

        Args:
            name: Name of the skill (will be converted to lowercase)
            instance: Instance of the running skill
            subscription: Optional subscription associated with the skill
        """
        name = name.lower()
        self._running_skills[name] = (instance, subscription)
        logger.info(f"Registered running skill: {name}")

    def unregister_running_skill(self, name: str):
        """
        Unregister a running skill.

        Args:
            name: Name of the skill to remove (will be converted to lowercase)

        Returns:
            True if the skill was found and removed, False otherwise
        """
        name = name.lower()
        if name in self._running_skills:
            del self._running_skills[name]
            logger.info(f"Unregistered running skill: {name}")
            return True
        return False

    def get_running_skills(self):
        """
        Get all running skills.

        Returns:
            A dictionary of running skill names and their (instance, subscription) tuples
        """
        return self._running_skills.copy()

    def terminate_skill(self, name: str):
        """
        Terminate a running skill.

        Calls the instance's stop() when it has one, disposes the associated
        subscription when it is disposable, and always removes the registry
        entry, even if stopping raised.

        Args:
            name: Name of the skill to terminate (will be converted to lowercase)

        Returns:
            A message indicating whether the skill was successfully terminated
        """
        name = name.lower()
        if name in self._running_skills:
            instance, subscription = self._running_skills[name]

            try:
                # Call the stop method if it exists
                if hasattr(instance, "stop") and callable(instance.stop):
                    instance.stop()
                    logger.info(f"Stopped skill: {name}")
                else:
                    logger.warning(f"Skill {name} does not have a stop method")

                # Also dispose the subscription if it exists
                if (
                    subscription is not None
                    and hasattr(subscription, "dispose")
                    and callable(subscription.dispose)
                ):
                    subscription.dispose()
                    logger.info(f"Disposed subscription for skill: {name}")
                elif subscription is not None:
                    logger.warning(f"Skill {name} has a subscription but it's not disposable")

                # unregister the skill
                self.unregister_running_skill(name)
                return f"Successfully terminated skill: {name}"

            except Exception as e:
                error_msg = f"Error terminating skill {name}: {e}"
                logger.error(error_msg)
                # Even on error, try to unregister the skill
                self.unregister_running_skill(name)
                return error_msg
        else:
            return f"No running skill found with name: {name}"


# endregion SkillLibrary

# region AbstractSkill


# Base class for all skills: a pydantic model whose fields are the skill's
# parameters. Documented with comments rather than a docstring so that no
# description text is added to the model itself.
class AbstractSkill(BaseModel):
    def __init__(self, *args, **kwargs):
        print("Initializing AbstractSkill Class")
        super().__init__(*args, **kwargs)
        self._instances = {}
        self._list_of_skills = []  # Initialize the list of skills
        print(f"Instances: {self._instances}")

    def clone(self) -> "AbstractSkill":
        """Return a shallow copy of this skill, keeping its concrete type and values.

        Previously this returned a fresh, empty AbstractSkill(), which dropped
        both the subclass and every field value.
        """
        return copy.copy(self)

    def register_as_running(self, name: str, skill_library: SkillLibrary, subscription=None):
        """
        Register this skill as running in the skill library.

        Args:
            name: Name of the skill (will be converted to lowercase)
            skill_library: The skill library to register with
            subscription: Optional subscription associated with the skill
        """
        skill_library.register_running_skill(name, self, subscription)

    def unregister_as_running(self, name: str, skill_library: SkillLibrary):
        """
        Unregister this skill from the skill library.

        Args:
            name: Name of the skill to remove (will be converted to lowercase)
            skill_library: The skill library to unregister from
        """
        skill_library.unregister_running_skill(name)

    # ==== Tools ====
    def get_tools(self) -> Any:
        """Return this skill's own skill list converted by pydantic_function_tool."""
        tools_json = self.get_list_of_skills_as_json(list_of_skills=self._list_of_skills)
        # print(f"Tools JSON: {tools_json}")
        return tools_json

    def get_list_of_skills_as_json(self, list_of_skills: list["AbstractSkill"]) -> list[str]:
        """Apply pydantic_function_tool to each skill class and return the results."""
        return list(map(pydantic_function_tool, list_of_skills))


# endregion AbstractSkill

# region Abstract Robot Skill

# Robot is imported for type checkers only; at runtime the name is a plain
# string used in annotations.
if TYPE_CHECKING:
    from dimos.robot.robot import Robot
else:
    Robot = "Robot"


# Base class for skills that act on a robot. Calling the skill without a
# robot attached raises RuntimeError.
class AbstractRobotSkill(AbstractSkill):
    _robot: Robot = None

    def __init__(self, *args, robot: Optional[Robot] = None, **kwargs):
        super().__init__(*args, **kwargs)
        self._robot = robot
        print(
            f"{Colors.BLUE_PRINT_COLOR}Robot Skill Initialized with Robot: {robot}{Colors.RESET_COLOR}"
        )

    def set_robot(self, robot: Robot) -> None:
        """Set the robot reference for this skills instance.

        Args:
            robot: The robot instance to associate with these skills.
        """
        self._robot = robot

    def __call__(self):
        """Verify that a robot is attached; subclasses call this before doing work.

        Raises:
            RuntimeError: If no robot instance has been provided.
        """
        if self._robot is None:
            raise RuntimeError(
                f"{Colors.RED_PRINT_COLOR}"
                f"No Robot instance provided to Robot Skill: {self.__class__.__name__}"
                f"{Colors.RESET_COLOR}"
            )
        else:
            print(
                f"{Colors.BLUE_PRINT_COLOR}Robot Instance provided to Robot Skill: {self.__class__.__name__}{Colors.RESET_COLOR}"
            )


# endregion Abstract Robot Skill
from dimos.skills.skills import AbstractSkill
from pydantic import Field
from reactivex import Subject
from typing import Optional, Any, List
import time
import threading
import queue
from dimos.utils.logging_config import setup_logger

logger = setup_logger("dimos.skills.speak")

# Global lock to prevent multiple simultaneous audio playbacks
_audio_device_lock = threading.RLock()

# Global queue for sequential audio processing
_audio_queue = queue.Queue()
_queue_processor_thread = None
_queue_running = False


def _process_audio_queue():
    """Background thread body: run queued audio tasks one at a time.

    Polls the queue with a 1-second timeout so the `_queue_running` flag is
    re-checked regularly. A `None` item is a sentinel that ends the loop.
    Every dequeued item is acknowledged with task_done(), including the
    sentinel and tasks that raise, so a Queue.join() on the audio queue
    cannot hang on a failed task.
    """
    while _queue_running:
        try:
            # Get the next queued audio task with a timeout
            task = _audio_queue.get(timeout=1.0)
        except queue.Empty:
            # No tasks in queue, just continue waiting
            continue

        try:
            if task is None:  # Sentinel value to stop the thread
                break

            # Execute the task (which is a function to be called)
            task()
        except Exception as e:
            logger.error(f"Error in audio queue processor: {e}")
            # Continue processing other tasks
        finally:
            _audio_queue.task_done()


def start_audio_queue_processor():
    """Start the background thread for processing audio requests.

    Does nothing when a processor thread already exists and is alive.
    """
    global _queue_processor_thread, _queue_running

    if _queue_processor_thread is None or not _queue_processor_thread.is_alive():
        _queue_running = True
        _queue_processor_thread = threading.Thread(
            target=_process_audio_queue, daemon=True, name="AudioQueueProcessor"
        )
        _queue_processor_thread.start()
        logger.info("Started audio queue processor thread")


# Start the queue processor when module is imported
start_audio_queue_processor()


# NOTE(review): the class docstring and Field description are kept verbatim
# because they are presumably exported as the tool description for the LLM.
class Speak(AbstractSkill):
    """Speak text out loud to humans nearby or to other robots."""

    text: str = Field(..., description="Text to speak")

    def __init__(self, tts_node: Optional[Any] = None, **data):
        """
        Initialize the Speak skill.

        Args:
            tts_node: Object exposing consume_text(observable) and emit_text()
            **data: Field values (text)
        """
        super().__init__(**data)
        self._tts_node = tts_node
        self._audio_complete = threading.Event()
        self._subscription = None
        self._subscriptions: List = []  # Track all subscriptions

    def __call__(self):
        """
        Queue the text for speech and wait for the audio task to report back.

        The speech itself runs on the shared audio queue thread so requests
        are played sequentially; this call blocks until that task posts a
        result or the outer timeout (max(10 s, 0.15 s per character)) expires.

        Returns:
            A success message, an error message from the task, or a timeout
            message.
        """
        if not self._tts_node:
            logger.error("No TTS node provided to Speak skill")
            return "Error: No TTS node available"

        # Create a result queue to get the result back from the audio thread
        result_queue = queue.Queue(1)

        # Define the speech task to run in the audio queue
        def speak_task():
            try:
                # Using a lock to ensure exclusive access to audio device
                with _audio_device_lock:
                    text_subject = Subject()
                    self._audio_complete.clear()
                    self._subscriptions = []

                    # This function will be called when audio processing is complete
                    def on_complete():
                        logger.info(f"TTS audio playback completed for: {self.text}")
                        self._audio_complete.set()

                    # This function will be called if there's an error
                    def on_error(error):
                        logger.error(f"Error in TTS processing: {error}")
                        self._audio_complete.set()

                    try:
                        # Connect the Subject to the TTS node and keep the subscription
                        self._tts_node.consume_text(text_subject)

                        # Subscribe to the audio output to know when it's done
                        self._subscription = self._tts_node.emit_text().subscribe(
                            on_next=lambda text: logger.debug(f"TTS processing: {text}"),
                            on_completed=on_complete,
                            on_error=on_error,
                        )
                        self._subscriptions.append(self._subscription)

                        # Emit the text to the Subject
                        text_subject.on_next(self.text)
                        text_subject.on_completed()  # Signal that we're done sending text

                        # Wait for audio playback to complete with a timeout
                        # Using a dynamic timeout based on text length
                        timeout = max(5, len(self.text) * 0.1)
                        logger.debug(f"Waiting for TTS completion with timeout {timeout:.1f}s")

                        if not self._audio_complete.wait(timeout=timeout):
                            logger.warning(f"TTS timeout reached for: {self.text}")
                        else:
                            # Add a small delay after audio completes to ensure buffers are fully flushed
                            time.sleep(0.3)
                    finally:
                        # Clean up all subscriptions, also when something above
                        # raised, so a failed request does not leave an observer
                        # attached to the TTS node.
                        for sub in self._subscriptions:
                            if sub:
                                sub.dispose()
                        self._subscriptions = []

                    # Successfully completed
                    result_queue.put(f"Spoke: {self.text} successfully")
            except Exception as e:
                logger.error(f"Error in speak task: {e}")
                result_queue.put(f"Error speaking text: {str(e)}")

        # Add our speech task to the global queue for sequential processing
        display_text = self.text[:50] + "..." if len(self.text) > 50 else self.text
        logger.info(f"Queueing speech task: '{display_text}'")
        _audio_queue.put(speak_task)

        # Wait for the result with a timeout
        try:
            # Use a longer timeout than the audio playback itself
            text_len_timeout = len(self.text) * 0.15  # 150ms per character
            max_timeout = max(10, text_len_timeout)  # At least 10 seconds

            return result_queue.get(timeout=max_timeout)
        except queue.Empty:
            logger.error("Timed out waiting for speech task to complete")
            return f"Error: Timed out while speaking: {self.text}"
+ +from dimos.skills.skills import AbstractRobotSkill +from pydantic import Field +import time +import tempfile +import os +import json +import base64 +import hashlib +import soundfile as sf +import numpy as np +from openai import OpenAI +from dimos.utils.logging_config import setup_logger +from go2_webrtc_driver.constants import RTC_TOPIC + +logger = setup_logger("dimos.skills.unitree.unitree_speak") + +# Audio API constants (from go2_webrtc_driver) +AUDIO_API = { + "GET_AUDIO_LIST": 1001, + "SELECT_START_PLAY": 1002, + "PAUSE": 1003, + "UNSUSPEND": 1004, + "SET_PLAY_MODE": 1007, + "UPLOAD_AUDIO_FILE": 2001, + "ENTER_MEGAPHONE": 4001, + "EXIT_MEGAPHONE": 4002, + "UPLOAD_MEGAPHONE": 4003, +} + +PLAY_MODES = {"NO_CYCLE": "no_cycle", "SINGLE_CYCLE": "single_cycle", "LIST_LOOP": "list_loop"} + + +class UnitreeSpeak(AbstractRobotSkill): + """Speak text out loud through the robot's speakers using WebRTC audio upload.""" + + text: str = Field(..., description="Text to speak") + voice: str = Field( + default="echo", description="Voice to use (alloy, echo, fable, onyx, nova, shimmer)" + ) + speed: float = Field(default=1.2, description="Speech speed (0.25 to 4.0)") + use_megaphone: bool = Field( + default=False, description="Use megaphone mode for lower latency (experimental)" + ) + + def __init__(self, **data): + super().__init__(**data) + self._openai_client = None + + def _get_openai_client(self): + if self._openai_client is None: + self._openai_client = OpenAI() + return self._openai_client + + def _generate_audio(self, text: str) -> bytes: + try: + client = self._get_openai_client() + response = client.audio.speech.create( + model="tts-1", voice=self.voice, input=text, speed=self.speed, response_format="mp3" + ) + return response.content + except Exception as e: + logger.error(f"Error generating audio: {e}") + raise + + def _webrtc_request(self, api_id: int, parameter: dict = None): + if parameter is None: + parameter = {} + + request_data = {"api_id": api_id, 
"parameter": json.dumps(parameter) if parameter else "{}"} + + return self._robot.webrtc_connection.publish_request( + RTC_TOPIC["AUDIO_HUB_REQ"], request_data + ) + + def _upload_audio_to_robot(self, audio_data: bytes, filename: str) -> str: + try: + file_md5 = hashlib.md5(audio_data).hexdigest() + b64_data = base64.b64encode(audio_data).decode("utf-8") + + chunk_size = 61440 + chunks = [b64_data[i : i + chunk_size] for i in range(0, len(b64_data), chunk_size)] + total_chunks = len(chunks) + + logger.info(f"Uploading audio '{filename}' in {total_chunks} chunks (optimized)") + + for i, chunk in enumerate(chunks, 1): + parameter = { + "file_name": filename, + "file_type": "wav", + "file_size": len(audio_data), + "current_block_index": i, + "total_block_number": total_chunks, + "block_content": chunk, + "current_block_size": len(chunk), + "file_md5": file_md5, + "create_time": int(time.time() * 1000), + } + + logger.debug(f"Sending chunk {i}/{total_chunks}") + response = self._webrtc_request(AUDIO_API["UPLOAD_AUDIO_FILE"], parameter) + + logger.info(f"Audio upload completed for '{filename}'") + + list_response = self._webrtc_request(AUDIO_API["GET_AUDIO_LIST"], {}) + + if list_response and "data" in list_response: + data_str = list_response.get("data", {}).get("data", "{}") + audio_list = json.loads(data_str).get("audio_list", []) + + for audio in audio_list: + if audio.get("CUSTOM_NAME") == filename: + return audio.get("UNIQUE_ID") + + logger.warning( + f"Could not find uploaded audio '{filename}' in list, using filename as UUID" + ) + return filename + + except Exception as e: + logger.error(f"Error uploading audio to robot: {e}") + raise + + def _play_audio_on_robot(self, uuid: str): + try: + self._webrtc_request(AUDIO_API["SET_PLAY_MODE"], {"play_mode": PLAY_MODES["NO_CYCLE"]}) + time.sleep(0.1) + + parameter = {"unique_id": uuid} + + logger.info(f"Playing audio with UUID: {uuid}") + self._webrtc_request(AUDIO_API["SELECT_START_PLAY"], parameter) + + except 
Exception as e: + logger.error(f"Error playing audio on robot: {e}") + raise + + def _stop_audio_playback(self): + try: + logger.debug("Stopping audio playback") + self._webrtc_request(AUDIO_API["PAUSE"], {}) + except Exception as e: + logger.warning(f"Error stopping audio playback: {e}") + + def _upload_and_play_megaphone(self, audio_data: bytes, duration: float): + try: + logger.debug("Entering megaphone mode") + self._webrtc_request(AUDIO_API["ENTER_MEGAPHONE"], {}) + + time.sleep(0.2) + + b64_data = base64.b64encode(audio_data).decode("utf-8") + + chunk_size = 4096 + chunks = [b64_data[i : i + chunk_size] for i in range(0, len(b64_data), chunk_size)] + total_chunks = len(chunks) + + logger.info(f"Uploading megaphone audio in {total_chunks} chunks") + + for i, chunk in enumerate(chunks, 1): + parameter = { + "current_block_size": len(chunk), + "block_content": chunk, + "current_block_index": i, + "total_block_number": total_chunks, + } + + logger.debug(f"Sending megaphone chunk {i}/{total_chunks}") + self._webrtc_request(AUDIO_API["UPLOAD_MEGAPHONE"], parameter) + + if i < total_chunks: + time.sleep(0.05) + + logger.info("Megaphone audio upload completed, waiting for playback") + + time.sleep(duration + 1.0) + + except Exception as e: + logger.error(f"Error in megaphone mode: {e}") + try: + self._webrtc_request(AUDIO_API["EXIT_MEGAPHONE"], {}) + except: + pass + raise + finally: + try: + logger.debug("Exiting megaphone mode") + self._webrtc_request(AUDIO_API["EXIT_MEGAPHONE"], {}) + time.sleep(0.1) + except Exception as e: + logger.warning(f"Error exiting megaphone mode: {e}") + + def __call__(self): + super().__call__() + + if not self._robot: + logger.error("No robot instance provided to UnitreeSpeak skill") + return "Error: No robot instance available" + + try: + display_text = self.text[:50] + "..." 
if len(self.text) > 50 else self.text + logger.info(f"Speaking: '{display_text}'") + + logger.debug("Generating audio with OpenAI TTS") + audio_data = self._generate_audio(self.text) + + with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as tmp_mp3: + tmp_mp3.write(audio_data) + tmp_mp3_path = tmp_mp3.name + + try: + audio_array, sample_rate = sf.read(tmp_mp3_path) + + if audio_array.ndim > 1: + audio_array = np.mean(audio_array, axis=1) + + target_sample_rate = 22050 + if sample_rate != target_sample_rate: + logger.debug(f"Resampling from {sample_rate}Hz to {target_sample_rate}Hz") + old_length = len(audio_array) + new_length = int(old_length * target_sample_rate / sample_rate) + old_indices = np.arange(old_length) + new_indices = np.linspace(0, old_length - 1, new_length) + audio_array = np.interp(new_indices, old_indices, audio_array) + sample_rate = target_sample_rate + + audio_array = audio_array / np.max(np.abs(audio_array)) + + with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_wav: + sf.write(tmp_wav.name, audio_array, sample_rate, format="WAV", subtype="PCM_16") + tmp_wav.seek(0) + wav_data = open(tmp_wav.name, "rb").read() + os.unlink(tmp_wav.name) + + logger.info( + f"Audio size: {len(wav_data) / 1024:.1f}KB, duration: {len(audio_array) / sample_rate:.1f}s" + ) + + finally: + os.unlink(tmp_mp3_path) + + if self.use_megaphone: + logger.debug("Using megaphone mode for lower latency") + duration = len(audio_array) / sample_rate + self._upload_and_play_megaphone(wav_data, duration) + + return f"Spoke: '{display_text}' on robot successfully (megaphone mode)" + else: + filename = f"speak_{int(time.time() * 1000)}" + + logger.debug("Uploading audio to robot") + uuid = self._upload_audio_to_robot(wav_data, filename) + + logger.debug("Playing audio on robot") + self._play_audio_on_robot(uuid) + + duration = len(audio_array) / sample_rate + logger.debug(f"Waiting {duration:.1f}s for playback to complete") + # time.sleep(duration + 0.2) 
+ + # self._stop_audio_playback() + + return f"Spoke: '{display_text}' on robot successfully" + + except Exception as e: + logger.error(f"Error in speak skill: {e}") + return f"Error speaking text: {str(e)}" diff --git a/build/lib/dimos/skills/visual_navigation_skills.py b/build/lib/dimos/skills/visual_navigation_skills.py new file mode 100644 index 0000000000..96e21eb92d --- /dev/null +++ b/build/lib/dimos/skills/visual_navigation_skills.py @@ -0,0 +1,148 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Visual navigation skills for robot interaction. + +This module provides skills for visual navigation, including following humans +and navigating to specific objects using computer vision. +""" + +import time +import logging +import threading +from typing import Optional, Tuple + +from dimos.skills.skills import AbstractRobotSkill +from dimos.utils.logging_config import setup_logger +from dimos.perception.visual_servoing import VisualServoing +from pydantic import Field +from dimos.types.vector import Vector + +logger = setup_logger("dimos.skills.visual_navigation", level=logging.DEBUG) + + +class FollowHuman(AbstractRobotSkill): + """ + A skill that makes the robot follow a human using visual servoing continuously. + + This skill uses the robot's person tracking stream to follow a human + while maintaining a specified distance. It will keep following the human + until the timeout is reached or the skill is stopped. 
Don't use this skill + if you want to navigate to a specific person, use NavigateTo instead. + """ + + distance: float = Field( + 1.5, description="Desired distance to maintain from the person in meters" + ) + timeout: float = Field(20.0, description="Maximum time to follow the person in seconds") + point: Optional[Tuple[int, int]] = Field( + None, description="Optional point to start tracking (x,y pixel coordinates)" + ) + + def __init__(self, robot=None, **data): + super().__init__(robot=robot, **data) + self._stop_event = threading.Event() + self._visual_servoing = None + + def __call__(self): + """ + Start following a human using visual servoing. + + Returns: + bool: True if successful, False otherwise + """ + super().__call__() + + if ( + not hasattr(self._robot, "person_tracking_stream") + or self._robot.person_tracking_stream is None + ): + logger.error("Robot does not have a person tracking stream") + return False + + # Stop any existing operation + self.stop() + self._stop_event.clear() + + success = False + + try: + # Initialize visual servoing + self._visual_servoing = VisualServoing( + tracking_stream=self._robot.person_tracking_stream + ) + + logger.warning(f"Following human for {self.timeout} seconds...") + start_time = time.time() + + # Start tracking + track_success = self._visual_servoing.start_tracking( + point=self.point, desired_distance=self.distance + ) + + if not track_success: + logger.error("Failed to start tracking") + return False + + # Main follow loop + while ( + self._visual_servoing.running + and time.time() - start_time < self.timeout + and not self._stop_event.is_set() + ): + output = self._visual_servoing.updateTracking() + x_vel = output.get("linear_vel") + z_vel = output.get("angular_vel") + logger.debug(f"Following human: x_vel: {x_vel}, z_vel: {z_vel}") + self._robot.move(Vector(x_vel, 0, z_vel)) + time.sleep(0.05) + + # If we completed the full timeout duration, consider it success + if time.time() - start_time >= 
self.timeout: + success = True + logger.info("Human following completed successfully") + elif self._stop_event.is_set(): + logger.info("Human following stopped externally") + else: + logger.info("Human following stopped due to tracking loss") + + return success + + except Exception as e: + logger.error(f"Error in follow human: {e}") + return False + finally: + # Clean up + if self._visual_servoing: + self._visual_servoing.stop_tracking() + self._visual_servoing = None + + def stop(self): + """ + Stop the human following process. + + Returns: + bool: True if stopped, False if it wasn't running + """ + if self._visual_servoing is not None: + logger.info("Stopping FollowHuman skill") + self._stop_event.set() + + # Clean up visual servoing if it exists + self._visual_servoing.stop_tracking() + self._visual_servoing = None + + return True + return False diff --git a/build/lib/dimos/stream/__init__.py b/build/lib/dimos/stream/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/build/lib/dimos/stream/audio/__init__.py b/build/lib/dimos/stream/audio/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/build/lib/dimos/stream/audio/base.py b/build/lib/dimos/stream/audio/base.py new file mode 100644 index 0000000000..a22e6606d6 --- /dev/null +++ b/build/lib/dimos/stream/audio/base.py @@ -0,0 +1,114 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from abc import ABC, abstractmethod +from reactivex import Observable +import numpy as np + + +class AbstractAudioEmitter(ABC): + """Base class for components that emit audio.""" + + @abstractmethod + def emit_audio(self) -> Observable: + """Create an observable that emits audio frames. + + Returns: + Observable emitting audio frames + """ + pass + + +class AbstractAudioConsumer(ABC): + """Base class for components that consume audio.""" + + @abstractmethod + def consume_audio(self, audio_observable: Observable) -> "AbstractAudioConsumer": + """Set the audio observable to consume. + + Args: + audio_observable: Observable emitting audio frames + + Returns: + Self for method chaining + """ + pass + + +class AbstractAudioTransform(AbstractAudioConsumer, AbstractAudioEmitter): + """Base class for components that both consume and emit audio. + + This represents a transform in an audio processing pipeline. + """ + + pass + + +class AudioEvent: + """Class to represent an audio frame event with metadata.""" + + def __init__(self, data: np.ndarray, sample_rate: int, timestamp: float, channels: int = 1): + """ + Initialize an AudioEvent. 
+ + Args: + data: Audio data as numpy array + sample_rate: Audio sample rate in Hz + timestamp: Unix timestamp when the audio was captured + channels: Number of audio channels + """ + self.data = data + self.sample_rate = sample_rate + self.timestamp = timestamp + self.channels = channels + self.dtype = data.dtype + self.shape = data.shape + + def to_float32(self) -> "AudioEvent": + """Convert audio data to float32 format normalized to [-1.0, 1.0].""" + if self.data.dtype == np.float32: + return self + + new_data = self.data.astype(np.float32) + if self.data.dtype == np.int16: + new_data /= 32768.0 + + return AudioEvent( + data=new_data, + sample_rate=self.sample_rate, + timestamp=self.timestamp, + channels=self.channels, + ) + + def to_int16(self) -> "AudioEvent": + """Convert audio data to int16 format.""" + if self.data.dtype == np.int16: + return self + + new_data = self.data + if self.data.dtype == np.float32: + new_data = (new_data * 32767).astype(np.int16) + + return AudioEvent( + data=new_data, + sample_rate=self.sample_rate, + timestamp=self.timestamp, + channels=self.channels, + ) + + def __repr__(self) -> str: + return ( + f"AudioEvent(shape={self.shape}, dtype={self.dtype}, " + f"sample_rate={self.sample_rate}, channels={self.channels})" + ) diff --git a/build/lib/dimos/stream/audio/node_key_recorder.py b/build/lib/dimos/stream/audio/node_key_recorder.py new file mode 100644 index 0000000000..6494dcbef9 --- /dev/null +++ b/build/lib/dimos/stream/audio/node_key_recorder.py @@ -0,0 +1,336 @@ +#!/usr/bin/env python3 +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List +import numpy as np +import time +import threading +import sys +import select +from reactivex import Observable +from reactivex.subject import Subject, ReplaySubject + +from dimos.stream.audio.base import AbstractAudioTransform, AudioEvent + +from dimos.utils.logging_config import setup_logger + +logger = setup_logger("dimos.audio.key_recorder") + + +class KeyRecorder(AbstractAudioTransform): + """ + Audio recorder that captures audio events and combines them. + Press a key to toggle recording on/off. + """ + + def __init__( + self, + max_recording_time: float = 120.0, + always_subscribe: bool = False, + ): + """ + Initialize KeyRecorder. 
+ + Args: + max_recording_time: Maximum recording time in seconds + always_subscribe: If True, subscribe to audio source continuously, + If False, only subscribe when recording (more efficient + but some audio devices may need time to initialize) + """ + self.max_recording_time = max_recording_time + self.always_subscribe = always_subscribe + + self._audio_buffer = [] + self._is_recording = False + self._recording_start_time = 0 + self._sample_rate = None # Will be updated from incoming audio + self._channels = None # Will be set from first event + + self._audio_observable = None + self._subscription = None + self._output_subject = Subject() # For record-time passthrough + self._recording_subject = ReplaySubject(1) # For full completed recordings + + # Start a thread to monitor for input + self._running = True + self._input_thread = threading.Thread(target=self._input_monitor, daemon=True) + self._input_thread.start() + + logger.info("Started audio recorder (press any key to start/stop recording)") + + def consume_audio(self, audio_observable: Observable) -> "KeyRecorder": + """ + Set the audio observable to use when recording. + If always_subscribe is True, subscribes immediately. + Otherwise, subscribes only when recording starts. + + Args: + audio_observable: Observable emitting AudioEvent objects + + Returns: + Self for method chaining + """ + self._audio_observable = audio_observable + + # If configured to always subscribe, do it now + if self.always_subscribe and not self._subscription: + self._subscription = audio_observable.subscribe( + on_next=self._process_audio_event, + on_error=self._handle_error, + on_completed=self._handle_completion, + ) + logger.debug("Subscribed to audio source (always_subscribe=True)") + + return self + + def emit_audio(self) -> Observable: + """ + Create an observable that emits audio events in real-time (pass-through). 
+ + Returns: + Observable emitting AudioEvent objects in real-time + """ + return self._output_subject + + def emit_recording(self) -> Observable: + """ + Create an observable that emits combined audio recordings when recording stops. + + Returns: + Observable emitting AudioEvent objects with complete recordings + """ + return self._recording_subject + + def stop(self): + """Stop recording and clean up resources.""" + logger.info("Stopping audio recorder") + + # If recording is in progress, stop it first + if self._is_recording: + self._stop_recording() + + # Always clean up subscription on full stop + if self._subscription: + self._subscription.dispose() + self._subscription = None + + # Stop input monitoring thread + self._running = False + if self._input_thread.is_alive(): + self._input_thread.join(1.0) + + def _input_monitor(self): + """Monitor for key presses to toggle recording.""" + logger.info("Press Enter to start/stop recording...") + + while self._running: + # Check if there's input available + if select.select([sys.stdin], [], [], 0.1)[0]: + sys.stdin.readline() + + if self._is_recording: + self._stop_recording() + else: + self._start_recording() + + # Sleep a bit to reduce CPU usage + time.sleep(0.1) + + def _start_recording(self): + """Start recording audio and subscribe to the audio source if not always subscribed.""" + if not self._audio_observable: + logger.error("Cannot start recording: No audio source has been set") + return + + # Subscribe to the observable if not using always_subscribe + if not self._subscription: + self._subscription = self._audio_observable.subscribe( + on_next=self._process_audio_event, + on_error=self._handle_error, + on_completed=self._handle_completion, + ) + logger.debug("Subscribed to audio source for recording") + + self._is_recording = True + self._recording_start_time = time.time() + self._audio_buffer = [] + logger.info("Recording... 
(press Enter to stop)") + + def _stop_recording(self): + """Stop recording, unsubscribe from audio source if not always subscribed, and emit the combined audio event.""" + self._is_recording = False + recording_duration = time.time() - self._recording_start_time + + # Unsubscribe from the audio source if not using always_subscribe + if not self.always_subscribe and self._subscription: + self._subscription.dispose() + self._subscription = None + logger.debug("Unsubscribed from audio source after recording") + + logger.info(f"Recording stopped after {recording_duration:.2f} seconds") + + # Combine all audio events into one + if len(self._audio_buffer) > 0: + combined_audio = self._combine_audio_events(self._audio_buffer) + self._recording_subject.on_next(combined_audio) + else: + logger.warning("No audio was recorded") + + def _process_audio_event(self, audio_event): + """Process incoming audio events.""" + + # Only buffer if recording + if not self._is_recording: + return + + # Pass through audio events in real-time + self._output_subject.on_next(audio_event) + + # First audio event - determine channel count/sample rate + if self._channels is None: + self._channels = audio_event.channels + self._sample_rate = audio_event.sample_rate + logger.info(f"Setting channel count to {self._channels}") + + # Add to buffer + self._audio_buffer.append(audio_event) + + # Check if we've exceeded max recording time + if time.time() - self._recording_start_time > self.max_recording_time: + logger.warning(f"Max recording time ({self.max_recording_time}s) reached") + self._stop_recording() + + def _combine_audio_events(self, audio_events: List[AudioEvent]) -> AudioEvent: + """Combine multiple audio events into a single event.""" + if not audio_events: + logger.warning("Attempted to combine empty audio events list") + return None + + # Filter out any empty events that might cause broadcasting errors + valid_events = [ + event + for event in audio_events + if event is not None + and 
(hasattr(event, "data") and event.data is not None and event.data.size > 0) + ] + + if not valid_events: + logger.warning("No valid audio events to combine") + return None + + first_event = valid_events[0] + channels = first_event.channels + dtype = first_event.data.dtype + + # Calculate total samples only from valid events + total_samples = sum(event.data.shape[0] for event in valid_events) + + # Safety check - if somehow we got no samples + if total_samples <= 0: + logger.warning(f"Combined audio would have {total_samples} samples - aborting") + return None + + # For multichannel audio, data shape could be (samples,) or (samples, channels) + if len(first_event.data.shape) == 1: + # 1D audio data (mono) + combined_data = np.zeros(total_samples, dtype=dtype) + + # Copy data + offset = 0 + for event in valid_events: + samples = event.data.shape[0] + if samples > 0: # Extra safety check + combined_data[offset : offset + samples] = event.data + offset += samples + else: + # Multichannel audio data (stereo or more) + combined_data = np.zeros((total_samples, channels), dtype=dtype) + + # Copy data + offset = 0 + for event in valid_events: + samples = event.data.shape[0] + if samples > 0 and offset + samples <= total_samples: # Safety check + try: + combined_data[offset : offset + samples] = event.data + offset += samples + except ValueError as e: + logger.error( + f"Error combining audio events: {e}. 
" + f"Event shape: {event.data.shape}, " + f"Combined shape: {combined_data.shape}, " + f"Offset: {offset}, Samples: {samples}" + ) + # Continue with next event instead of failing completely + + # Create new audio event with the combined data + if combined_data.size > 0: + return AudioEvent( + data=combined_data, + sample_rate=self._sample_rate, + timestamp=valid_events[0].timestamp, + channels=channels, + ) + else: + logger.warning("Failed to create valid combined audio event") + return None + + def _handle_error(self, error): + """Handle errors from the observable.""" + logger.error(f"Error in audio observable: {error}") + + def _handle_completion(self): + """Handle completion of the observable.""" + logger.info("Audio observable completed") + self.stop() + + +if __name__ == "__main__": + from dimos.stream.audio.node_microphone import ( + SounddeviceAudioSource, + ) + from dimos.stream.audio.node_output import SounddeviceAudioOutput + from dimos.stream.audio.node_volume_monitor import monitor + from dimos.stream.audio.node_normalizer import AudioNormalizer + from dimos.stream.audio.utils import keepalive + + # Create microphone source, recorder, and audio output + mic = SounddeviceAudioSource() + + # my audio device needs time to init, so for smoother ux we constantly listen + recorder = KeyRecorder(always_subscribe=True) + + normalizer = AudioNormalizer() + speaker = SounddeviceAudioOutput() + + # Connect the components + normalizer.consume_audio(mic.emit_audio()) + recorder.consume_audio(normalizer.emit_audio()) + # recorder.consume_audio(mic.emit_audio()) + + # Monitor microphone input levels (real-time pass-through) + monitor(recorder.emit_audio()) + + # Connect the recorder output to the speakers to hear recordings when completed + playback_speaker = SounddeviceAudioOutput() + playback_speaker.consume_audio(recorder.emit_recording()) + + # TODO: we should be able to run normalizer post hoc on the recording as well, + # it's not working, this needs a review + 
# + # normalizer.consume_audio(recorder.emit_recording()) + # playback_speaker.consume_audio(normalizer.emit_audio()) + + keepalive() diff --git a/build/lib/dimos/stream/audio/node_microphone.py b/build/lib/dimos/stream/audio/node_microphone.py new file mode 100644 index 0000000000..bdb9b32180 --- /dev/null +++ b/build/lib/dimos/stream/audio/node_microphone.py @@ -0,0 +1,131 @@ +#!/usr/bin/env python3 +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos.stream.audio.base import ( + AbstractAudioEmitter, + AudioEvent, +) + +import numpy as np +from typing import Optional, List, Dict, Any +from reactivex import Observable, create, disposable +import time +import sounddevice as sd + +from dimos.utils.logging_config import setup_logger + +logger = setup_logger("dimos.audio.node_microphone") + + +class SounddeviceAudioSource(AbstractAudioEmitter): + """Audio source implementation using the sounddevice library.""" + + def __init__( + self, + device_index: Optional[int] = None, + sample_rate: int = 16000, + channels: int = 1, + block_size: int = 1024, + dtype: np.dtype = np.float32, + ): + """ + Initialize SounddeviceAudioSource. 
+ + Args: + device_index: Audio device index (None for default) + sample_rate: Audio sample rate in Hz + channels: Number of audio channels (1=mono, 2=stereo) + block_size: Number of samples per audio frame + dtype: Data type for audio samples (np.float32 or np.int16) + """ + self.device_index = device_index + self.sample_rate = sample_rate + self.channels = channels + self.block_size = block_size + self.dtype = dtype + + self._stream = None + self._running = False + + def emit_audio(self) -> Observable: + """ + Create an observable that emits audio frames. + + Returns: + Observable emitting AudioEvent objects + """ + + def on_subscribe(observer, scheduler): + # Callback function to process audio data + def audio_callback(indata, frames, time_info, status): + if status: + logger.warning(f"Audio callback status: {status}") + + # Create audio event + audio_event = AudioEvent( + data=indata.copy(), + sample_rate=self.sample_rate, + timestamp=time.time(), + channels=self.channels, + ) + + observer.on_next(audio_event) + + # Start the audio stream + try: + self._stream = sd.InputStream( + device=self.device_index, + samplerate=self.sample_rate, + channels=self.channels, + blocksize=self.block_size, + dtype=self.dtype, + callback=audio_callback, + ) + self._stream.start() + self._running = True + + logger.info( + f"Started audio capture: {self.sample_rate}Hz, " + f"{self.channels} channels, {self.block_size} samples per frame" + ) + + except Exception as e: + logger.error(f"Error starting audio stream: {e}") + observer.on_error(e) + + # Return a disposable to clean up resources + def dispose(): + logger.info("Stopping audio capture") + self._running = False + if self._stream: + self._stream.stop() + self._stream.close() + self._stream = None + + return disposable.Disposable(dispose) + + return create(on_subscribe) + + def get_available_devices(self) -> List[Dict[str, Any]]: + """Get a list of available audio input devices.""" + return sd.query_devices() + + +if __name__ 
== "__main__": + from dimos.stream.audio.node_volume_monitor import monitor + from dimos.stream.audio.utils import keepalive + + monitor(SounddeviceAudioSource().emit_audio()) + keepalive() diff --git a/build/lib/dimos/stream/audio/node_normalizer.py b/build/lib/dimos/stream/audio/node_normalizer.py new file mode 100644 index 0000000000..db9557a5b1 --- /dev/null +++ b/build/lib/dimos/stream/audio/node_normalizer.py @@ -0,0 +1,220 @@ +#!/usr/bin/env python3 +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable + +import numpy as np +from reactivex import Observable, create, disposable + +from dimos.utils.logging_config import setup_logger +from dimos.stream.audio.volume import ( + calculate_rms_volume, + calculate_peak_volume, +) +from dimos.stream.audio.base import ( + AbstractAudioTransform, + AudioEvent, +) + + +logger = setup_logger("dimos.stream.audio.node_normalizer") + + +class AudioNormalizer(AbstractAudioTransform): + """ + Audio normalizer that remembers max volume and rescales audio to normalize it. + + This class applies dynamic normalization to audio frames. It keeps track of + the max volume encountered and uses that to normalize the audio to a target level. 
+ """ + + def __init__( + self, + target_level: float = 1.0, + min_volume_threshold: float = 0.01, + max_gain: float = 10.0, + decay_factor: float = 0.999, + adapt_speed: float = 0.05, + volume_func: Callable[[np.ndarray], float] = calculate_peak_volume, + ): + """ + Initialize AudioNormalizer. + + Args: + target_level: Target normalization level (0.0 to 1.0) + min_volume_threshold: Minimum volume to apply normalization + max_gain: Maximum allowed gain to prevent excessive amplification + decay_factor: Decay factor for max volume (0.0-1.0, higher = slower decay) + adapt_speed: How quickly to adapt to new volume levels (0.0-1.0) + volume_func: Function to calculate volume (default: peak volume) + """ + self.target_level = target_level + self.min_volume_threshold = min_volume_threshold + self.max_gain = max_gain + self.decay_factor = decay_factor + self.adapt_speed = adapt_speed + self.volume_func = volume_func + + # Internal state + self.max_volume = 0.0 + self.current_gain = 1.0 + self.audio_observable = None + + def _normalize_audio(self, audio_event: AudioEvent) -> AudioEvent: + """ + Normalize audio data based on tracked max volume. 
+ + Args: + audio_event: Input audio event + + Returns: + Normalized audio event + """ + # Convert to float32 for processing if needed + if audio_event.data.dtype != np.float32: + audio_event = audio_event.to_float32() + + # Calculate current volume using provided function + current_volume = self.volume_func(audio_event.data) + + # Update max volume with decay + self.max_volume = max(current_volume, self.max_volume * self.decay_factor) + + # Calculate ideal gain + if self.max_volume > self.min_volume_threshold: + ideal_gain = self.target_level / self.max_volume + else: + ideal_gain = 1.0 # No normalization needed for very quiet audio + + # Limit gain to max_gain + ideal_gain = min(ideal_gain, self.max_gain) + + # Smoothly adapt current gain towards ideal gain + self.current_gain = ( + 1 - self.adapt_speed + ) * self.current_gain + self.adapt_speed * ideal_gain + + # Apply gain to audio data + normalized_data = audio_event.data * self.current_gain + + # Clip to prevent distortion (values should stay within -1.0 to 1.0) + normalized_data = np.clip(normalized_data, -1.0, 1.0) + + # Create new audio event with normalized data + return AudioEvent( + data=normalized_data, + sample_rate=audio_event.sample_rate, + timestamp=audio_event.timestamp, + channels=audio_event.channels, + ) + + def consume_audio(self, audio_observable: Observable) -> "AudioNormalizer": + """ + Set the audio source observable to consume. + + Args: + audio_observable: Observable emitting AudioEvent objects + + Returns: + Self for method chaining + """ + self.audio_observable = audio_observable + return self + + def emit_audio(self) -> Observable: + """ + Create an observable that emits normalized audio frames. + + Returns: + Observable emitting normalized AudioEvent objects + """ + if self.audio_observable is None: + raise ValueError("No audio source provided. 
Call consume_audio() first.") + + def on_subscribe(observer, scheduler): + # Subscribe to the audio observable + audio_subscription = self.audio_observable.subscribe( + on_next=lambda event: observer.on_next(self._normalize_audio(event)), + on_error=lambda error: observer.on_error(error), + on_completed=lambda: observer.on_completed(), + ) + + logger.info( + f"Started audio normalizer with target level: {self.target_level}, max gain: {self.max_gain}" + ) + + # Return a disposable to clean up resources + def dispose(): + logger.info("Stopping audio normalizer") + audio_subscription.dispose() + + return disposable.Disposable(dispose) + + return create(on_subscribe) + + +if __name__ == "__main__": + import sys + from dimos.stream.audio.node_microphone import ( + SounddeviceAudioSource, + ) + from dimos.stream.audio.node_simulated import SimulatedAudioSource + from dimos.stream.audio.node_volume_monitor import monitor + from dimos.stream.audio.node_output import SounddeviceAudioOutput + from dimos.stream.audio.utils import keepalive + + # Parse command line arguments + volume_method = "peak" # Default to peak + use_mic = False # Default to microphone input + target_level = 1 # Default target level + + # Process arguments + for arg in sys.argv[1:]: + if arg == "rms": + volume_method = "rms" + elif arg == "peak": + volume_method = "peak" + elif arg == "mic": + use_mic = True + elif arg.startswith("level="): + try: + target_level = float(arg.split("=")[1]) + except ValueError: + print(f"Invalid target level: {arg}") + sys.exit(1) + + # Create appropriate audio source + if use_mic: + audio_source = SounddeviceAudioSource() + print("Using microphone input") + else: + audio_source = SimulatedAudioSource(volume_oscillation=True) + print("Using simulated audio source") + + # Select volume function + volume_func = calculate_rms_volume if volume_method == "rms" else calculate_peak_volume + + # Create normalizer + normalizer = AudioNormalizer(target_level=target_level, 
volume_func=volume_func) + + # Connect the audio source to the normalizer + normalizer.consume_audio(audio_source.emit_audio()) + + print(f"Using {volume_method} volume method with target level {target_level}") + SounddeviceAudioOutput().consume_audio(normalizer.emit_audio()) + + # Monitor the normalized audio + monitor(normalizer.emit_audio()) + keepalive() diff --git a/build/lib/dimos/stream/audio/node_output.py b/build/lib/dimos/stream/audio/node_output.py new file mode 100644 index 0000000000..ee2e2c5ec2 --- /dev/null +++ b/build/lib/dimos/stream/audio/node_output.py @@ -0,0 +1,187 @@ +#!/usr/bin/env python3 +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, List, Dict, Any +import numpy as np +import sounddevice as sd +from reactivex import Observable + +from dimos.utils.logging_config import setup_logger +from dimos.stream.audio.base import ( + AbstractAudioTransform, +) + +logger = setup_logger("dimos.stream.audio.node_output") + + +class SounddeviceAudioOutput(AbstractAudioTransform): + """ + Audio output implementation using the sounddevice library. + + This class implements AbstractAudioTransform so it can both play audio and + optionally pass audio events through to other components (for example, to + record audio while playing it, or to visualize the waveform while playing). 
+ """ + + def __init__( + self, + device_index: Optional[int] = None, + sample_rate: int = 16000, + channels: int = 1, + block_size: int = 1024, + dtype: np.dtype = np.float32, + ): + """ + Initialize SounddeviceAudioOutput. + + Args: + device_index: Audio device index (None for default) + sample_rate: Audio sample rate in Hz + channels: Number of audio channels (1=mono, 2=stereo) + block_size: Number of samples per audio frame + dtype: Data type for audio samples (np.float32 or np.int16) + """ + self.device_index = device_index + self.sample_rate = sample_rate + self.channels = channels + self.block_size = block_size + self.dtype = dtype + + self._stream = None + self._running = False + self._subscription = None + self.audio_observable = None + + def consume_audio(self, audio_observable: Observable) -> "SounddeviceAudioOutput": + """ + Subscribe to an audio observable and play the audio through the speakers. + + Args: + audio_observable: Observable emitting AudioEvent objects + + Returns: + Self for method chaining + """ + self.audio_observable = audio_observable + + # Create and start the output stream + try: + self._stream = sd.OutputStream( + device=self.device_index, + samplerate=self.sample_rate, + channels=self.channels, + blocksize=self.block_size, + dtype=self.dtype, + ) + self._stream.start() + self._running = True + + logger.info( + f"Started audio output: {self.sample_rate}Hz, " + f"{self.channels} channels, {self.block_size} samples per frame" + ) + + except Exception as e: + logger.error(f"Error starting audio output stream: {e}") + raise e + + # Subscribe to the observable + self._subscription = audio_observable.subscribe( + on_next=self._play_audio_event, + on_error=self._handle_error, + on_completed=self._handle_completion, + ) + + return self + + def emit_audio(self) -> Observable: + """ + Pass through the audio observable to allow chaining with other components. 
+ + Returns: + The same Observable that was provided to consume_audio + """ + if self.audio_observable is None: + raise ValueError("No audio source provided. Call consume_audio() first.") + + return self.audio_observable + + def stop(self): + """Stop audio output and clean up resources.""" + logger.info("Stopping audio output") + self._running = False + + if self._subscription: + self._subscription.dispose() + self._subscription = None + + if self._stream: + self._stream.stop() + self._stream.close() + self._stream = None + + def _play_audio_event(self, audio_event): + """Play audio from an AudioEvent.""" + if not self._running or not self._stream: + return + + try: + # Ensure data type matches our stream + if audio_event.dtype != self.dtype: + if self.dtype == np.float32: + audio_event = audio_event.to_float32() + elif self.dtype == np.int16: + audio_event = audio_event.to_int16() + + # Write audio data to the stream + self._stream.write(audio_event.data) + except Exception as e: + logger.error(f"Error playing audio: {e}") + + def _handle_error(self, error): + """Handle errors from the observable.""" + logger.error(f"Error in audio observable: {error}") + + def _handle_completion(self): + """Handle completion of the observable.""" + logger.info("Audio observable completed") + self._running = False + if self._stream: + self._stream.stop() + self._stream.close() + self._stream = None + + def get_available_devices(self) -> List[Dict[str, Any]]: + """Get a list of available audio output devices.""" + return sd.query_devices() + + +if __name__ == "__main__": + from dimos.stream.audio.node_microphone import ( + SounddeviceAudioSource, + ) + from dimos.stream.audio.node_normalizer import AudioNormalizer + from dimos.stream.audio.utils import keepalive + + # Create microphone source, normalizer and audio output + mic = SounddeviceAudioSource() + normalizer = AudioNormalizer() + speaker = SounddeviceAudioOutput() + + # Connect the components in a pipeline + 
normalizer.consume_audio(mic.emit_audio()) + speaker.consume_audio(normalizer.emit_audio()) + + keepalive() diff --git a/build/lib/dimos/stream/audio/node_simulated.py b/build/lib/dimos/stream/audio/node_simulated.py new file mode 100644 index 0000000000..c9aff9a32d --- /dev/null +++ b/build/lib/dimos/stream/audio/node_simulated.py @@ -0,0 +1,221 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos.stream.audio.abstract import ( + AbstractAudioEmitter, + AudioEvent, +) +import numpy as np +from reactivex import Observable, create, disposable +import threading +import time + +from dimos.utils.logging_config import setup_logger + +logger = setup_logger("dimos.stream.audio.node_simulated") + + +class SimulatedAudioSource(AbstractAudioEmitter): + """Audio source that generates simulated audio for testing.""" + + def __init__( + self, + sample_rate: int = 16000, + frame_length: int = 1024, + channels: int = 1, + dtype: np.dtype = np.float32, + frequency: float = 440.0, # A4 note + waveform: str = "sine", # Type of waveform + modulation_rate: float = 0.5, # Modulation rate in Hz + volume_oscillation: bool = True, # Enable sinusoidal volume changes + volume_oscillation_rate: float = 0.2, # Volume oscillation rate in Hz + ): + """ + Initialize SimulatedAudioSource. 
+ + Args: + sample_rate: Audio sample rate in Hz + frame_length: Number of samples per frame + channels: Number of audio channels + dtype: Data type for audio samples + frequency: Frequency of the sine wave in Hz + waveform: Type of waveform ("sine", "square", "triangle", "sawtooth") + modulation_rate: Frequency modulation rate in Hz + volume_oscillation: Whether to oscillate volume sinusoidally + volume_oscillation_rate: Rate of volume oscillation in Hz + """ + self.sample_rate = sample_rate + self.frame_length = frame_length + self.channels = channels + self.dtype = dtype + self.frequency = frequency + self.waveform = waveform.lower() + self.modulation_rate = modulation_rate + self.volume_oscillation = volume_oscillation + self.volume_oscillation_rate = volume_oscillation_rate + self.phase = 0.0 + self.volume_phase = 0.0 + + self._running = False + self._thread = None + + def _generate_sine_wave(self, time_points: np.ndarray) -> np.ndarray: + """Generate a waveform based on selected type.""" + # Generate base time points with phase + t = time_points + self.phase + + # Add frequency modulation for more interesting sounds + if self.modulation_rate > 0: + # Modulate frequency between 0.5x and 1.5x the base frequency + freq_mod = self.frequency * (1.0 + 0.5 * np.sin(2 * np.pi * self.modulation_rate * t)) + else: + freq_mod = np.ones_like(t) * self.frequency + + # Create phase argument for oscillators + phase_arg = 2 * np.pi * np.cumsum(freq_mod / self.sample_rate) + + # Generate waveform based on selection + if self.waveform == "sine": + wave = np.sin(phase_arg) + elif self.waveform == "square": + wave = np.sign(np.sin(phase_arg)) + elif self.waveform == "triangle": + wave = ( + 2 * np.abs(2 * (phase_arg / (2 * np.pi) - np.floor(phase_arg / (2 * np.pi) + 0.5))) + - 1 + ) + elif self.waveform == "sawtooth": + wave = 2 * (phase_arg / (2 * np.pi) - np.floor(0.5 + phase_arg / (2 * np.pi))) + else: + # Default to sine wave + wave = np.sin(phase_arg) + + # Apply sinusoidal 
volume oscillation if enabled + if self.volume_oscillation: + # Current time points for volume calculation + vol_t = t + self.volume_phase + + # Volume oscillates between 0.0 and 1.0 using a sine wave (complete silence to full volume) + volume_factor = 0.5 + 0.5 * np.sin(2 * np.pi * self.volume_oscillation_rate * vol_t) + + # Apply the volume factor + wave *= volume_factor * 0.7 + + # Update volume phase for next frame + self.volume_phase += ( + time_points[-1] - time_points[0] + (time_points[1] - time_points[0]) + ) + + # Update phase for next frame + self.phase += time_points[-1] - time_points[0] + (time_points[1] - time_points[0]) + + # Add a second channel if needed + if self.channels == 2: + wave = np.column_stack((wave, wave)) + elif self.channels > 2: + wave = np.tile(wave.reshape(-1, 1), (1, self.channels)) + + # Convert to int16 if needed + if self.dtype == np.int16: + wave = (wave * 32767).astype(np.int16) + + return wave + + def _audio_thread(self, observer, interval: float): + """Thread function for simulated audio generation.""" + try: + sample_index = 0 + self._running = True + + while self._running: + # Calculate time points for this frame + time_points = ( + np.arange(sample_index, sample_index + self.frame_length) / self.sample_rate + ) + + # Generate audio data + audio_data = self._generate_sine_wave(time_points) + + # Create audio event + audio_event = AudioEvent( + data=audio_data, + sample_rate=self.sample_rate, + timestamp=time.time(), + channels=self.channels, + ) + + observer.on_next(audio_event) + + # Update sample index for next frame + sample_index += self.frame_length + + # Sleep to simulate real-time audio + time.sleep(interval) + + except Exception as e: + logger.error(f"Error in simulated audio thread: {e}") + observer.on_error(e) + finally: + self._running = False + observer.on_completed() + + def emit_audio(self, fps: int = 30) -> Observable: + """ + Create an observable that emits simulated audio frames. 
+ + Args: + fps: Frames per second to emit + + Returns: + Observable emitting AudioEvent objects + """ + + def on_subscribe(observer, scheduler): + # Calculate interval based on fps + interval = 1.0 / fps + + # Start the audio generation thread + self._thread = threading.Thread( + target=self._audio_thread, args=(observer, interval), daemon=True + ) + self._thread.start() + + logger.info( + f"Started simulated audio source: {self.sample_rate}Hz, " + f"{self.channels} channels, {self.frame_length} samples per frame" + ) + + # Return a disposable to clean up + def dispose(): + logger.info("Stopping simulated audio") + self._running = False + if self._thread and self._thread.is_alive(): + self._thread.join(timeout=1.0) + + return disposable.Disposable(dispose) + + return create(on_subscribe) + + +if __name__ == "__main__": + from dimos.stream.audio.utils import keepalive + from dimos.stream.audio.node_volume_monitor import monitor + from dimos.stream.audio.node_output import SounddeviceAudioOutput + + source = SimulatedAudioSource() + speaker = SounddeviceAudioOutput() + speaker.consume_audio(source.emit_audio()) + monitor(speaker.emit_audio()) + + keepalive() diff --git a/build/lib/dimos/stream/audio/node_volume_monitor.py b/build/lib/dimos/stream/audio/node_volume_monitor.py new file mode 100644 index 0000000000..6510667307 --- /dev/null +++ b/build/lib/dimos/stream/audio/node_volume_monitor.py @@ -0,0 +1,176 @@ +#!/usr/bin/env python3 +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable +from reactivex import Observable, create, disposable + +from dimos.stream.audio.base import AudioEvent, AbstractAudioConsumer +from dimos.stream.audio.text.base import AbstractTextEmitter +from dimos.stream.audio.text.node_stdout import TextPrinterNode +from dimos.stream.audio.volume import calculate_peak_volume +from dimos.utils.logging_config import setup_logger + +logger = setup_logger("dimos.stream.audio.node_volume_monitor") + + +class VolumeMonitorNode(AbstractAudioConsumer, AbstractTextEmitter): + """ + A node that monitors audio volume and emits text descriptions. + """ + + def __init__( + self, + threshold: float = 0.01, + bar_length: int = 50, + volume_func: Callable = calculate_peak_volume, + ): + """ + Initialize VolumeMonitorNode. + + Args: + threshold: Threshold for considering audio as active + bar_length: Length of the volume bar in characters + volume_func: Function to calculate volume (defaults to peak volume) + """ + self.threshold = threshold + self.bar_length = bar_length + self.volume_func = volume_func + self.func_name = volume_func.__name__.replace("calculate_", "") + self.audio_observable = None + + def create_volume_text(self, volume: float) -> str: + """ + Create a text representation of the volume level. 
+ + Args: + volume: Volume level between 0.0 and 1.0 + + Returns: + String representation of the volume + """ + # Calculate number of filled segments + filled = int(volume * self.bar_length) + + # Create the bar + bar = "█" * filled + "░" * (self.bar_length - filled) + + # Determine if we're above threshold + active = volume >= self.threshold + + # Format the text + percentage = int(volume * 100) + activity = "active" if active else "silent" + return f"{bar} {percentage:3d}% {activity}" + + def consume_audio(self, audio_observable: Observable) -> "VolumeMonitorNode": + """ + Set the audio source observable to consume. + + Args: + audio_observable: Observable emitting AudioEvent objects + + Returns: + Self for method chaining + """ + self.audio_observable = audio_observable + return self + + def emit_text(self) -> Observable: + """ + Create an observable that emits volume text descriptions. + + Returns: + Observable emitting text descriptions of audio volume + """ + if self.audio_observable is None: + raise ValueError("No audio source provided. 
Call consume_audio() first.") + + def on_subscribe(observer, scheduler): + logger.info(f"Starting volume monitor (method: {self.func_name})") + + # Subscribe to the audio source + def on_audio_event(event: AudioEvent): + try: + # Calculate volume + volume = self.volume_func(event.data) + + # Create text representation + text = self.create_volume_text(volume) + + # Emit the text + observer.on_next(text) + except Exception as e: + logger.error(f"Error processing audio event: {e}") + observer.on_error(e) + + # Set up subscription to audio source + subscription = self.audio_observable.subscribe( + on_next=on_audio_event, + on_error=lambda e: observer.on_error(e), + on_completed=lambda: observer.on_completed(), + ) + + # Return a disposable to clean up resources + def dispose(): + logger.info("Stopping volume monitor") + subscription.dispose() + + return disposable.Disposable(dispose) + + return create(on_subscribe) + + +def monitor( + audio_source: Observable, + threshold: float = 0.01, + bar_length: int = 50, + volume_func: Callable = calculate_peak_volume, +) -> VolumeMonitorNode: + """ + Create a volume monitor node connected to a text output node. 
+ + Args: + audio_source: The audio source to monitor + threshold: Threshold for considering audio as active + bar_length: Length of the volume bar in characters + volume_func: Function to calculate volume + + Returns: + The configured volume monitor node + """ + # Create the volume monitor node with specified parameters + volume_monitor = VolumeMonitorNode( + threshold=threshold, bar_length=bar_length, volume_func=volume_func + ) + + # Connect the volume monitor to the audio source + volume_monitor.consume_audio(audio_source) + + # Create and connect the text printer node + text_printer = TextPrinterNode() + text_printer.consume_text(volume_monitor.emit_text()) + + # Return the volume monitor node + return volume_monitor + + +if __name__ == "__main__": + from utils import keepalive + from audio.node_simulated import SimulatedAudioSource + + # Use the monitor function to create and connect the nodes + volume_monitor = monitor(SimulatedAudioSource().emit_audio()) + + keepalive() diff --git a/build/lib/dimos/stream/audio/pipelines.py b/build/lib/dimos/stream/audio/pipelines.py new file mode 100644 index 0000000000..ee2ae43316 --- /dev/null +++ b/build/lib/dimos/stream/audio/pipelines.py @@ -0,0 +1,52 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dimos.stream.audio.node_microphone import SounddeviceAudioSource +from dimos.stream.audio.node_normalizer import AudioNormalizer +from dimos.stream.audio.node_volume_monitor import monitor +from dimos.stream.audio.node_key_recorder import KeyRecorder +from dimos.stream.audio.node_output import SounddeviceAudioOutput +from dimos.stream.audio.stt.node_whisper import WhisperNode +from dimos.stream.audio.tts.node_openai import OpenAITTSNode, Voice +from dimos.stream.audio.text.node_stdout import TextPrinterNode + + +def stt(): + # Create microphone source, recorder, and audio output + mic = SounddeviceAudioSource() + normalizer = AudioNormalizer() + recorder = KeyRecorder(always_subscribe=True) + whisper_node = WhisperNode() # Assign to global variable + + # Connect audio processing pipeline + normalizer.consume_audio(mic.emit_audio()) + recorder.consume_audio(normalizer.emit_audio()) + monitor(recorder.emit_audio()) + whisper_node.consume_audio(recorder.emit_recording()) + + user_text_printer = TextPrinterNode(prefix="USER: ") + user_text_printer.consume_text(whisper_node.emit_text()) + + return whisper_node + + +def tts(): + tts_node = OpenAITTSNode(speed=1.2, voice=Voice.ONYX) + agent_text_printer = TextPrinterNode(prefix="AGENT: ") + agent_text_printer.consume_text(tts_node.emit_text()) + + response_output = SounddeviceAudioOutput(sample_rate=24000) + response_output.consume_audio(tts_node.emit_audio()) + + return tts_node diff --git a/build/lib/dimos/stream/audio/utils.py b/build/lib/dimos/stream/audio/utils.py new file mode 100644 index 0000000000..712086ffd6 --- /dev/null +++ b/build/lib/dimos/stream/audio/utils.py @@ -0,0 +1,26 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + + +def keepalive(): + try: + # Keep the program running + print("Press Ctrl+C to exit") + print("-" * 60) + while True: + time.sleep(0.1) + except KeyboardInterrupt: + print("\nStopping pipeline") diff --git a/build/lib/dimos/stream/audio/volume.py b/build/lib/dimos/stream/audio/volume.py new file mode 100644 index 0000000000..f2e50ab72c --- /dev/null +++ b/build/lib/dimos/stream/audio/volume.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python3 +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + + +def calculate_rms_volume(audio_data: np.ndarray) -> float: + """ + Calculate RMS (Root Mean Square) volume of audio data. 
+ + Args: + audio_data: Audio data as numpy array + + Returns: + RMS volume as a float between 0.0 and 1.0 + """ + # For multi-channel audio, calculate RMS across all channels + if len(audio_data.shape) > 1 and audio_data.shape[1] > 1: + # Flatten all channels + audio_data = audio_data.flatten() + + # Calculate RMS + rms = np.sqrt(np.mean(np.square(audio_data))) + + # For int16 data, normalize to [0, 1] + if audio_data.dtype == np.int16: + rms = rms / 32768.0 + + return rms + + +def calculate_peak_volume(audio_data: np.ndarray) -> float: + """ + Calculate peak volume of audio data. + + Args: + audio_data: Audio data as numpy array + + Returns: + Peak volume as a float between 0.0 and 1.0 + """ + # For multi-channel audio, find max across all channels + if len(audio_data.shape) > 1 and audio_data.shape[1] > 1: + # Flatten all channels + audio_data = audio_data.flatten() + + # Find absolute peak value + peak = np.max(np.abs(audio_data)) + + # For int16 data, normalize to [0, 1] + if audio_data.dtype == np.int16: + peak = peak / 32768.0 + + return peak + + +if __name__ == "__main__": + # Example usage + import time + from .node_simulated import SimulatedAudioSource + + # Create a simulated audio source + audio_source = SimulatedAudioSource() + + # Create observable and subscribe to get a single frame + audio_observable = audio_source.capture_audio_as_observable() + + def process_frame(frame): + # Calculate and print both RMS and peak volumes + rms_vol = calculate_rms_volume(frame.data) + peak_vol = calculate_peak_volume(frame.data) + + print(f"RMS Volume: {rms_vol:.4f}") + print(f"Peak Volume: {peak_vol:.4f}") + print(f"Ratio (Peak/RMS): {peak_vol / rms_vol:.2f}") + + # Set a flag to track when processing is complete + processed = {"done": False} + + def process_frame_wrapper(frame): + # Process the frame + process_frame(frame) + # Mark as processed + processed["done"] = True + + # Subscribe to get a single frame and process it + subscription = 
audio_observable.subscribe( + on_next=process_frame_wrapper, on_completed=lambda: print("Completed") + ) + + # Wait for frame processing to complete + while not processed["done"]: + time.sleep(0.01) + + # Now dispose the subscription from the main thread, not from within the callback + subscription.dispose() diff --git a/build/lib/dimos/stream/data_provider.py b/build/lib/dimos/stream/data_provider.py new file mode 100644 index 0000000000..73e1ba0f20 --- /dev/null +++ b/build/lib/dimos/stream/data_provider.py @@ -0,0 +1,183 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class AbstractDataProvider(ABC):
    """Base class for data providers that publish through a ReactiveX Subject.

    Attributes:
        dev_name: Name identifying this provider.
    """

    def __init__(self, dev_name: str = "NA"):
        # A plain Subject is used, so the stream carries no initial value.
        subject = Subject()
        self.dev_name = dev_name
        self._data_subject = subject

    @property
    def data_stream(self) -> Observable:
        """Return the underlying subject as the observable data stream."""
        return self._data_subject

    def push_data(self, data):
        """Emit ``data`` to the stream via the subject's ``on_next``."""
        self._data_subject.on_next(data)

    def dispose(self):
        """Dispose the underlying subject."""
        self._data_subject.dispose()
class QueryDataProvider(AbstractDataProvider):
    """
    Data provider that pushes formatted numeric text queries on a timer.

    A range of integers is generated, each value is substituted into a
    template containing a `{query}` placeholder, and the resulting strings
    are pushed to the provider's data stream one per timer tick.

    Attributes:
        dev_name (str): The name of the data provider.
        logger (logging.Logger): Logger instance for logging messages.
    """

    def __init__(self, dev_name: str = "query_provider"):
        """
        Set up the provider and its logger.

        Args:
            dev_name (str): The name of the data provider. Defaults to "query_provider".
        """
        super().__init__(dev_name)
        self.logger = logging.getLogger(dev_name)

    def start_query_stream(
        self,
        query_template: str = None,
        frequency: float = 3.0,
        start_count: int = 0,
        end_count: int = 5000,
        step: int = 250,
    ) -> None:
        """
        Start pushing formatted queries to the data stream on a timer.

        The numbers `range(start_count, end_count + 1, step)` are paired, one
        per tick, with a timer that fires immediately and then repeats with
        period `frequency`. Each number is rendered through the template and
        forwarded to `push_data`.

        Args:
            query_template (str): Template containing a `{query}` placeholder
                for the numeric value. If None, a built-in template is used.
            frequency (float): Timer period passed to `rx.timer`. Defaults to 3.0.
            start_count (int): First number generated. Defaults to 0.
            end_count (int): Last number generated (inclusive). Defaults to 5000.
            step (int): Increment between consecutive numbers. Defaults to 250.
        """
        template = query_template
        if template is None:
            template = (
                "{query}; Denote the number at the beginning of this query before the semicolon. "
                "Only provide the number, without any other text in your response. "
                "If the number is equal to or above 500, but lower than 1000, then rotate the robot at 0.5 rad/s for 1 second. "
                "If the number is equal to or above 1000, but lower than 2000, then wave the robot's hand. "
                "If the number is equal to or above 2000, then clear debris. "
                "IF YOU DO NOT FOLLOW THESE INSTRUCTIONS EXACTLY, YOU WILL DIE!!!"
            )

        # end_count is inclusive, hence the +1 on the range bound.
        numbers = list(range(start_count, end_count + 1, step))

        def _render(pair):
            # pair is (timer tick, number); only the number is formatted.
            return template.format(query=pair[1])

        def _forward(text):
            self.push_data(text)

        # Fires immediately, then once per period.
        ticks = rx.timer(0, frequency)

        # Zipping with the timer releases one number per tick.
        formatted_queries = ticks.pipe(
            ops.zip(rx.from_iterable(numbers)),
            ops.map(_render),
            ops.observe_on(pool_scheduler),
            ops.share(),
        )

        # Subscribing starts the timer; every rendered query goes to the data stream.
        formatted_queries.subscribe(_forward)
+ + Args: + output_dir: Directory path for storing processed frames. + Defaults to '{os.getcwd()}/assets/output/frames'. + delete_on_init: If True, deletes all existing JPG files in output_dir. + Defaults to False. + + Raises: + OSError: If directory creation fails or if file deletion fails. + PermissionError: If lacking permissions for directory/file operations. + """ + self.output_dir = output_dir + os.makedirs(self.output_dir, exist_ok=True) + + if delete_on_init: + try: + jpg_files = [f for f in os.listdir(self.output_dir) if f.lower().endswith(".jpg")] + for file in jpg_files: + file_path = os.path.join(self.output_dir, file) + os.remove(file_path) + print(f"Cleaned up {len(jpg_files)} existing JPG files from {self.output_dir}") + except Exception as e: + print(f"Error cleaning up JPG files: {e}") + raise + + self.image_count = 1 + # TODO: Add randomness to jpg folder storage naming. + # Will overwrite between sessions. + + def to_grayscale(self, frame): + if frame is None: + print("Received None frame for grayscale conversion.") + return None + return cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) + + def edge_detection(self, frame): + return cv2.Canny(frame, 100, 200) + + def resize(self, frame, scale=0.5): + return cv2.resize(frame, None, fx=scale, fy=scale, interpolation=cv2.INTER_AREA) + + def export_to_jpeg(self, frame, save_limit=100, loop=False, suffix=""): + if frame is None: + print("Error: Attempted to save a None image.") + return None + + # Check if the image has an acceptable number of channels + if len(frame.shape) == 3 and frame.shape[2] not in [1, 3, 4]: + print(f"Error: Frame with shape {frame.shape} has unsupported number of channels.") + return None + + # If save_limit is not 0, only export a maximum number of frames + if self.image_count > save_limit and save_limit != 0: + if loop: + self.image_count = 1 + else: + return frame + + filepath = os.path.join(self.output_dir, f"{self.image_count}_{suffix}.jpg") + cv2.imwrite(filepath, frame) + 
self.image_count += 1 + return frame + + def compute_optical_flow( + self, + acc: Tuple[np.ndarray, np.ndarray, Optional[float]], + current_frame: np.ndarray, + compute_relevancy: bool = True, + ) -> Tuple[np.ndarray, np.ndarray, Optional[float]]: + """Computes optical flow between consecutive frames. + + Uses the Farneback algorithm to compute dense optical flow between the + previous and current frame. Optionally calculates a relevancy score + based on the mean magnitude of motion vectors. + + Args: + acc: Accumulator tuple containing: + prev_frame: Previous video frame (np.ndarray) + prev_flow: Previous optical flow (np.ndarray) + prev_relevancy: Previous relevancy score (float or None) + current_frame: Current video frame as BGR image (np.ndarray) + compute_relevancy: If True, calculates mean magnitude of flow vectors. + Defaults to True. + + Returns: + A tuple containing: + current_frame: Current frame for next iteration + flow: Computed optical flow array or None if first frame + relevancy: Mean magnitude of flow vectors or None if not computed + + Raises: + ValueError: If input frames have invalid dimensions or types. + TypeError: If acc is not a tuple of correct types. 
+ """ + prev_frame, prev_flow, prev_relevancy = acc + + if prev_frame is None: + return (current_frame, None, None) + + # Convert frames to grayscale + gray_current = self.to_grayscale(current_frame) + gray_prev = self.to_grayscale(prev_frame) + + # Compute optical flow + flow = cv2.calcOpticalFlowFarneback(gray_prev, gray_current, None, 0.5, 3, 15, 3, 5, 1.2, 0) + + # Relevancy calulation (average magnitude of flow vectors) + relevancy = None + if compute_relevancy: + mag, _ = cv2.cartToPolar(flow[..., 0], flow[..., 1]) + relevancy = np.mean(mag) + + # Return the current frame as the new previous frame and the processed optical flow, with relevancy score + return (current_frame, flow, relevancy) + + def visualize_flow(self, flow): + if flow is None: + return None + hsv = np.zeros((flow.shape[0], flow.shape[1], 3), dtype=np.uint8) + hsv[..., 1] = 255 + mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1]) + hsv[..., 0] = ang * 180 / np.pi / 2 + hsv[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX) + rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR) + return rgb + + # ============================== + + def process_stream_edge_detection(self, frame_stream): + return frame_stream.pipe( + ops.map(self.edge_detection), + ) + + def process_stream_resize(self, frame_stream): + return frame_stream.pipe( + ops.map(self.resize), + ) + + def process_stream_to_greyscale(self, frame_stream): + return frame_stream.pipe( + ops.map(self.to_grayscale), + ) + + def process_stream_optical_flow(self, frame_stream: Observable) -> Observable: + """Processes video stream to compute and visualize optical flow. + + Computes optical flow between consecutive frames and generates a color-coded + visualization where hue represents flow direction and intensity represents + flow magnitude. This method optimizes performance by disabling relevancy + computation. + + Args: + frame_stream: An Observable emitting video frames as numpy arrays. 
+ Each frame should be in BGR format with shape (height, width, 3). + + Returns: + An Observable emitting visualized optical flow frames as BGR images + (np.ndarray). Hue indicates flow direction, intensity shows magnitude. + + Raises: + TypeError: If frame_stream is not an Observable. + ValueError: If frames have invalid dimensions or format. + + Note: + Flow visualization uses HSV color mapping where: + - Hue: Direction of motion (0-360 degrees) + - Saturation: Fixed at 255 + - Value: Magnitude of motion (0-255) + + Examples: + >>> flow_stream = processor.process_stream_optical_flow(frame_stream) + >>> flow_stream.subscribe(lambda flow: cv2.imshow('Flow', flow)) + """ + return frame_stream.pipe( + ops.scan( + lambda acc, frame: self.compute_optical_flow(acc, frame, compute_relevancy=False), + (None, None, None), + ), + ops.map(lambda result: result[1]), # Extract flow component + ops.filter(lambda flow: flow is not None), + ops.map(self.visualize_flow), + ) + + def process_stream_optical_flow_with_relevancy(self, frame_stream: Observable) -> Observable: + """Processes video stream to compute optical flow with movement relevancy. + + Applies optical flow computation to each frame and returns both the + visualized flow and a relevancy score indicating the amount of movement. + The relevancy score is calculated as the mean magnitude of flow vectors. + This method includes relevancy computation for motion detection. + + Args: + frame_stream: An Observable emitting video frames as numpy arrays. + Each frame should be in BGR format with shape (height, width, 3). + + Returns: + An Observable emitting tuples of (visualized_flow, relevancy_score): + visualized_flow: np.ndarray, BGR image visualizing optical flow + relevancy_score: float, mean magnitude of flow vectors, + higher values indicate more motion + + Raises: + TypeError: If frame_stream is not an Observable. + ValueError: If frames have invalid dimensions or format. 
+ + Examples: + >>> flow_stream = processor.process_stream_optical_flow_with_relevancy( + ... frame_stream + ... ) + >>> flow_stream.subscribe( + ... lambda result: print(f"Motion score: {result[1]}") + ... ) + + Note: + Relevancy scores are computed using mean magnitude of flow vectors. + Higher scores indicate more movement in the frame. + """ + return frame_stream.pipe( + ops.scan( + lambda acc, frame: self.compute_optical_flow(acc, frame, compute_relevancy=True), + (None, None, None), + ), + # Result is (current_frame, flow, relevancy) + ops.filter(lambda result: result[1] is not None), # Filter out None flows + ops.map( + lambda result: ( + self.visualize_flow(result[1]), # Visualized flow + result[2], # Relevancy score + ) + ), + ops.filter(lambda result: result[0] is not None), # Ensure valid visualization + ) + + def process_stream_with_jpeg_export( + self, frame_stream: Observable, suffix: str = "", loop: bool = False + ) -> Observable: + """Processes stream by saving frames as JPEGs while passing them through. + + Saves each frame from the stream as a JPEG file and passes the frame + downstream unmodified. Files are saved sequentially with optional suffix + in the configured output directory (self.output_dir). If loop is True, + it will cycle back and overwrite images starting from the first one + after reaching the save_limit. + + Args: + frame_stream: An Observable emitting video frames as numpy arrays. + Each frame should be in BGR format with shape (height, width, 3). + suffix: Optional string to append to filename before index. + Defaults to empty string. Example: "optical" -> "optical_1.jpg" + loop: If True, reset the image counter to 1 after reaching + save_limit, effectively looping the saves. Defaults to False. + + Returns: + An Observable emitting the same frames that were saved. Returns None + for frames that could not be saved due to format issues or save_limit + (unless loop is True). 
class ROSVideoProvider(AbstractVideoProvider):
    """Video provider that uses a Subject to broadcast frames pushed by ROS.

    Frames handed to ``push_data`` are forwarded to an internal ReactiveX
    Subject; ``capture_video_as_observable`` exposes that subject as a shared
    Observable.

    Attributes:
        logger: Logger instance for this provider.
        _subject: ReactiveX Subject that broadcasts frames.
        _last_frame_time: ``time.time()`` of the last pushed frame, or None
            before the first frame.
    """

    def __init__(
        self, dev_name: str = "ros_video", pool_scheduler: Optional[ThreadPoolScheduler] = None
    ):
        """Initialize the ROS video provider.

        Args:
            dev_name: A string identifying this provider; also used as the logger name.
            pool_scheduler: Optional ThreadPoolScheduler, passed to the base class.
        """
        super().__init__(dev_name, pool_scheduler)
        self.logger = logging.getLogger(dev_name)
        self._subject = Subject()
        self._last_frame_time = None
        self.logger.info("ROSVideoProvider initialized")

    def push_data(self, frame: np.ndarray) -> None:
        """Push a new frame into the provider.

        Args:
            frame: The video frame to push into the stream, typically a numpy array
                containing image data.

        Raises:
            Exception: Any error raised while pushing the frame is logged and re-raised.
        """
        try:
            current_time = time.time()
            # Compare against None explicitly so a legitimate 0.0 timestamp
            # is not mistaken for "no previous frame".
            if self._last_frame_time is not None:
                frame_interval = current_time - self._last_frame_time
                if frame_interval > 0:
                    # Lazy %-formatting: the division only happens when the
                    # interval is positive, and the text is only built when
                    # debug logging is enabled.
                    self.logger.debug(
                        "Frame interval: %.3fs (%.1f FPS)", frame_interval, 1 / frame_interval
                    )
                else:
                    # Two frames can share a timestamp (coarse clock) or the
                    # wall clock can step backwards; a diagnostic must not
                    # raise ZeroDivisionError and drop the frame.
                    self.logger.debug("Frame interval: %.3fs", frame_interval)
            self._last_frame_time = current_time

            self.logger.debug(f"Pushing frame type: {type(frame)}")
            self._subject.on_next(frame)
            self.logger.debug("Frame pushed")
        except Exception as e:
            self.logger.error(f"Push error: {e}")
            raise

    def capture_video_as_observable(self, fps: int = 30) -> Observable:
        """Return an observable of video frames.

        Args:
            fps: Frames per second rate limit (default: 30; ignored for now).

        Returns:
            Observable: An observable stream of the frames passed to
            ``push_data``, multicast to all subscribers.

        Note:
            The fps parameter is currently not enforced; it is only logged.
        """
        self.logger.info(f"Creating observable with {fps} FPS rate limiting")
        # TODO: Implement rate limiting using ops.throttle_with_timeout() or
        # ops.sample() to restrict emissions to one frame per (1/fps) seconds.
        # Example: ops.sample(1.0/fps)
        return self._subject.pipe(
            # Ensure subscription work happens on the thread pool
            ops.subscribe_on(self.pool_scheduler),
            # Ensure observer callbacks execute on the thread pool
            ops.observe_on(self.pool_scheduler),
            # Make the stream hot/multicast so multiple subscribers get the same frames
            ops.share(),
        )
+ +"""RTSP video provider using ffmpeg for robust stream handling.""" + +import subprocess +import threading +import time +from typing import Optional + +import ffmpeg # ffmpeg-python wrapper +import numpy as np +import reactivex as rx +from reactivex import operators as ops +from reactivex.disposable import Disposable +from reactivex.observable import Observable +from reactivex.scheduler import ThreadPoolScheduler + +from dimos.utils.logging_config import setup_logger + +# Assuming AbstractVideoProvider and exceptions are in the sibling file +from .video_provider import AbstractVideoProvider, VideoFrameError, VideoSourceError + +logger = setup_logger("dimos.stream.rtsp_video_provider") + + +class RtspVideoProvider(AbstractVideoProvider): + """Video provider implementation for capturing RTSP streams using ffmpeg. + + This provider uses the ffmpeg-python library to interact with ffmpeg, + providing more robust handling of various RTSP streams compared to OpenCV's + built-in VideoCapture for RTSP. + """ + + def __init__( + self, dev_name: str, rtsp_url: str, pool_scheduler: Optional[ThreadPoolScheduler] = None + ) -> None: + """Initializes the RTSP video provider. + + Args: + dev_name: The name of the device or stream (for identification). + rtsp_url: The URL of the RTSP stream (e.g., "rtsp://user:pass@ip:port/path"). + pool_scheduler: The scheduler for thread pool operations. Defaults to global scheduler. 
+ """ + super().__init__(dev_name, pool_scheduler) + self.rtsp_url = rtsp_url + # Holds the currently active ffmpeg process Popen object + self._ffmpeg_process: Optional[subprocess.Popen] = None + # Lock to protect access to the ffmpeg process object + self._lock = threading.Lock() + + def _get_stream_info(self) -> dict: + """Probes the RTSP stream to get video dimensions and FPS using ffprobe.""" + logger.info(f"({self.dev_name}) Probing RTSP stream.") + try: + # Probe the stream without the problematic timeout argument + probe = ffmpeg.probe(self.rtsp_url) + except ffmpeg.Error as e: + stderr = e.stderr.decode("utf8", errors="ignore") if e.stderr else "No stderr" + msg = f"({self.dev_name}) Failed to probe RTSP stream {self.rtsp_url}: {stderr}" + logger.error(msg) + raise VideoSourceError(msg) from e + except Exception as e: + msg = f"({self.dev_name}) Unexpected error during probing {self.rtsp_url}: {e}" + logger.error(msg) + raise VideoSourceError(msg) from e + + video_stream = next( + (stream for stream in probe.get("streams", []) if stream.get("codec_type") == "video"), + None, + ) + + if video_stream is None: + msg = f"({self.dev_name}) No video stream found in {self.rtsp_url}" + logger.error(msg) + raise VideoSourceError(msg) + + width = video_stream.get("width") + height = video_stream.get("height") + fps_str = video_stream.get("avg_frame_rate", "0/1") + + if not width or not height: + msg = f"({self.dev_name}) Could not determine resolution for {self.rtsp_url}. Stream info: {video_stream}" + logger.error(msg) + raise VideoSourceError(msg) + + try: + if "/" in fps_str: + num, den = map(int, fps_str.split("/")) + fps = float(num) / den if den != 0 else 30.0 + else: + fps = float(fps_str) + if fps <= 0: + logger.warning( + f"({self.dev_name}) Invalid avg_frame_rate '{fps_str}', defaulting FPS to 30." + ) + fps = 30.0 + except ValueError: + logger.warning( + f"({self.dev_name}) Could not parse FPS '{fps_str}', defaulting FPS to 30." 
+ ) + fps = 30.0 + + logger.info(f"({self.dev_name}) Stream info: {width}x{height} @ {fps:.2f} FPS") + return {"width": width, "height": height, "fps": fps} + + def _start_ffmpeg_process(self, width: int, height: int) -> subprocess.Popen: + """Starts the ffmpeg process to capture and decode the stream.""" + logger.info(f"({self.dev_name}) Starting ffmpeg process for rtsp stream.") + try: + # Configure ffmpeg input: prefer TCP, set timeout, reduce buffering/delay + input_options = { + "rtsp_transport": "tcp", + "stimeout": "5000000", # 5 seconds timeout for RTSP server responses + "fflags": "nobuffer", # Reduce input buffering + "flags": "low_delay", # Reduce decoding delay + # 'timeout': '10000000' # Removed: This was misinterpreted as listen timeout + } + process = ( + ffmpeg.input(self.rtsp_url, **input_options) + .output("pipe:", format="rawvideo", pix_fmt="bgr24") # Output raw BGR frames + .global_args("-loglevel", "warning") # Reduce ffmpeg log spam, use 'error' for less + .run_async(pipe_stdout=True, pipe_stderr=True) # Capture stdout and stderr + ) + logger.info(f"({self.dev_name}) ffmpeg process started (PID: {process.pid})") + return process + except ffmpeg.Error as e: + stderr = e.stderr.decode("utf8", errors="ignore") if e.stderr else "No stderr" + msg = f"({self.dev_name}) Failed to start ffmpeg for {self.rtsp_url}: {stderr}" + logger.error(msg) + raise VideoSourceError(msg) from e + except Exception as e: # Catch other errors like ffmpeg executable not found + msg = f"({self.dev_name}) An unexpected error occurred starting ffmpeg: {e}" + logger.error(msg) + raise VideoSourceError(msg) from e + + def capture_video_as_observable(self, fps: int = 0) -> Observable: + """Creates an observable from the RTSP stream using ffmpeg. + + The observable attempts to reconnect if the stream drops. + + Args: + fps: This argument is currently ignored. The provider attempts + to use the stream's native frame rate. 
Future versions might + allow specifying an output FPS via ffmpeg filters. + + Returns: + Observable: An observable emitting video frames as NumPy arrays (BGR format). + + Raises: + VideoSourceError: If the stream cannot be initially probed or the + ffmpeg process fails to start. + VideoFrameError: If there's an error reading or processing frames. + """ + if fps != 0: + logger.warning( + f"({self.dev_name}) The 'fps' argument ({fps}) is currently ignored. Using stream native FPS." + ) + + def emit_frames(observer, scheduler): + """Function executed by rx.create to emit frames.""" + process: Optional[subprocess.Popen] = None + # Event to signal the processing loop should stop (e.g., on dispose) + should_stop = threading.Event() + + def cleanup_process(): + """Safely terminate the ffmpeg process if it's running.""" + nonlocal process + logger.debug(f"({self.dev_name}) Cleanup requested.") + # Use lock to ensure thread safety when accessing/modifying process + with self._lock: + # Check if the process exists and is still running + if process and process.poll() is None: + logger.info( + f"({self.dev_name}) Terminating ffmpeg process (PID: {process.pid})." + ) + try: + process.terminate() # Ask ffmpeg to exit gracefully + process.wait(timeout=1.0) # Wait up to 1 second + except subprocess.TimeoutExpired: + logger.warning( + f"({self.dev_name}) ffmpeg (PID: {process.pid}) did not terminate gracefully, killing." 
+ ) + process.kill() # Force kill if it didn't exit + except Exception as e: + logger.error(f"({self.dev_name}) Error during ffmpeg termination: {e}") + finally: + # Ensure we clear the process variable even if wait/kill fails + process = None + # Also clear the shared class attribute if this was the active process + if self._ffmpeg_process and self._ffmpeg_process.pid == process.pid: + self._ffmpeg_process = None + elif process and process.poll() is not None: + # Process exists but already terminated + logger.debug( + f"({self.dev_name}) ffmpeg process (PID: {process.pid}) already terminated (exit code: {process.poll()})." + ) + process = None # Clear the variable + # Clear shared attribute if it matches + if self._ffmpeg_process and self._ffmpeg_process.pid == process.pid: + self._ffmpeg_process = None + else: + # Process variable is already None or doesn't match _ffmpeg_process + logger.debug( + f"({self.dev_name}) No active ffmpeg process found needing termination in cleanup." + ) + + try: + # 1. Probe the stream to get essential info (width, height) + stream_info = self._get_stream_info() + width = stream_info["width"] + height = stream_info["height"] + # Calculate expected bytes per frame (BGR format = 3 bytes per pixel) + frame_size = width * height * 3 + + # 2. Main loop: Start ffmpeg and read frames. Retry on failure. + while not should_stop.is_set(): + try: + # Start or reuse the ffmpeg process + with self._lock: + # Check if another thread/subscription already started the process + if self._ffmpeg_process and self._ffmpeg_process.poll() is None: + logger.warning( + f"({self.dev_name}) ffmpeg process (PID: {self._ffmpeg_process.pid}) seems to be already running. Reusing." + ) + process = self._ffmpeg_process + else: + # Start a new ffmpeg process + process = self._start_ffmpeg_process(width, height) + # Store the new process handle in the shared class attribute + self._ffmpeg_process = process + + # 3. 
Frame reading loop + while not should_stop.is_set(): + # Read exactly one frame's worth of bytes + in_bytes = process.stdout.read(frame_size) + + if len(in_bytes) == 0: + # End of stream or process terminated unexpectedly + logger.warning( + f"({self.dev_name}) ffmpeg stdout returned 0 bytes. EOF or process terminated." + ) + process.wait(timeout=0.5) # Allow stderr to flush + stderr_data = process.stderr.read().decode("utf8", errors="ignore") + exit_code = process.poll() + logger.warning( + f"({self.dev_name}) ffmpeg process (PID: {process.pid}) exited with code {exit_code}. Stderr: {stderr_data}" + ) + # Break inner loop to trigger cleanup and potential restart + with self._lock: + # Clear the shared process handle if it matches the one that just exited + if ( + self._ffmpeg_process + and self._ffmpeg_process.pid == process.pid + ): + self._ffmpeg_process = None + process = None # Clear local process variable + break # Exit frame reading loop + + elif len(in_bytes) != frame_size: + # Received incomplete frame data - indicates a problem + msg = f"({self.dev_name}) Incomplete frame read. Expected {frame_size}, got {len(in_bytes)}. Stopping." + logger.error(msg) + observer.on_error(VideoFrameError(msg)) + should_stop.set() # Signal outer loop to stop + break # Exit frame reading loop + + # Convert the raw bytes to a NumPy array (height, width, channels) + frame = np.frombuffer(in_bytes, np.uint8).reshape((height, width, 3)) + # Emit the frame to subscribers + observer.on_next(frame) + + # 4. Handle ffmpeg process exit/crash (if not stopping deliberately) + if not should_stop.is_set() and process is None: + logger.info( + f"({self.dev_name}) ffmpeg process ended. Attempting reconnection in 5 seconds..." 
+ ) + # Wait for a few seconds before trying to restart + time.sleep(5) + # Continue to the next iteration of the outer loop to restart + + except (VideoSourceError, ffmpeg.Error) as e: + # Errors during ffmpeg process start or severe runtime errors + logger.error( + f"({self.dev_name}) Unrecoverable ffmpeg error: {e}. Stopping emission." + ) + observer.on_error(e) + should_stop.set() # Stop retrying + except Exception as e: + # Catch other unexpected errors during frame reading/processing + logger.error( + f"({self.dev_name}) Unexpected error processing stream: {e}", + exc_info=True, + ) + observer.on_error(VideoFrameError(f"Frame processing failed: {e}")) + should_stop.set() # Stop retrying + + # 5. Loop finished (likely due to should_stop being set) + logger.info(f"({self.dev_name}) Frame emission loop stopped.") + observer.on_completed() + + except VideoSourceError as e: + # Handle errors during the initial probing phase + logger.error(f"({self.dev_name}) Failed initial setup: {e}") + observer.on_error(e) + except Exception as e: + # Catch-all for unexpected errors before the main loop starts + logger.error(f"({self.dev_name}) Unexpected setup error: {e}", exc_info=True) + observer.on_error(VideoSourceError(f"Setup failed: {e}")) + finally: + # Crucial: Ensure the ffmpeg process is terminated when the observable + # is completed, errored, or disposed. + logger.debug(f"({self.dev_name}) Entering finally block in emit_frames.") + cleanup_process() + + # Return a Disposable that, when called (by unsubscribe/dispose), + # signals the loop to stop. The finally block handles the actual cleanup. 
+ return Disposable(should_stop.set) + + # Create the observable using rx.create, applying scheduling and sharing + return rx.create(emit_frames).pipe( + ops.subscribe_on(self.pool_scheduler), # Run the emit_frames logic on the pool + # ops.observe_on(self.pool_scheduler), # Optional: Switch thread for downstream operators + ops.share(), # Ensure multiple subscribers share the same ffmpeg process + ) + + def dispose_all(self) -> None: + """Disposes of all managed resources, including terminating the ffmpeg process.""" + logger.info(f"({self.dev_name}) dispose_all called.") + # Terminate the ffmpeg process using the same locked logic as cleanup + with self._lock: + process = self._ffmpeg_process # Get the current process handle + if process and process.poll() is None: + logger.info( + f"({self.dev_name}) Terminating ffmpeg process (PID: {process.pid}) via dispose_all." + ) + try: + process.terminate() + process.wait(timeout=1.0) + except subprocess.TimeoutExpired: + logger.warning( + f"({self.dev_name}) ffmpeg process (PID: {process.pid}) kill required in dispose_all." + ) + process.kill() + except Exception as e: + logger.error( + f"({self.dev_name}) Error during ffmpeg termination in dispose_all: {e}" + ) + finally: + self._ffmpeg_process = None # Clear handle after attempting termination + elif process: # Process exists but already terminated + logger.debug( + f"({self.dev_name}) ffmpeg process (PID: {process.pid}) already terminated in dispose_all." + ) + self._ffmpeg_process = None + else: + logger.debug( + f"({self.dev_name}) No active ffmpeg process found during dispose_all." + ) + + # Call the parent class's dispose_all to handle Rx Disposables + super().dispose_all() + + def __del__(self) -> None: + """Destructor attempts to clean up resources if not explicitly disposed.""" + # Logging in __del__ is generally discouraged due to interpreter state issues, + # but can be helpful for debugging resource leaks. Use print for robustness here if needed. 
def create_stream_merger(
    data_input_stream: Observable[T], text_query_stream: Observable[Q]
) -> Observable[Tuple[Q, List[T]]]:
    """Combine each query with the most recent data item.

    Every item of ``data_input_stream`` is wrapped in a one-element list
    before the two streams are combined. The wrapping keeps the raw item
    (for example an array) from ever being evaluated in a boolean context.

    Args:
        data_input_stream: Observable stream of data values.
        text_query_stream: Observable stream of query values; its emissions
            drive the output.

    Returns:
        Observable that emits ``(query, [latest_data])`` tuples.
    """

    def _wrap_in_list(item: T) -> List[T]:
        # The data itself is not modified, only boxed.
        return [item]

    boxed_data_stream = data_input_stream.pipe(ops.map(_wrap_in_list))

    # Combine against the boxed stream, never the raw data stream.
    return text_query_stream.pipe(ops.with_latest_from(boxed_data_stream))
class VideoOperators:
    """Collection of video processing operators for reactive video streams."""

    @staticmethod
    def with_fps_sampling(
        fps: int = 25, *, sample_interval: Optional[timedelta] = None, use_latest: bool = True
    ) -> Callable[[Observable], Observable]:
        """Creates an operator that samples frames at a specified rate.

        Creates a transformation operator that samples frames either by taking
        the latest frame or the first frame in each interval. Provides frame
        rate control through time-based selection.

        Args:
            fps: Desired frames per second, defaults to 25 FPS.
                Ignored if sample_interval is provided.
            sample_interval: Optional explicit interval between samples.
                If provided, overrides the fps parameter.
            use_latest: If True, uses the latest frame in interval.
                If False, uses the first frame. Defaults to True.

        Returns:
            A function that transforms an Observable[np.ndarray] stream to a sampled
            Observable[np.ndarray] stream with controlled frame rate.

        Raises:
            ValueError: If fps is not positive or sample_interval is negative.
            TypeError: If sample_interval is provided but not a timedelta object.

        Examples:
            Sample latest frame at 30 FPS (good for real-time):
            >>> video_stream.pipe(
            ...     VideoOperators.with_fps_sampling(fps=30)
            ... )

            Sample first frame with custom interval (good for consistent timing):
            >>> video_stream.pipe(
            ...     VideoOperators.with_fps_sampling(
            ...         sample_interval=timedelta(milliseconds=40),
            ...         use_latest=False
            ...     )
            ... )

        Note:
            This operator helps manage high-speed video streams through time-based
            frame selection. It reduces the frame rate by selecting frames at
            specified intervals.

            When use_latest=True:
                - Uses sampling to select the most recent frame at fixed intervals
                - Discards intermediate frames, keeping only the latest
                - Best for real-time video where latest frame is most relevant
                - Uses ops.sample internally

            When use_latest=False:
                - Uses throttling to select the first frame in each interval
                - Ignores subsequent frames until next interval
                - Best for scenarios where you want consistent frame timing
                - Uses ops.throttle_first internally

            This is an appropriate solution for managing video frame rates and
            memory usage in many scenarios.
        """
        if sample_interval is None:
            if fps <= 0:
                raise ValueError("FPS must be positive")
            sample_interval = timedelta(microseconds=int(1_000_000 / fps))
        elif not isinstance(sample_interval, timedelta):
            # Enforce the documented contract instead of failing later inside Rx.
            raise TypeError("sample_interval must be a timedelta")
        elif sample_interval < timedelta(0):
            raise ValueError("sample_interval must not be negative")

        def _operator(source: Observable) -> Observable:
            return source.pipe(
                ops.sample(sample_interval) if use_latest else ops.throttle_first(sample_interval)
            )

        return _operator

    @staticmethod
    def with_jpeg_export(
        frame_processor: "FrameProcessor",
        save_limit: int = 100,
        suffix: str = "",
        loop: bool = False,
    ) -> Callable[[Observable], Observable]:
        """Creates an operator that saves video frames as JPEG files.

        Each frame is handed to ``frame_processor.export_to_jpeg``; the value
        emitted downstream is whatever that call returns (presumably the
        unchanged frame - confirm against FrameProcessor).

        Args:
            frame_processor: FrameProcessor instance that handles the JPEG export
                operations and maintains file count.
            save_limit: Maximum number of frames to save before stopping.
                Defaults to 100. Set to 0 for unlimited saves.
            suffix: Optional string to append to filename before index.
                Example: "raw" creates "1_raw.jpg".
                Defaults to empty string.
            loop: If True, when save_limit is reached, the files saved are
                loopbacked and overwritten with the most recent frame.
                Defaults to False.

        Returns:
            A function that transforms an Observable of frames into another
            Observable of the same frames, with side effect of saving JPEGs.

        Raises:
            ValueError: If save_limit is negative.

        Example:
            >>> video_stream.pipe(
            ...     VideoOperators.with_jpeg_export(processor, suffix="raw")
            ... )
        """
        if save_limit < 0:
            raise ValueError("save_limit must not be negative")

        def _operator(source: Observable) -> Observable:
            return source.pipe(
                ops.map(
                    lambda frame: frame_processor.export_to_jpeg(frame, save_limit, loop, suffix)
                )
            )

        return _operator

    @staticmethod
    def with_optical_flow_filtering(threshold: float = 1.0) -> Callable[[Observable], Observable]:
        """Creates an operator that filters optical flow frames by relevancy score.

        Filters a stream of optical flow results (frame, relevancy_score) tuples,
        passing through only frames whose score is strictly greater than the
        threshold.

        Args:
            threshold: Relevancy score a frame must exceed to pass through.
                Defaults to 1.0. Higher values mean more motion required.

        Returns:
            A function that transforms an Observable of (frame, score) tuples
            into an Observable of frames that exceed the threshold.

        Raises:
            ValueError: If threshold is negative.

        Examples:
            Basic filtering:
            >>> optical_flow_stream.pipe(
            ...     VideoOperators.with_optical_flow_filtering(threshold=1.0)
            ... )

            With custom threshold:
            >>> optical_flow_stream.pipe(
            ...     VideoOperators.with_optical_flow_filtering(threshold=2.5)
            ... )

        Note:
            Input stream should contain tuples of (frame, relevancy_score) where
            frame is a numpy array and relevancy_score is a float or None.
            None scores are filtered out.
        """
        if threshold < 0:
            raise ValueError("threshold must not be negative")

        def _is_relevant(result) -> bool:
            score = result[1]
            return score is not None and score > threshold

        return lambda source: source.pipe(
            ops.filter(_is_relevant),
            ops.map(lambda result: result[0]),
        )

    @staticmethod
    def with_edge_detection(
        frame_processor: "FrameProcessor",
    ) -> Callable[[Observable], Observable]:
        """Creates an operator that maps each frame through ``frame_processor.edge_detection``."""
        return lambda source: source.pipe(
            ops.map(lambda frame: frame_processor.edge_detection(frame))
        )

    @staticmethod
    def with_optical_flow(
        frame_processor: "FrameProcessor",
    ) -> Callable[[Observable], Observable]:
        """Creates an operator that turns frames into optical-flow visualizations.

        ``frame_processor.compute_optical_flow`` is folded over the stream with
        ``ops.scan`` (seed ``(None, None, None)``). The element at index 1 of each
        accumulated result is treated as the flow; ``None`` flows are dropped and
        the rest are passed to ``frame_processor.visualize_flow``.
        """
        return lambda source: source.pipe(
            ops.scan(
                lambda acc, frame: frame_processor.compute_optical_flow(
                    acc, frame, compute_relevancy=False
                ),
                (None, None, None),
            ),
            ops.map(lambda result: result[1]),  # Extract flow component
            ops.filter(lambda flow: flow is not None),
            ops.map(frame_processor.visualize_flow),
        )

    @staticmethod
    def with_zmq_socket(
        socket: zmq.Socket, scheduler: Optional[Any] = None
    ) -> Callable[[Observable], Observable]:
        """Creates an operator that sends every frame over a ZeroMQ socket as JPEG bytes.

        Frames continue downstream unchanged (``ops.do_action``).

        Args:
            socket: Socket the encoded frames are sent on.
            scheduler: Scheduler the sending runs on. When None, a dedicated
                single-worker ThreadPoolScheduler is created for this operator.

        Returns:
            A function that transforms an Observable of frames into an
            Observable of the same frames, observed on ``scheduler``.
        """

        def send_frame(frame, socket):
            _, img_encoded = cv2.imencode(".jpg", frame)
            socket.send(img_encoded.tobytes())

        # Use a default scheduler if none is provided
        if scheduler is None:
            from reactivex.scheduler import ThreadPoolScheduler

            scheduler = ThreadPoolScheduler(1)  # Single-threaded pool for isolation

        return lambda source: source.pipe(
            ops.observe_on(scheduler),  # Ensure this part runs on its own thread
            ops.do_action(lambda frame: send_frame(frame, socket)),
        )

    @staticmethod
    def encode_image() -> Callable[[Observable], Observable]:
        """
        Operator to encode an image to JPEG format and convert it to a Base64 string.

        Returns:
            A function that transforms an Observable of images into an Observable
            of tuples containing the Base64 string of the encoded image and its
            dimensions as ``(width, height)``.

        Raises:
            ValueError: (inside the stream) if OpenCV reports that encoding failed.
        """

        def _operator(source: Observable) -> Observable:
            def _encode_image(image: np.ndarray) -> Tuple[str, Tuple[int, int]]:
                # ndarray images are indexed (row, column), i.e. shape[:2] is
                # (height, width) - not (width, height).
                height, width = image.shape[:2]
                # imencode returns (success_flag, buffer); the flag, not the
                # buffer, tells whether encoding worked.
                success, buffer = cv2.imencode(".jpg", image)
                if not success:
                    raise ValueError("Failed to encode image")
                base64_image = base64.b64encode(buffer).decode("utf-8")
                return base64_image, (width, height)

            return source.pipe(ops.map(_encode_image))

        return _operator
active_inner_disp = inner_source.subscribe( + on_next=inner_on_next, + on_error=inner_on_error, + on_completed=inner_on_completed, + scheduler=scheduler, + ) + + def on_error(err): + dispose_all() + observer.on_error(err) + + def on_completed(): + nonlocal upstream_done + with lock: + upstream_done = True + # If we're not busy, we can end now + if not in_flight: + observer.on_completed() + + upstream_disp = source.subscribe( + on_next, on_error, on_completed, scheduler=scheduler + ) + return dispose_all + + return create(_subscribe) + + return _exhaust_lock + + @staticmethod + def exhaust_lock_per_instance(process_item, lock: Lock): + """ + - For each item from upstream, call process_item(item) -> Observable. + - If a frame arrives while one is "in flight", discard it. + - 'lock' ensures we safely check/modify the 'in_flight' state in a multithreaded environment. + """ + + def _exhaust_lock(source: Observable) -> Observable: + def _subscribe(observer, scheduler=None): + in_flight = False + upstream_done = False + + upstream_disp = None + active_inner_disp = None + + def dispose_all(): + if upstream_disp: + upstream_disp.dispose() + if active_inner_disp: + active_inner_disp.dispose() + + def on_next(value): + nonlocal in_flight, active_inner_disp + with lock: + # If not busy, claim the slot + if not in_flight: + in_flight = True + print("\033[34mProcessing new item.\033[0m") + else: + # Already processing => drop + print("\033[34mSkipping item, already processing.\033[0m") + return + + # We only get here if we acquired the slot + try: + inner_source = process_item(value) + except Exception as ex: + observer.on_error(ex) + return + + def inner_on_next(ivalue): + observer.on_next(ivalue) + + def inner_on_error(err): + nonlocal in_flight + with lock: + in_flight = False + print("\033[34mError in inner on error.\033[0m") + observer.on_error(err) + + def inner_on_completed(): + nonlocal in_flight + with lock: + in_flight = False + print("\033[34mInner on 
@staticmethod
def exhaust_map(project):
    """Map each item to an inner Observable, ignoring items while one is active.

    For every upstream item that arrives while no inner Observable is running,
    ``project(item)`` is subscribed and its emissions are forwarded. Items that
    arrive while an inner Observable is still running are dropped.

    Completion is signalled only once the upstream has completed AND no inner
    Observable is still running, so a running inner Observable is never cut
    short by upstream completion. Disposing the resulting subscription
    disposes both the upstream and the active inner subscription.

    Args:
        project: Callable taking an upstream item and returning an Observable.

    Returns:
        A function that transforms the source Observable accordingly.
    """

    def _exhaust_map(source: Observable):
        def subscribe(observer, scheduler=None):
            # Guards the two flags below; upstream and inner callbacks may
            # run on different threads.
            state_lock = Lock()
            is_processing = False
            upstream_done = False

            upstream_disp = None
            inner_disp = None

            def dispose_all():
                if upstream_disp:
                    upstream_disp.dispose()
                if inner_disp:
                    inner_disp.dispose()

            def inner_on_completed():
                nonlocal is_processing
                with state_lock:
                    is_processing = False
                    finished = upstream_done
                print("\033[35mItem processed.\033[0m")
                # Upstream ended while this inner Observable was running:
                # now is the moment to complete downstream.
                if finished:
                    observer.on_completed()

            def on_next(item):
                nonlocal is_processing, inner_disp
                with state_lock:
                    busy = is_processing
                    if not busy:
                        is_processing = True
                if busy:
                    print("\033[35mSkipping item, already processing.\033[0m")
                    return

                print("\033[35mProcessing item.\033[0m")
                try:
                    inner_observable = project(item)  # Create the inner observable
                    inner_disp = inner_observable.subscribe(
                        on_next=observer.on_next,
                        on_error=observer.on_error,
                        on_completed=inner_on_completed,
                        scheduler=scheduler,
                    )
                except Exception as e:
                    observer.on_error(e)

            def on_error(err):
                dispose_all()
                observer.on_error(err)

            def on_completed():
                nonlocal upstream_done
                with state_lock:
                    upstream_done = True
                    idle = not is_processing
                # If an inner Observable is still running, inner_on_completed
                # will complete downstream instead.
                if idle:
                    observer.on_completed()

            upstream_disp = source.subscribe(
                on_next=on_next,
                on_error=on_error,
                on_completed=on_completed,
                scheduler=scheduler,
            )
            return Disposable(dispose_all)

        return create(subscribe)

    return _exhaust_map
lock, processing item.\033[0m") + observer.on_next(item) + finally: # Ensure lock release even if observer.on_next throws + lock.release() + else: + print("\033[34mLock busy, skipping item.\033[0m") + else: + print("\033[34mLock busy, skipping item.\033[0m") + + def on_error(error): + observer.on_error(error) + + def on_completed(): + observer.on_completed() + + return source.subscribe( + on_next=on_next, + on_error=on_error, + on_completed=on_completed, + scheduler=scheduler, + ) + + return Observable(subscribe) + + return operator + + @staticmethod + def with_lock_check(lock: Lock): # Renamed for clarity + def operator(source: Observable): + def subscribe(observer, scheduler=None): + def on_next(item): + if not lock.locked(): # Check if the lock is held WITHOUT acquiring + print(f"\033[32mLock is free, processing item: {item}\033[0m") + observer.on_next(item) + else: + print(f"\033[34mLock is busy, skipping item: {item}\033[0m") + # observer.on_completed() + + def on_error(error): + observer.on_error(error) + + def on_completed(): + observer.on_completed() + + return source.subscribe( + on_next=on_next, + on_error=on_error, + on_completed=on_completed, + scheduler=scheduler, + ) + + return Observable(subscribe) + + return operator + + # PrintColor enum for standardized color formatting + class PrintColor(Enum): + RED = "\033[31m" + GREEN = "\033[32m" + BLUE = "\033[34m" + YELLOW = "\033[33m" + MAGENTA = "\033[35m" + CYAN = "\033[36m" + WHITE = "\033[37m" + RESET = "\033[0m" + + @staticmethod + def print_emission( + id: str, + dev_name: str = "NA", + counts: dict = None, + color: "Operators.PrintColor" = None, + enabled: bool = True, + ): + """ + Creates an operator that prints the emission with optional counts for debugging. + + Args: + id: Identifier for the emission point (e.g., 'A', 'B') + dev_name: Device or component name for context + counts: External dictionary to track emission count across operators. If None, will not print emission count. 
+ color: Color for the printed output from PrintColor enum (default is RED) + enabled: Whether to print the emission count (default is True) + Returns: + An operator that counts and prints emissions without modifying the stream + """ + # If enabled is false, return the source unchanged + if not enabled: + return lambda source: source + + # Use RED as default if no color provided + if color is None: + color = Operators.PrintColor.RED + + def _operator(source: Observable) -> Observable: + def _subscribe(observer: Observer, scheduler=None): + def on_next(value): + if counts is not None: + # Initialize count if necessary + if id not in counts: + counts[id] = 0 + + # Increment and print + counts[id] += 1 + print( + f"{color.value}({dev_name} - {id}) Emission Count - {counts[id]} {datetime.now()}{Operators.PrintColor.RESET.value}" + ) + else: + print( + f"{color.value}({dev_name} - {id}) Emitted - {datetime.now()}{Operators.PrintColor.RESET.value}" + ) + + # Pass value through unchanged + observer.on_next(value) + + return source.subscribe( + on_next=on_next, + on_error=observer.on_error, + on_completed=observer.on_completed, + scheduler=scheduler, + ) + + return create(_subscribe) + + return _operator diff --git a/build/lib/dimos/stream/video_provider.py b/build/lib/dimos/stream/video_provider.py new file mode 100644 index 0000000000..050905a024 --- /dev/null +++ b/build/lib/dimos/stream/video_provider.py @@ -0,0 +1,235 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# Specific exception classes
class VideoSourceError(Exception):
    """Raised when there's an issue with the video source."""


class VideoFrameError(Exception):
    """Raised when there's an issue with frame acquisition."""


class AbstractVideoProvider(ABC):
    """Abstract base class for video providers managing video capture resources."""

    def __init__(
        self, dev_name: str = "NA", pool_scheduler: Optional[ThreadPoolScheduler] = None
    ) -> None:
        """Initializes the video provider with a device name.

        Args:
            dev_name: The name of the device. Defaults to "NA".
            pool_scheduler: The scheduler to use for thread pool operations.
                If None, the global scheduler from get_scheduler() will be used.
        """
        self.dev_name = dev_name
        self.pool_scheduler = pool_scheduler if pool_scheduler else get_scheduler()
        self.disposables = CompositeDisposable()

    @abstractmethod
    def capture_video_as_observable(self, fps: int = 30) -> Observable:
        """Create an observable from video capture.

        Args:
            fps: Frames per second to emit. Defaults to 30fps.

        Returns:
            Observable: An observable emitting frames at the specified rate.

        Raises:
            VideoSourceError: If the video source cannot be opened.
            VideoFrameError: If frames cannot be read properly.
        """
        pass

    def dispose_all(self) -> None:
        """Disposes of all active subscriptions managed by this provider."""
        if self.disposables:
            self.disposables.dispose()
        else:
            logger.info("No disposables to dispose.")

    def __del__(self) -> None:
        """Destructor to ensure resources are cleaned up if not explicitly disposed."""
        self.dispose_all()


class VideoProvider(AbstractVideoProvider):
    """Video provider implementation for capturing video as an observable."""

    def __init__(
        self,
        dev_name: str,
        video_source: Optional[str] = None,
        pool_scheduler: Optional[ThreadPoolScheduler] = None,
    ) -> None:
        """Initializes the video provider with a device name and video source.

        Args:
            dev_name: The name of the device.
            video_source: The path to the video source. When None, the sample
                video ``assets/video-f30-480p.mp4`` under the current working
                directory is used; the directory is looked up when the
                provider is created, not when this module is imported.
            pool_scheduler: The scheduler to use for thread pool operations.
                If None, the global scheduler from get_scheduler() will be used.
        """
        super().__init__(dev_name, pool_scheduler)
        if video_source is None:
            # Resolved per instance: a default written in the signature would
            # freeze whatever the working directory was at import time.
            video_source = f"{os.getcwd()}/assets/video-f30-480p.mp4"
        self.video_source = video_source
        self.cap = None
        self.lock = Lock()

    def _initialize_capture(self) -> None:
        """Initializes the video capture object if not already initialized.

        Raises:
            VideoSourceError: If the video source cannot be opened.
        """
        if self.cap is None or not self.cap.isOpened():
            # Release previous capture if it exists but is closed
            if self.cap:
                self.cap.release()
                logger.info("Released previous capture")

            # Attempt to open new capture
            self.cap = cv2.VideoCapture(self.video_source)
            if self.cap is None or not self.cap.isOpened():
                error_msg = f"Failed to open video source: {self.video_source}"
                logger.error(error_msg)
                raise VideoSourceError(error_msg)

            logger.info(f"Opened new capture: {self.video_source}")

    def capture_video_as_observable(self, realtime: bool = True, fps: int = 30) -> Observable:
        """Creates an observable from video capture.

        Creates an observable that emits frames at specified FPS or the video's
        native FPS, with proper resource management and error handling. When
        the end of the video is reached, playback restarts from the first frame.

        Args:
            realtime: If True, use the video's native FPS. Defaults to True.
            fps: Frames per second to emit. Defaults to 30fps. Only used if
                realtime is False or the video's native FPS is not available.

        Returns:
            Observable: An observable emitting frames at the configured rate.
            Errors are delivered through ``on_error``: VideoSourceError if the
            source cannot be opened, VideoFrameError for any other failure
            (including a non-positive effective FPS).
        """

        def emit_frames(observer, scheduler):
            try:
                self._initialize_capture()

                # Determine the FPS to use based on configuration and availability
                local_fps: float = fps
                if realtime:
                    native_fps: float = self.cap.get(cv2.CAP_PROP_FPS)
                    if native_fps > 0:
                        local_fps = native_fps
                    else:
                        logger.warning("Native FPS not available, defaulting to specified FPS")

                if local_fps <= 0:
                    # Reported via the generic handler below as VideoFrameError,
                    # with a readable message instead of "division by zero".
                    raise ValueError(f"FPS must be positive, got {local_fps}")

                frame_interval: float = 1.0 / local_fps
                frame_time: float = time.monotonic()

                while self.cap.isOpened():
                    # Thread-safe access to video capture
                    with self.lock:
                        ret, frame = self.cap.read()

                    if not ret:
                        # Loop video when we reach the end
                        logger.warning("End of video reached, restarting playback")
                        with self.lock:
                            self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
                        continue

                    # Control frame rate to match target FPS
                    now: float = time.monotonic()
                    next_frame_time: float = frame_time + frame_interval
                    sleep_time: float = next_frame_time - now

                    if sleep_time > 0:
                        time.sleep(sleep_time)

                    observer.on_next(frame)
                    frame_time = next_frame_time

            except VideoSourceError as e:
                logger.error(f"Video source error: {e}")
                observer.on_error(e)
            except Exception as e:
                logger.error(f"Unexpected error during frame emission: {e}")
                observer.on_error(VideoFrameError(f"Frame acquisition failed: {e}"))
            finally:
                # Clean up resources regardless of success or failure
                with self.lock:
                    if self.cap and self.cap.isOpened():
                        self.cap.release()
                        logger.info("Capture released")
                # NOTE(review): also reached after on_error; reactivex is
                # expected to ignore a second terminal notification - confirm.
                observer.on_completed()

        return rx.create(emit_frames).pipe(
            ops.subscribe_on(self.pool_scheduler),
            ops.observe_on(self.pool_scheduler),
            ops.share(),  # Share the stream among multiple subscribers
        )

    def dispose_all(self) -> None:
        """Disposes of all resources including video capture."""
        with self.lock:
            if self.cap and self.cap.isOpened():
                self.cap.release()
                logger.info("Capture released in dispose_all")
        super().dispose_all()

    def __del__(self) -> None:
        """Destructor to ensure resources are cleaned up if not explicitly disposed."""
        self.dispose_all()
+ + Args: + dev_name: Name of the device + connection_method: WebRTC connection method (LocalSTA, LocalAP, Remote) + serial_number: Serial number of the robot (required for LocalSTA with serial) + ip: IP address of the robot (required for LocalSTA with IP) + """ + super().__init__(dev_name) + self.frame_queue = Queue() + self.loop = None + self.asyncio_thread = None + + # Initialize WebRTC connection based on method + if connection_method == WebRTCConnectionMethod.LocalSTA: + if serial_number: + self.conn = Go2WebRTCConnection(connection_method, serialNumber=serial_number) + elif ip: + self.conn = Go2WebRTCConnection(connection_method, ip=ip) + else: + raise ValueError( + "Either serial_number or ip must be provided for LocalSTA connection" + ) + elif connection_method == WebRTCConnectionMethod.LocalAP: + self.conn = Go2WebRTCConnection(connection_method) + else: + raise ValueError("Unsupported connection method") + + async def _recv_camera_stream(self, track: MediaStreamTrack): + """Receive video frames from WebRTC and put them in the queue.""" + while True: + frame = await track.recv() + # Convert the frame to a NumPy array in BGR format + img = frame.to_ndarray(format="bgr24") + self.frame_queue.put(img) + + def _run_asyncio_loop(self, loop): + """Run the asyncio event loop in a separate thread.""" + asyncio.set_event_loop(loop) + + async def setup(): + try: + await self.conn.connect() + self.conn.video.switchVideoChannel(True) + self.conn.video.add_track_callback(self._recv_camera_stream) + + await self.conn.datachannel.switchToNormalMode() + # await self.conn.datachannel.sendDamp() + + # await asyncio.sleep(5) + + # await self.conn.datachannel.sendDamp() + # await asyncio.sleep(5) + # await self.conn.datachannel.sendStandUp() + # await asyncio.sleep(5) + + # Wiggle the robot + # await self.conn.datachannel.switchToNormalMode() + # await self.conn.datachannel.sendWiggle() + # await asyncio.sleep(3) + + # Stretch the robot + # await 
self.conn.datachannel.sendStretch() + # await asyncio.sleep(3) + + except Exception as e: + logging.error(f"Error in WebRTC connection: {e}") + raise + + loop.run_until_complete(setup()) + loop.run_forever() + + def capture_video_as_observable(self, fps: int = 30) -> Observable: + """Create an observable that emits video frames at the specified FPS. + + Args: + fps: Frames per second to emit (default: 30) + + Returns: + Observable emitting video frames + """ + frame_interval = 1.0 / fps + + def emit_frames(observer, scheduler): + try: + # Start asyncio loop if not already running + if not self.loop: + self.loop = asyncio.new_event_loop() + self.asyncio_thread = threading.Thread( + target=self._run_asyncio_loop, args=(self.loop,) + ) + self.asyncio_thread.start() + + frame_time = time.monotonic() + + while True: + if not self.frame_queue.empty(): + frame = self.frame_queue.get() + + # Control frame rate + now = time.monotonic() + next_frame_time = frame_time + frame_interval + sleep_time = next_frame_time - now + + if sleep_time > 0: + time.sleep(sleep_time) + + observer.on_next(frame) + frame_time = next_frame_time + else: + time.sleep(0.001) # Small sleep to prevent CPU overuse + + except Exception as e: + logging.error(f"Error during frame emission: {e}") + observer.on_error(e) + finally: + if self.loop: + self.loop.call_soon_threadsafe(self.loop.stop) + if self.asyncio_thread: + self.asyncio_thread.join() + observer.on_completed() + + return create(emit_frames).pipe( + ops.share() # Share the stream among multiple subscribers + ) + + def dispose_all(self): + """Clean up resources.""" + if self.loop: + self.loop.call_soon_threadsafe(self.loop.stop) + if self.asyncio_thread: + self.asyncio_thread.join() + super().dispose_all() diff --git a/build/lib/dimos/stream/videostream.py b/build/lib/dimos/stream/videostream.py new file mode 100644 index 0000000000..ee63261ae6 --- /dev/null +++ b/build/lib/dimos/stream/videostream.py @@ -0,0 +1,41 @@ +# Copyright 2025 
class VideoStream:
    """Iterator over frames read from an OpenCV video capture.

    Iteration stops (and the capture is released) the first time a read
    fails. The object can also be used as a context manager so the capture
    is released even when iteration is abandoned early.
    """

    def __init__(self, source=0):
        """
        Initialize the video stream from a camera source.

        Args:
            source (int or str): Camera index or video file path.

        Raises:
            ValueError: If the source cannot be opened.
        """
        self.capture = cv2.VideoCapture(source)
        if not self.capture.isOpened():
            # Do not keep a handle to a source that failed to open.
            self.capture.release()
            raise ValueError(f"Unable to open video source {source}")

    def __enter__(self):
        """Return the stream itself for use in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Release the capture on leaving the ``with`` block; never suppress errors."""
        self.release()
        return False

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next frame, or release the capture and stop when a read fails."""
        ret, frame = self.capture.read()
        if not ret:
            self.capture.release()
            raise StopIteration
        return frame

    def release(self):
        """Release the underlying capture."""
        self.capture.release()
class Colors:
    """ANSI escape sequences used to color text printed to a terminal.

    Write one of the ``*_PRINT_COLOR`` values before the text and
    ``RESET_COLOR`` after it to return to the terminal default.
    """

    # Foreground colors, listed by ascending escape code (31-37).
    RED_PRINT_COLOR: str = "\033[31m"
    GREEN_PRINT_COLOR: str = "\033[32m"
    YELLOW_PRINT_COLOR: str = "\033[33m"
    BLUE_PRINT_COLOR: str = "\033[34m"
    MAGENTA_PRINT_COLOR: str = "\033[35m"
    CYAN_PRINT_COLOR: str = "\033[36m"
    WHITE_PRINT_COLOR: str = "\033[37m"

    # Clears every attribute set by the sequences above.
    RESET_COLOR: str = "\033[0m"
# Short string tags for the numpy scalar types that encode_ndarray supports.
DTYPE2STR = {
    np.float32: "f32",
    np.float64: "f64",
    np.int32: "i32",
    np.int8: "i8",
}

STR2DTYPE = {v: k for k, v in DTYPE2STR.items()}


class CostValues:
    """Standard cost values for occupancy grid cells."""

    FREE = 0  # Free space
    UNKNOWN = -1  # Unknown space
    OCCUPIED = 100  # Occupied/lethal space


def encode_ndarray(arr: np.ndarray, compress: bool = False):
    """Encode *arr* as a dict with base64 data.

    Args:
        arr: Array to encode; its dtype must be one of the keys of DTYPE2STR.
        compress: Accepted for interface compatibility; currently unused.

    Returns:
        Dict with ``type``, ``shape``, ``dtype`` (string tag) and ``data``
        (base64 of the C-contiguous raw bytes).

    Raises:
        KeyError: If the array dtype has no entry in DTYPE2STR.
    """
    arr_c = np.ascontiguousarray(arr)
    payload = arr_c.tobytes()
    b64 = base64.b64encode(payload).decode("ascii")

    return {
        "type": "grid",
        "shape": arr_c.shape,
        "dtype": DTYPE2STR[arr_c.dtype.type],
        "data": b64,
    }


class Costmap:
    """Class to hold ROS OccupancyGrid data."""

    def __init__(
        self,
        grid: np.ndarray,
        origin: "VectorLike",
        origin_theta: float = 0,
        resolution: float = 0.05,
    ):
        """Initialize Costmap with its core attributes.

        Args:
            grid: 2-D cost array indexed as ``grid[row (y), column (x)]``.
            origin: World position of grid cell (0, 0); reduced to 2-D.
            origin_theta: Orientation of the grid origin (stored as given).
            resolution: World size of one cell (the grid conversions use it
                as world units per cell).
        """
        self.grid = grid
        self.resolution = resolution
        self.origin = to_vector(origin).to_2d()
        self.origin_theta = origin_theta
        self.width = self.grid.shape[1]
        self.height = self.grid.shape[0]

    def serialize(self) -> dict:
        """Serialize the Costmap instance to a dict."""
        return {
            "type": "costmap",
            "grid": encode_ndarray(self.grid),
            "origin": self.origin.serialize(),
            "resolution": self.resolution,
            "origin_theta": self.origin_theta,
        }

    @classmethod
    def from_msg(cls, costmap_msg: "OccupancyGrid") -> "Costmap":
        """Create a Costmap instance from a ROS OccupancyGrid message.

        Raises:
            ValueError: If *costmap_msg* is None.
        """
        if costmap_msg is None:
            raise ValueError("need costmap msg")

        # Extract info from the message
        info = costmap_msg.info
        width = info.width
        height = info.height
        resolution = info.resolution

        # Get origin position as a vector-like object
        origin = (info.origin.position.x, info.origin.position.y)

        # Yaw angle from the orientation quaternion
        qx = info.origin.orientation.x
        qy = info.origin.orientation.y
        qz = info.origin.orientation.z
        qw = info.origin.orientation.w
        origin_theta = math.atan2(2.0 * (qw * qz + qx * qy), 1.0 - 2.0 * (qy * qy + qz * qz))

        # Convert to numpy array
        data = np.array(costmap_msg.data, dtype=np.int8)
        grid = data.reshape((height, width))

        return cls(
            grid=grid,
            resolution=resolution,
            origin=origin,
            origin_theta=origin_theta,
        )

    def save_pickle(self, pickle_path: str):
        """Save costmap to a pickle file.

        Args:
            pickle_path: Path to save the pickle file
        """
        with open(pickle_path, "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def from_pickle(cls, pickle_path: str) -> "Costmap":
        """Load costmap from a pickle file containing either a Costmap object or constructor arguments.

        SECURITY: unpickling executes arbitrary code embedded in the file;
        only load pickle files from trusted sources.
        """
        with open(pickle_path, "rb") as f:
            data = pickle.load(f)

        # Check if data is already a Costmap object
        if isinstance(data, cls):
            return data

        # Otherwise assume it is a sequence of positional constructor arguments
        return cls(*data)

    @classmethod
    def create_empty(
        cls, width: int = 100, height: int = 100, resolution: float = 0.1
    ) -> "Costmap":
        """Create an empty costmap with specified dimensions."""
        return cls(
            grid=np.zeros((height, width), dtype=np.int8),
            resolution=resolution,
            origin=(0.0, 0.0),
            origin_theta=0.0,
        )

    def world_to_grid(self, point: "VectorLike") -> "Vector":
        """Convert world coordinates to grid coordinates.

        Args:
            point: A vector-like object containing X,Y coordinates

        Returns:
            Vector of (possibly fractional) grid coordinates; callers convert
            to integer cell indices themselves.
        """
        return (to_vector(point) - self.origin) / self.resolution

    def grid_to_world(self, grid_point: "VectorLike") -> "Vector":
        """Convert grid coordinates to world coordinates."""
        return to_vector(grid_point) * self.resolution + self.origin

    def is_occupied(self, point: "VectorLike", threshold: int = 50) -> bool:
        """Check if a position in world coordinates is occupied.

        Args:
            point: Vector-like object containing X,Y coordinates
            threshold: Cost threshold above which a cell is considered occupied (0-100)

        Returns:
            True if position is occupied or out of bounds, False otherwise
        """
        grid_point = self.world_to_grid(point)
        # Bounds are checked on the fractional coordinates: int() truncates
        # toward zero, so values in (-1, 0) would otherwise land in cell 0.
        if 0 <= grid_point.x < self.width and 0 <= grid_point.y < self.height:
            # Unknown (-1) is below any non-negative threshold -> unoccupied
            value = self.grid[int(grid_point.y), int(grid_point.x)]
            return bool(value >= threshold)
        return True  # Consider out-of-bounds as occupied

    def get_value(self, point: "VectorLike") -> Optional[int]:
        """Return the cost at a world position, or None when out of bounds."""
        point = self.world_to_grid(point)

        if 0 <= point.x < self.width and 0 <= point.y < self.height:
            return int(self.grid[int(point.y), int(point.x)])
        return None

    def set_value(self, point: "VectorLike", value: int = 0) -> bool:
        """Set the cost at a world position.

        Returns:
            True if the cell was written, False when the point is out of bounds.
        """
        point = self.world_to_grid(point)

        if 0 <= point.x < self.width and 0 <= point.y < self.height:
            self.grid[int(point.y), int(point.x)] = value
            return True
        return False

    def smudge(
        self,
        kernel_size: int = 7,
        iterations: int = 25,
        decay_factor: float = 0.9,
        threshold: int = 90,
        preserve_unknown: bool = False,
    ) -> "Costmap":
        """
        Creates a new costmap with expanded obstacles (smudged).

        Args:
            kernel_size: Size of the convolution kernel for dilation (even
                values are bumped to the next odd value)
            iterations: Number of dilation iterations
            decay_factor: Factor to reduce cost as distance increases (0.0-1.0)
            threshold: Minimum cost value to consider as an obstacle for expansion
            preserve_unknown: Whether to keep unknown (-1) cells as unknown

        Returns:
            A new Costmap instance with expanded obstacles
        """
        # Make sure kernel size is odd
        if kernel_size % 2 == 0:
            kernel_size += 1

        # Work on a copy so the original grid is untouched
        grid_copy = self.grid.copy()

        unknown_mask = None
        if preserve_unknown:
            unknown_mask = grid_copy == -1
            # Temporarily treat unknown cells as free so smudging can pass
            # over them; they are restored below.
            grid_copy[unknown_mask] = 0

        # Cells at or above the threshold seed the expansion
        obstacle_mask = grid_copy >= threshold

        # Circular structuring element of radius kernel_size // 2
        radius = kernel_size // 2
        yy, xx = np.ogrid[-radius : radius + 1, -radius : radius + 1]
        kernel = xx * xx + yy * yy <= radius * radius

        # Seed layer: 100 on obstacle cells, 0 elsewhere
        smudged_map = obstacle_mask.astype(np.float64) * 100
        reached = obstacle_mask.copy()

        for i in range(iterations):
            dilated = ndimage.binary_dilation(reached, structure=kernel, iterations=1)

            # Cells added by this iteration get a cost that decays with the
            # number of dilation steps taken from the obstacle.
            ring = dilated & ~reached
            smudged_map = np.maximum(smudged_map, ring * (100.0 * decay_factor ** (i + 1)))

            reached = dilated

        # Preserve original obstacle costs
        smudged_map[obstacle_mask] = grid_copy[obstacle_mask]

        if preserve_unknown:
            # Overlay the original unknown cells and clip only the known ones
            smudged_map[unknown_mask] = -1
            valid_cells = ~unknown_mask
            smudged_map[valid_cells] = np.clip(smudged_map[valid_cells], 0, 100)
        else:
            smudged_map = np.clip(smudged_map, 0, 100)

        return Costmap(
            grid=smudged_map.astype(np.int8),
            resolution=self.resolution,
            origin=self.origin,
            origin_theta=self.origin_theta,
        )

    def subsample(self, subsample_factor: int = 2) -> "Costmap":
        """
        Create a subsampled (lower resolution) version of the costmap.

        Each output cell summarises a ``subsample_factor`` x ``subsample_factor``
        block with the most conservative value: the highest positive cost in
        the block, else unknown if any cell is unknown, else free. Trailing
        rows/columns that do not fill a whole block are dropped.

        Args:
            subsample_factor: Factor by which to reduce resolution (e.g., 2 = half resolution, 4 = quarter resolution)

        Returns:
            New Costmap instance with reduced resolution
        """
        if subsample_factor <= 1:
            return self  # No subsampling needed

        new_height = self.height // subsample_factor
        new_width = self.width // subsample_factor

        # View the grid as (new_height, f, new_width, f) blocks
        blocks = self.grid[
            : new_height * subsample_factor, : new_width * subsample_factor
        ].reshape(new_height, subsample_factor, new_width, subsample_factor)

        block_max = blocks.max(axis=(1, 3))
        has_unknown = (blocks == CostValues.UNKNOWN).any(axis=(1, 3))

        # Priority: positive cost (max) > Unknown (-1) > Free (0)
        subsampled_grid = np.where(
            block_max > CostValues.FREE,
            block_max,
            np.where(has_unknown, CostValues.UNKNOWN, CostValues.FREE),
        ).astype(self.grid.dtype)

        return Costmap(
            grid=subsampled_grid,
            resolution=self.resolution * subsample_factor,
            origin=self.origin,  # Origin stays the same
            origin_theta=self.origin_theta,
        )

    @property
    def total_cells(self) -> int:
        return self.width * self.height

    @property
    def occupied_cells(self) -> int:
        return int(np.sum(self.grid >= 0.1))

    @property
    def unknown_cells(self) -> int:
        return int(np.sum(self.grid == -1))

    @property
    def free_cells(self) -> int:
        return self.total_cells - self.occupied_cells - self.unknown_cells

    @property
    def free_percent(self) -> float:
        return (self.free_cells / self.total_cells) * 100 if self.total_cells > 0 else 0.0

    @property
    def occupied_percent(self) -> float:
        return (self.occupied_cells / self.total_cells) * 100 if self.total_cells > 0 else 0.0

    @property
    def unknown_percent(self) -> float:
        return (self.unknown_cells / self.total_cells) * 100 if self.total_cells > 0 else 0.0

    def __str__(self) -> str:
        """
        Create a string representation of the Costmap.

        Returns:
            A formatted string with key costmap information
        """

        cell_info = [
            "▦ Costmap",
            f"{self.width}x{self.height}",
            f"({self.width * self.resolution:.1f}x{self.height * self.resolution:.1f}m @",
            # cell size in centimetres, taking resolution as metres per cell
            f"{self.resolution * 100:.0f}cm res)",
            f"Origin: ({self.origin.x:.2f}, {self.origin.y:.2f})",
            f"▣ {self.occupied_percent:.1f}%",
            f"□ {self.free_percent:.1f}%",
            f"◌ {self.unknown_percent:.1f}%",
        ]

        return " ".join(cell_info)

    def costmap_to_image(self, image_path: str) -> None:
        """
        Convert costmap to JPEG image with ROS-style coloring.
        Free space: light grey, Obstacles: black, Unknown: dark gray

        Args:
            image_path: Path to save the JPEG image
        """
        # Free space and any other (low cost) value = light grey
        img_array = np.full((self.height, self.width, 3), 205, dtype=np.uint8)
        # Unknown = dark gray
        img_array[self.grid == CostValues.UNKNOWN] = 128
        # Occupied/obstacles = black
        img_array[self.grid >= CostValues.OCCUPIED] = 0

        # Flip vertically to match ROS convention (origin at bottom-left)
        img_array = np.flipud(img_array)

        # Create PIL image and save as JPEG
        img = Image.fromarray(img_array, "RGB")
        img.save(image_path, "JPEG", quality=95)
        print(f"Costmap image saved to: {image_path}")
def _inflate_lethal(costmap: np.ndarray, radius: int, lethal_val: int = 100) -> np.ndarray:
    """Return *costmap* with lethal cells dilated by *radius* grid steps (circular).

    Cells shifted past a grid edge are discarded, so inflation never wraps
    around to the opposite side of the map. The input is returned unchanged
    (same object) when there is nothing to inflate.
    """
    if radius <= 0:
        return costmap

    mask = costmap == lethal_val
    if not mask.any():
        return costmap

    rows, cols = mask.shape
    radius_sq = radius * radius
    dilated = mask.copy()
    for dy in range(-radius, radius + 1):
        if abs(dy) >= rows:
            continue  # shift larger than the grid: nothing stays in bounds
        for dx in range(-radius, radius + 1):
            if abs(dx) >= cols or dx * dx + dy * dy > radius_sq or (dx == 0 and dy == 0):
                continue
            # OR the mask shifted by (dy, dx) into place; slicing (rather than
            # np.roll) drops whatever would cross the border.
            dst_rows = slice(max(dy, 0), rows + min(dy, 0))
            src_rows = slice(max(-dy, 0), rows + min(-dy, 0))
            dst_cols = slice(max(dx, 0), cols + min(dx, 0))
            src_cols = slice(max(-dx, 0), cols + min(-dx, 0))
            dilated[dst_rows, dst_cols] |= mask[src_rows, src_cols]

    out = costmap.copy()
    out[dilated] = lethal_val
    return out


def pointcloud_to_costmap(
    pcd: "o3d.geometry.PointCloud",
    *,
    resolution: float = 0.05,
    ground_z: float = 0.0,
    obs_min_height: float = 0.15,
    max_height: Optional[float] = 0.5,
    inflate_radius_m: Optional[float] = None,
    default_unknown: int = -1,
    cost_free: int = 0,
    cost_lethal: int = 100,
) -> tuple[np.ndarray, np.ndarray]:
    """Rasterise *pcd* into a 2-D int8 cost-map with optional obstacle inflation.

    Grid origin is **aligned** to the `resolution` lattice so that when
    `resolution == voxel_size` every voxel centroid lands squarely inside a cell
    (no alternating blank lines).

    Args:
        pcd: Object whose ``points`` attribute converts to an (N, 3) array.
        resolution: Cell size, in the same units as the points.
        ground_z: Height of the ground plane.
        obs_min_height: A cell is lethal when its highest point is at least
            this far above ``ground_z``.
        max_height: Points above this height are discarded; None disables
            the filter.
        inflate_radius_m: Radius by which lethal cells are grown; falsy or
            non-positive disables inflation.
        default_unknown: Cost for cells that received no point.
        cost_free: Cost for observed, non-lethal cells.
        cost_lethal: Cost for lethal cells.

    Returns:
        ``(costmap, origin)``: an int8 grid indexed ``[y, x]`` and the float32
        world XY of cell (0, 0). An empty cloud yields a 1x1 unknown grid
        with a zero origin.
    """

    pts = np.asarray(pcd.points, dtype=np.float32)
    if pts.size == 0:
        return np.full((1, 1), default_unknown, np.int8), np.zeros(2, np.float32)

    # 0. Ceiling filter --------------------------------------------------------
    if max_height is not None:
        pts = pts[pts[:, 2] <= max_height]
        if pts.size == 0:
            return np.full((1, 1), default_unknown, np.int8), np.zeros(2, np.float32)

    # 1. Bounding box & aligned origin ---------------------------------------
    xy_min = pts[:, :2].min(axis=0)
    xy_max = pts[:, :2].max(axis=0)

    # Align origin to the resolution grid (anchor = 0,0)
    origin = np.floor(xy_min / resolution) * resolution

    # Grid dimensions (inclusive) -------------------------------------------
    Nx, Ny = (np.ceil((xy_max - origin) / resolution).astype(int) + 1).tolist()

    # 2. Bin points ------------------------------------------------------------
    idx_xy = np.floor((pts[:, :2] - origin) / resolution).astype(np.int32)
    np.clip(idx_xy[:, 0], 0, Nx - 1, out=idx_xy[:, 0])
    np.clip(idx_xy[:, 1], 0, Ny - 1, out=idx_xy[:, 1])

    # Highest point per cell; -inf marks cells that received no point
    lin = idx_xy[:, 1] * Nx + idx_xy[:, 0]
    z_max = np.full(Nx * Ny, -np.inf, np.float32)
    np.maximum.at(z_max, lin, pts[:, 2])
    z_max = z_max.reshape(Ny, Nx)

    # 3. Cost rules -----------------------------------------------------------
    costmap = np.full_like(z_max, default_unknown, np.int8)
    known = z_max != -np.inf
    costmap[known] = cost_free

    lethal = z_max >= (ground_z + obs_min_height)
    costmap[lethal] = cost_lethal

    # 4. Optional inflation ----------------------------------------------------
    if inflate_radius_m and inflate_radius_m > 0:
        cells = int(np.ceil(inflate_radius_m / resolution))
        costmap = _inflate_lethal(costmap, cells, lethal_val=cost_lethal)

    return costmap, origin.astype(np.float32)


if __name__ == "__main__":
    costmap = Costmap.from_pickle("costmapMsg.pickle")
    print(costmap)

    # Create a smudged version of the costmap for better planning
    smudged_costmap = costmap.smudge(
        kernel_size=10, iterations=10, threshold=80, preserve_unknown=False
    )

    print(smudged_costmap)
class LabelType:
    """Container pairing a label dictionary with optional metadata."""

    def __init__(self, labels: Dict[str, Any], metadata: Any = None):
        """
        Store the labels and their metadata.

        Args:
            labels (Dict[str, Any]): Mapping of label name to an entry that
                carries a "description" item.
            metadata (Any, optional): Extra information kept alongside the labels.
        """
        self.metadata = metadata
        self.labels = labels

    def get_label_descriptions(self):
        """Collect the "description" of every label entry, in dictionary order."""
        descriptions = []
        for entry in self.labels.values():
            descriptions.append(entry["description"])
        return descriptions

    def save_to_json(self, filepath: str):
        """Write the labels (not the metadata) to *filepath* as indented JSON."""
        import json

        with open(filepath, "w") as out_file:
            json.dump(self.labels, out_file, indent=4)
class ConstraintType(Enum):
    """Types of manipulation constraints."""

    TRANSLATION = "translation"
    ROTATION = "rotation"
    FORCE = "force"


@dataclass
class AbstractConstraint(ABC):
    """Base class for all manipulation constraints."""

    description: str = ""
    # Short identifier: first 8 characters of a random UUID
    id: str = field(default_factory=lambda: str(uuid.uuid4())[:8])


@dataclass
class TranslationConstraint(AbstractConstraint):
    """Constraint parameters for translational movement along a single axis."""

    translation_axis: Optional[Literal["x", "y", "z"]] = None  # Axis to translate along
    reference_point: Optional["Vector"] = None
    bounds_min: Optional["Vector"] = None  # For bounded translation
    bounds_max: Optional["Vector"] = None  # For bounded translation
    target_point: Optional["Vector"] = None  # For relative positioning


@dataclass
class RotationConstraint(AbstractConstraint):
    """Constraint parameters for rotational movement around a single axis."""

    rotation_axis: Optional[Literal["roll", "pitch", "yaw"]] = None  # Axis to rotate around
    start_angle: Optional["Vector"] = None  # Angle values applied to the specified rotation axis
    end_angle: Optional["Vector"] = None  # Angle values applied to the specified rotation axis
    pivot_point: Optional["Vector"] = None  # Point of rotation
    secondary_pivot_point: Optional["Vector"] = None  # For double point rotations


@dataclass
class ForceConstraint(AbstractConstraint):
    """Constraint parameters for force application."""

    max_force: float = 0.0  # Maximum force in newtons
    min_force: float = 0.0  # Minimum force in newtons
    force_direction: Optional["Vector"] = None  # Direction of force application


class ObjectData(TypedDict, total=False):
    """Data about an object in the manipulation scene."""

    object_id: int  # Unique ID for the object
    bbox: List[float]  # Bounding box [x1, y1, x2, y2]
    depth: float  # Depth in meters from Metric3d
    confidence: float  # Detection confidence
    class_id: int  # Class ID from the detector
    label: str  # Semantic label (e.g., 'cup', 'table')
    movement_tolerance: float  # (0.0 = immovable, 1.0 = freely movable)
    segmentation_mask: np.ndarray  # Binary mask of the object's pixels
    position: Dict[str, float]  # 3D position {x, y, z}
    rotation: Dict[str, float]  # 3D rotation {roll, pitch, yaw}
    size: Dict[str, float]  # Object dimensions {width, height}


class ManipulationMetadata(TypedDict, total=False):
    """Typed metadata for manipulation constraints."""

    timestamp: float
    objects: Dict[str, ObjectData]


@dataclass
class ManipulationTaskConstraint:
    """Set of constraints for a specific manipulation action."""

    constraints: List[AbstractConstraint] = field(default_factory=list)

    def add_constraint(self, constraint: AbstractConstraint):
        """Add a constraint to this set (ignored if an equal one is present)."""
        if constraint not in self.constraints:
            self.constraints.append(constraint)

    def get_constraints(self) -> List[AbstractConstraint]:
        """Get all constraints in this set."""
        return self.constraints


@dataclass
class ManipulationTask:
    """Complete definition of a manipulation task.

    ``constraints`` may be a list of constraints, a single constraint or a
    ManipulationTaskConstraint; ``add_constraint`` and ``get_constraints``
    handle all three forms.
    """

    description: str
    target_object: str  # Semantic label of target object
    target_point: Optional[Tuple[float, float]] = (
        None  # (X,Y) point in pixel-space of the point to manipulate on target object
    )
    metadata: ManipulationMetadata = field(default_factory=dict)
    timestamp: float = field(default_factory=time.time)
    task_id: str = ""
    result: Optional[Dict[str, Any]] = None  # Any result data from the task execution
    constraints: Union[List[AbstractConstraint], ManipulationTaskConstraint, AbstractConstraint] = (
        field(default_factory=list)
    )

    def add_constraint(self, constraint: AbstractConstraint):
        """Add a constraint to this manipulation task.

        A constraint equal to one already present is not added again, matching
        ManipulationTaskConstraint.add_constraint.
        """
        # A constraint set handles insertion itself
        if isinstance(self.constraints, ManipulationTaskConstraint):
            self.constraints.add_constraint(constraint)
            return

        # Normalise a single constraint (or None) to a list before appending
        if isinstance(self.constraints, AbstractConstraint):
            self.constraints = [self.constraints]
        elif self.constraints is None:
            self.constraints = []

        if constraint not in self.constraints:
            self.constraints.append(constraint)

    def get_constraints(self) -> List[AbstractConstraint]:
        """Get all constraints in this manipulation task."""
        if isinstance(self.constraints, ManipulationTaskConstraint):
            return self.constraints.get_constraints()

        # A single constraint is returned wrapped in a list
        if isinstance(self.constraints, AbstractConstraint):
            return [self.constraints]

        if self.constraints is None:
            return []

        # Plain list (including the empty default)
        return self.constraints
import numpy as np
from typing import Any, Dict, List, Optional, Union, Tuple, Iterator, TypeVar
from dimos.types.vector import Vector

T = TypeVar("T", bound="Path")


class Path:
    """A path stored as an ordered sequence of points.

    Points live in ``self._points``, a float ``numpy`` array with one row per
    point.  ``append``, ``extend``, ``insert``, ``remove`` and ``clear``
    modify the path in place; ``resample``, ``simplify``, ``smooth``,
    ``reverse``, ``ipush``, ``iclip_tail`` and ``+`` return a new ``Path``.
    """

    def __init__(
        self,
        points: Union[List[Vector], List[np.ndarray], List[Tuple], np.ndarray, None] = None,
    ):
        """Initialize a path from a list of points.

        Args:
            points: List of Vector objects, numpy arrays, tuples, or a 2D numpy
                array where each row is a point. If None (or an empty
                sequence), creates an empty path.

        Examples:
            Path([Vector(1, 2), Vector(3, 4)])  # from Vector objects
            Path([(1, 2), (3, 4)])              # from tuples
            Path(np.array([[1, 2], [3, 4]]))    # from 2D numpy array
        """
        if points is None:
            self._points = np.zeros((0, 0), dtype=float)
            return

        if isinstance(points, np.ndarray) and points.ndim == 2:
            # astype() copies, so the caller's array is never aliased.
            self._points = points.astype(float)
            return

        converted = [p.data if isinstance(p, Vector) else p for p in points]
        if not converted:
            # Keep the empty path 2-D, exactly like Path() and clear().
            self._points = np.zeros((0, 0), dtype=float)
            return
        self._points = np.array(converted, dtype=float)

    @staticmethod
    def _as_array(point: Union[Vector, np.ndarray, Tuple]) -> np.ndarray:
        """Return *point* as an array (``Vector.data`` or a float ``np.array``)."""
        if isinstance(point, Vector):
            return point.data
        return np.array(point, dtype=float)

    def serialize(self) -> Dict[str, Any]:
        """Serialize the path to a dict with ``"type"`` and ``"points"`` keys."""
        return {
            "type": "path",
            "points": self._points.tolist(),
        }

    @property
    def points(self) -> np.ndarray:
        """Get the path points as a numpy array (the internal array, not a copy)."""
        return self._points

    def as_vectors(self) -> List[Vector]:
        """Get the path points as Vector objects."""
        return [Vector(p) for p in self._points]

    def append(self, point: Union[Vector, np.ndarray, Tuple]) -> None:
        """Append a point to the path in place.

        Args:
            point: A Vector, numpy array, or tuple representing a point
        """
        point_data = self._as_array(point)

        if len(self._points) == 0:
            # If empty, create with correct dimensionality
            self._points = np.array([point_data])
        else:
            self._points = np.vstack((self._points, point_data))

    def extend(self, points: Union[List[Vector], List[np.ndarray], List[Tuple], "Path"]) -> None:
        """Extend the path in place with more points.

        Args:
            points: List of points or another Path object
        """
        if not isinstance(points, Path):
            for point in points:
                self.append(point)
            return

        if len(points) == 0:
            # An empty Path has shape (0, 0); stacking it would raise.
            return
        if len(self._points) == 0:
            self._points = points.points.copy()
        else:
            self._points = np.vstack((self._points, points.points))

    def insert(self, index: int, point: Union[Vector, np.ndarray, Tuple]) -> None:
        """Insert a point at a specific index.

        Args:
            index: The index at which to insert the point
            point: A Vector, numpy array, or tuple representing a point
        """
        point_data = self._as_array(point)

        if len(self._points) == 0:
            self._points = np.array([point_data])
        else:
            self._points = np.insert(self._points, index, point_data, axis=0)

    def remove(self, index: int) -> np.ndarray:
        """Remove and return the point at the given index.

        Args:
            index: The index of the point to remove

        Returns:
            The removed point as a numpy array
        """
        point = self._points[index].copy()
        self._points = np.delete(self._points, index, axis=0)
        return point

    def clear(self) -> None:
        """Remove all points from the path, keeping the point dimensionality."""
        self._points = np.zeros(
            (0, self._points.shape[1] if len(self._points) > 0 else 0), dtype=float
        )

    def length(self) -> float:
        """Calculate the total length of the path.

        Returns:
            The sum of the distances between consecutive points
        """
        if len(self._points) < 2:
            return 0.0

        # Efficient vector calculation of consecutive point distances
        diff = self._points[1:] - self._points[:-1]
        segment_lengths = np.sqrt(np.sum(diff * diff, axis=1))
        return float(np.sum(segment_lengths))

    def resample(self: T, point_spacing: float) -> T:
        """Resample the path with approximately uniform spacing between points.

        The first and last original points are always kept.

        Args:
            point_spacing: The desired distance between consecutive points

        Returns:
            A new Path object with resampled points (a plain copy when the
            path has fewer than two points or ``point_spacing <= 0``)
        """
        if len(self._points) < 2 or point_spacing <= 0:
            return self.__class__(self._points.copy())

        resampled_points = [self._points[0].copy()]
        # Distance travelled since the last emitted point.
        accumulated_distance = 0.0

        for i in range(1, len(self._points)):
            current_point = self._points[i]
            prev_point = self._points[i - 1]
            segment_vector = current_point - prev_point
            segment_length = np.linalg.norm(segment_vector)

            if segment_length < 1e-10:
                # Duplicate consecutive points contribute no distance.
                continue

            direction = segment_vector / segment_length

            # Add points along this segment until we reach the end
            while accumulated_distance + segment_length >= point_spacing:
                # How far along this segment the next point should be
                dist_along_segment = point_spacing - accumulated_distance
                if dist_along_segment < 0:
                    break

                # Create the new point
                new_point = prev_point + direction * dist_along_segment
                resampled_points.append(new_point)

                # Continue measuring from the point just emitted
                accumulated_distance = 0
                segment_length -= dist_along_segment
                prev_point = new_point

            # Carry the unused remainder over to the next segment
            accumulated_distance += segment_length

        # Add the last point if it's not already there
        last_point = self._points[-1]
        if not np.array_equal(resampled_points[-1], last_point):
            resampled_points.append(last_point.copy())

        return self.__class__(np.array(resampled_points))

    def simplify(self: T, tolerance: float) -> T:
        """Simplify the path using the Ramer-Douglas-Peucker algorithm.

        Args:
            tolerance: The maximum distance a point can deviate from the simplified path

        Returns:
            A new simplified Path object; the first and last points are always kept
        """
        if len(self._points) <= 2:
            return self.__class__(self._points.copy())

        pts = self._points
        last = len(pts) - 1
        keep = np.zeros(len(pts), dtype=bool)
        keep[0] = keep[last] = True

        # Explicit stack instead of recursion so long paths cannot hit the
        # interpreter's recursion limit.
        pending = [(0, last)]
        while pending:
            start, end = pending.pop()
            if end <= start + 1:
                # Adjacent points: nothing in between to drop or keep.
                continue

            offsets = pts[start + 1 : end] - pts[start]
            chord = pts[end] - pts[start]
            chord_length = np.linalg.norm(chord)

            if chord_length < 1e-10:
                # Start and end coincide: measure distance to that point.
                dists = np.linalg.norm(offsets, axis=1)
            else:
                # Perpendicular distance from each interior point to the chord line.
                proj = (offsets @ chord) / (chord_length * chord_length)
                dists = np.linalg.norm(offsets - np.outer(proj, chord), axis=1)

            far = int(np.argmax(dists))
            if dists[far] > tolerance:
                split = start + 1 + far
                keep[split] = True
                pending.append((start, split))
                pending.append((split, end))

        return self.__class__(pts[keep])

    def smooth(self: T, weight: float = 0.5, iterations: int = 1) -> T:
        """Smooth the path using a moving average filter.

        Args:
            weight: How much to weight the neighboring points (0-1)
            iterations: Number of smoothing passes to apply

        Returns:
            A new smoothed Path object; the end points are left unchanged
        """
        if len(self._points) <= 2 or weight <= 0 or iterations <= 0:
            return self.__class__(self._points.copy())

        smoothed_points = self._points.copy()

        for _ in range(iterations):
            # Blend every interior point with the mean of its two neighbours.
            neighbor_avg = 0.5 * (smoothed_points[:-2] + smoothed_points[2:])
            new_points = smoothed_points.copy()
            new_points[1:-1] = (1 - weight) * smoothed_points[1:-1] + weight * neighbor_avg
            smoothed_points = new_points

        return self.__class__(smoothed_points)

    def nearest_point_index(self, point: Union[Vector, np.ndarray, Tuple]) -> int:
        """Find the index of the closest point on the path to the given point.

        Args:
            point: The reference point

        Returns:
            Index of the closest point on the path

        Raises:
            ValueError: If the path is empty
        """
        if len(self._points) == 0:
            raise ValueError("Cannot find nearest point in an empty path")

        point_data = self._as_array(point)

        # Squared distances are enough to find the minimum.
        diff = self._points - point_data
        sq_distances = np.sum(diff * diff, axis=1)

        return int(np.argmin(sq_distances))

    def reverse(self: T) -> T:
        """Reverse the path direction.

        Returns:
            A new Path object with points in reverse order
        """
        return self.__class__(self._points[::-1].copy())

    def __len__(self) -> int:
        """Return the number of points in the path."""
        return len(self._points)

    def __getitem__(self, idx) -> Union[np.ndarray, "Path"]:
        """Get a point (copied array) or, for a slice, a new Path."""
        if isinstance(idx, slice):
            return self.__class__(self._points[idx])
        return self._points[idx].copy()

    def get_vector(self, idx: int) -> Vector:
        """Get a point at the given index as a Vector object."""
        return Vector(self._points[idx])

    def last(self) -> Optional[Vector]:
        """Get the last point in the path as a Vector, or None if the path is empty."""
        if len(self._points) == 0:
            return None
        return Vector(self._points[-1])

    def head(self) -> Optional[Vector]:
        """Get the first point in the path as a Vector, or None if the path is empty."""
        if len(self._points) == 0:
            return None
        return Vector(self._points[0])

    def tail(self) -> Optional["Path"]:
        """Get all points except the first as a new Path, or None if there are none."""
        if len(self._points) <= 1:
            return None
        return self.__class__(self._points[1:].copy())

    def __iter__(self) -> Iterator[np.ndarray]:
        """Iterate over copies of the points in the path."""
        for point in self._points:
            yield point.copy()

    def __repr__(self) -> str:
        """String representation of the path."""
        return f"↶ Path ({len(self._points)} Points)"

    def ipush(self, point: Union[Vector, np.ndarray, Tuple]) -> "Path":
        """Return a new Path with `point` appended."""
        if isinstance(point, Vector):
            p = point.data
        else:
            p = np.asarray(point, dtype=float)

        if len(self._points) == 0:
            new_pts = p.reshape(1, -1)
        else:
            new_pts = np.vstack((self._points, p))
        return self.__class__(new_pts)

    def iclip_tail(self, max_len: int) -> "Path":
        """Return a new Path containing only the last `max_len` points.

        Raises:
            ValueError: If ``max_len`` is negative
        """
        if max_len < 0:
            raise ValueError("max_len must be ≥ 0")
        # Slice from an explicit start index: ``[-0:]`` would return every
        # point instead of none.
        start = max(len(self._points) - max_len, 0)
        return self.__class__(self._points[start:])

    def __add__(self, point):
        """path + vec -> path.ipush(vec)"""
        return self.ipush(point)


if __name__ == "__main__":
    # Test vectors in various directions
    print(
        Path(
            [
                Vector(1, 0),  # Right
                Vector(1, 1),  # Up-Right
                Vector(0, 1),  # Up
                Vector(-1, 1),  # Up-Left
                Vector(-1, 0),  # Left
                Vector(-1, -1),  # Down-Left
                Vector(0, -1),  # Down
                Vector(1, -1),  # Down-Right
                Vector(0.5, 0.5),  # Up-Right (shorter)
                Vector(-3, 4),  # Up-Left (longer)
            ]
        )
    )

    print(Path())


# ---------------------------------------------------------------------------
# build/lib/dimos/types/pose.py (new file in this patch)
# ---------------------------------------------------------------------------
# Copyright 2025 Dimensional Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, TypeVar, Union, Sequence
import numpy as np
from plum import dispatch
import math

from dimos.types.vector import Vector, to_vector, to_numpy, VectorLike


T = TypeVar("T", bound="Pose")

PoseLike = Union["Pose", VectorLike, Sequence[VectorLike]]


def yaw_to_matrix(yaw: float) -> np.ndarray:
    """2-D yaw (about Z) to a 3×3 rotation matrix."""
    c, s = math.cos(yaw), math.sin(yaw)
    return np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]])


class Pose(Vector):
    """A pose in 3D space, consisting of a position vector and a rotation vector.

    Pose inherits from Vector and behaves like a vector for the position component.
    The rotation vector is stored separately and can be accessed via the rot property.
    """

    # None means "no rotation given"; the ``rot`` property then reports zeros.
    _rot: Optional[Vector] = None

    @dispatch
    def __init__(self, pos: VectorLike):
        self._data = to_numpy(pos)
        self._rot = None

    @dispatch
    def __init__(self, pos: VectorLike, rot: VectorLike):
        self._data = to_numpy(pos)
        self._rot = to_vector(rot)

    def __repr__(self) -> str:
        return f"Pose({self.pos.__repr__()}, {self.rot.__repr__()})"

    def __str__(self) -> str:
        return self.__repr__()

    def is_zero(self) -> bool:
        """Return True only when both position and rotation are zero."""
        return super().is_zero() and self.rot.is_zero()

    def __bool__(self) -> bool:
        return not self.is_zero()

    def serialize(self):
        """Serialize the pose to a dictionary."""
        return {"type": "pose", "pos": self.to_list(), "rot": self.rot.to_list()}

    def vector_to(self, target: Vector) -> Vector:
        """Return the 2-D offset to ``target`` rotated by ``-yaw``.

        The offset ``target - pos`` is computed in 2-D and then rotated by the
        negative of this pose's yaw.
        """
        direction = target - self.pos.to_2d()

        cos_y = math.cos(-self.yaw)
        sin_y = math.sin(-self.yaw)

        x = cos_y * direction.x - sin_y * direction.y
        y = sin_y * direction.x + cos_y * direction.y

        return Vector(x, y)

    def __eq__(self, other) -> bool:
        """Check if two poses are equal using numpy's allclose for floating point comparison."""
        if not isinstance(other, Pose):
            return False
        return np.allclose(self.pos._data, other.pos._data) and np.allclose(
            self.rot._data, other.rot._data
        )

    @property
    def rot(self) -> Vector:
        """Get the rotation vector, or a zero vector when none was given."""
        # Compare against None explicitly: truth-testing a Vector depends on
        # how Vector defines truthiness, not on whether a rotation was set.
        if self._rot is not None:
            return self._rot
        return Vector(0, 0, 0)

    @property
    def pos(self) -> Vector:
        """Get the position vector (self)."""
        return to_vector(self._data)

    def __add__(self: T, other) -> T:
        """Override Vector's __add__ to handle Pose objects specially.

        When adding two Pose objects, both position and rotation components are added.
        """
        if isinstance(other, Pose):
            # Add both position and rotation components
            result = super().__add__(other)
            result._rot = self.rot + other.rot
            return result
        else:
            # For other types, just use Vector's addition
            return Pose(super().__add__(other), self.rot)

    @property
    def yaw(self) -> float:
        """Get the yaw (rotation around Z-axis) from the rotation vector."""
        return self.rot.z

    def __sub__(self: T, other) -> T:
        """Override Vector's __sub__ to handle Pose objects specially.

        When subtracting two Pose objects, both position and rotation components are subtracted.
        """
        if isinstance(other, Pose):
            # Subtract both position and rotation components
            result = super().__sub__(other)
            result._rot = self.rot - other.rot
            return result
        else:
            # For other types, just use Vector's subtraction
            return super().__sub__(other)

    def __mul__(self: T, scalar: float) -> T:
        """Scale the position by ``scalar``; the rotation is carried over unchanged."""
        return Pose(self.pos * scalar, self.rot)


@dispatch
def to_pose(pos: Pose) -> Pose:
    return pos


@dispatch
def to_pose(pos: VectorLike) -> Pose:
    return Pose(pos)


@dispatch
def to_pose(pos: Sequence[VectorLike]) -> Pose:
    return Pose(*pos)


# ---------------------------------------------------------------------------
# build/lib/dimos/types/robot_capabilities.py (new file in this patch)
# ---------------------------------------------------------------------------
# Copyright 2025 Dimensional Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Robot capabilities module for defining robot functionality."""

from enum import Enum, auto


class RobotCapability(Enum):
    """Enum defining possible robot capabilities."""

    MANIPULATION = auto()
    VISION = auto()
    AUDIO = auto()
    SPEECH = auto()
    LOCOMOTION = auto()


# ---------------------------------------------------------------------------
# build/lib/dimos/types/robot_location.py (new file in this patch)
# ---------------------------------------------------------------------------
# Copyright 2025 Dimensional Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
RobotLocation type definition for storing and managing robot location data.
"""

from dataclasses import dataclass, field
from typing import Dict, Any, Optional, Tuple
import time
import uuid

# Keys written by RobotLocation.to_vector_metadata(); anything else found in a
# metadata dict is treated as free-form user metadata when reading it back.
_VECTOR_METADATA_KEYS = frozenset(
    {
        "pos_x",
        "pos_y",
        "pos_z",
        "rot_x",
        "rot_y",
        "rot_z",
        "timestamp",
        "location_id",
        "frame_id",
        "location_name",
        "description",
    }
)


@dataclass
class RobotLocation:
    """
    Represents a named location in the robot's spatial memory.

    This class stores the position, rotation, and descriptive metadata for
    locations that the robot can remember and navigate to.

    Attributes:
        name: Human-readable name of the location (e.g., "kitchen", "office")
        position: 3D position coordinates (x, y, z)
        rotation: 3D rotation angles in radians (roll, pitch, yaw)
        frame_id: ID of the associated video frame if available
        timestamp: Time when the location was recorded
        location_id: Unique identifier for this location
        metadata: Additional metadata for the location
    """

    name: str
    position: Tuple[float, float, float]
    rotation: Tuple[float, float, float]
    frame_id: Optional[str] = None
    timestamp: float = field(default_factory=time.time)
    location_id: str = field(default_factory=lambda: f"loc_{uuid.uuid4().hex[:8]}")
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Normalize position and rotation to tuples of floats.

        A 2-element position is padded with ``z = 0.0``; a 1-element rotation
        is taken as yaw and padded with ``roll = pitch = 0.0``.  Every
        component is converted with ``float()`` in all cases.
        """
        position = tuple(float(x) for x in self.position)
        if len(position) == 2:
            position = (position[0], position[1], 0.0)
        self.position = position

        rotation = tuple(float(x) for x in self.rotation)
        if len(rotation) == 1:
            rotation = (0.0, 0.0, rotation[0])
        self.rotation = rotation

    def to_vector_metadata(self) -> Dict[str, Any]:
        """
        Convert the location to metadata format for storing in a vector database.

        The free-form ``metadata`` attribute is not included in the result.

        Returns:
            Dictionary with metadata fields compatible with vector DB storage
        """
        return {
            "pos_x": float(self.position[0]),
            "pos_y": float(self.position[1]),
            "pos_z": float(self.position[2]),
            "rot_x": float(self.rotation[0]),
            "rot_y": float(self.rotation[1]),
            "rot_z": float(self.rotation[2]),
            "timestamp": self.timestamp,
            "location_id": self.location_id,
            "frame_id": self.frame_id,
            "location_name": self.name,
            "description": self.name,  # Makes it searchable by text
        }

    @classmethod
    def from_vector_metadata(cls, metadata: Dict[str, Any]) -> "RobotLocation":
        """
        Create a RobotLocation object from vector database metadata.

        Missing fields fall back to defaults (name ``"unknown"``, zero
        position/rotation, current time, a fresh ``location_id``).  Keys that
        ``to_vector_metadata`` does not produce are kept in ``metadata``.

        Args:
            metadata: Dictionary with metadata from vector database

        Returns:
            RobotLocation object
        """
        extras = {k: v for k, v in metadata.items() if k not in _VECTOR_METADATA_KEYS}
        return cls(
            name=metadata.get("location_name", "unknown"),
            position=(
                metadata.get("pos_x", 0.0),
                metadata.get("pos_y", 0.0),
                metadata.get("pos_z", 0.0),
            ),
            rotation=(
                metadata.get("rot_x", 0.0),
                metadata.get("rot_y", 0.0),
                metadata.get("rot_z", 0.0),
            ),
            frame_id=metadata.get("frame_id"),
            timestamp=metadata.get("timestamp", time.time()),
            location_id=metadata.get("location_id", f"loc_{uuid.uuid4().hex[:8]}"),
            metadata=extras,
        )


# ---------------------------------------------------------------------------
# build/lib/dimos/types/ros_polyfill.py (new file in this patch)
# ---------------------------------------------------------------------------
# Copyright 2025 Dimensional Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Minimal stand-ins for a handful of ROS message types.  Each group is taken
# from the real ROS packages when they can be imported and replaced by the
# plain-Python classes below otherwise.

try:
    from geometry_msgs.msg import Vector3
except ImportError:

    class Vector3:
        """Fallback 3-component vector with float ``x``, ``y``, ``z``."""

        def __init__(self, x: float = 0.0, y: float = 0.0, z: float = 0.0):
            self.x = float(x)
            self.y = float(y)
            self.z = float(z)

        def __repr__(self) -> str:
            return f"Vector3(x={self.x}, y={self.y}, z={self.z})"


try:
    # MapMetaData is imported too so this module exposes the same names
    # whether or not ROS is installed (the fallback branch defines it).
    from nav_msgs.msg import MapMetaData, OccupancyGrid, Odometry
    from geometry_msgs.msg import Pose, Point, Quaternion, Twist
    from std_msgs.msg import Header
except ImportError:

    class Header:
        """Fallback message header: ``stamp`` (None) and ``frame_id`` ("")."""

        def __init__(self):
            self.stamp = None
            self.frame_id = ""

        def __repr__(self) -> str:
            return f"Header(stamp={self.stamp}, frame_id={self.frame_id!r})"

    class Point:
        """Fallback 3-D point with float ``x``, ``y``, ``z``."""

        def __init__(self, x: float = 0.0, y: float = 0.0, z: float = 0.0):
            self.x = float(x)
            self.y = float(y)
            self.z = float(z)

        def __repr__(self) -> str:
            return f"Point(x={self.x}, y={self.y}, z={self.z})"

    class Quaternion:
        """Fallback quaternion; defaults to (0, 0, 0, 1)."""

        def __init__(self, x: float = 0.0, y: float = 0.0, z: float = 0.0, w: float = 1.0):
            self.x = float(x)
            self.y = float(y)
            self.z = float(z)
            self.w = float(w)

        def __repr__(self) -> str:
            return f"Quaternion(x={self.x}, y={self.y}, z={self.z}, w={self.w})"

    class Pose:
        """Fallback pose: a ``Point`` position and a ``Quaternion`` orientation."""

        def __init__(self):
            self.position = Point()
            self.orientation = Quaternion()

        def __repr__(self) -> str:
            return f"Pose(position={self.position}, orientation={self.orientation})"

    class MapMetaData:
        """Fallback map metadata: resolution, grid size and origin pose."""

        def __init__(self):
            self.map_load_time = None
            self.resolution = 0.05
            self.width = 0
            self.height = 0
            self.origin = Pose()

        def __repr__(self) -> str:
            return f"MapMetaData(resolution={self.resolution}, width={self.width}, height={self.height}, origin={self.origin})"

    class Twist:
        """Fallback twist: ``linear`` and ``angular`` ``Vector3`` parts."""

        def __init__(self):
            self.linear = Vector3()
            self.angular = Vector3()

        def __repr__(self) -> str:
            return f"Twist(linear={self.linear}, angular={self.angular})"

    class OccupancyGrid:
        """Fallback occupancy grid: header, ``MapMetaData`` info and a data list."""

        def __init__(self):
            self.header = Header()
            self.info = MapMetaData()
            self.data = []

        def __repr__(self) -> str:
            return f"OccupancyGrid(info={self.info}, data_length={len(self.data)})"

    class Odometry:
        """Fallback odometry holding a ``Pose`` and a ``Twist`` directly."""

        def __init__(self):
            self.header = Header()
            self.child_frame_id = ""
            self.pose = Pose()
            self.twist = Twist()

        def __repr__(self) -> str:
            return f"Odometry(pose={self.pose}, twist={self.twist})"


# ---------------------------------------------------------------------------
# build/lib/dimos/types/sample.py (new file in this patch)
# ---------------------------------------------------------------------------
# Copyright 2025 Dimensional Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+ +import json +import logging +from collections import OrderedDict +from enum import Enum +from pathlib import Path +from typing import Any, Dict, List, Literal, Sequence, Union, get_origin + +import numpy as np +from datasets import Dataset +from gymnasium import spaces +from jsonref import replace_refs +from pydantic import BaseModel, ConfigDict, ValidationError +from pydantic.fields import FieldInfo +from pydantic_core import from_json +from typing_extensions import Annotated + +from mbodied.data.utils import to_features +from mbodied.utils.import_utils import smart_import + +Flattenable = Annotated[Literal["dict", "np", "pt", "list"], "Numpy, PyTorch, list, or dict"] + + +class Sample(BaseModel): + """A base model class for serializing, recording, and manipulating arbitray data. + + It was designed to be extensible, flexible, yet strongly typed. In addition to + supporting any json API out of the box, it can be used to represent + arbitrary action and observation spaces in robotics and integrates seemlessly with H5, Gym, Arrow, + PyTorch, DSPY, numpy, and HuggingFace. + + Methods: + schema: Get a simplified json schema of your data. + to: Convert the Sample instance to a different container type: + - + default_value: Get the default value for the Sample instance. + unflatten: Unflatten a one-dimensional array or dictionary into a Sample instance. + flatten: Flatten the Sample instance into a one-dimensional array or dictionary. + space_for: Default Gym space generation for a given value. + init_from: Initialize a Sample instance from a given value. + from_space: Generate a Sample instance from a Gym space. + pack_from: Pack a list of samples into a single sample with lists for attributes. + unpack: Unpack the packed Sample object into a list of Sample objects or dictionaries. + dict: Return the Sample object as a dictionary with None values excluded. + model_field_info: Get the FieldInfo for a given attribute key. 
+ space: Return the corresponding Gym space for the Sample instance based on its instance attributes. + random_sample: Generate a random Sample instance based on its instance attributes. + + Examples: + >>> sample = Sample(x=1, y=2, z={"a": 3, "b": 4}, extra_field=5) + >>> flat_list = sample.flatten() + >>> print(flat_list) + [1, 2, 3, 4, 5] + >>> schema = sample.schema() + {'type': 'object', 'properties': {'x': {'type': 'number'}, 'y': {'type': 'number'}, 'z': {'type': 'object', 'properties': {'a': {'type': 'number'}, 'b': {'type': 'number'}}}, 'extra_field': {'type': 'number'}}} + >>> unflattened_sample = Sample.unflatten(flat_list, schema) + >>> print(unflattened_sample) + Sample(x=1, y=2, z={'a': 3, 'b': 4}, extra_field=5) + """ + + __doc__ = "A base model class for serializing, recording, and manipulating arbitray data." + + model_config: ConfigDict = ConfigDict( + use_enum_values=False, + from_attributes=True, + validate_assignment=False, + extra="allow", + arbitrary_types_allowed=True, + ) + + def __init__(self, datum=None, **data): + """Accepts an arbitrary datum as well as keyword arguments.""" + if datum is not None: + if isinstance(datum, Sample): + data.update(datum.dict()) + elif isinstance(datum, dict): + data.update(datum) + else: + data["datum"] = datum + super().__init__(**data) + + def __hash__(self) -> int: + """Return a hash of the Sample instance.""" + return hash(tuple(self.dict().values())) + + def __str__(self) -> str: + """Return a string representation of the Sample instance.""" + return f"{self.__class__.__name__}({', '.join([f'{k}={v}' for k, v in self.dict().items() if v is not None])})" + + def dict(self, exclude_none=True, exclude: set[str] = None) -> Dict[str, Any]: + """Return the Sample object as a dictionary with None values excluded. + + Args: + exclude_none (bool, optional): Whether to exclude None values. Defaults to True. + exclude (set[str], optional): Set of attribute names to exclude. Defaults to None. 
+ + Returns: + Dict[str, Any]: Dictionary representation of the Sample object. + """ + return self.model_dump(exclude_none=exclude_none, exclude=exclude) + + @classmethod + def unflatten(cls, one_d_array_or_dict, schema=None) -> "Sample": + """Unflatten a one-dimensional array or dictionary into a Sample instance. + + If a dictionary is provided, its keys are ignored. + + Args: + one_d_array_or_dict: A one-dimensional array or dictionary to unflatten. + schema: A dictionary representing the JSON schema. Defaults to using the class's schema. + + Returns: + Sample: The unflattened Sample instance. + + Examples: + >>> sample = Sample(x=1, y=2, z={"a": 3, "b": 4}, extra_field=5) + >>> flat_list = sample.flatten() + >>> print(flat_list) + [1, 2, 3, 4, 5] + >>> Sample.unflatten(flat_list, sample.schema()) + Sample(x=1, y=2, z={'a': 3, 'b': 4}, extra_field=5) + """ + if schema is None: + schema = cls().schema() + + # Convert input to list if it's not already + if isinstance(one_d_array_or_dict, dict): + flat_data = list(one_d_array_or_dict.values()) + else: + flat_data = list(one_d_array_or_dict) + + def unflatten_recursive(schema_part, index=0): + if schema_part["type"] == "object": + result = {} + for prop, prop_schema in schema_part["properties"].items(): + value, index = unflatten_recursive(prop_schema, index) + result[prop] = value + return result, index + elif schema_part["type"] == "array": + items = [] + for _ in range(schema_part.get("maxItems", len(flat_data) - index)): + value, index = unflatten_recursive(schema_part["items"], index) + items.append(value) + return items, index + else: # Assuming it's a primitive type + return flat_data[index], index + 1 + + unflattened_dict, _ = unflatten_recursive(schema) + return cls(**unflattened_dict) + + def flatten( + self, + output_type: Flattenable = "dict", + non_numerical: Literal["ignore", "forbid", "allow"] = "allow", + ) -> Dict[str, Any] | np.ndarray | "torch.Tensor" | List: + accumulator = {} if output_type == 
"dict" else [] + + def flatten_recursive(obj, path=""): + if isinstance(obj, Sample): + for k, v in obj.dict().items(): + flatten_recursive(v, path + k + "/") + elif isinstance(obj, dict): + for k, v in obj.items(): + flatten_recursive(v, path + k + "/") + elif isinstance(obj, list | tuple): + for i, item in enumerate(obj): + flatten_recursive(item, path + str(i) + "/") + elif hasattr(obj, "__len__") and not isinstance(obj, str): + flat_list = obj.flatten().tolist() + if output_type == "dict": + # Convert to list for dict storage + accumulator[path[:-1]] = flat_list + else: + accumulator.extend(flat_list) + else: + if non_numerical == "ignore" and not isinstance(obj, int | float | bool): + return + final_key = path[:-1] # Remove trailing slash + if output_type == "dict": + accumulator[final_key] = obj + else: + accumulator.append(obj) + + flatten_recursive(self) + accumulator = accumulator.values() if output_type == "dict" else accumulator + if non_numerical == "forbid" and any( + not isinstance(v, int | float | bool) for v in accumulator + ): + raise ValueError("Non-numerical values found in flattened data.") + if output_type == "np": + return np.array(accumulator) + if output_type == "pt": + torch = smart_import("torch") + return torch.tensor(accumulator) + return accumulator + + @staticmethod + def obj_to_schema(value: Any) -> Dict: + """Generates a simplified JSON schema from a dictionary. + + Args: + value (Any): An object to generate a schema for. + + Returns: + dict: A simplified JSON schema representing the structure of the dictionary. 
+ """ + if isinstance(value, dict): + return { + "type": "object", + "properties": {k: Sample.obj_to_schema(v) for k, v in value.items()}, + } + if isinstance(value, list | tuple | np.ndarray): + if len(value) > 0: + return {"type": "array", "items": Sample.obj_to_schema(value[0])} + return {"type": "array", "items": {}} + if isinstance(value, str): + return {"type": "string"} + if isinstance(value, int | np.integer): + return {"type": "integer"} + if isinstance(value, float | np.floating): + return {"type": "number"} + if isinstance(value, bool): + return {"type": "boolean"} + return {} + + def schema(self, resolve_refs: bool = True, include_descriptions=False) -> Dict: + """Returns a simplified json schema. + + Removing additionalProperties, + selecting the first type in anyOf, and converting numpy schema to the desired type. + Optionally resolves references. + + Args: + resolve_refs (bool): Whether to resolve references in the schema. Defaults to True. + include_descriptions (bool): Whether to include descriptions in the schema. Defaults to False. + + Returns: + dict: A simplified JSON schema. + """ + schema = self.model_json_schema() + if "additionalProperties" in schema: + del schema["additionalProperties"] + + if resolve_refs: + schema = replace_refs(schema) + + if not include_descriptions and "description" in schema: + del schema["description"] + + properties = schema.get("properties", {}) + for key, value in self.dict().items(): + if key not in properties: + properties[key] = Sample.obj_to_schema(value) + if isinstance(value, Sample): + properties[key] = value.schema( + resolve_refs=resolve_refs, include_descriptions=include_descriptions + ) + else: + properties[key] = Sample.obj_to_schema(value) + return schema + + @classmethod + def read(cls, data: Any) -> "Sample": + """Read a Sample instance from a JSON string or dictionary or path. + + Args: + data (Any): The JSON string or dictionary to read. + + Returns: + Sample: The read Sample instance. 
+ """ + if isinstance(data, str): + try: + data = cls.model_validate(from_json(data)) + except Exception as e: + logging.info(f"Error reading data: {e}. Attempting to read as JSON.") + if isinstance(data, str): + if Path(data).exists(): + if hasattr(cls, "open"): + data = cls.open(data) + else: + data = Path(data).read_text() + data = json.loads(data) + else: + data = json.load(data) + + if isinstance(data, dict): + return cls(**data) + return cls(data) + + def to(self, container: Any) -> Any: + """Convert the Sample instance to a different container type. + + Args: + container (Any): The container type to convert to. Supported types are + 'dict', 'list', 'np', 'pt' (pytorch), 'space' (gym.space), + 'schema', 'json', 'hf' (datasets.Dataset) and any subtype of Sample. + + Returns: + Any: The converted container. + """ + if isinstance(container, Sample) and not issubclass(container, Sample): + return container(**self.dict()) + if isinstance(container, type) and issubclass(container, Sample): + return container.unflatten(self.flatten()) + + if container == "dict": + return self.dict() + if container == "list": + return self.flatten(output_type="list") + if container == "np": + return self.flatten(output_type="np") + if container == "pt": + return self.flatten(output_type="pt") + if container == "space": + return self.space() + if container == "schema": + return self.schema() + if container == "json": + return self.model_dump_json() + if container == "hf": + return Dataset.from_dict(self.dict()) + if container == "features": + return to_features(self.dict()) + raise ValueError(f"Unsupported container type: {container}") + + @classmethod + def default_value(cls) -> "Sample": + """Get the default value for the Sample instance. + + Returns: + Sample: The default value for the Sample instance. 
+ """ + return cls() + + @classmethod + def space_for( + cls, + value: Any, + max_text_length: int = 1000, + info: Annotated = None, + ) -> spaces.Space: + """Default Gym space generation for a given value. + + Only used for subclasses that do not override the space method. + """ + if isinstance(value, Enum) or get_origin(value) == Literal: + return spaces.Discrete(len(value.__args__)) + if isinstance(value, bool): + return spaces.Discrete(2) + if isinstance(value, dict | Sample): + if isinstance(value, Sample): + value = value.dict() + return spaces.Dict( + {k: Sample.space_for(v, max_text_length, info) for k, v in value.items()}, + ) + if isinstance(value, str): + return spaces.Text(max_length=max_text_length) + if isinstance(value, int | float | list | tuple | np.ndarray): + shape = None + le = None + ge = None + dtype = None + if info is not None: + shape = info.metadata_lookup.get("shape") + le = info.metadata_lookup.get("le") + ge = info.metadata_lookup.get("ge") + dtype = info.metadata_lookup.get("dtype") + logging.debug( + "Generating space for value: %s, shape: %s, le: %s, ge: %s, dtype: %s", + value, + shape, + le, + ge, + dtype, + ) + try: + value = np.asfarray(value) + shape = shape or value.shape + dtype = dtype or value.dtype + le = le or -np.inf + ge = ge or np.inf + return spaces.Box(low=le, high=ge, shape=shape, dtype=dtype) + except Exception as e: + logging.info(f"Could not convert value {value} to numpy array: {e}") + if len(value) > 0 and isinstance(value[0], dict | Sample): + return spaces.Tuple( + [spaces.Dict(cls.space_for(v, max_text_length, info)) for v in value], + ) + return spaces.Tuple( + [cls.space_for(value[0], max_text_length, info) for value in value[:1]], + ) + raise ValueError(f"Unsupported object {value} of type: {type(value)} for space generation") + + @classmethod + def init_from(cls, d: Any, pack=False) -> "Sample": + if isinstance(d, spaces.Space): + return cls.from_space(d) + if isinstance(d, Union[Sequence, np.ndarray]): # 
noqa: UP007 + if pack: + return cls.pack_from(d) + return cls.unflatten(d) + if isinstance(d, dict): + try: + return cls.model_validate(d) + except ValidationError as e: + logging.info(f" Unable to validate {d} as {cls} {e}. Attempting to unflatten.") + + try: + return cls.unflatten(d) + except Exception as e: + logging.info(f" Unable to unflatten {d} as {cls} {e}. Attempting to read.") + return cls.read(d) + return cls(d) + + @classmethod + def from_flat_dict(cls, flat_dict: Dict[str, Any], schema: Dict = None) -> "Sample": + """Initialize a Sample instance from a flattened dictionary.""" + """ + Reconstructs the original JSON object from a flattened dictionary using the provided schema. + + Args: + flat_dict (dict): A flattened dictionary with keys like "key1.nestedkey1". + schema (dict): A dictionary representing the JSON schema. + + Returns: + dict: The reconstructed JSON object. + """ + schema = schema or replace_refs(cls.model_json_schema()) + reconstructed = {} + + for flat_key, value in flat_dict.items(): + keys = flat_key.split(".") + current = reconstructed + for key in keys[:-1]: + if key not in current: + current[key] = {} + current = current[key] + current[keys[-1]] = value + + return reconstructed + + @classmethod + def from_space(cls, space: spaces.Space) -> "Sample": + """Generate a Sample instance from a Gym space.""" + sampled = space.sample() + if isinstance(sampled, dict | OrderedDict): + return cls(**sampled) + if hasattr(sampled, "__len__") and not isinstance(sampled, str): + sampled = np.asarray(sampled) + if len(sampled.shape) > 0 and isinstance(sampled[0], dict | Sample): + return cls.pack_from(sampled) + return cls(sampled) + + @classmethod + def pack_from(cls, samples: List[Union["Sample", Dict]]) -> "Sample": + """Pack a list of samples into a single sample with lists for attributes. + + Args: + samples (List[Union[Sample, Dict]]): List of samples or dictionaries. + + Returns: + Sample: Packed sample with lists for attributes. 
+ """ + if samples is None or len(samples) == 0: + return cls() + + first_sample = samples[0] + if isinstance(first_sample, dict): + attributes = list(first_sample.keys()) + elif hasattr(first_sample, "__dict__"): + attributes = list(first_sample.__dict__.keys()) + else: + attributes = ["item" + str(i) for i in range(len(samples))] + + aggregated = {attr: [] for attr in attributes} + for sample in samples: + for attr in attributes: + # Handle both Sample instances and dictionaries + if isinstance(sample, dict): + aggregated[attr].append(sample.get(attr, None)) + else: + aggregated[attr].append(getattr(sample, attr, None)) + return cls(**aggregated) + + def unpack(self, to_dicts=False) -> List[Union["Sample", Dict]]: + """Unpack the packed Sample object into a list of Sample objects or dictionaries.""" + attributes = list(self.model_extra.keys()) + list(self.model_fields.keys()) + attributes = [attr for attr in attributes if getattr(self, attr) is not None] + if not attributes or getattr(self, attributes[0]) is None: + return [] + + # Ensure all attributes are lists and have the same length + list_sizes = { + len(getattr(self, attr)) for attr in attributes if isinstance(getattr(self, attr), list) + } + if len(list_sizes) != 1: + raise ValueError("Not all attribute lists have the same length.") + list_size = list_sizes.pop() + + if to_dicts: + return [{key: getattr(self, key)[i] for key in attributes} for i in range(list_size)] + + return [ + self.__class__(**{key: getattr(self, key)[i] for key in attributes}) + for i in range(list_size) + ] + + @classmethod + def default_space(cls) -> spaces.Dict: + """Return the Gym space for the Sample class based on its class attributes.""" + return cls().space() + + @classmethod + def default_sample(cls, output_type="Sample") -> Union["Sample", Dict[str, Any]]: + """Generate a default Sample instance from its class attributes. Useful for padding. + + This is the "no-op" instance and should be overriden as needed. 
+ """ + if output_type == "Sample": + return cls() + return cls().dict() + + def model_field_info(self, key: str) -> FieldInfo: + """Get the FieldInfo for a given attribute key.""" + if self.model_extra and self.model_extra.get(key) is not None: + info = FieldInfo(metadata=self.model_extra[key]) + if self.model_fields.get(key) is not None: + info = FieldInfo(metadata=self.model_fields[key]) + + if info and hasattr(info, "annotation"): + return info.annotation + return None + + def space(self) -> spaces.Dict: + """Return the corresponding Gym space for the Sample instance based on its instance attributes. Omits None values. + + Override this method in subclasses to customize the space generation. + """ + space_dict = {} + for key, value in self.dict().items(): + logging.debug("Generating space for key: '%s', value: %s", key, value) + info = self.model_field_info(key) + value = getattr(self, key) if hasattr(self, key) else value # noqa: PLW2901 + space_dict[key] = ( + value.space() if isinstance(value, Sample) else self.space_for(value, info=info) + ) + return spaces.Dict(space_dict) + + def random_sample(self) -> "Sample": + """Generate a random Sample instance based on its instance attributes. Omits None values. + + Override this method in subclasses to customize the sample generation. + """ + return self.__class__.model_validate(self.space().sample()) + + +if __name__ == "__main__": + sample = Sample(x=1, y=2, z={"a": 3, "b": 4}, extra_field=5) diff --git a/build/lib/dimos/types/segmentation.py b/build/lib/dimos/types/segmentation.py new file mode 100644 index 0000000000..5995f302f9 --- /dev/null +++ b/build/lib/dimos/types/segmentation.py @@ -0,0 +1,44 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Any +import numpy as np + + +class SegmentationType: + def __init__(self, masks: List[np.ndarray], metadata: Any = None): + """ + Initializes a standardized segmentation type. + + Args: + masks (List[np.ndarray]): A list of binary masks for segmentation. + metadata (Any, optional): Additional metadata related to the segmentations. + """ + self.masks = masks + self.metadata = metadata + + def combine_masks(self): + """Combine all masks into a single mask.""" + combined_mask = np.zeros_like(self.masks[0]) + for mask in self.masks: + combined_mask = np.logical_or(combined_mask, mask) + return combined_mask + + def save_masks(self, directory: str): + """Save each mask to a separate file.""" + import os + + os.makedirs(directory, exist_ok=True) + for i, mask in enumerate(self.masks): + np.save(os.path.join(directory, f"mask_{i}.npy"), mask) diff --git a/build/lib/dimos/types/test_pose.py b/build/lib/dimos/types/test_pose.py new file mode 100644 index 0000000000..e95133e035 --- /dev/null +++ b/build/lib/dimos/types/test_pose.py @@ -0,0 +1,323 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import math +from dimos.types.pose import Pose, to_pose +from dimos.types.vector import Vector + + +def test_pose_default_init(): + """Test that default initialization of Pose() has zero vectors for pos and rot.""" + pose = Pose() + + # Check that pos is a zero vector + assert isinstance(pose.pos, Vector) + assert pose.pos.is_zero() + assert pose.pos.x == 0.0 + assert pose.pos.y == 0.0 + assert pose.pos.z == 0.0 + + # Check that rot is a zero vector + assert isinstance(pose.rot, Vector) + assert pose.rot.is_zero() + assert pose.rot.x == 0.0 + assert pose.rot.y == 0.0 + assert pose.rot.z == 0.0 + + assert pose.is_zero() + + assert not pose + + +def test_pose_vector_init(): + """Test initialization with custom vectors.""" + pos = Vector(1.0, 2.0, 3.0) + rot = Vector(4.0, 5.0, 6.0) + + pose = Pose(pos, rot) + + # Check pos vector + assert pose.pos == pos + assert pose.pos.x == 1.0 + assert pose.pos.y == 2.0 + assert pose.pos.z == 3.0 + + # Check rot vector + assert pose.rot == rot + assert pose.rot.x == 4.0 + assert pose.rot.y == 5.0 + assert pose.rot.z == 6.0 + + # even if pos has the same xyz as pos vector + # it shouldn't accept equality comparisons + # as both are not the same type + assert not pose == pos + + +def test_pose_partial_init(): + """Test initialization with only one custom vector.""" + pos = Vector(1.0, 2.0, 3.0) + assert pos + + # Only specify pos + pose1 = Pose(pos) + assert pose1.pos == pos + assert pose1.pos.x == 1.0 + assert pose1.pos.y == 2.0 + assert pose1.pos.z == 3.0 + assert not pose1.pos.is_zero() + + assert isinstance(pose1.rot, Vector) + assert pose1.rot.is_zero() + assert pose1.rot.x == 0.0 + assert pose1.rot.y == 0.0 + assert pose1.rot.z == 0.0 + + +def test_pose_equality(): + """Test equality comparison between positions.""" + pos1 = Pose(Vector(1.0, 2.0, 3.0), Vector(4.0, 5.0, 6.0)) + pos2 = Pose(Vector(1.0, 
2.0, 3.0), Vector(4.0, 5.0, 6.0)) + pos3 = Pose(Vector(1.0, 2.0, 3.0), Vector(7.0, 8.0, 9.0)) + pos4 = Pose(Vector(7.0, 8.0, 9.0), Vector(4.0, 5.0, 6.0)) + + # Same pos and rot values should be equal + assert pos1 == pos2 + + # Different rot values should not be equal + assert pos1 != pos3 + + # Different pos values should not be equal + assert pos1 != pos4 + + # Pose should not equal a vector even if values match + assert pos1 != Vector(1.0, 2.0, 3.0) + + +def test_pose_vector_operations(): + """Test that Pose inherits Vector operations.""" + pos1 = Pose(Vector(1.0, 2.0, 3.0), Vector(4.0, 5.0, 6.0)) + pos2 = Pose(Vector(2.0, 3.0, 4.0), Vector(7.0, 8.0, 9.0)) + + # Addition should work on both position and rotation components + sum_pos = pos1 + pos2 + assert isinstance(sum_pos, Pose) + assert sum_pos.x == 3.0 + assert sum_pos.y == 5.0 + assert sum_pos.z == 7.0 + # Rotation should be added as well + assert sum_pos.rot.x == 11.0 # 4.0 + 7.0 + assert sum_pos.rot.y == 13.0 # 5.0 + 8.0 + assert sum_pos.rot.z == 15.0 # 6.0 + 9.0 + + # Subtraction should work on both position and rotation components + diff_pos = pos2 - pos1 + assert isinstance(diff_pos, Pose) + assert diff_pos.x == 1.0 + assert diff_pos.y == 1.0 + assert diff_pos.z == 1.0 + # Rotation should be subtracted as well + assert diff_pos.rot.x == 3.0 # 7.0 - 4.0 + assert diff_pos.rot.y == 3.0 # 8.0 - 5.0 + assert diff_pos.rot.z == 3.0 # 9.0 - 6.0 + + # Scalar multiplication + scaled_pos = pos1 * 2.0 + assert isinstance(scaled_pos, Pose) + assert scaled_pos.x == 2.0 + assert scaled_pos.y == 4.0 + assert scaled_pos.z == 6.0 + assert scaled_pos.rot == pos1.rot # Rotation not affected by scalar multiplication + + # Adding a Vector to a Pose (only affects position component) + vec = Vector(5.0, 6.0, 7.0) + pos_plus_vec = pos1 + vec + assert isinstance(pos_plus_vec, Pose) + assert pos_plus_vec.x == 6.0 + assert pos_plus_vec.y == 8.0 + assert pos_plus_vec.z == 10.0 + assert pos_plus_vec.rot == pos1.rot # Rotation 
unchanged + + +def test_pose_serialization(): + """Test pose serialization.""" + pos = Pose(Vector(1.0, 2.0, 3.0), Vector(4.0, 5.0, 6.0)) + serialized = pos.serialize() + + assert serialized["type"] == "pose" + assert serialized["pos"] == [1.0, 2.0, 3.0] + assert serialized["rot"] == [4.0, 5.0, 6.0] + + +def test_pose_initialization_with_arrays(): + """Test initialization with numpy arrays, lists and tuples.""" + # Test with numpy arrays + np_pos = np.array([1.0, 2.0, 3.0]) + np_rot = np.array([4.0, 5.0, 6.0]) + + pos1 = Pose(np_pos, np_rot) + + assert pos1.x == 1.0 + assert pos1.y == 2.0 + assert pos1.z == 3.0 + assert pos1.rot.x == 4.0 + assert pos1.rot.y == 5.0 + assert pos1.rot.z == 6.0 + + # Test with lists + list_pos = [7.0, 8.0, 9.0] + list_rot = [10.0, 11.0, 12.0] + pos2 = Pose(list_pos, list_rot) + + assert pos2.x == 7.0 + assert pos2.y == 8.0 + assert pos2.z == 9.0 + assert pos2.rot.x == 10.0 + assert pos2.rot.y == 11.0 + assert pos2.rot.z == 12.0 + + # Test with tuples + tuple_pos = (13.0, 14.0, 15.0) + tuple_rot = (16.0, 17.0, 18.0) + pos3 = Pose(tuple_pos, tuple_rot) + + assert pos3.x == 13.0 + assert pos3.y == 14.0 + assert pos3.z == 15.0 + assert pos3.rot.x == 16.0 + assert pos3.rot.y == 17.0 + assert pos3.rot.z == 18.0 + + +def test_to_pose_with_pose(): + """Test to_pose with Pose input.""" + # Create a pose + original_pos = Pose(Vector(1.0, 2.0, 3.0), Vector(4.0, 5.0, 6.0)) + + # Convert using to_pose + converted_pos = to_pose(original_pos) + + # Should return the exact same object + assert converted_pos is original_pos + assert converted_pos == original_pos + + # Check values + assert converted_pos.x == 1.0 + assert converted_pos.y == 2.0 + assert converted_pos.z == 3.0 + assert converted_pos.rot.x == 4.0 + assert converted_pos.rot.y == 5.0 + assert converted_pos.rot.z == 6.0 + + +def test_to_pose_with_vector(): + """Test to_pose with Vector input.""" + # Create a vector + vec = Vector(1.0, 2.0, 3.0) + + # Convert using to_pose + pos = 
to_pose(vec) + + # Should return a Pose with the vector as position and zero rotation + assert isinstance(pos, Pose) + assert pos.pos == vec + assert pos.x == 1.0 + assert pos.y == 2.0 + assert pos.z == 3.0 + + # Rotation should be zero + assert isinstance(pos.rot, Vector) + assert pos.rot.is_zero() + assert pos.rot.x == 0.0 + assert pos.rot.y == 0.0 + assert pos.rot.z == 0.0 + + +def test_to_pose_with_vectorlike(): + """Test to_pose with VectorLike inputs (arrays, lists, tuples).""" + # Test with numpy arrays + np_arr = np.array([1.0, 2.0, 3.0]) + pos1 = to_pose(np_arr) + + assert isinstance(pos1, Pose) + assert pos1.x == 1.0 + assert pos1.y == 2.0 + assert pos1.z == 3.0 + assert pos1.rot.is_zero() + + # Test with lists + list_val = [4.0, 5.0, 6.0] + pos2 = to_pose(list_val) + + assert isinstance(pos2, Pose) + assert pos2.x == 4.0 + assert pos2.y == 5.0 + assert pos2.z == 6.0 + assert pos2.rot.is_zero() + + # Test with tuples + tuple_val = (7.0, 8.0, 9.0) + pos3 = to_pose(tuple_val) + + assert isinstance(pos3, Pose) + assert pos3.x == 7.0 + assert pos3.y == 8.0 + assert pos3.z == 9.0 + assert pos3.rot.is_zero() + + +def test_to_pose_with_sequence(): + """Test to_pose with Sequence of VectorLike inputs.""" + # Test with sequence of two vectors + pos_vec = Vector(1.0, 2.0, 3.0) + rot_vec = Vector(4.0, 5.0, 6.0) + pos1 = to_pose([pos_vec, rot_vec]) + + assert isinstance(pos1, Pose) + assert pos1.pos == pos_vec + assert pos1.rot == rot_vec + assert pos1.x == 1.0 + assert pos1.y == 2.0 + assert pos1.z == 3.0 + assert pos1.rot.x == 4.0 + assert pos1.rot.y == 5.0 + assert pos1.rot.z == 6.0 + + # Test with sequence of lists + pos2 = to_pose([[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]) + + assert isinstance(pos2, Pose) + assert pos2.x == 7.0 + assert pos2.y == 8.0 + assert pos2.z == 9.0 + assert pos2.rot.x == 10.0 + assert pos2.rot.y == 11.0 + assert pos2.rot.z == 12.0 + + # Test with mixed sequence (tuple and numpy array) + pos3 = to_pose([(13.0, 14.0, 15.0), np.array([16.0, 
17.0, 18.0])]) + + assert isinstance(pos3, Pose) + assert pos3.x == 13.0 + assert pos3.y == 14.0 + assert pos3.z == 15.0 + assert pos3.rot.x == 16.0 + assert pos3.rot.y == 17.0 + assert pos3.rot.z == 18.0 + + +def test_vector_transform(): + robot_pose = Pose(Vector(4.0, 2.0, 0.5), Vector(0.0, 0.0, math.pi / 2)) + target = Vector(1.0, 3.0, 0.0) + print(robot_pose.vector_to(target)) diff --git a/build/lib/dimos/types/test_timestamped.py b/build/lib/dimos/types/test_timestamped.py new file mode 100644 index 0000000000..bf7962371e --- /dev/null +++ b/build/lib/dimos/types/test_timestamped.py @@ -0,0 +1,26 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datetime import datetime + +from dimos.types.timestamped import Timestamped + + +def test_timestamped_dt_method(): + ts = 1751075203.4120464 + timestamped = Timestamped(ts) + dt = timestamped.dt() + assert isinstance(dt, datetime) + assert abs(dt.timestamp() - ts) < 1e-6 + assert dt.tzinfo is not None, "datetime should be timezone-aware" diff --git a/build/lib/dimos/types/test_vector.py b/build/lib/dimos/types/test_vector.py new file mode 100644 index 0000000000..6a93d37afd --- /dev/null +++ b/build/lib/dimos/types/test_vector.py @@ -0,0 +1,384 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import pytest + +from dimos.types.vector import Vector + + +def test_vector_default_init(): + """Test that default initialization of Vector() has x,y,z components all zero.""" + v = Vector() + assert v.x == 0.0 + assert v.y == 0.0 + assert v.z == 0.0 + assert v.dim == 0 + assert len(v.data) == 0 + assert v.to_list() == [] + assert v.is_zero() == True # Empty vector should be considered zero + + +def test_vector_specific_init(): + """Test initialization with specific values.""" + # 2D vector + v1 = Vector(1.0, 2.0) + assert v1.x == 1.0 + assert v1.y == 2.0 + assert v1.z == 0.0 + assert v1.dim == 2 + + # 3D vector + v2 = Vector(3.0, 4.0, 5.0) + assert v2.x == 3.0 + assert v2.y == 4.0 + assert v2.z == 5.0 + assert v2.dim == 3 + + # From list + v3 = Vector([6.0, 7.0, 8.0]) + assert v3.x == 6.0 + assert v3.y == 7.0 + assert v3.z == 8.0 + assert v3.dim == 3 + + # From numpy array + v4 = Vector(np.array([9.0, 10.0, 11.0])) + assert v4.x == 9.0 + assert v4.y == 10.0 + assert v4.z == 11.0 + assert v4.dim == 3 + + +def test_vector_addition(): + """Test vector addition.""" + v1 = Vector(1.0, 2.0, 3.0) + v2 = Vector(4.0, 5.0, 6.0) + + v_add = v1 + v2 + assert v_add.x == 5.0 + assert v_add.y == 7.0 + assert v_add.z == 9.0 + + +def test_vector_subtraction(): + """Test vector subtraction.""" + v1 = Vector(1.0, 2.0, 3.0) + v2 = Vector(4.0, 5.0, 6.0) + + v_sub = v2 - v1 + assert v_sub.x == 3.0 + assert v_sub.y == 3.0 + assert v_sub.z == 3.0 + + +def test_vector_scalar_multiplication(): + """Test vector multiplication by a scalar.""" + v1 = Vector(1.0, 
2.0, 3.0) + + v_mul = v1 * 2.0 + assert v_mul.x == 2.0 + assert v_mul.y == 4.0 + assert v_mul.z == 6.0 + + # Test right multiplication + v_rmul = 2.0 * v1 + assert v_rmul.x == 2.0 + assert v_rmul.y == 4.0 + assert v_rmul.z == 6.0 + + +def test_vector_scalar_division(): + """Test vector division by a scalar.""" + v2 = Vector(4.0, 5.0, 6.0) + + v_div = v2 / 2.0 + assert v_div.x == 2.0 + assert v_div.y == 2.5 + assert v_div.z == 3.0 + + +def test_vector_dot_product(): + """Test vector dot product.""" + v1 = Vector(1.0, 2.0, 3.0) + v2 = Vector(4.0, 5.0, 6.0) + + dot = v1.dot(v2) + assert dot == 32.0 + + +def test_vector_length(): + """Test vector length calculation.""" + # 2D vector with length 5 + v1 = Vector(3.0, 4.0) + assert v1.length() == 5.0 + + # 3D vector + v2 = Vector(2.0, 3.0, 6.0) + assert v2.length() == pytest.approx(7.0, 0.001) + + # Test length_squared + assert v1.length_squared() == 25.0 + assert v2.length_squared() == 49.0 + + +def test_vector_normalize(): + """Test vector normalization.""" + v = Vector(2.0, 3.0, 6.0) + assert v.is_zero() == False + + v_norm = v.normalize() + length = v.length() + expected_x = 2.0 / length + expected_y = 3.0 / length + expected_z = 6.0 / length + + assert np.isclose(v_norm.x, expected_x) + assert np.isclose(v_norm.y, expected_y) + assert np.isclose(v_norm.z, expected_z) + assert np.isclose(v_norm.length(), 1.0) + assert v_norm.is_zero() == False + + # Test normalizing a zero vector + v_zero = Vector(0.0, 0.0, 0.0) + assert v_zero.is_zero() == True + v_zero_norm = v_zero.normalize() + assert v_zero_norm.x == 0.0 + assert v_zero_norm.y == 0.0 + assert v_zero_norm.z == 0.0 + assert v_zero_norm.is_zero() == True + + +def test_vector_to_2d(): + """Test conversion to 2D vector.""" + v = Vector(2.0, 3.0, 6.0) + + v_2d = v.to_2d() + assert v_2d.x == 2.0 + assert v_2d.y == 3.0 + assert v_2d.z == 0.0 + assert v_2d.dim == 2 + + # Already 2D vector + v2 = Vector(4.0, 5.0) + v2_2d = v2.to_2d() + assert v2_2d.x == 4.0 + assert 
v2_2d.y == 5.0 + assert v2_2d.dim == 2 + + +def test_vector_distance(): + """Test distance calculations between vectors.""" + v1 = Vector(1.0, 2.0, 3.0) + v2 = Vector(4.0, 6.0, 8.0) + + # Distance + dist = v1.distance(v2) + expected_dist = np.sqrt(9.0 + 16.0 + 25.0) # sqrt((4-1)² + (6-2)² + (8-3)²) + assert dist == pytest.approx(expected_dist) + + # Distance squared + dist_sq = v1.distance_squared(v2) + assert dist_sq == 50.0 # 9 + 16 + 25 + + +def test_vector_cross_product(): + """Test vector cross product.""" + v1 = Vector(1.0, 0.0, 0.0) # Unit x vector + v2 = Vector(0.0, 1.0, 0.0) # Unit y vector + + # v1 × v2 should be unit z vector + cross = v1.cross(v2) + assert cross.x == 0.0 + assert cross.y == 0.0 + assert cross.z == 1.0 + + # Test with more complex vectors + a = Vector(2.0, 3.0, 4.0) + b = Vector(5.0, 6.0, 7.0) + c = a.cross(b) + + # Cross product manually calculated: + # (3*7-4*6, 4*5-2*7, 2*6-3*5) + assert c.x == -3.0 + assert c.y == 6.0 + assert c.z == -3.0 + + # Test with 2D vectors (should raise error) + v_2d = Vector(1.0, 2.0) + with pytest.raises(ValueError): + v_2d.cross(v2) + + +def test_vector_zeros(): + """Test Vector.zeros class method.""" + # 3D zero vector + v_zeros = Vector.zeros(3) + assert v_zeros.x == 0.0 + assert v_zeros.y == 0.0 + assert v_zeros.z == 0.0 + assert v_zeros.dim == 3 + assert v_zeros.is_zero() == True + + # 2D zero vector + v_zeros_2d = Vector.zeros(2) + assert v_zeros_2d.x == 0.0 + assert v_zeros_2d.y == 0.0 + assert v_zeros_2d.z == 0.0 + assert v_zeros_2d.dim == 2 + assert v_zeros_2d.is_zero() == True + + +def test_vector_ones(): + """Test Vector.ones class method.""" + # 3D ones vector + v_ones = Vector.ones(3) + assert v_ones.x == 1.0 + assert v_ones.y == 1.0 + assert v_ones.z == 1.0 + assert v_ones.dim == 3 + + # 2D ones vector + v_ones_2d = Vector.ones(2) + assert v_ones_2d.x == 1.0 + assert v_ones_2d.y == 1.0 + assert v_ones_2d.z == 0.0 + assert v_ones_2d.dim == 2 + + +def test_vector_conversion_methods(): + """Test 
vector conversion methods (to_list, to_tuple, to_numpy).""" + v = Vector(1.0, 2.0, 3.0) + + # to_list + assert v.to_list() == [1.0, 2.0, 3.0] + + # to_tuple + assert v.to_tuple() == (1.0, 2.0, 3.0) + + # to_numpy + np_array = v.to_numpy() + assert isinstance(np_array, np.ndarray) + assert np.array_equal(np_array, np.array([1.0, 2.0, 3.0])) + + +def test_vector_equality(): + """Test vector equality.""" + v1 = Vector(1, 2, 3) + v2 = Vector(1, 2, 3) + v3 = Vector(4, 5, 6) + + assert v1 == v2 + assert v1 != v3 + assert v1 != Vector(1, 2) # Different dimensions + assert v1 != Vector(1.1, 2, 3) # Different values + assert v1 != [1, 2, 3] + + +def test_vector_is_zero(): + """Test is_zero method for vectors.""" + # Default empty vector + v0 = Vector() + assert v0.is_zero() == True + + # Explicit zero vector + v1 = Vector(0.0, 0.0, 0.0) + assert v1.is_zero() == True + + # Zero vector with different dimensions + v2 = Vector(0.0, 0.0) + assert v2.is_zero() == True + + # Non-zero vectors + v3 = Vector(1.0, 0.0, 0.0) + assert v3.is_zero() == False + + v4 = Vector(0.0, 2.0, 0.0) + assert v4.is_zero() == False + + v5 = Vector(0.0, 0.0, 3.0) + assert v5.is_zero() == False + + # Almost zero (within tolerance) + v6 = Vector(1e-10, 1e-10, 1e-10) + assert v6.is_zero() == True + + # Almost zero (outside tolerance) + v7 = Vector(1e-6, 1e-6, 1e-6) + assert v7.is_zero() == False + + +def test_vector_bool_conversion(): + """Test boolean conversion of vectors.""" + # Zero vectors should be False + v0 = Vector() + assert bool(v0) == False + + v1 = Vector(0.0, 0.0, 0.0) + assert bool(v1) == False + + # Almost zero vectors should be False + v2 = Vector(1e-10, 1e-10, 1e-10) + assert bool(v2) == False + + # Non-zero vectors should be True + v3 = Vector(1.0, 0.0, 0.0) + assert bool(v3) == True + + v4 = Vector(0.0, 2.0, 0.0) + assert bool(v4) == True + + v5 = Vector(0.0, 0.0, 3.0) + assert bool(v5) == True + + # Direct use in if statements + if v0: + assert False, "Zero vector should be False in 
boolean context" + else: + pass # Expected path + + if v3: + pass # Expected path + else: + assert False, "Non-zero vector should be True in boolean context" + + +def test_vector_add(): + """Test vector addition operator.""" + v1 = Vector(1.0, 2.0, 3.0) + v2 = Vector(4.0, 5.0, 6.0) + + # Using __add__ method + v_add = v1.__add__(v2) + assert v_add.x == 5.0 + assert v_add.y == 7.0 + assert v_add.z == 9.0 + + # Using + operator + v_add_op = v1 + v2 + assert v_add_op.x == 5.0 + assert v_add_op.y == 7.0 + assert v_add_op.z == 9.0 + + # Adding zero vector should return original vector + v_zero = Vector.zeros(3) + assert (v1 + v_zero) == v1 + + +def test_vector_add_dim_mismatch(): + """Test vector addition operator.""" + v1 = Vector(1.0, 2.0) + v2 = Vector(4.0, 5.0, 6.0) + + # Using + operator + v_add_op = v1 + v2 diff --git a/build/lib/dimos/types/timestamped.py b/build/lib/dimos/types/timestamped.py new file mode 100644 index 0000000000..189bf7eaec --- /dev/null +++ b/build/lib/dimos/types/timestamped.py @@ -0,0 +1,55 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from datetime import datetime, timezone +from typing import Generic, Iterable, Tuple, TypedDict, TypeVar, Union + +# any class that carries a timestamp should inherit from this +# this allows us to work with timeseries in consistent way, allign messages, replay etc +# aditional functionality will come to this class soon + + +class RosStamp(TypedDict): + sec: int + nanosec: int + + +EpochLike = Union[int, float, datetime, RosStamp] + + +def to_timestamp(ts: EpochLike) -> float: + """Convert EpochLike to a timestamp in seconds.""" + if isinstance(ts, datetime): + return ts.timestamp() + if isinstance(ts, (int, float)): + return float(ts) + if isinstance(ts, dict) and "sec" in ts and "nanosec" in ts: + return ts["sec"] + ts["nanosec"] / 1e9 + raise TypeError("unsupported timestamp type") + + +class Timestamped: + ts: float + + def __init__(self, ts: float): + self.ts = ts + + def dt(self) -> datetime: + return datetime.fromtimestamp(self.ts, tz=timezone.utc).astimezone() + + def ros_timestamp(self) -> dict[str, int]: + """Convert timestamp to ROS-style dictionary.""" + sec = int(self.ts) + nanosec = int((self.ts - sec) * 1_000_000_000) + return [sec, nanosec] diff --git a/build/lib/dimos/types/vector.py b/build/lib/dimos/types/vector.py new file mode 100644 index 0000000000..d980e28105 --- /dev/null +++ b/build/lib/dimos/types/vector.py @@ -0,0 +1,460 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from typing import List, Tuple, TypeVar, Union, Sequence

import numpy as np
from dimos.types.ros_polyfill import Vector3

T = TypeVar("T", bound="Vector")

# Vector-like types that can be converted to/from Vector
VectorLike = Union[Sequence[Union[int, float]], Vector3, "Vector", np.ndarray]


class Vector:
    """A wrapper around numpy arrays for vector operations with intuitive syntax."""

    def __init__(self, *args: VectorLike):
        """Initialize a vector from components or another iterable.

        Examples:
            Vector(1, 2)       # 2D vector
            Vector(1, 2, 3)    # 3D vector
            Vector([1, 2, 3])  # From list
            Vector(np.array([1, 2, 3])) # From numpy array
        """
        if len(args) == 1 and hasattr(args[0], "__iter__"):
            self._data = np.array(args[0], dtype=float)

        elif len(args) == 1:
            # Single non-iterable argument: read its x/y/z attributes.
            self._data = np.array([args[0].x, args[0].y, args[0].z], dtype=float)

        else:
            self._data = np.array(args, dtype=float)

    @property
    def yaw(self) -> float:
        return self.x

    @property
    def tuple(self) -> Tuple[float, ...]:
        """Tuple representation of the vector."""
        return tuple(self._data)

    @property
    def x(self) -> float:
        """X component of the vector."""
        return self._data[0] if len(self._data) > 0 else 0.0

    @property
    def y(self) -> float:
        """Y component of the vector."""
        return self._data[1] if len(self._data) > 1 else 0.0

    @property
    def z(self) -> float:
        """Z component of the vector."""
        return self._data[2] if len(self._data) > 2 else 0.0

    @property
    def dim(self) -> int:
        """Dimensionality of the vector."""
        return len(self._data)

    @property
    def data(self) -> np.ndarray:
        """Get the underlying numpy array."""
        return self._data

    def __getitem__(self, idx):
        return self._data[idx]

    def __repr__(self) -> str:
        return f"Vector({self.data})"

    def __str__(self) -> str:
        if self.dim < 2:
            return self.__repr__()

        def getArrow():
            # Named `arrows` so the builtin `repr` is not shadowed.
            arrows = ["←", "↖", "↑", "↗", "→", "↘", "↓", "↙"]

            if self.x == 0 and self.y == 0:
                return "·"

            # Calculate angle in radians and convert to directional index
            angle = np.arctan2(self.y, self.x)
            # Map angle to 0-7 index (8 directions) with proper orientation
            dir_index = int(((angle + np.pi) * 4 / np.pi) % 8)
            # Get directional arrow symbol
            return arrows[dir_index]

        return f"{getArrow()} Vector {self.__repr__()}"

    def serialize(self) -> dict:
        """Serialize the vector to a dict of the form ``{"type": "vector", "c": [...]}``."""
        return {"type": "vector", "c": self._data.tolist()}

    def __eq__(self, other) -> bool:
        """Check if two vectors are equal using numpy's allclose for floating point comparison."""
        if not isinstance(other, Vector):
            return False
        if len(self._data) != len(other._data):
            return False
        return np.allclose(self._data, other._data)

    def __add__(self: T, other: VectorLike) -> T:
        other = to_vector(other)
        if self.dim != other.dim:
            # Zero-pad the shorter operand so mixed-dimension addition works.
            max_dim = max(self.dim, other.dim)
            return self.pad(max_dim) + other.pad(max_dim)
        return self.__class__(self._data + other._data)

    def __sub__(self: T, other: VectorLike) -> T:
        other = to_vector(other)
        if self.dim != other.dim:
            max_dim = max(self.dim, other.dim)
            return self.pad(max_dim) - other.pad(max_dim)
        return self.__class__(self._data - other._data)

    def __mul__(self: T, scalar: float) -> T:
        return self.__class__(self._data * scalar)

    def __rmul__(self: T, scalar: float) -> T:
        return self.__mul__(scalar)

    def __truediv__(self: T, scalar: float) -> T:
        return self.__class__(self._data / scalar)

    def __neg__(self: T) -> T:
        return self.__class__(-self._data)

    def dot(self, other: VectorLike) -> float:
        """Compute dot product."""
        other = to_vector(other)
        return float(np.dot(self._data, other._data))

    def cross(self: T, other: VectorLike) -> T:
        """Compute cross product (3D vectors only)."""
        if self.dim != 3:
            raise ValueError("Cross product is only defined for 3D vectors")

        other = to_vector(other)
        if other.dim != 3:
            raise ValueError("Cross product requires two 3D vectors")

        return self.__class__(np.cross(self._data, other._data))

    def length(self) -> float:
        """Compute the Euclidean length (magnitude) of the vector."""
        return float(np.linalg.norm(self._data))

    def length_squared(self) -> float:
        """Compute the squared length of the vector (faster than length())."""
        return float(np.sum(self._data * self._data))

    def normalize(self: T) -> T:
        """Return a normalized unit vector in the same direction."""
        length = self.length()
        if length < 1e-10:  # Avoid division by near-zero
            return self.__class__(np.zeros_like(self._data))
        return self.__class__(self._data / length)

    def to_2d(self: T) -> T:
        """Convert a vector to a 2D vector by taking only the x and y components."""
        return self.__class__(self._data[:2])

    def pad(self: T, dim: int) -> T:
        """Pad a vector with zeros to reach the specified dimension.

        If vector already has dimension >= dim, it is returned unchanged.
        """
        if self.dim >= dim:
            return self

        padded = np.zeros(dim, dtype=float)
        padded[: len(self._data)] = self._data
        return self.__class__(padded)

    def distance(self, other: VectorLike) -> float:
        """Compute Euclidean distance to another vector."""
        other = to_vector(other)
        return float(np.linalg.norm(self._data - other._data))

    def distance_squared(self, other: VectorLike) -> float:
        """Compute squared Euclidean distance to another vector (faster than distance())."""
        other = to_vector(other)
        diff = self._data - other._data
        return float(np.sum(diff * diff))

    def angle(self, other: VectorLike) -> float:
        """Compute the angle (in radians) between this vector and another."""
        other = to_vector(other)
        if self.length() < 1e-10 or other.length() < 1e-10:
            return 0.0

        # Clip guards arccos against rounding just outside [-1, 1].
        cos_angle = np.clip(
            np.dot(self._data, other._data)
            / (np.linalg.norm(self._data) * np.linalg.norm(other._data)),
            -1.0,
            1.0,
        )
        return float(np.arccos(cos_angle))

    def project(self: T, onto: VectorLike) -> T:
        """Project this vector onto another vector."""
        onto = to_vector(onto)
        onto_length_sq = np.sum(onto._data * onto._data)
        if onto_length_sq < 1e-10:
            return self.__class__(np.zeros_like(self._data))

        scalar_projection = np.dot(self._data, onto._data) / onto_length_sq
        return self.__class__(scalar_projection * onto._data)

    # this is here to test ros_observable_topic
    # doesn't happen irl afaik that we want a vector from ros message
    @classmethod
    def from_msg(cls: type[T], msg) -> T:
        return cls(*msg)

    @classmethod
    def zeros(cls: type[T], dim: int) -> T:
        """Create a zero vector of given dimension."""
        return cls(np.zeros(dim))

    @classmethod
    def ones(cls: type[T], dim: int) -> T:
        """Create a vector of ones with given dimension."""
        return cls(np.ones(dim))

    @classmethod
    def unit_x(cls: type[T], dim: int = 3) -> T:
        """Create a unit vector in the x direction."""
        v = np.zeros(dim)
        v[0] = 1.0
        return cls(v)

    @classmethod
    def unit_y(cls: type[T], dim: int = 3) -> T:
        """Create a unit vector in the y direction."""
        v = np.zeros(dim)
        v[1] = 1.0
        return cls(v)

    @classmethod
    def unit_z(cls: type[T], dim: int = 3) -> T:
        """Create a unit vector in the z direction."""
        v = np.zeros(dim)
        if dim > 2:
            v[2] = 1.0
        return cls(v)

    def to_list(self) -> List[float]:
        """Convert the vector to a list."""
        return self._data.tolist()

    def to_tuple(self) -> Tuple[float, ...]:
        """Convert the vector to a tuple."""
        return tuple(self._data)

    def to_numpy(self) -> np.ndarray:
        """Convert the vector to a numpy array."""
        return self._data

    def is_zero(self) -> bool:
        """Check if this is a zero vector (all components are zero).

        Returns:
            True if all components are zero, False otherwise
        """
        return np.allclose(self._data, 0.0)

    def __bool__(self) -> bool:
        """Boolean conversion for Vector.

        A Vector is considered False if it's a zero vector (all components are zero),
        and True otherwise.

        Returns:
            False if vector is zero, True otherwise
        """
        return not self.is_zero()


def to_numpy(value: VectorLike) -> np.ndarray:
    """Convert a vector-compatible value to a numpy array.

    Args:
        value: Any vector-like object (Vector, numpy array, tuple, list)

    Returns:
        Numpy array representation
    """
    if isinstance(value, Vector3):
        return np.array([value.x, value.y, value.z], dtype=float)
    if isinstance(value, Vector):
        return value.data
    elif isinstance(value, np.ndarray):
        return value
    else:
        return np.array(value, dtype=float)


def to_vector(value: VectorLike) -> Vector:
    """Convert a vector-compatible value to a Vector object.

    Args:
        value: Any vector-like object (Vector, numpy array, tuple, list)

    Returns:
        Vector object
    """
    if isinstance(value, Vector):
        return value
    else:
        return Vector(value)


def to_tuple(value: VectorLike) -> Tuple[float, ...]:
    """Convert a vector-compatible value to a tuple.

    Args:
        value: Any vector-like object (Vector, numpy array, tuple, list)

    Returns:
        Tuple of floats
    """
    if isinstance(value, Vector3):
        return tuple([value.x, value.y, value.z])
    if isinstance(value, Vector):
        return tuple(value.data)
    elif isinstance(value, np.ndarray):
        return tuple(value.tolist())
    elif isinstance(value, tuple):
        return value
    else:
        return tuple(value)


def to_list(value: VectorLike) -> List[float]:
    """Convert a vector-compatible value to a list.

    Args:
        value: Any vector-like object (Vector, numpy array, tuple, list)

    Returns:
        List of floats
    """
    if isinstance(value, Vector):
        return value.data.tolist()
    elif isinstance(value, np.ndarray):
        return value.tolist()
    elif isinstance(value, list):
        return value
    else:
        return list(value)


# Helper functions to check dimensionality
def is_2d(value: VectorLike) -> bool:
    """Check if a vector-compatible value is 2D.

    Args:
        value: Any vector-like object (Vector, numpy array, tuple, list)

    Returns:
        True if the value is 2D
    """
    if isinstance(value, Vector3):
        return False
    elif isinstance(value, Vector):
        # Vector defines no __len__, so len(value) would raise TypeError.
        return value.dim == 2
    elif isinstance(value, np.ndarray):
        return value.shape[-1] == 2 or value.size == 2
    else:
        return len(value) == 2


def is_3d(value: VectorLike) -> bool:
    """Check if a vector-compatible value is 3D.

    Args:
        value: Any vector-like object (Vector, numpy array, tuple, list)

    Returns:
        True if the value is 3D
    """
    if isinstance(value, Vector):
        # Vector defines no __len__, so len(value) would raise TypeError.
        return value.dim == 3
    elif isinstance(value, Vector3):
        return True
    elif isinstance(value, np.ndarray):
        return value.shape[-1] == 3 or value.size == 3
    else:
        return len(value) == 3


# Extraction functions for XYZ components
def x(value: VectorLike) -> float:
    """Get the X component of a vector-compatible value.

    Args:
        value: Any vector-like object (Vector, numpy array, tuple, list)

    Returns:
        X component as a float
    """
    if isinstance(value, Vector):
        return value.x
    elif isinstance(value, Vector3):
        return value.x
    else:
        return float(to_numpy(value)[0])


def y(value: VectorLike) -> float:
    """Get the Y component of a vector-compatible value.

    Args:
        value: Any vector-like object (Vector, numpy array, tuple, list)

    Returns:
        Y component as a float
    """
    if isinstance(value, Vector):
        return value.y
    elif isinstance(value, Vector3):
        return value.y
    else:
        arr = to_numpy(value)
        return float(arr[1]) if len(arr) > 1 else 0.0


def z(value: VectorLike) -> float:
    """Get the Z component of a vector-compatible value.

    Args:
        value: Any vector-like object (Vector, numpy array, tuple, list)

    Returns:
        Z component as a float
    """
    if isinstance(value, Vector):
        return value.z
    elif isinstance(value, Vector3):
        return value.z
    else:
        arr = to_numpy(value)
        return float(arr[2]) if len(arr) > 2 else 0.0


# --- Patch text that shared the replaced physical lines (preserved verbatim) ---
# diff --git a/build/lib/dimos/web/__init__.py b/build/lib/dimos/web/__init__.py
# new file mode 100644
# index 0000000000..e69de29bb2
# diff --git a/build/lib/dimos/web/dimos_interface/__init__.py b/build/lib/dimos/web/dimos_interface/__init__.py
# new file mode 100644
# index 0000000000..5ca28b30e5
# --- /dev/null
# +++ b/build/lib/dimos/web/dimos_interface/__init__.py
# @@ -0,0 +1,7 @@
# +"""
# +Dimensional Interface package
# +"""
# +
# +from .api.server import FastAPIServer
# +
# +__all__ = ["FastAPIServer"]
# diff --git a/build/lib/dimos/web/dimos_interface/api/__init__.py b/build/lib/dimos/web/dimos_interface/api/__init__.py
# new file mode 100644
# index 0000000000..e69de29bb2
# diff --git a/build/lib/dimos/web/dimos_interface/api/server.py b/build/lib/dimos/web/dimos_interface/api/server.py
# new file mode 100644
# index 0000000000..bcc590ab46
# --- /dev/null
# +++ b/build/lib/dimos/web/dimos_interface/api/server.py
# @@ -0,0 +1,362 @@
# +#!/usr/bin/env python3
# +# Copyright 2025 Dimensional Inc.
# +#
# +# Licensed under the Apache License, Version 2.0 (the "License");
# +# you may not use this file except in compliance with the License.
# +# You may obtain a copy of the License at
# +#
# +#     http://www.apache.org/licenses/LICENSE-2.0
# +#
# +# Unless required by applicable law or agreed to in writing, software
# +# distributed under the License is distributed on an "AS IS" BASIS,
# +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# +# See the License for the specific language governing permissions and
# +# limitations under the License.
# +
# +
# +# Working FastAPI/Uvicorn Impl.
# +
# +# Notes: Do not use simultaneously with Flask, this includes imports.
+# Workers are not yet setup, as this requires a much more intricate +# reorganization. There appears to be possible signalling issues when +# opening up streams on multiple windows/reloading which will need to +# be fixed. Also note, Chrome only supports 6 simultaneous web streams, +# and its advised to test threading/worker performance with another +# browser like Safari. + +# Fast Api & Uvicorn +import cv2 +from dimos.web.edge_io import EdgeIO +from fastapi import FastAPI, Request, Form, HTTPException, UploadFile, File +from fastapi.responses import HTMLResponse, StreamingResponse, JSONResponse +from sse_starlette.sse import EventSourceResponse +from fastapi.templating import Jinja2Templates +import uvicorn +from threading import Lock +from pathlib import Path +from queue import Queue, Empty +import asyncio + +from reactivex.disposable import SingleAssignmentDisposable +from reactivex import operators as ops +import reactivex as rx +from fastapi.middleware.cors import CORSMiddleware + +# For audio processing +import io +import time +import numpy as np +import ffmpeg +import soundfile as sf +from dimos.stream.audio.base import AudioEvent + +# TODO: Resolve threading, start/stop stream functionality. 
class FastAPIServer(EdgeIO):
    """FastAPI server exposing video feeds, SSE text streams, queries and audio upload."""

    def __init__(
        self,
        dev_name="FastAPI Server",
        edge_type="Bidirectional",
        host="0.0.0.0",
        port=5555,
        text_streams=None,
        audio_subject=None,
        **streams,
    ):
        """Build the app, wrap the given streams and register all routes.

        Args:
            dev_name: Device name forwarded to EdgeIO.
            edge_type: Edge type forwarded to EdgeIO.
            host: Interface uvicorn binds to.
            port: Port uvicorn listens on.
            text_streams: Optional mapping of key -> observable of text messages.
            audio_subject: Optional subject that receives uploaded AudioEvent objects.
            **streams: Mapping of key -> observable of video frames.
        """
        print("Starting FastAPIServer initialization...")  # Debug print
        super().__init__(dev_name, edge_type)
        self.app = FastAPI()

        # Add CORS middleware with more permissive settings for development
        # NOTE(review): wildcard origins combined with allow_credentials=True is
        # only appropriate for development; restrict origins before exposing this.
        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],  # More permissive for development
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
            expose_headers=["*"],
        )

        self.port = port
        self.host = host
        BASE_DIR = Path(__file__).resolve().parent
        self.templates = Jinja2Templates(directory=str(BASE_DIR / "templates"))
        self.streams = streams
        self.active_streams = {}
        self.stream_locks = {key: Lock() for key in self.streams}
        self.stream_queues = {}
        self.stream_disposables = {}

        # Initialize text streams
        self.text_streams = text_streams or {}
        self.text_queues = {}
        self.text_disposables = {}
        self.text_clients = set()

        # Create a Subject for text queries
        self.query_subject = rx.subject.Subject()
        self.query_stream = self.query_subject.pipe(ops.share())
        self.audio_subject = audio_subject

        for key in self.streams:
            if self.streams[key] is not None:
                self.active_streams[key] = self.streams[key].pipe(
                    ops.map(self.process_frame_fastapi), ops.share()
                )

        # Set up text stream subscriptions
        for key, stream in self.text_streams.items():
            if stream is not None:
                self.text_queues[key] = Queue(maxsize=100)
                # Enqueue without blocking: with no SSE client draining the
                # queue, a blocking put() would stall the producer forever
                # once 100 messages have accumulated.
                disposable = stream.subscribe(
                    lambda text, k=key: self._put_dropping_oldest(self.text_queues[k], text)
                    if text is not None
                    else None,
                    lambda e, k=key: self._put_dropping_oldest(self.text_queues[k], None),
                    lambda k=key: self._put_dropping_oldest(self.text_queues[k], None),
                )
                self.text_disposables[key] = disposable
                self.disposables.add(disposable)

        print("Setting up routes...")  # Debug print
        self.setup_routes()
        print("FastAPIServer initialization complete")  # Debug print

    @staticmethod
    def _put_dropping_oldest(target_queue, item):
        """Put ``item`` on a bounded queue without blocking, evicting the oldest entry if full."""
        from queue import Full

        while True:
            try:
                target_queue.put_nowait(item)
                return
            except Full:
                try:
                    target_queue.get_nowait()
                except Empty:
                    # A consumer drained the queue in between; just retry.
                    pass

    def process_frame_fastapi(self, frame):
        """Convert frame to JPEG format for streaming."""
        _, buffer = cv2.imencode(".jpg", frame)
        return buffer.tobytes()

    def stream_generator(self, key):
        """Generate frames for a given video stream."""

        def generate():
            if key not in self.stream_queues:
                self.stream_queues[key] = Queue(maxsize=10)

            frame_queue = self.stream_queues[key]

            # Clear any existing disposable for this stream
            if key in self.stream_disposables:
                self.stream_disposables[key].dispose()

            disposable = SingleAssignmentDisposable()
            self.stream_disposables[key] = disposable
            self.disposables.add(disposable)

            if key in self.active_streams:
                with self.stream_locks[key]:
                    # Clear the queue before starting new subscription
                    while not frame_queue.empty():
                        try:
                            frame_queue.get_nowait()
                        except Empty:
                            break

                    disposable.disposable = self.active_streams[key].subscribe(
                        lambda frame: frame_queue.put(frame) if frame is not None else None,
                        lambda e: frame_queue.put(None),
                        lambda: frame_queue.put(None),
                    )

            try:
                while True:
                    try:
                        frame = frame_queue.get(timeout=1)
                        if frame is None:
                            # None is the end-of-stream / error sentinel.
                            break
                        yield (b"--frame\r\nContent-Type: image/jpeg\r\n\r\n" + frame + b"\r\n")
                    except Empty:
                        # Instead of breaking, continue waiting for new frames
                        continue
            finally:
                if key in self.stream_disposables:
                    self.stream_disposables[key].dispose()

        return generate

    def create_video_feed_route(self, key):
        """Create a video feed route for a specific stream."""

        async def video_feed():
            return StreamingResponse(
                self.stream_generator(key)(), media_type="multipart/x-mixed-replace; boundary=frame"
            )

        return video_feed

    async def text_stream_generator(self, key):
        """Generate SSE events for text stream."""
        # Keep a unique object alive as the client token. id(object()) on a
        # temporary can be reused by the next temporary, so two clients could
        # end up with the same id and the second removal would raise KeyError.
        client_id = object()
        self.text_clients.add(client_id)

        try:
            while True:
                if key not in self.text_queues:
                    yield {"event": "ping", "data": ""}
                    await asyncio.sleep(0.1)
                    continue

                try:
                    text = self.text_queues[key].get_nowait()
                    if text is not None:
                        yield {"event": "message", "id": key, "data": text}
                    else:
                        # None is the end-of-stream / error sentinel.
                        break
                except Empty:
                    yield {"event": "ping", "data": ""}
                    await asyncio.sleep(0.1)
        finally:
            self.text_clients.discard(client_id)

    @staticmethod
    def _decode_audio(raw: bytes) -> "tuple[np.ndarray | None, int | None]":
        """Convert the webm/opus blob sent by the browser into mono 16-kHz PCM.

        Returns:
            ``(samples, sample_rate)`` on success, ``(None, None)`` if decoding fails.
        """
        try:
            # Use ffmpeg to convert to 16-kHz mono 16-bit PCM WAV in memory
            out, _ = (
                ffmpeg.input("pipe:0")
                .output(
                    "pipe:1",
                    format="wav",
                    acodec="pcm_s16le",
                    ac=1,
                    ar="16000",
                    loglevel="quiet",
                )
                .run(input=raw, capture_stdout=True, capture_stderr=True)
            )
            # Load with soundfile (returns float32 by default)
            audio, sr = sf.read(io.BytesIO(out), dtype="float32")
            # Ensure 1-D array (mono)
            if audio.ndim > 1:
                audio = audio[:, 0]
            return np.array(audio), sr
        except Exception as exc:
            print(f"ffmpeg decoding failed: {exc}")
            return None, None

    def setup_routes(self):
        """Set up FastAPI routes."""

        @self.app.get("/streams")
        async def get_streams():
            """Get list of available video streams"""
            return {"streams": list(self.streams.keys())}

        @self.app.get("/text_streams")
        async def get_text_streams():
            """Get list of available text streams"""
            return {"streams": list(self.text_streams.keys())}

        @self.app.get("/", response_class=HTMLResponse)
        async def index(request: Request):
            stream_keys = list(self.streams.keys())
            text_stream_keys = list(self.text_streams.keys())
            return self.templates.TemplateResponse(
                "index_fastapi.html",
                {
                    "request": request,
                    "stream_keys": stream_keys,
                    "text_stream_keys": text_stream_keys,
                    "has_voice": self.audio_subject is not None,
                },
            )

        @self.app.post("/submit_query")
        async def submit_query(query: str = Form(...)):
            # Using Form directly as a dependency ensures proper form handling
            try:
                if query:
                    # Emit the query through our Subject
                    self.query_subject.on_next(query)
                    return JSONResponse({"success": True, "message": "Query received"})
                return JSONResponse({"success": False, "message": "No query provided"})
            except Exception as e:
                # Ensure we always return valid JSON even on error
                return JSONResponse(
                    status_code=500,
                    content={"success": False, "message": f"Server error: {str(e)}"},
                )

        @self.app.post("/upload_audio")
        async def upload_audio(file: UploadFile = File(...)):
            """Handle audio upload from the browser."""
            if self.audio_subject is None:
                return JSONResponse(
                    status_code=400,
                    content={"success": False, "message": "Voice input not configured"},
                )

            try:
                data = await file.read()
                audio_np, sr = self._decode_audio(data)
                if audio_np is None:
                    return JSONResponse(
                        status_code=400,
                        content={"success": False, "message": "Unable to decode audio"},
                    )

                event = AudioEvent(
                    data=audio_np,
                    sample_rate=sr,
                    timestamp=time.time(),
                    channels=1 if audio_np.ndim == 1 else audio_np.shape[1],
                )

                # Push to reactive stream
                self.audio_subject.on_next(event)
                print(f"Received audio – {event.data.shape[0] / sr:.2f} s, {sr} Hz")
                return {"success": True}
            except Exception as e:
                print(f"Failed to process uploaded audio: {e}")
                return JSONResponse(status_code=500, content={"success": False, "message": str(e)})

        # Unitree API endpoints
        @self.app.get("/unitree/status")
        async def unitree_status():
            """Check the status of the Unitree API server"""
            return JSONResponse({"status": "online", "service": "unitree"})

        @self.app.post("/unitree/command")
        async def unitree_command(request: Request):
            """Process commands sent from the terminal frontend"""
            try:
                data = await request.json()
                command_text = data.get("command", "")

                # Emit the command through the query_subject
                self.query_subject.on_next(command_text)

                response = {
                    "success": True,
                    "command": command_text,
                    "result": f"Processed command: {command_text}",
                }

                return JSONResponse(response)
            except Exception as e:
                print(f"Error processing command: {str(e)}")
                return JSONResponse(
                    status_code=500,
                    content={"success": False, "message": f"Error processing command: {str(e)}"},
                )

        @self.app.get("/text_stream/{key}")
        async def text_stream(key: str):
            if key not in self.text_streams:
                raise HTTPException(status_code=404, detail=f"Text stream '{key}' not found")
            return EventSourceResponse(self.text_stream_generator(key))

        for key in self.streams:
            self.app.get(f"/video_feed/{key}")(self.create_video_feed_route(key))

    def run(self):
        """Run the FastAPI server."""
        uvicorn.run(
            self.app, host=self.host, port=self.port
        )  # TODO: Translate structure to enable in-built workers'


if __name__ == "__main__":
    server = FastAPIServer()
    server.run()


# --- Patch text that shared the replaced physical lines (preserved verbatim) ---
# diff --git a/build/lib/dimos/web/edge_io.py b/build/lib/dimos/web/edge_io.py
# new file mode 100644
# index 0000000000..8511df2ce3
# --- /dev/null
# +++ b/build/lib/dimos/web/edge_io.py
# @@ -0,0 +1,26 @@
# +# Copyright 2025 Dimensional Inc.
# +#
# +# Licensed under the Apache License, Version 2.0 (the "License");
# +# you may not use this file except in compliance with the License.
# +# You may obtain a copy of the License at
# +#
# +#     http://www.apache.org/licenses/LICENSE-2.0
# +#
# +# Unless required by applicable law or agreed to in writing, software
# +# distributed under the License is distributed on an "AS IS" BASIS,
# +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# +# See the License for the specific language governing permissions and
# +# limitations under the License.
from reactivex.disposable import CompositeDisposable


class EdgeIO:
    """Base class holding a device name, an edge type and a set of subscriptions."""

    def __init__(self, dev_name: str = "NA", edge_type: str = "Base"):
        """Store the identifying fields and create an empty disposable collection.

        Args:
            dev_name: Name of this device.
            edge_type: Label for the kind of edge this object represents.
        """
        self.dev_name = dev_name
        self.edge_type = edge_type
        # Subclasses add their subscriptions here so dispose_all() can release them.
        self.disposables = CompositeDisposable()

    def dispose_all(self) -> None:
        """Disposes of all active subscriptions managed by this agent."""
        self.disposables.dispose()


# --- Patch text that shared the replaced physical line (preserved verbatim) ---
# diff --git a/build/lib/dimos/web/fastapi_server.py b/build/lib/dimos/web/fastapi_server.py
# new file mode 100644
# index 0000000000..7dcd0f6d73
# --- /dev/null
# +++ b/build/lib/dimos/web/fastapi_server.py
# @@ -0,0 +1,224 @@
# +# Copyright 2025 Dimensional Inc.
# +#
# +# Licensed under the Apache License, Version 2.0 (the "License");
# +# you may not use this file except in compliance with the License.
# +# You may obtain a copy of the License at
# +#
# +#     http://www.apache.org/licenses/LICENSE-2.0
# +#
# +# Unless required by applicable law or agreed to in writing, software
# +# distributed under the License is distributed on an "AS IS" BASIS,
# +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# +# See the License for the specific language governing permissions and
# +# limitations under the License.
# +
# +# Working FastAPI/Uvicorn Impl.
# +
# +# Notes: Do not use simultaneously with Flask, this includes imports.
# +# Workers are not yet setup, as this requires a much more intricate
# +# reorganization. There appears to be possible signalling issues when
# +# opening up streams on multiple windows/reloading which will need to
# +# be fixed. Also note, Chrome only supports 6 simultaneous web streams,
# +# and its advised to test threading/worker performance with another
# +# browser like Safari.
# Fast Api & Uvicorn
import cv2
from dimos.web.edge_io import EdgeIO
from fastapi import FastAPI, Request, Form, HTTPException
from fastapi.responses import HTMLResponse, StreamingResponse, JSONResponse
from sse_starlette.sse import EventSourceResponse
from fastapi.templating import Jinja2Templates
import uvicorn
from threading import Lock
from pathlib import Path
from queue import Queue, Empty
import asyncio

from reactivex.disposable import SingleAssignmentDisposable
from reactivex import operators as ops
import reactivex as rx

# TODO: Resolve threading, start/stop stream functionality.


class FastAPIServer(EdgeIO):
    """FastAPI server exposing video feeds, SSE text streams and a query endpoint."""

    def __init__(
        self,
        dev_name="FastAPI Server",
        edge_type="Bidirectional",
        host="0.0.0.0",
        port=5555,
        text_streams=None,
        **streams,
    ):
        """Build the app, wrap the given streams and register all routes.

        Args:
            dev_name: Device name forwarded to EdgeIO.
            edge_type: Edge type forwarded to EdgeIO.
            host: Interface uvicorn binds to.
            port: Port uvicorn listens on.
            text_streams: Optional mapping of key -> observable of text messages.
            **streams: Mapping of key -> observable of video frames.
        """
        super().__init__(dev_name, edge_type)
        self.app = FastAPI()
        self.port = port
        self.host = host
        BASE_DIR = Path(__file__).resolve().parent
        self.templates = Jinja2Templates(directory=str(BASE_DIR / "templates"))
        self.streams = streams
        self.active_streams = {}
        self.stream_locks = {key: Lock() for key in self.streams}
        self.stream_queues = {}
        self.stream_disposables = {}

        # Initialize text streams
        self.text_streams = text_streams or {}
        self.text_queues = {}
        self.text_disposables = {}
        self.text_clients = set()

        # Create a Subject for text queries
        self.query_subject = rx.subject.Subject()
        self.query_stream = self.query_subject.pipe(ops.share())

        for key in self.streams:
            if self.streams[key] is not None:
                self.active_streams[key] = self.streams[key].pipe(
                    ops.map(self.process_frame_fastapi), ops.share()
                )

        # Set up text stream subscriptions
        for key, stream in self.text_streams.items():
            if stream is not None:
                self.text_queues[key] = Queue(maxsize=100)
                # Enqueue without blocking: with no SSE client draining the
                # queue, a blocking put() would stall the producer forever
                # once 100 messages have accumulated.
                disposable = stream.subscribe(
                    lambda text, k=key: self._put_dropping_oldest(self.text_queues[k], text)
                    if text is not None
                    else None,
                    lambda e, k=key: self._put_dropping_oldest(self.text_queues[k], None),
                    lambda k=key: self._put_dropping_oldest(self.text_queues[k], None),
                )
                self.text_disposables[key] = disposable
                self.disposables.add(disposable)

        self.setup_routes()

    @staticmethod
    def _put_dropping_oldest(target_queue, item):
        """Put ``item`` on a bounded queue without blocking, evicting the oldest entry if full."""
        from queue import Full

        while True:
            try:
                target_queue.put_nowait(item)
                return
            except Full:
                try:
                    target_queue.get_nowait()
                except Empty:
                    # A consumer drained the queue in between; just retry.
                    pass

    def process_frame_fastapi(self, frame):
        """Convert frame to JPEG format for streaming."""
        _, buffer = cv2.imencode(".jpg", frame)
        return buffer.tobytes()

    def stream_generator(self, key):
        """Generate frames for a given video stream."""

        def generate():
            if key not in self.stream_queues:
                self.stream_queues[key] = Queue(maxsize=10)

            frame_queue = self.stream_queues[key]

            # Clear any existing disposable for this stream
            if key in self.stream_disposables:
                self.stream_disposables[key].dispose()

            disposable = SingleAssignmentDisposable()
            self.stream_disposables[key] = disposable
            self.disposables.add(disposable)

            if key in self.active_streams:
                with self.stream_locks[key]:
                    # Clear the queue before starting new subscription
                    while not frame_queue.empty():
                        try:
                            frame_queue.get_nowait()
                        except Empty:
                            break

                    disposable.disposable = self.active_streams[key].subscribe(
                        lambda frame: frame_queue.put(frame) if frame is not None else None,
                        lambda e: frame_queue.put(None),
                        lambda: frame_queue.put(None),
                    )

            try:
                while True:
                    try:
                        frame = frame_queue.get(timeout=1)
                        if frame is None:
                            # None is the end-of-stream / error sentinel.
                            break
                        yield (b"--frame\r\nContent-Type: image/jpeg\r\n\r\n" + frame + b"\r\n")
                    except Empty:
                        # Instead of breaking, continue waiting for new frames
                        continue
            finally:
                if key in self.stream_disposables:
                    self.stream_disposables[key].dispose()

        return generate

    def create_video_feed_route(self, key):
        """Create a video feed route for a specific stream."""

        async def video_feed():
            return StreamingResponse(
                self.stream_generator(key)(), media_type="multipart/x-mixed-replace; boundary=frame"
            )

        return video_feed

    async def text_stream_generator(self, key):
        """Generate SSE events for text stream.

        The queue is polled without blocking so the event loop stays
        responsive; a keep-alive ping is sent after about one second of
        silence, and the generator ends when the end-of-stream sentinel
        (None) is read.
        """
        # Keep a unique object alive as the client token. id(object()) on a
        # temporary can be reused by the next temporary, so two clients could
        # end up with the same id and the second removal would raise KeyError.
        client_id = object()
        self.text_clients.add(client_id)

        idle_polls = 0
        try:
            while True:
                text_queue = self.text_queues.get(key)
                try:
                    if text_queue is None:
                        # No queue for this key (its stream was None): only keep-alives.
                        raise Empty
                    text = text_queue.get_nowait()
                except Empty:
                    idle_polls += 1
                    if idle_polls >= 10:
                        idle_polls = 0
                        # Send a keep-alive comment
                        yield {"event": "ping", "data": ""}
                    # Always yield control so other requests can be served.
                    await asyncio.sleep(0.1)
                    continue

                idle_polls = 0
                if text is None:
                    # None is the end-of-stream / error sentinel.
                    break
                yield {"event": "message", "id": key, "data": text}
        finally:
            self.text_clients.discard(client_id)

    def setup_routes(self):
        """Set up FastAPI routes."""

        @self.app.get("/", response_class=HTMLResponse)
        async def index(request: Request):
            stream_keys = list(self.streams.keys())
            text_stream_keys = list(self.text_streams.keys())
            return self.templates.TemplateResponse(
                "index_fastapi.html",
                {
                    "request": request,
                    "stream_keys": stream_keys,
                    "text_stream_keys": text_stream_keys,
                },
            )

        @self.app.post("/submit_query")
        async def submit_query(query: str = Form(...)):
            # Using Form directly as a dependency ensures proper form handling
            try:
                if query:
                    # Emit the query through our Subject
                    self.query_subject.on_next(query)
                    return JSONResponse({"success": True, "message": "Query received"})
                return JSONResponse({"success": False, "message": "No query provided"})
            except Exception as e:
                # Ensure we always return valid JSON even on error
                return JSONResponse(
                    status_code=500,
                    content={"success": False, "message": f"Server error: {str(e)}"},
                )

        @self.app.get("/text_stream/{key}")
        async def text_stream(key: str):
            if key not in self.text_streams:
                raise HTTPException(status_code=404, detail=f"Text stream '{key}' not found")
            return EventSourceResponse(self.text_stream_generator(key))

        for key in self.streams:
            self.app.get(f"/video_feed/{key}")(self.create_video_feed_route(key))

    def run(self):
        """Run the FastAPI server."""
        uvicorn.run(
            self.app, host=self.host, port=self.port
        )  # TODO: Translate structure to enable in-built workers'


# --- Patch text that shared the replaced physical lines (preserved verbatim) ---
# diff --git a/build/lib/dimos/web/flask_server.py b/build/lib/dimos/web/flask_server.py
# new file mode 100644
# index 0000000000..01d79f63cd
# ---
/dev/null +++ b/build/lib/dimos/web/flask_server.py @@ -0,0 +1,95 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from flask import Flask, Response, render_template +import cv2 +from reactivex import operators as ops +from reactivex.disposable import SingleAssignmentDisposable +from queue import Queue + +from dimos.web.edge_io import EdgeIO + + +class FlaskServer(EdgeIO): + def __init__(self, dev_name="Flask Server", edge_type="Bidirectional", port=5555, **streams): + super().__init__(dev_name, edge_type) + self.app = Flask(__name__) + self.port = port + self.streams = streams + self.active_streams = {} + + # Initialize shared stream references with ref_count + for key in self.streams: + if self.streams[key] is not None: + # Apply share and ref_count to manage subscriptions + self.active_streams[key] = self.streams[key].pipe( + ops.map(self.process_frame_flask), ops.share() + ) + + self.setup_routes() + + def process_frame_flask(self, frame): + """Convert frame to JPEG format for streaming.""" + _, buffer = cv2.imencode(".jpg", frame) + return buffer.tobytes() + + def setup_routes(self): + @self.app.route("/") + def index(): + stream_keys = list(self.streams.keys()) # Get the keys from the streams dictionary + return render_template("index_flask.html", stream_keys=stream_keys) + + # Function to create a streaming response + def stream_generator(key): + def generate(): + frame_queue = Queue() + disposable = 
SingleAssignmentDisposable() + + # Subscribe to the shared, ref-counted stream + if key in self.active_streams: + disposable.disposable = self.active_streams[key].subscribe( + lambda frame: frame_queue.put(frame) if frame is not None else None, + lambda e: frame_queue.put(None), + lambda: frame_queue.put(None), + ) + + try: + while True: + frame = frame_queue.get() + if frame is None: + break + yield (b"--frame\r\nContent-Type: image/jpeg\r\n\r\n" + frame + b"\r\n") + finally: + disposable.dispose() + + return generate + + def make_response_generator(key): + def response_generator(): + return Response( + stream_generator(key)(), mimetype="multipart/x-mixed-replace; boundary=frame" + ) + + return response_generator + + # Dynamically adding routes using add_url_rule + for key in self.streams: + endpoint = f"video_feed_{key}" + self.app.add_url_rule( + f"/video_feed/{key}", endpoint, view_func=make_response_generator(key) + ) + + def run(self, host="0.0.0.0", port=5555, threaded=True): + self.port = port + self.app.run(host=host, port=self.port, debug=False, threaded=threaded) diff --git a/build/lib/dimos/web/robot_web_interface.py b/build/lib/dimos/web/robot_web_interface.py new file mode 100644 index 0000000000..33847c0056 --- /dev/null +++ b/build/lib/dimos/web/robot_web_interface.py @@ -0,0 +1,35 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Robot Web Interface wrapper for DIMOS. 
+Provides a clean interface to the dimensional-interface FastAPI server. +""" + +from dimos.web.dimos_interface.api.server import FastAPIServer + + +class RobotWebInterface(FastAPIServer): + """Wrapper class for the dimos-interface FastAPI server.""" + + def __init__(self, port=5555, text_streams=None, audio_subject=None, **streams): + super().__init__( + dev_name="Robot Web Interface", + edge_type="Bidirectional", + host="0.0.0.0", + port=port, + text_streams=text_streams, + audio_subject=audio_subject, + **streams, + ) diff --git a/build/lib/tests/__init__.py b/build/lib/tests/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/build/lib/tests/__init__.py @@ -0,0 +1 @@ + diff --git a/build/lib/tests/agent_manip_flow_fastapi_test.py b/build/lib/tests/agent_manip_flow_fastapi_test.py new file mode 100644 index 0000000000..c7dec66f74 --- /dev/null +++ b/build/lib/tests/agent_manip_flow_fastapi_test.py @@ -0,0 +1,153 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module initializes and manages the video processing pipeline integrated with a web server. +It handles video capture, frame processing, and exposes the processed video streams via HTTP endpoints. 
+""" + +import tests.test_header +import os + +# ----- + +# Standard library imports +import multiprocessing +from dotenv import load_dotenv + +# Third-party imports +from fastapi import FastAPI +from reactivex import operators as ops +from reactivex.disposable import CompositeDisposable +from reactivex.scheduler import ThreadPoolScheduler, CurrentThreadScheduler, ImmediateScheduler + +# Local application imports +from dimos.agents.agent import OpenAIAgent +from dimos.stream.frame_processor import FrameProcessor +from dimos.stream.video_operators import VideoOperators as vops +from dimos.stream.video_provider import VideoProvider +from dimos.web.fastapi_server import FastAPIServer + +# Load environment variables +load_dotenv() + + +def main(): + """ + Initializes and runs the video processing pipeline with web server output. + + This function orchestrates a video processing system that handles capture, processing, + and visualization of video streams. It demonstrates parallel processing capabilities + and various video manipulation techniques across multiple stages including capture + and processing at different frame rates, edge detection, and optical flow analysis. + + Raises: + RuntimeError: If video sources are unavailable or processing fails. 
+ """ + disposables = CompositeDisposable() + + processor = FrameProcessor( + output_dir=f"{os.getcwd()}/assets/output/frames", delete_on_init=True + ) + + optimal_thread_count = multiprocessing.cpu_count() # Gets number of CPU cores + thread_pool_scheduler = ThreadPoolScheduler(optimal_thread_count) + + VIDEO_SOURCES = [ + f"{os.getcwd()}/assets/ldru.mp4", + f"{os.getcwd()}/assets/ldru_480p.mp4", + f"{os.getcwd()}/assets/trimmed_video_480p.mov", + f"{os.getcwd()}/assets/video-f30-480p.mp4", + "rtsp://192.168.50.207:8080/h264.sdp", + "rtsp://10.0.0.106:8080/h264.sdp", + ] + + VIDEO_SOURCE_INDEX = 3 + VIDEO_SOURCE_INDEX_2 = 2 + + my_video_provider = VideoProvider("Video File", video_source=VIDEO_SOURCES[VIDEO_SOURCE_INDEX]) + my_video_provider_2 = VideoProvider( + "Video File 2", video_source=VIDEO_SOURCES[VIDEO_SOURCE_INDEX_2] + ) + + video_stream_obs = my_video_provider.capture_video_as_observable(fps=120).pipe( + ops.subscribe_on(thread_pool_scheduler), + # Move downstream operations to thread pool for parallel processing + # Disabled: Evaluating performance impact + # ops.observe_on(thread_pool_scheduler), + vops.with_jpeg_export(processor, suffix="raw"), + vops.with_fps_sampling(fps=30), + vops.with_jpeg_export(processor, suffix="raw_slowed"), + ) + + video_stream_obs_2 = my_video_provider_2.capture_video_as_observable(fps=120).pipe( + ops.subscribe_on(thread_pool_scheduler), + # Move downstream operations to thread pool for parallel processing + # Disabled: Evaluating performance impact + # ops.observe_on(thread_pool_scheduler), + vops.with_jpeg_export(processor, suffix="raw_2"), + vops.with_fps_sampling(fps=30), + vops.with_jpeg_export(processor, suffix="raw_2_slowed"), + ) + + edge_detection_stream_obs = processor.process_stream_edge_detection(video_stream_obs).pipe( + vops.with_jpeg_export(processor, suffix="edge"), + ) + + optical_flow_relevancy_stream_obs = processor.process_stream_optical_flow_with_relevancy( + video_stream_obs + ) + + 
optical_flow_stream_obs = optical_flow_relevancy_stream_obs.pipe( + ops.do_action(lambda result: print(f"Optical Flow Relevancy Score: {result[1]}")), + vops.with_optical_flow_filtering(threshold=2.0), + ops.do_action(lambda _: print(f"Optical Flow Passed Threshold.")), + vops.with_jpeg_export(processor, suffix="optical"), + ) + + # + # ====== Agent Orchastrator (Qu.s Awareness, Temporality, Routing) ====== + # + + # Agent 1 + # my_agent = OpenAIAgent( + # "Agent 1", + # query="You are a robot. What do you see? Put a JSON with objects of what you see in the format {object, description}.") + # my_agent.subscribe_to_image_processing(slowed_video_stream_obs) + # disposables.add(my_agent.disposables) + + # # Agent 2 + # my_agent_two = OpenAIAgent( + # "Agent 2", + # query="This is a visualization of dense optical flow. What movement(s) have occured? Put a JSON with mapped directions you see in the format {direction, probability, english_description}.") + # my_agent_two.subscribe_to_image_processing(optical_flow_stream_obs) + # disposables.add(my_agent_two.disposables) + + # + # ====== Create and start the FastAPI server ====== + # + + # Will be visible at http://[host]:[port]/video_feed/[key] + streams = { + "video_one": video_stream_obs, + "video_two": video_stream_obs_2, + "edge_detection": edge_detection_stream_obs, + "optical_flow": optical_flow_stream_obs, + } + fast_api_server = FastAPIServer(port=5555, **streams) + fast_api_server.run() + + +if __name__ == "__main__": + main() diff --git a/build/lib/tests/agent_manip_flow_flask_test.py b/build/lib/tests/agent_manip_flow_flask_test.py new file mode 100644 index 0000000000..2356eb74ae --- /dev/null +++ b/build/lib/tests/agent_manip_flow_flask_test.py @@ -0,0 +1,195 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module initializes and manages the video processing pipeline integrated with a web server. +It handles video capture, frame processing, and exposes the processed video streams via HTTP endpoints. +""" + +import tests.test_header +import os + +# ----- + +# Standard library imports +import multiprocessing +from dotenv import load_dotenv + +# Third-party imports +from flask import Flask +from reactivex import operators as ops +from reactivex import of, interval, zip +from reactivex.disposable import CompositeDisposable +from reactivex.scheduler import ThreadPoolScheduler, CurrentThreadScheduler, ImmediateScheduler + +# Local application imports +from dimos.agents.agent import PromptBuilder, OpenAIAgent +from dimos.stream.frame_processor import FrameProcessor +from dimos.stream.video_operators import VideoOperators as vops +from dimos.stream.video_provider import VideoProvider +from dimos.web.flask_server import FlaskServer + +# Load environment variables +load_dotenv() + +app = Flask(__name__) + + +def main(): + """ + Initializes and runs the video processing pipeline with web server output. + + This function orchestrates a video processing system that handles capture, processing, + and visualization of video streams. It demonstrates parallel processing capabilities + and various video manipulation techniques across multiple stages including capture + and processing at different frame rates, edge detection, and optical flow analysis. + + Raises: + RuntimeError: If video sources are unavailable or processing fails. 
+ """ + disposables = CompositeDisposable() + + processor = FrameProcessor( + output_dir=f"{os.getcwd()}/assets/output/frames", delete_on_init=True + ) + + optimal_thread_count = multiprocessing.cpu_count() # Gets number of CPU cores + thread_pool_scheduler = ThreadPoolScheduler(optimal_thread_count) + + VIDEO_SOURCES = [ + f"{os.getcwd()}/assets/ldru.mp4", + f"{os.getcwd()}/assets/ldru_480p.mp4", + f"{os.getcwd()}/assets/trimmed_video_480p.mov", + f"{os.getcwd()}/assets/video-f30-480p.mp4", + f"{os.getcwd()}/assets/video.mov", + "rtsp://192.168.50.207:8080/h264.sdp", + "rtsp://10.0.0.106:8080/h264.sdp", + f"{os.getcwd()}/assets/people_1080p_24fps.mp4", + ] + + VIDEO_SOURCE_INDEX = 4 + + my_video_provider = VideoProvider("Video File", video_source=VIDEO_SOURCES[VIDEO_SOURCE_INDEX]) + + video_stream_obs = my_video_provider.capture_video_as_observable(fps=120).pipe( + ops.subscribe_on(thread_pool_scheduler), + # Move downstream operations to thread pool for parallel processing + # Disabled: Evaluating performance impact + # ops.observe_on(thread_pool_scheduler), + # vops.with_jpeg_export(processor, suffix="raw"), + vops.with_fps_sampling(fps=30), + # vops.with_jpeg_export(processor, suffix="raw_slowed"), + ) + + edge_detection_stream_obs = processor.process_stream_edge_detection(video_stream_obs).pipe( + # vops.with_jpeg_export(processor, suffix="edge"), + ) + + optical_flow_relevancy_stream_obs = processor.process_stream_optical_flow(video_stream_obs) + + optical_flow_stream_obs = optical_flow_relevancy_stream_obs.pipe( + # ops.do_action(lambda result: print(f"Optical Flow Relevancy Score: {result[1]}")), + # vops.with_optical_flow_filtering(threshold=2.0), + # ops.do_action(lambda _: print(f"Optical Flow Passed Threshold.")), + # vops.with_jpeg_export(processor, suffix="optical") + ) + + # + # ====== Agent Orchastrator (Qu.s Awareness, Temporality, Routing) ====== + # + + # Observable that emits every 2 seconds + secondly_emission = interval(2, 
scheduler=thread_pool_scheduler).pipe( + ops.map(lambda x: f"Second {x + 1}"), + # ops.take(30) + ) + + # Agent 1 + my_agent = OpenAIAgent( + "Agent 1", + query="You are a robot. What do you see? Put a JSON with objects of what you see in the format {object, description}.", + json_mode=False, + ) + + # Create an agent for each subset of questions that it would be theroized to handle. + # Set std. template/blueprints, and devs will add to that likely. + + ai_1_obs = video_stream_obs.pipe( + # vops.with_fps_sampling(fps=30), + # ops.throttle_first(1), + vops.with_jpeg_export(processor, suffix="open_ai_agent_1"), + ops.take(30), + ops.replay(buffer_size=30, scheduler=thread_pool_scheduler), + ) + ai_1_obs.connect() + + ai_1_repeat_obs = ai_1_obs.pipe(ops.repeat()) + + my_agent.subscribe_to_image_processing(ai_1_obs) + disposables.add(my_agent.disposables) + + # Agent 2 + my_agent_two = OpenAIAgent( + "Agent 2", + query="This is a visualization of dense optical flow. What movement(s) have occured? 
Put a JSON with mapped directions you see in the format {direction, probability, english_description}.", + max_input_tokens_per_request=1000, + max_output_tokens_per_request=300, + json_mode=False, + model_name="gpt-4o-2024-08-06", + ) + + ai_2_obs = optical_flow_stream_obs.pipe( + # vops.with_fps_sampling(fps=30), + # ops.throttle_first(1), + vops.with_jpeg_export(processor, suffix="open_ai_agent_2"), + ops.take(30), + ops.replay(buffer_size=30, scheduler=thread_pool_scheduler), + ) + ai_2_obs.connect() + + ai_2_repeat_obs = ai_2_obs.pipe(ops.repeat()) + + # Combine emissions using zip + ai_1_secondly_repeating_obs = zip(secondly_emission, ai_1_repeat_obs).pipe( + # ops.do_action(lambda s: print(f"AI 1 - Emission Count: {s[0]}")), + ops.map(lambda r: r[1]), + ) + + # Combine emissions using zip + ai_2_secondly_repeating_obs = zip(secondly_emission, ai_2_repeat_obs).pipe( + # ops.do_action(lambda s: print(f"AI 2 - Emission Count: {s[0]}")), + ops.map(lambda r: r[1]), + ) + + my_agent_two.subscribe_to_image_processing(ai_2_obs) + disposables.add(my_agent_two.disposables) + + # + # ====== Create and start the Flask server ====== + # + + # Will be visible at http://[host]:[port]/video_feed/[key] + flask_server = FlaskServer( + # video_one=video_stream_obs, + # edge_detection=edge_detection_stream_obs, + # optical_flow=optical_flow_stream_obs, + OpenAIAgent_1=ai_1_secondly_repeating_obs, + OpenAIAgent_2=ai_2_secondly_repeating_obs, + ) + + flask_server.run(threaded=True) + + +if __name__ == "__main__": + main() diff --git a/build/lib/tests/agent_memory_test.py b/build/lib/tests/agent_memory_test.py new file mode 100644 index 0000000000..b662af18bd --- /dev/null +++ b/build/lib/tests/agent_memory_test.py @@ -0,0 +1,61 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tests.test_header +import os + +# ----- + +from dotenv import load_dotenv +import os + +load_dotenv() + +from dimos.agents.memory.chroma_impl import OpenAISemanticMemory + +agent_memory = OpenAISemanticMemory() +print("Initialization done.") + +agent_memory.add_vector("id0", "Food") +agent_memory.add_vector("id1", "Cat") +agent_memory.add_vector("id2", "Mouse") +agent_memory.add_vector("id3", "Bike") +agent_memory.add_vector("id4", "Dog") +agent_memory.add_vector("id5", "Tricycle") +agent_memory.add_vector("id6", "Car") +agent_memory.add_vector("id7", "Horse") +agent_memory.add_vector("id8", "Vehicle") +agent_memory.add_vector("id6", "Red") +agent_memory.add_vector("id7", "Orange") +agent_memory.add_vector("id8", "Yellow") +print("Adding vectors done.") + +print(agent_memory.get_vector("id1")) +print("Done retrieving sample vector.") + +results = agent_memory.query("Colors") +print(results) +print("Done querying agent memory (basic).") + +results = agent_memory.query("Colors", similarity_threshold=0.2) +print(results) +print("Done querying agent memory (similarity_threshold=0.2).") + +results = agent_memory.query("Colors", n_results=2) +print(results) +print("Done querying agent memory (n_results=2).") + +results = agent_memory.query("Colors", n_results=19, similarity_threshold=0.45) +print(results) +print("Done querying agent memory (n_results=19, similarity_threshold=0.45).") diff --git a/build/lib/tests/colmap_test.py b/build/lib/tests/colmap_test.py new file mode 100644 index 0000000000..e1f451a7dc --- /dev/null +++ 
b/build/lib/tests/colmap_test.py @@ -0,0 +1,25 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tests.test_header +import os +import sys + +# ----- + +# Now try to import +from dimos.environment.colmap_environment import COLMAPEnvironment + +env = COLMAPEnvironment() +env.initialize_from_video("data/IMG_1525.MOV", "data/frames") diff --git a/build/lib/tests/run.py b/build/lib/tests/run.py new file mode 100644 index 0000000000..9ae6f81398 --- /dev/null +++ b/build/lib/tests/run.py @@ -0,0 +1,361 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import tests.test_header +import os + +import time +from dotenv import load_dotenv +from dimos.agents.cerebras_agent import CerebrasAgent +from dimos.agents.claude_agent import ClaudeAgent +from dimos.robot.unitree_webrtc.unitree_go2 import UnitreeGo2 + +# from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl +from dimos.robot.unitree.unitree_skills import MyUnitreeSkills +from dimos.web.robot_web_interface import RobotWebInterface +from dimos.web.websocket_vis.server import WebsocketVis +from dimos.skills.observe_stream import ObserveStream +from dimos.skills.observe import Observe +from dimos.skills.kill_skill import KillSkill +from dimos.skills.navigation import NavigateWithText, GetPose, NavigateToGoal, Explore +from dimos.skills.visual_navigation_skills import FollowHuman +import reactivex as rx +import reactivex.operators as ops +from dimos.stream.audio.pipelines import tts, stt +import threading +import json +from dimos.types.vector import Vector +from dimos.skills.unitree.unitree_speak import UnitreeSpeak + +from dimos.perception.object_detection_stream import ObjectDetectionStream +from dimos.perception.detection2d.detic_2d_det import Detic2DDetector +from dimos.utils.reactive import backpressure +import asyncio +import atexit +import signal +import sys +import warnings +import logging + +# Filter out known WebRTC warnings that don't affect functionality +warnings.filterwarnings("ignore", message="coroutine.*was never awaited") +warnings.filterwarnings("ignore", message=".*RTCSctpTransport.*") + +# Set up logging to reduce asyncio noise +logging.getLogger("asyncio").setLevel(logging.ERROR) + +# Load API key from environment +load_dotenv() + +# Allow command line arguments to control spatial memory parameters +import argparse + + +def parse_arguments(): + parser = argparse.ArgumentParser( + description="Run the robot with optional spatial memory parameters" + ) + parser.add_argument( + "--new-memory", action="store_true", help="Create a 
new spatial memory from scratch" + ) + parser.add_argument( + "--spatial-memory-dir", type=str, help="Directory for storing spatial memory data" + ) + return parser.parse_args() + + +args = parse_arguments() + +# Initialize robot with spatial memory parameters - using WebRTC mode instead of "ai" +robot = UnitreeGo2( + ip=os.getenv("ROBOT_IP"), + mode="normal", +) + + +# Add graceful shutdown handling to prevent WebRTC task destruction errors +def cleanup_robot(): + print("Cleaning up robot connection...") + try: + # Make cleanup non-blocking to avoid hangs + def quick_cleanup(): + try: + robot.liedown() + except: + pass + + # Run cleanup in a separate thread with timeout + cleanup_thread = threading.Thread(target=quick_cleanup) + cleanup_thread.daemon = True + cleanup_thread.start() + cleanup_thread.join(timeout=3.0) # Max 3 seconds for cleanup + + # Force stop the robot's WebRTC connection + try: + robot.stop() + except: + pass + + except Exception as e: + print(f"Error during cleanup: {e}") + # Continue anyway + + +atexit.register(cleanup_robot) + + +def signal_handler(signum, frame): + print("Received shutdown signal, cleaning up...") + try: + cleanup_robot() + except: + pass + # Force exit if cleanup hangs + os._exit(0) + + +signal.signal(signal.SIGINT, signal_handler) +signal.signal(signal.SIGTERM, signal_handler) + +# Initialize WebSocket visualization +websocket_vis = WebsocketVis() +websocket_vis.start() +websocket_vis.connect(robot.global_planner.vis_stream()) + + +def msg_handler(msgtype, data): + if msgtype == "click": + print(f"Received click at position: {data['position']}") + + try: + print("Setting goal...") + + # Instead of disabling visualization, make it timeout-safe + original_vis = robot.global_planner.vis + + def safe_vis(name, drawable): + """Visualization wrapper that won't block on timeouts""" + try: + # Use a separate thread for visualization to avoid blocking + def vis_update(): + try: + original_vis(name, drawable) + except Exception as 
e: + print(f"Visualization update failed (non-critical): {e}") + + vis_thread = threading.Thread(target=vis_update) + vis_thread.daemon = True + vis_thread.start() + # Don't wait for completion - let it run asynchronously + except Exception as e: + print(f"Visualization setup failed (non-critical): {e}") + + robot.global_planner.vis = safe_vis + robot.global_planner.set_goal(Vector(data["position"])) + robot.global_planner.vis = original_vis + + print("Goal set successfully") + except Exception as e: + print(f"Error setting goal: {e}") + import traceback + + traceback.print_exc() + + +def threaded_msg_handler(msgtype, data): + print(f"Processing message: {msgtype}") + + # Create a dedicated event loop for goal setting to avoid conflicts + def run_with_dedicated_loop(): + try: + # Use asyncio.run which creates and manages its own event loop + # This won't conflict with the robot's or websocket's event loops + async def async_msg_handler(): + msg_handler(msgtype, data) + + asyncio.run(async_msg_handler()) + print("Goal setting completed successfully") + except Exception as e: + print(f"Error in goal setting thread: {e}") + import traceback + + traceback.print_exc() + + thread = threading.Thread(target=run_with_dedicated_loop) + thread.daemon = True + thread.start() + + +websocket_vis.msg_handler = threaded_msg_handler + + +def newmap(msg): + return ["costmap", robot.map.costmap.smudge()] + + +websocket_vis.connect(robot.map_stream.pipe(ops.map(newmap))) +websocket_vis.connect(robot.odom_stream().pipe(ops.map(lambda pos: ["robot_pos", pos.pos.to_2d()]))) + +# Create a subject for agent responses +agent_response_subject = rx.subject.Subject() +agent_response_stream = agent_response_subject.pipe(ops.share()) +local_planner_viz_stream = robot.local_planner_viz_stream.pipe(ops.share()) +audio_subject = rx.subject.Subject() + +# Initialize object detection stream +min_confidence = 0.6 +class_filter = None # No class filtering + +# Create video stream from robot's camera 
+video_stream = backpressure(robot.get_video_stream()) # WebRTC doesn't use ROS video stream + +# # Initialize ObjectDetectionStream with robot +object_detector = ObjectDetectionStream( + camera_intrinsics=robot.camera_intrinsics, + class_filter=class_filter, + get_pose=robot.get_pose, + video_stream=video_stream, + draw_masks=True, +) + +# # Create visualization stream for web interface +viz_stream = backpressure(object_detector.get_stream()).pipe( + ops.share(), + ops.map(lambda x: x["viz_frame"] if x is not None else None), + ops.filter(lambda x: x is not None), +) + +# # Get the formatted detection stream +formatted_detection_stream = object_detector.get_formatted_stream().pipe( + ops.filter(lambda x: x is not None) +) + + +# Create a direct mapping that combines detection data with locations +def combine_with_locations(object_detections): + # Get locations from spatial memory + try: + spatial_memory = robot.get_spatial_memory() + if spatial_memory is None: + # If spatial memory is disabled, just return the object detections + return object_detections + + locations = spatial_memory.get_robot_locations() + + # Format the locations section + locations_text = "\n\nSaved Robot Locations:\n" + if locations: + for loc in locations: + locations_text += f"- {loc.name}: Position ({loc.position[0]:.2f}, {loc.position[1]:.2f}, {loc.position[2]:.2f}), " + locations_text += f"Rotation ({loc.rotation[0]:.2f}, {loc.rotation[1]:.2f}, {loc.rotation[2]:.2f})\n" + else: + locations_text += "None\n" + + # Simply concatenate the strings + return object_detections + locations_text + except Exception as e: + print(f"Error adding locations: {e}") + return object_detections + + +# Create the combined stream with a simple pipe operation +enhanced_data_stream = formatted_detection_stream.pipe(ops.map(combine_with_locations), ops.share()) + +streams = { + "unitree_video": robot.get_video_stream(), # Changed from get_ros_video_stream to get_video_stream for WebRTC + "local_planner_viz": 
local_planner_viz_stream, + "object_detection": viz_stream, # Uncommented object detection +} +text_streams = { + "agent_responses": agent_response_stream, +} + +web_interface = RobotWebInterface( + port=5555, text_streams=text_streams, audio_subject=audio_subject, **streams +) + +stt_node = stt() +stt_node.consume_audio(audio_subject.pipe(ops.share())) + +# Read system query from prompt.txt file +with open( + os.path.join(os.path.dirname(os.path.dirname(__file__)), "assets/agent/prompt.txt"), "r" +) as f: + system_query = f.read() + +# Create a ClaudeAgent instance +agent = ClaudeAgent( + dev_name="test_agent", + input_query_stream=stt_node.emit_text(), + # input_query_stream=web_interface.query_stream, + input_data_stream=enhanced_data_stream, + skills=robot.get_skills(), + system_query=system_query, + model_name="claude-3-5-haiku-latest", + thinking_budget_tokens=0, + max_output_tokens_per_request=8192, + # model_name="llama-4-scout-17b-16e-instruct", +) + +# tts_node = tts() +# tts_node.consume_text(agent.get_response_observable()) + +robot_skills = robot.get_skills() +robot_skills.add(ObserveStream) +robot_skills.add(Observe) +robot_skills.add(KillSkill) +robot_skills.add(NavigateWithText) +# robot_skills.add(FollowHuman) # TODO: broken +robot_skills.add(GetPose) +robot_skills.add(UnitreeSpeak) # Re-enable Speak skill +robot_skills.add(NavigateToGoal) +robot_skills.add(Explore) + +robot_skills.create_instance("ObserveStream", robot=robot, agent=agent) +robot_skills.create_instance("Observe", robot=robot, agent=agent) +robot_skills.create_instance("KillSkill", robot=robot, skill_library=robot_skills) +robot_skills.create_instance("NavigateWithText", robot=robot) +# robot_skills.create_instance("FollowHuman", robot=robot) +robot_skills.create_instance("GetPose", robot=robot) +robot_skills.create_instance("NavigateToGoal", robot=robot) +robot_skills.create_instance("Explore", robot=robot) +robot_skills.create_instance("UnitreeSpeak", robot=robot) # Now only needs 
robot instance + +# Subscribe to agent responses and send them to the subject +agent.get_response_observable().subscribe(lambda x: agent_response_subject.on_next(x)) + +print("ObserveStream and Kill skills registered and ready for use") +print("Created memory.txt file") + +# Start web interface in a separate thread to avoid blocking +web_thread = threading.Thread(target=web_interface.run) +web_thread.daemon = True +web_thread.start() + +try: + while True: + # Main loop - can add robot movement or other logic here + time.sleep(0.01) + +except KeyboardInterrupt: + print("Stopping robot") + robot.liedown() +except Exception as e: + print(f"Unexpected error in main loop: {e}") + import traceback + + traceback.print_exc() +finally: + print("Cleaning up...") + cleanup_robot() diff --git a/build/lib/tests/run_go2_ros.py b/build/lib/tests/run_go2_ros.py new file mode 100644 index 0000000000..6bba1c1797 --- /dev/null +++ b/build/lib/tests/run_go2_ros.py @@ -0,0 +1,178 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import tests.test_header + +import os +import time + +from dimos.robot.unitree.unitree_go2 import UnitreeGo2, WebRTCConnectionMethod +from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl + + +def get_env_var(var_name, default=None, required=False): + """Get environment variable with validation.""" + value = os.getenv(var_name, default) + if value == "": + value = default + if required and not value: + raise ValueError(f"{var_name} environment variable is required") + return value + + +if __name__ == "__main__": + # Get configuration from environment variables + robot_ip = get_env_var("ROBOT_IP") + connection_method = get_env_var("CONNECTION_METHOD", "LocalSTA") + serial_number = get_env_var("SERIAL_NUMBER", None) + output_dir = get_env_var("ROS_OUTPUT_DIR", os.path.join(os.getcwd(), "assets/output/ros")) + + # Ensure output directory exists + os.makedirs(output_dir, exist_ok=True) + print(f"Ensuring output directory exists: {output_dir}") + + use_ros = True + use_webrtc = False + # Convert connection method string to enum + connection_method = getattr(WebRTCConnectionMethod, connection_method) + + print("Initializing UnitreeGo2...") + print(f"Configuration:") + print(f" IP: {robot_ip}") + print(f" Connection Method: {connection_method}") + print(f" Serial Number: {serial_number if serial_number else 'Not provided'}") + print(f" Output Directory: {output_dir}") + + if use_ros: + ros_control = UnitreeROSControl(node_name="unitree_go2", use_raw=True) + else: + ros_control = None + + robot = UnitreeGo2( + ip=robot_ip, + connection_method=connection_method, + serial_number=serial_number, + output_dir=output_dir, + ros_control=ros_control, + use_ros=use_ros, + use_webrtc=use_webrtc, + ) + time.sleep(5) + try: + # Start perception + print("\nStarting perception system...") + + # Get the processed stream + processed_stream = robot.get_ros_video_stream(fps=30) + + # Create frame counter for unique filenames + frame_count = 0 + + # Create a subscriber to 
handle the frames + def handle_frame(frame): + global frame_count + frame_count += 1 + + try: + # Save frame to output directory if desired for debugging frame streaming + # MAKE SURE TO CHANGE OUTPUT DIR depending on if running in ROS or local + # frame_path = os.path.join(output_dir, f"frame_{frame_count:04d}.jpg") + # success = cv2.imwrite(frame_path, frame) + # print(f"Frame #{frame_count} {'saved successfully' if success else 'failed to save'} to {frame_path}") + pass + + except Exception as e: + print(f"Error in handle_frame: {e}") + import traceback + + print(traceback.format_exc()) + + def handle_error(error): + print(f"Error in stream: {error}") + + def handle_completion(): + print("Stream completed") + + # Subscribe to the stream + print("Creating subscription...") + try: + subscription = processed_stream.subscribe( + on_next=handle_frame, + on_error=lambda e: print(f"Subscription error: {e}"), + on_completed=lambda: print("Subscription completed"), + ) + print("Subscription created successfully") + except Exception as e: + print(f"Error creating subscription: {e}") + + time.sleep(5) + + # First put the robot in a good starting state + print("Running recovery stand...") + robot.webrtc_req(api_id=1006) # RecoveryStand + + # Queue 20 WebRTC requests back-to-back + print("\n🤖 QUEUEING WEBRTC COMMANDS BACK-TO-BACK FOR TESTING UnitreeGo2🤖\n") + + # Dance 1 + robot.webrtc_req(api_id=1033) + print("Queued: WiggleHips (1033)") + + robot.reverse(distance=0.2, speed=0.5) + print("Queued: Reverse 0.5m at 0.5m/s") + + # Wiggle Hips + robot.webrtc_req(api_id=1033) + print("Queued: WiggleHips (1033)") + + robot.move(distance=0.2, speed=0.5) + print("Queued: Move forward 1.0m at 0.5m/s") + + robot.webrtc_req(api_id=1017) + print("Queued: Stretch (1017)") + + robot.move(distance=0.2, speed=0.5) + print("Queued: Move forward 1.0m at 0.5m/s") + + robot.webrtc_req(api_id=1017) + print("Queued: Stretch (1017)") + + robot.reverse(distance=0.2, speed=0.5) + print("Queued: 
Reverse 0.5m at 0.5m/s") + + robot.webrtc_req(api_id=1017) + print("Queued: Stretch (1017)") + robot.spin(degrees=-90.0, speed=45.0) + print("Queued: Spin right 90 degrees at 45 degrees/s") + + robot.spin(degrees=90.0, speed=45.0) + print("Queued: Spin left 90 degrees at 45 degrees/s") + + # To prevent termination + while True: + time.sleep(0.1) + + except KeyboardInterrupt: + print("\nStopping perception...") + if "subscription" in locals(): + subscription.dispose() + except Exception as e: + print(f"Error in main loop: {e}") + finally: + # Cleanup + print("Cleaning up resources...") + if "subscription" in locals(): + subscription.dispose() + del robot + print("Cleanup complete.") diff --git a/build/lib/tests/run_navigation_only.py b/build/lib/tests/run_navigation_only.py new file mode 100644 index 0000000000..2995750e2b --- /dev/null +++ b/build/lib/tests/run_navigation_only.py @@ -0,0 +1,191 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os +from dotenv import load_dotenv +from dimos.robot.unitree_webrtc.unitree_go2 import UnitreeGo2 +from dimos.robot.unitree_webrtc.testing.helpers import show3d_stream +from dimos.web.websocket_vis.server import WebsocketVis +from dimos.web.robot_web_interface import RobotWebInterface +from dimos.types.vector import Vector +import reactivex.operators as ops +import time +import threading +import asyncio +import atexit +import signal +import sys +import warnings +import logging +# logging.basicConfig(level=logging.DEBUG) + +# Filter out known WebRTC warnings that don't affect functionality +warnings.filterwarnings("ignore", message="coroutine.*was never awaited") +warnings.filterwarnings("ignore", message=".*RTCSctpTransport.*") + +# Set up logging to reduce asyncio noise +logging.getLogger("asyncio").setLevel(logging.ERROR) + +load_dotenv() +robot = UnitreeGo2(ip=os.getenv("ROBOT_IP"), mode="normal", enable_perception=False) + + +# Add graceful shutdown handling to prevent WebRTC task destruction errors +def cleanup_robot(): + print("Cleaning up robot connection...") + try: + # Make cleanup non-blocking to avoid hangs + def quick_cleanup(): + try: + robot.liedown() + except: + pass + + # Run cleanup in a separate thread with timeout + cleanup_thread = threading.Thread(target=quick_cleanup) + cleanup_thread.daemon = True + cleanup_thread.start() + cleanup_thread.join(timeout=3.0) # Max 3 seconds for cleanup + + # Force stop the robot's WebRTC connection + try: + robot.stop() + except: + pass + + except Exception as e: + print(f"Error during cleanup: {e}") + # Continue anyway + + +atexit.register(cleanup_robot) + + +def signal_handler(signum, frame): + print("Received shutdown signal, cleaning up...") + try: + cleanup_robot() + except: + pass + # Force exit if cleanup hangs + os._exit(0) + + +signal.signal(signal.SIGINT, signal_handler) +signal.signal(signal.SIGTERM, signal_handler) + +websocket_vis = WebsocketVis() +websocket_vis.start() 
+websocket_vis.connect(robot.global_planner.vis_stream()) + + +def msg_handler(msgtype, data): + if msgtype == "click": + print(f"Received click at position: {data['position']}") + + try: + print("Setting goal...") + + # Instead of disabling visualization, make it timeout-safe + original_vis = robot.global_planner.vis + + def safe_vis(name, drawable): + """Visualization wrapper that won't block on timeouts""" + try: + # Use a separate thread for visualization to avoid blocking + def vis_update(): + try: + original_vis(name, drawable) + except Exception as e: + print(f"Visualization update failed (non-critical): {e}") + + vis_thread = threading.Thread(target=vis_update) + vis_thread.daemon = True + vis_thread.start() + # Don't wait for completion - let it run asynchronously + except Exception as e: + print(f"Visualization setup failed (non-critical): {e}") + + robot.global_planner.vis = safe_vis + robot.global_planner.set_goal(Vector(data["position"])) + robot.global_planner.vis = original_vis + + print("Goal set successfully") + except Exception as e: + print(f"Error setting goal: {e}") + import traceback + + traceback.print_exc() + + +def threaded_msg_handler(msgtype, data): + print(f"Processing message: {msgtype}") + + # Create a dedicated event loop for goal setting to avoid conflicts + def run_with_dedicated_loop(): + try: + # Use asyncio.run which creates and manages its own event loop + # This won't conflict with the robot's or websocket's event loops + async def async_msg_handler(): + msg_handler(msgtype, data) + + asyncio.run(async_msg_handler()) + print("Goal setting completed successfully") + except Exception as e: + print(f"Error in goal setting thread: {e}") + import traceback + + traceback.print_exc() + + thread = threading.Thread(target=run_with_dedicated_loop) + thread.daemon = True + thread.start() + + +websocket_vis.msg_handler = threaded_msg_handler + +print("standing up") +robot.standup() +print("robot is up") + + +def newmap(msg): + return 
["costmap", robot.map.costmap.smudge()] + + +websocket_vis.connect(robot.map_stream.pipe(ops.map(newmap))) +websocket_vis.connect(robot.odom_stream().pipe(ops.map(lambda pos: ["robot_pos", pos.pos.to_2d()]))) + +local_planner_viz_stream = robot.local_planner_viz_stream.pipe(ops.share()) + +# Add RobotWebInterface with video stream +streams = {"unitree_video": robot.get_video_stream(), "local_planner_viz": local_planner_viz_stream} +web_interface = RobotWebInterface(port=5555, **streams) +web_interface.run() + +try: + while True: + # robot.move_vel(Vector(0.1, 0.1, 0.1)) + time.sleep(0.01) + +except KeyboardInterrupt: + print("Stopping robot") + robot.liedown() +except Exception as e: + print(f"Unexpected error in main loop: {e}") + import traceback + + traceback.print_exc() +finally: + print("Cleaning up...") + cleanup_robot() diff --git a/build/lib/tests/simple_agent_test.py b/build/lib/tests/simple_agent_test.py new file mode 100644 index 0000000000..2534eac31b --- /dev/null +++ b/build/lib/tests/simple_agent_test.py @@ -0,0 +1,39 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import tests.test_header + +from dimos.robot.unitree.unitree_go2 import UnitreeGo2 +from dimos.robot.unitree.unitree_skills import MyUnitreeSkills +from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl +from dimos.agents.agent import OpenAIAgent +import os + +# Initialize robot +robot = UnitreeGo2( + ip=os.getenv("ROBOT_IP"), ros_control=UnitreeROSControl(), skills=MyUnitreeSkills() +) + +# Initialize agent +agent = OpenAIAgent( + dev_name="UnitreeExecutionAgent", + input_video_stream=robot.get_ros_video_stream(), + skills=robot.get_skills(), + system_query="Wiggle when you see a person! Jump when you see a person waving!", +) + +try: + input("Press ESC to exit...") +except KeyboardInterrupt: + print("\nExiting...") diff --git a/build/lib/tests/test_agent.py b/build/lib/tests/test_agent.py new file mode 100644 index 0000000000..e2c8f89f8e --- /dev/null +++ b/build/lib/tests/test_agent.py @@ -0,0 +1,60 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import sys +import os +import tests.test_header + +# ----- + +from dotenv import load_dotenv + + +# Sanity check for dotenv +def test_dotenv(): + print("test_dotenv:") + load_dotenv() + openai_api_key = os.getenv("OPENAI_API_KEY") + print("\t\tOPENAI_API_KEY: ", openai_api_key) + + +# Sanity check for openai connection +def test_openai_connection(): + from openai import OpenAI + + client = OpenAI() + print("test_openai_connection:") + response = client.chat.completions.create( + model="gpt-4o-mini", + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "What's in this image?"}, + { + "type": "image_url", + "image_url": { + "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", + }, + }, + ], + } + ], + max_tokens=300, + ) + print("\t\tOpenAI Response: ", response.choices[0]) + + +test_dotenv() +test_openai_connection() diff --git a/build/lib/tests/test_agent_alibaba.py b/build/lib/tests/test_agent_alibaba.py new file mode 100644 index 0000000000..9519387b7b --- /dev/null +++ b/build/lib/tests/test_agent_alibaba.py @@ -0,0 +1,59 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import tests.test_header + +import os +from dimos.agents.agent import OpenAIAgent +from openai import OpenAI +from dimos.stream.video_provider import VideoProvider +from dimos.utils.threadpool import get_scheduler +from dimos.agents.tokenizer.huggingface_tokenizer import HuggingFaceTokenizer +from dimos.robot.unitree.unitree_skills import MyUnitreeSkills + +# Initialize video stream +video_stream = VideoProvider( + dev_name="VideoProvider", + # video_source=f"{os.getcwd()}/assets/framecount.mp4", + video_source=f"{os.getcwd()}/assets/trimmed_video_office.mov", + pool_scheduler=get_scheduler(), +).capture_video_as_observable(realtime=False, fps=1) + +# Specify the OpenAI client for Alibaba +qwen_client = OpenAI( + base_url="https://dashscope-intl.aliyuncs.com/compatible-mode/v1", + api_key=os.getenv("ALIBABA_API_KEY"), +) + +# Initialize Unitree skills +myUnitreeSkills = MyUnitreeSkills() +myUnitreeSkills.initialize_skills() + +# Initialize agent +agent = OpenAIAgent( + dev_name="AlibabaExecutionAgent", + openai_client=qwen_client, + model_name="qwen2.5-vl-72b-instruct", + tokenizer=HuggingFaceTokenizer(model_name="Qwen/Qwen2.5-VL-72B-Instruct"), + max_output_tokens_per_request=8192, + input_video_stream=video_stream, + # system_query="Tell me the number in the video. Find me the center of the number spotted, and print the coordinates to the console using an appropriate function call. Then provide me a deep history of the number in question and its significance in history. Additionally, tell me what model and version of language model you are.", + system_query="Tell me about any objects seen. Print the coordinates for center of the objects seen to the console using an appropriate function call. Then provide me a deep history of the number in question and its significance in history. 
Additionally, tell me what model and version of language model you are.", + skills=myUnitreeSkills, +) + +try: + input("Press ESC to exit...") +except KeyboardInterrupt: + print("\nExiting...") diff --git a/build/lib/tests/test_agent_ctransformers_gguf.py b/build/lib/tests/test_agent_ctransformers_gguf.py new file mode 100644 index 0000000000..6cd3405239 --- /dev/null +++ b/build/lib/tests/test_agent_ctransformers_gguf.py @@ -0,0 +1,44 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tests.test_header + +from dimos.agents.agent_ctransformers_gguf import CTransformersGGUFAgent + +system_query = "You are a robot with the following functions. Move(), Reverse(), Left(), Right(), Stop(). Given the following user comands return the correct function." 
+ +# Initialize agent +agent = CTransformersGGUFAgent( + dev_name="GGUF-Agent", + model_name="TheBloke/Llama-2-7B-GGUF", + model_file="llama-2-7b.Q4_K_M.gguf", + model_type="llama", + system_query=system_query, + gpu_layers=50, + max_input_tokens_per_request=250, + max_output_tokens_per_request=10, +) + +test_query = "User: Travel forward 10 meters" + +agent.run_observable_query(test_query).subscribe( + on_next=lambda response: print(f"One-off query response: {response}"), + on_error=lambda error: print(f"Error: {error}"), + on_completed=lambda: print("Query completed"), +) + +try: + input("Press ESC to exit...") +except KeyboardInterrupt: + print("\nExiting...") diff --git a/build/lib/tests/test_agent_huggingface_local.py b/build/lib/tests/test_agent_huggingface_local.py new file mode 100644 index 0000000000..4c4536a197 --- /dev/null +++ b/build/lib/tests/test_agent_huggingface_local.py @@ -0,0 +1,72 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dimos.stream.data_provider import QueryDataProvider +import tests.test_header + +import os +from dimos.stream.video_provider import VideoProvider +from dimos.utils.threadpool import get_scheduler +from dimos.agents.tokenizer.huggingface_tokenizer import HuggingFaceTokenizer +from dimos.agents.agent_huggingface_local import HuggingFaceLocalAgent +from dimos.robot.unitree.unitree_skills import MyUnitreeSkills + +# Initialize video stream +video_stream = VideoProvider( + dev_name="VideoProvider", + # video_source=f"{os.getcwd()}/assets/framecount.mp4", + video_source=f"{os.getcwd()}/assets/trimmed_video_office.mov", + pool_scheduler=get_scheduler(), +).capture_video_as_observable(realtime=False, fps=1) + +# Initialize Unitree skills +myUnitreeSkills = MyUnitreeSkills() +myUnitreeSkills.initialize_skills() + +# Initialize query stream +query_provider = QueryDataProvider() + +system_query = "You are a robot with the following functions. Move(), Reverse(), Left(), Right(), Stop(). Given the following user comands return ONLY the correct function." + +# Initialize agent +agent = HuggingFaceLocalAgent( + dev_name="HuggingFaceLLMAgent", + model_name="Qwen/Qwen2.5-3B", + agent_type="HF-LLM", + system_query=system_query, + input_query_stream=query_provider.data_stream, + process_all_inputs=False, + max_input_tokens_per_request=250, + max_output_tokens_per_request=20, + # output_dir=self.output_dir, + # skills=skills_instance, + # frame_processor=frame_processor, +) + +# Start the query stream. +# Queries will be pushed every 1 second, in a count from 100 to 5000. +# This will cause listening agents to consume the queries and respond +# to them via skill execution and provide 1-shot responses. 
+query_provider.start_query_stream( + query_template="{query}; User: travel forward by 10 meters", + frequency=10, + start_count=1, + end_count=10000, + step=1, +) + +try: + input("Press ESC to exit...") +except KeyboardInterrupt: + print("\nExiting...") diff --git a/build/lib/tests/test_agent_huggingface_local_jetson.py b/build/lib/tests/test_agent_huggingface_local_jetson.py new file mode 100644 index 0000000000..6d29b3903f --- /dev/null +++ b/build/lib/tests/test_agent_huggingface_local_jetson.py @@ -0,0 +1,73 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dimos.stream.data_provider import QueryDataProvider +import tests.test_header + +import os +from dimos.stream.video_provider import VideoProvider +from dimos.utils.threadpool import get_scheduler +from dimos.agents.tokenizer.huggingface_tokenizer import HuggingFaceTokenizer +from dimos.agents.agent_huggingface_local import HuggingFaceLocalAgent +from dimos.robot.unitree.unitree_skills import MyUnitreeSkills + +# Initialize video stream +video_stream = VideoProvider( + dev_name="VideoProvider", + # video_source=f"{os.getcwd()}/assets/framecount.mp4", + video_source=f"{os.getcwd()}/assets/trimmed_video_office.mov", + pool_scheduler=get_scheduler(), +).capture_video_as_observable(realtime=False, fps=1) + +# Initialize Unitree skills +myUnitreeSkills = MyUnitreeSkills() +myUnitreeSkills.initialize_skills() + +# Initialize query stream +query_provider = QueryDataProvider() + +system_query = "You are a helpful assistant." + +# Initialize agent +agent = HuggingFaceLocalAgent( + dev_name="HuggingFaceLLMAgent", + model_name="Qwen/Qwen2.5-0.5B", + # model_name="HuggingFaceTB/SmolLM2-135M", + agent_type="HF-LLM", + system_query=system_query, + input_query_stream=query_provider.data_stream, + process_all_inputs=False, + max_input_tokens_per_request=250, + max_output_tokens_per_request=20, + # output_dir=self.output_dir, + # skills=skills_instance, + # frame_processor=frame_processor, +) + +# Start the query stream. +# Queries will be pushed every 1 second, in a count from 100 to 5000. +# This will cause listening agents to consume the queries and respond +# to them via skill execution and provide 1-shot responses. 
+query_provider.start_query_stream( + query_template="{query}; User: Hello how are you!", + frequency=30, + start_count=1, + end_count=10000, + step=1, +) + +try: + input("Press ESC to exit...") +except KeyboardInterrupt: + print("\nExiting...") diff --git a/build/lib/tests/test_agent_huggingface_remote.py b/build/lib/tests/test_agent_huggingface_remote.py new file mode 100644 index 0000000000..7129523bf0 --- /dev/null +++ b/build/lib/tests/test_agent_huggingface_remote.py @@ -0,0 +1,64 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dimos.stream.data_provider import QueryDataProvider +import tests.test_header + +import os +from dimos.stream.video_provider import VideoProvider +from dimos.utils.threadpool import get_scheduler +from dimos.agents.tokenizer.huggingface_tokenizer import HuggingFaceTokenizer +from dimos.agents.agent_huggingface_remote import HuggingFaceRemoteAgent +from dimos.robot.unitree.unitree_skills import MyUnitreeSkills + +# Initialize video stream +# video_stream = VideoProvider( +# dev_name="VideoProvider", +# # video_source=f"{os.getcwd()}/assets/framecount.mp4", +# video_source=f"{os.getcwd()}/assets/trimmed_video_office.mov", +# pool_scheduler=get_scheduler(), +# ).capture_video_as_observable(realtime=False, fps=1) + +# Initialize Unitree skills +# myUnitreeSkills = MyUnitreeSkills() +# myUnitreeSkills.initialize_skills() + +# Initialize query stream +query_provider = QueryDataProvider() + +# Initialize agent +agent = HuggingFaceRemoteAgent( + dev_name="HuggingFaceRemoteAgent", + model_name="meta-llama/Meta-Llama-3-8B-Instruct", + tokenizer=HuggingFaceTokenizer(model_name="meta-llama/Meta-Llama-3-8B-Instruct"), + max_output_tokens_per_request=8192, + input_query_stream=query_provider.data_stream, + # input_video_stream=video_stream, + system_query="You are a helpful assistant that can answer questions and help with tasks.", +) + +# Start the query stream. +# Queries will be pushed every 1 second, in a count from 100 to 5000. +query_provider.start_query_stream( + query_template="{query}; Denote the number at the beginning of this query before the semicolon as the 'reference number'. 
Provide the reference number, without any other text in your response.", + frequency=5, + start_count=1, + end_count=10000, + step=1, +) + +try: + input("Press ESC to exit...") +except KeyboardInterrupt: + print("\nExiting...") diff --git a/build/lib/tests/test_audio_agent.py b/build/lib/tests/test_audio_agent.py new file mode 100644 index 0000000000..6caf24b9eb --- /dev/null +++ b/build/lib/tests/test_audio_agent.py @@ -0,0 +1,39 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos.stream.audio.utils import keepalive +from dimos.stream.audio.pipelines import tts, stt +from dimos.utils.threadpool import get_scheduler +from dimos.agents.agent import OpenAIAgent + + +def main(): + stt_node = stt() + + agent = OpenAIAgent( + dev_name="UnitreeExecutionAgent", + input_query_stream=stt_node.emit_text(), + system_query="You are a helpful robot named daneel that does my bidding", + pool_scheduler=get_scheduler(), + ) + + tts_node = tts() + tts_node.consume_text(agent.get_response_observable()) + + # Keep the main thread alive + keepalive() + + +if __name__ == "__main__": + main() diff --git a/build/lib/tests/test_audio_robot_agent.py b/build/lib/tests/test_audio_robot_agent.py new file mode 100644 index 0000000000..411e4a56c1 --- /dev/null +++ b/build/lib/tests/test_audio_robot_agent.py @@ -0,0 +1,51 @@ +# Copyright 2025 Dimensional Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos.utils.threadpool import get_scheduler +import os +from dimos.robot.unitree.unitree_go2 import UnitreeGo2 +from dimos.robot.unitree.unitree_skills import MyUnitreeSkills +from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl +from dimos.agents.agent import OpenAIAgent +from dimos.stream.audio.pipelines import tts, stt +from dimos.stream.audio.utils import keepalive + + +def main(): + stt_node = stt() + tts_node = tts() + + robot = UnitreeGo2( + ip=os.getenv("ROBOT_IP"), + ros_control=UnitreeROSControl(), + skills=MyUnitreeSkills(), + ) + + # Initialize agent with main thread pool scheduler + agent = OpenAIAgent( + dev_name="UnitreeExecutionAgent", + input_query_stream=stt_node.emit_text(), + system_query="You are a helpful robot named daneel that does my bidding", + pool_scheduler=get_scheduler(), + skills=robot.get_skills(), + ) + + tts_node.consume_text(agent.get_response_observable()) + + # Keep the main thread alive + keepalive() + + +if __name__ == "__main__": + main() diff --git a/build/lib/tests/test_cerebras_unitree_ros.py b/build/lib/tests/test_cerebras_unitree_ros.py new file mode 100644 index 0000000000..cbb7c130db --- /dev/null +++ b/build/lib/tests/test_cerebras_unitree_ros.py @@ -0,0 +1,118 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import os +from dimos.robot.robot import MockRobot +import tests.test_header + +import time +from dotenv import load_dotenv +from dimos.agents.cerebras_agent import CerebrasAgent +from dimos.robot.unitree.unitree_go2 import UnitreeGo2 +from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl +from dimos.robot.unitree.unitree_skills import MyUnitreeSkills +from dimos.web.robot_web_interface import RobotWebInterface +from dimos.skills.observe_stream import ObserveStream +from dimos.skills.kill_skill import KillSkill +from dimos.skills.navigation import NavigateWithText, GetPose, NavigateToGoal +from dimos.skills.visual_navigation_skills import FollowHuman +import reactivex as rx +import reactivex.operators as ops +from dimos.stream.audio.pipelines import tts, stt +from dimos.web.websocket_vis.server import WebsocketVis +import threading +from dimos.types.vector import Vector +from dimos.skills.speak import Speak + +# Load API key from environment +load_dotenv() + +# robot = MockRobot() +robot_skills = MyUnitreeSkills() + +robot = UnitreeGo2( + ip=os.getenv("ROBOT_IP"), + ros_control=UnitreeROSControl(), + skills=robot_skills, + mock_connection=False, + new_memory=True, +) + +# Create a subject for agent responses +agent_response_subject = rx.subject.Subject() +agent_response_stream = agent_response_subject.pipe(ops.share()) + +streams = { + "unitree_video": robot.get_ros_video_stream(), +} +text_streams = { + "agent_responses": agent_response_stream, +} + +web_interface = RobotWebInterface( + port=5555, + text_streams=text_streams, 
+ **streams, +) + +stt_node = stt() + +# Create a CerebrasAgent instance +agent = CerebrasAgent( + dev_name="test_cerebras_agent", + input_query_stream=stt_node.emit_text(), + # input_query_stream=web_interface.query_stream, + skills=robot_skills, + system_query="""You are an agent controlling a virtual robot. When given a query, respond by using the appropriate tool calls if needed to execute commands on the robot. + +IMPORTANT INSTRUCTIONS: +1. Each tool call must include the exact function name and appropriate parameters +2. If a function needs parameters like 'distance' or 'angle', be sure to include them +3. If you're unsure which tool to use, choose the most appropriate one based on the user's query +4. Parse the user's instructions carefully to determine correct parameter values + +When you need to call a skill or tool, ALWAYS respond ONLY with a JSON object in this exact format: {"name": "SkillName", "arguments": {"arg1": "value1", "arg2": "value2"}} + +Example: If the user asks to spin right by 90 degrees, output ONLY the following: {"name": "SpinRight", "arguments": {"degrees": 90}}""", + model_name="llama-4-scout-17b-16e-instruct", +) + +tts_node = tts() +tts_node.consume_text(agent.get_response_observable()) + +robot_skills.add(ObserveStream) +robot_skills.add(KillSkill) +robot_skills.add(NavigateWithText) +robot_skills.add(FollowHuman) +robot_skills.add(GetPose) +robot_skills.add(Speak) +robot_skills.add(NavigateToGoal) +robot_skills.create_instance("ObserveStream", robot=robot, agent=agent) +robot_skills.create_instance("KillSkill", robot=robot, skill_library=robot_skills) +robot_skills.create_instance("NavigateWithText", robot=robot) +robot_skills.create_instance("FollowHuman", robot=robot) +robot_skills.create_instance("GetPose", robot=robot) +robot_skills.create_instance("NavigateToGoal", robot=robot) + + +robot_skills.create_instance("Speak", tts_node=tts_node) + +# Subscribe to agent responses and send them to the subject 
+agent.get_response_observable().subscribe(lambda x: agent_response_subject.on_next(x)) + +# print(f"Registered skills: {', '.join([skill.__name__ for skill in robot_skills.skills])}") +print("Cerebras agent demo initialized. You can now interact with the agent via the web interface.") + +web_interface.run() diff --git a/build/lib/tests/test_claude_agent_query.py b/build/lib/tests/test_claude_agent_query.py new file mode 100644 index 0000000000..aabd85bc12 --- /dev/null +++ b/build/lib/tests/test_claude_agent_query.py @@ -0,0 +1,29 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tests.test_header + +from dotenv import load_dotenv +from dimos.agents.claude_agent import ClaudeAgent + +# Load API key from environment +load_dotenv() + +# Create a ClaudeAgent instance +agent = ClaudeAgent(dev_name="test_agent", query="What is the capital of France?") + +# Use the stream_query method to get a response +response = agent.run_observable_query("What is the capital of France?").run() + +print(f"Response from Claude Agent: {response}") diff --git a/build/lib/tests/test_claude_agent_skills_query.py b/build/lib/tests/test_claude_agent_skills_query.py new file mode 100644 index 0000000000..1aaeb795f1 --- /dev/null +++ b/build/lib/tests/test_claude_agent_skills_query.py @@ -0,0 +1,135 @@ +# Copyright 2025 Dimensional Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tests.test_header +import os + +import time +from dotenv import load_dotenv +from dimos.agents.claude_agent import ClaudeAgent +from dimos.robot.unitree.unitree_go2 import UnitreeGo2 +from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl +from dimos.robot.unitree.unitree_skills import MyUnitreeSkills +from dimos.web.robot_web_interface import RobotWebInterface +from dimos.skills.observe_stream import ObserveStream +from dimos.skills.kill_skill import KillSkill +from dimos.skills.navigation import Navigate, BuildSemanticMap, GetPose, NavigateToGoal +from dimos.skills.visual_navigation_skills import NavigateToObject, FollowHuman +import reactivex as rx +import reactivex.operators as ops +from dimos.stream.audio.pipelines import tts, stt +from dimos.web.websocket_vis.server import WebsocketVis +import threading +from dimos.types.vector import Vector +from dimos.skills.speak import Speak + +# Load API key from environment +load_dotenv() + +robot = UnitreeGo2( + ip=os.getenv("ROBOT_IP"), + ros_control=UnitreeROSControl(), + skills=MyUnitreeSkills(), + mock_connection=False, +) + +# Create a subject for agent responses +agent_response_subject = rx.subject.Subject() +agent_response_stream = agent_response_subject.pipe(ops.share()) +local_planner_viz_stream = robot.local_planner_viz_stream.pipe(ops.share()) + +streams = { + "unitree_video": robot.get_ros_video_stream(), + "local_planner_viz": 
local_planner_viz_stream, +} +text_streams = { + "agent_responses": agent_response_stream, +} + +web_interface = RobotWebInterface(port=5555, text_streams=text_streams, **streams) + +stt_node = stt() + +# Create a ClaudeAgent instance +agent = ClaudeAgent( + dev_name="test_agent", + input_query_stream=stt_node.emit_text(), + # input_query_stream=web_interface.query_stream, + skills=robot.get_skills(), + system_query="""You are an agent controlling a virtual robot. When given a query, respond by using the appropriate tool calls if needed to execute commands on the robot. + +IMPORTANT INSTRUCTIONS: +1. Each tool call must include the exact function name and appropriate parameters +2. If a function needs parameters like 'distance' or 'angle', be sure to include them +3. If you're unsure which tool to use, choose the most appropriate one based on the user's query +4. Parse the user's instructions carefully to determine correct parameter values + +Example: If the user asks to move forward 1 meter, call the Move function with distance=1""", + model_name="claude-3-7-sonnet-latest", + thinking_budget_tokens=2000, +) + +tts_node = tts() +# tts_node.consume_text(agent.get_response_observable()) + +robot_skills = robot.get_skills() +robot_skills.add(ObserveStream) +robot_skills.add(KillSkill) +robot_skills.add(Navigate) +robot_skills.add(BuildSemanticMap) +robot_skills.add(NavigateToObject) +robot_skills.add(FollowHuman) +robot_skills.add(GetPose) +robot_skills.add(Speak) +robot_skills.add(NavigateToGoal) +robot_skills.create_instance("ObserveStream", robot=robot, agent=agent) +robot_skills.create_instance("KillSkill", robot=robot, skill_library=robot_skills) +robot_skills.create_instance("Navigate", robot=robot) +robot_skills.create_instance("BuildSemanticMap", robot=robot) +robot_skills.create_instance("NavigateToObject", robot=robot) +robot_skills.create_instance("FollowHuman", robot=robot) +robot_skills.create_instance("GetPose", robot=robot) 
+robot_skills.create_instance("NavigateToGoal", robot=robot) +robot_skills.create_instance("Speak", tts_node=tts_node) + +# Subscribe to agent responses and send them to the subject +agent.get_response_observable().subscribe(lambda x: agent_response_subject.on_next(x)) + +print("ObserveStream and Kill skills registered and ready for use") +print("Created memory.txt file") + +websocket_vis = WebsocketVis() +websocket_vis.start() +websocket_vis.connect(robot.global_planner.vis_stream()) + + +def msg_handler(msgtype, data): + if msgtype == "click": + target = Vector(data["position"]) + try: + robot.global_planner.set_goal(target) + except Exception as e: + print(f"Error setting goal: {e}") + return + + +def threaded_msg_handler(msgtype, data): + thread = threading.Thread(target=msg_handler, args=(msgtype, data)) + thread.daemon = True + thread.start() + + +websocket_vis.msg_handler = threaded_msg_handler + +web_interface.run() diff --git a/build/lib/tests/test_command_pose_unitree.py b/build/lib/tests/test_command_pose_unitree.py new file mode 100644 index 0000000000..22cf0e82ed --- /dev/null +++ b/build/lib/tests/test_command_pose_unitree.py @@ -0,0 +1,82 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import sys + +# Add the parent directory to the Python path +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from dimos.robot.unitree.unitree_go2 import UnitreeGo2 +from dimos.robot.unitree.unitree_skills import MyUnitreeSkills +from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl +import os +import time +import math + +# Initialize robot +robot = UnitreeGo2( + ip=os.getenv("ROBOT_IP"), ros_control=UnitreeROSControl(), skills=MyUnitreeSkills() +) + + +# Helper function to send pose commands continuously for a duration +def send_pose_for_duration(roll, pitch, yaw, duration, hz=10): + """Send the same pose command repeatedly at specified frequency for the given duration""" + start_time = time.time() + while time.time() - start_time < duration: + robot.pose_command(roll=roll, pitch=pitch, yaw=yaw) + time.sleep(1.0 / hz) # Sleep to achieve the desired frequency + + +# Test pose commands + +# First, make sure the robot is in a stable position +print("Setting default pose...") +send_pose_for_duration(0.0, 0.0, 0.0, 1) + +# Test roll angle (lean left/right) +print("Testing roll angle - lean right...") +send_pose_for_duration(0.5, 0.0, 0.0, 1.5) # Lean right + +print("Testing roll angle - lean left...") +send_pose_for_duration(-0.5, 0.0, 0.0, 1.5) # Lean left + +# Test pitch angle (lean forward/backward) +print("Testing pitch angle - lean forward...") +send_pose_for_duration(0.0, 0.5, 0.0, 1.5) # Lean forward + +print("Testing pitch angle - lean backward...") +send_pose_for_duration(0.0, -0.5, 0.0, 1.5) # Lean backward + +# Test yaw angle (rotate body without moving feet) +print("Testing yaw angle - rotate clockwise...") +send_pose_for_duration(0.0, 0.0, 0.5, 1.5) # Rotate body clockwise + +print("Testing yaw angle - rotate counterclockwise...") +send_pose_for_duration(0.0, 0.0, -0.5, 1.5) # Rotate body counterclockwise + +# Reset to default pose +print("Resetting to default pose...") 
+send_pose_for_duration(0.0, 0.0, 0.0, 2) + +print("Pose command test completed") + +# Keep the program running (optional) +print("Press Ctrl+C to exit") +try: + while True: + time.sleep(1) +except KeyboardInterrupt: + print("Test terminated by user") diff --git a/build/lib/tests/test_header.py b/build/lib/tests/test_header.py new file mode 100644 index 0000000000..48ea6dd509 --- /dev/null +++ b/build/lib/tests/test_header.py @@ -0,0 +1,58 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test utilities for identifying caller information and path setup. + +This module provides functionality to determine which file called the current +script and sets up the Python path to include the parent directory, allowing +tests to import from the main application. +""" + +import sys +import os +import inspect + +# Add the parent directory of 'tests' to the Python path +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + + +def get_caller_info(): + """Identify the filename of the caller in the stack. + + Examines the call stack to find the first non-internal file that called + this module. Skips the current file and Python internal files. + + Returns: + str: The basename of the caller's filename, or "unknown" if not found. 
+ """ + current_file = os.path.abspath(__file__) + + # Look through the call stack to find the first file that's not this one + for frame in inspect.stack()[1:]: + filename = os.path.abspath(frame.filename) + # Skip this file and Python internals + if filename != current_file and " 0: + best_score = max(grasp.get("score", 0.0) for grasp in grasps) + print(f" Best grasp score: {best_score:.3f}") + last_grasp_count = current_count + last_update_time = current_time + else: + # Show periodic "still waiting" message + if current_time - last_update_time > 10.0: + print(f" Still waiting for grasps... ({time.strftime('%H:%M:%S')})") + last_update_time = current_time + + time.sleep(1.0) # Check every second + + except Exception as e: + print(f" Error in grasp monitor: {e}") + time.sleep(2.0) + + +def main(): + """Test point cloud filtering with grasp generation using ManipulationPipeline.""" + print(" Testing point cloud filtering + grasp generation with ManipulationPipeline...") + + # Configuration + min_confidence = 0.6 + web_port = 5555 + grasp_server_url = "ws://18.224.39.74:8000/ws/grasp" + + try: + # Initialize ZED camera stream + zed_stream = ZEDCameraStream(resolution=sl.RESOLUTION.HD1080, fps=10) + + # Get camera intrinsics + camera_intrinsics_dict = zed_stream.get_camera_info() + camera_intrinsics = [ + camera_intrinsics_dict["fx"], + camera_intrinsics_dict["fy"], + camera_intrinsics_dict["cx"], + camera_intrinsics_dict["cy"], + ] + + # Create the concurrent manipulation pipeline WITH grasp generation + pipeline = ManipulationPipeline( + camera_intrinsics=camera_intrinsics, + min_confidence=min_confidence, + max_objects=10, + grasp_server_url=grasp_server_url, + enable_grasp_generation=True, # Enable grasp generation + ) + + # Create ZED stream + zed_frame_stream = zed_stream.create_stream().pipe(ops.share()) + + # Create concurrent processing streams + streams = pipeline.create_streams(zed_frame_stream) + detection_viz_stream = streams["detection_viz"] + 
pointcloud_viz_stream = streams["pointcloud_viz"] + grasps_stream = streams.get("grasps") # Get grasp stream if available + grasp_overlay_stream = streams.get("grasp_overlay") # Get grasp overlay stream if available + + except ImportError: + print("Error: ZED SDK not installed. Please install pyzed package.") + sys.exit(1) + except RuntimeError as e: + print(f"Error: Failed to open ZED camera: {e}") + sys.exit(1) + + try: + # Set up web interface with concurrent visualization streams + print("Initializing web interface...") + web_interface = RobotWebInterface( + port=web_port, + object_detection=detection_viz_stream, + pointcloud_stream=pointcloud_viz_stream, + grasp_overlay_stream=grasp_overlay_stream, + ) + + # Start grasp monitoring in background thread + grasp_monitor_thread = threading.Thread( + target=monitor_grasps, args=(pipeline,), daemon=True + ) + grasp_monitor_thread.start() + + print(f"\n Point Cloud + Grasp Generation Test Running:") + print(f" Web Interface: http://localhost:{web_port}") + print(f" Object Detection View: RGB with bounding boxes") + print(f" Point Cloud View: Depth with colored point clouds and 3D bounding boxes") + print(f" Confidence threshold: {min_confidence}") + print(f" Grasp server: {grasp_server_url}") + print(f" Available streams: {list(streams.keys())}") + print("\nPress Ctrl+C to stop the test\n") + + # Start web server (blocking call) + web_interface.run() + + except KeyboardInterrupt: + print("\nTest interrupted by user") + except Exception as e: + print(f"Error during test: {e}") + finally: + print("Cleaning up resources...") + if "zed_stream" in locals(): + zed_stream.cleanup() + if "pipeline" in locals(): + pipeline.cleanup() + print("Test completed") + + +if __name__ == "__main__": + main() diff --git a/build/lib/tests/test_manipulation_pipeline_single_frame.py b/build/lib/tests/test_manipulation_pipeline_single_frame.py new file mode 100644 index 0000000000..fa7187f948 --- /dev/null +++ 
b/build/lib/tests/test_manipulation_pipeline_single_frame.py @@ -0,0 +1,248 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test manipulation processor with direct visualization and grasp data output.""" + +import os +import sys +import cv2 +import numpy as np +import time +import argparse +import matplotlib + +# Try to use TkAgg backend for live display, fallback to Agg if not available +try: + matplotlib.use("TkAgg") +except: + try: + matplotlib.use("Qt5Agg") + except: + matplotlib.use("Agg") # Fallback to non-interactive +import matplotlib.pyplot as plt +import open3d as o3d +from typing import Dict, List + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from dimos.perception.pointcloud.utils import visualize_clustered_point_clouds, visualize_voxel_grid +from dimos.manipulation.manip_aio_processer import ManipulationProcessor +from dimos.perception.pointcloud.utils import ( + load_camera_matrix_from_yaml, + visualize_pcd, + combine_object_pointclouds, +) +from dimos.utils.logging_config import setup_logger + +from dimos.perception.grasp_generation.utils import visualize_grasps_3d, create_grasp_overlay + +logger = setup_logger("test_pipeline_viz") + + +def load_first_frame(data_dir: str): + """Load first RGB-D frame and camera intrinsics.""" + # Load images + color_img = cv2.imread(os.path.join(data_dir, "color", "00000.png")) + color_img = cv2.cvtColor(color_img, cv2.COLOR_BGR2RGB) + + 
depth_img = cv2.imread(os.path.join(data_dir, "depth", "00000.png"), cv2.IMREAD_ANYDEPTH) + if depth_img.dtype == np.uint16: + depth_img = depth_img.astype(np.float32) / 1000.0 + # Load intrinsics + camera_matrix = load_camera_matrix_from_yaml(os.path.join(data_dir, "color_camera_info.yaml")) + intrinsics = [ + camera_matrix[0, 0], + camera_matrix[1, 1], + camera_matrix[0, 2], + camera_matrix[1, 2], + ] + + return color_img, depth_img, intrinsics + + +def create_point_cloud(color_img, depth_img, intrinsics): + """Create Open3D point cloud.""" + fx, fy, cx, cy = intrinsics + height, width = depth_img.shape + + o3d_intrinsics = o3d.camera.PinholeCameraIntrinsic(width, height, fx, fy, cx, cy) + color_o3d = o3d.geometry.Image(color_img) + depth_o3d = o3d.geometry.Image((depth_img * 1000).astype(np.uint16)) + + rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth( + color_o3d, depth_o3d, depth_scale=1000.0, convert_rgb_to_intensity=False + ) + + return o3d.geometry.PointCloud.create_from_rgbd_image(rgbd, o3d_intrinsics) + + +def run_processor(color_img, depth_img, intrinsics, grasp_server_url=None): + """Run processor and collect results.""" + processor_kwargs = { + "camera_intrinsics": intrinsics, + "enable_grasp_generation": True, + "enable_segmentation": True, + } + + if grasp_server_url: + processor_kwargs["grasp_server_url"] = grasp_server_url + + processor = ManipulationProcessor(**processor_kwargs) + + # Process frame without grasp generation + results = processor.process_frame(color_img, depth_img, generate_grasps=False) + + # Run grasp generation separately + grasps = processor.run_grasp_generation(results["all_objects"], results["full_pointcloud"]) + results["grasps"] = grasps + results["grasp_overlay"] = create_grasp_overlay(color_img, grasps, intrinsics) + + processor.cleanup() + return results + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--data-dir", default="assets/rgbd_data") + parser.add_argument("--wait-time", 
type=float, default=5.0) + parser.add_argument( + "--grasp-server-url", + default="ws://18.224.39.74:8000/ws/grasp", + help="WebSocket URL for AnyGrasp server", + ) + args = parser.parse_args() + + # Load data + color_img, depth_img, intrinsics = load_first_frame(args.data_dir) + logger.info(f"Loaded images: color {color_img.shape}, depth {depth_img.shape}") + + # Run processor + results = run_processor(color_img, depth_img, intrinsics, args.grasp_server_url) + + # Print results summary + print(f"Processing time: {results.get('processing_time', 0):.3f}s") + print(f"Detection objects: {len(results.get('detected_objects', []))}") + print(f"All objects processed: {len(results.get('all_objects', []))}") + + # Print grasp summary + grasp_data = results["grasps"] + total_grasps = len(grasp_data) if isinstance(grasp_data, list) else 0 + best_score = max(grasp["score"] for grasp in grasp_data) if grasp_data else 0 + + print(f"AnyGrasp grasps: {total_grasps} total (best score: {best_score:.3f})") + + # Create visualizations + plot_configs = [] + if results["detection_viz"] is not None: + plot_configs.append(("detection_viz", "Object Detection")) + if results["segmentation_viz"] is not None: + plot_configs.append(("segmentation_viz", "Semantic Segmentation")) + if results["pointcloud_viz"] is not None: + plot_configs.append(("pointcloud_viz", "All Objects Point Cloud")) + if results["detected_pointcloud_viz"] is not None: + plot_configs.append(("detected_pointcloud_viz", "Detection Objects Point Cloud")) + if results["misc_pointcloud_viz"] is not None: + plot_configs.append(("misc_pointcloud_viz", "Misc/Background Points")) + if results["grasp_overlay"] is not None: + plot_configs.append(("grasp_overlay", "Grasp Overlay")) + + # Create subplot layout + num_plots = len(plot_configs) + if num_plots <= 3: + fig, axes = plt.subplots(1, num_plots, figsize=(6 * num_plots, 5)) + else: + rows = 2 + cols = (num_plots + 1) // 2 + fig, axes = plt.subplots(rows, cols, figsize=(6 * cols, 
5 * rows)) + + if num_plots == 1: + axes = [axes] + elif num_plots > 2: + axes = axes.flatten() + + # Plot each result + for i, (key, title) in enumerate(plot_configs): + axes[i].imshow(results[key]) + axes[i].set_title(title) + axes[i].axis("off") + + # Hide unused subplots + if num_plots > 3: + for i in range(num_plots, len(axes)): + axes[i].axis("off") + + plt.tight_layout() + plt.savefig("manipulation_results.png", dpi=150, bbox_inches="tight") + plt.show(block=True) + plt.close() + + point_clouds = [obj["point_cloud"] for obj in results["all_objects"]] + colors = [obj["color"] for obj in results["all_objects"]] + combined_pcd = combine_object_pointclouds(point_clouds, colors) + + # 3D Grasp visualization + if grasp_data: + # Convert grasp format to visualization format for 3D display + viz_grasps = [] + for grasp in grasp_data: + translation = grasp.get("translation", [0, 0, 0]) + rotation_matrix = np.array(grasp.get("rotation_matrix", np.eye(3).tolist())) + score = grasp.get("score", 0.0) + width = grasp.get("width", 0.08) + + viz_grasp = { + "translation": translation, + "rotation_matrix": rotation_matrix, + "width": width, + "score": score, + } + viz_grasps.append(viz_grasp) + + # Use unified 3D visualization + visualize_grasps_3d(combined_pcd, viz_grasps) + + # Visualize full point cloud + visualize_pcd( + results["full_pointcloud"], + window_name="Full Scene Point Cloud", + point_size=2.0, + show_coordinate_frame=True, + ) + + # Visualize all objects point cloud + visualize_pcd( + combined_pcd, + window_name="All Objects Point Cloud", + point_size=3.0, + show_coordinate_frame=True, + ) + + # Visualize misc clusters + visualize_clustered_point_clouds( + results["misc_clusters"], + window_name="Misc/Background Clusters (DBSCAN)", + point_size=3.0, + show_coordinate_frame=True, + ) + + # Visualize voxel grid + visualize_voxel_grid( + results["misc_voxel_grid"], + window_name="Misc/Background Voxel Grid", + show_coordinate_frame=True, + ) + + +if __name__ == 
"__main__": + main() diff --git a/build/lib/tests/test_manipulation_pipeline_single_frame_lcm.py b/build/lib/tests/test_manipulation_pipeline_single_frame_lcm.py new file mode 100644 index 0000000000..62898816fa --- /dev/null +++ b/build/lib/tests/test_manipulation_pipeline_single_frame_lcm.py @@ -0,0 +1,431 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test manipulation processor with LCM topic subscription.""" + +import os +import sys +import cv2 +import numpy as np +import time +import argparse +import threading +import pickle +import matplotlib +import json +import copy + +# Try to use TkAgg backend for live display, fallback to Agg if not available +try: + matplotlib.use("TkAgg") +except: + try: + matplotlib.use("Qt5Agg") + except: + matplotlib.use("Agg") # Fallback to non-interactive +import matplotlib.pyplot as plt +import open3d as o3d +from typing import Dict, List, Optional + +# LCM imports +import lcm +from lcm_msgs.sensor_msgs import Image as LCMImage +from lcm_msgs.sensor_msgs import CameraInfo as LCMCameraInfo + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from dimos.perception.pointcloud.utils import visualize_clustered_point_clouds, visualize_voxel_grid +from dimos.manipulation.manip_aio_processer import ManipulationProcessor +from dimos.perception.grasp_generation.utils import visualize_grasps_3d +from dimos.perception.pointcloud.utils import visualize_pcd +from 
dimos.utils.logging_config import setup_logger + +logger = setup_logger("test_pipeline_lcm") + + +class LCMDataCollector: + """Collects one message from each required LCM topic.""" + + def __init__(self, lcm_url: str = "udpm://239.255.76.67:7667?ttl=1"): + self.lcm = lcm.LCM(lcm_url) + + # Data storage + self.rgb_data: Optional[np.ndarray] = None + self.depth_data: Optional[np.ndarray] = None + self.camera_intrinsics: Optional[List[float]] = None + + # Synchronization + self.data_lock = threading.Lock() + self.data_ready_event = threading.Event() + + # Flags to track received messages + self.rgb_received = False + self.depth_received = False + self.camera_info_received = False + + # Subscribe to topics + self.lcm.subscribe("head_cam_rgb#sensor_msgs.Image", self._handle_rgb_message) + self.lcm.subscribe("head_cam_depth#sensor_msgs.Image", self._handle_depth_message) + self.lcm.subscribe("head_cam_info#sensor_msgs.CameraInfo", self._handle_camera_info_message) + + logger.info("LCM Data Collector initialized") + logger.info("Subscribed to topics:") + logger.info(" - head_cam_rgb#sensor_msgs.Image") + logger.info(" - head_cam_depth#sensor_msgs.Image") + logger.info(" - head_cam_info#sensor_msgs.CameraInfo") + + def _handle_rgb_message(self, channel: str, data: bytes): + """Handle RGB image message.""" + if self.rgb_received: + return # Already got one, ignore subsequent messages + + try: + msg = LCMImage.decode(data) + + # Convert message data to numpy array + if msg.encoding == "rgb8": + # RGB8 format: 3 bytes per pixel + rgb_array = np.frombuffer(msg.data[: msg.data_length], dtype=np.uint8) + rgb_image = rgb_array.reshape((msg.height, msg.width, 3)) + + with self.data_lock: + self.rgb_data = rgb_image + self.rgb_received = True + logger.info( + f"RGB message received: {msg.width}x{msg.height}, encoding: {msg.encoding}" + ) + self._check_all_data_received() + + else: + logger.warning(f"Unsupported RGB encoding: {msg.encoding}") + + except Exception as e: + 
logger.error(f"Error processing RGB message: {e}") + + def _handle_depth_message(self, channel: str, data: bytes): + """Handle depth image message.""" + if self.depth_received: + return # Already got one, ignore subsequent messages + + try: + msg = LCMImage.decode(data) + + # Convert message data to numpy array + if msg.encoding == "32FC1": + # 32FC1 format: 4 bytes (float32) per pixel + depth_array = np.frombuffer(msg.data[: msg.data_length], dtype=np.float32) + depth_image = depth_array.reshape((msg.height, msg.width)) + + with self.data_lock: + self.depth_data = depth_image + self.depth_received = True + logger.info( + f"Depth message received: {msg.width}x{msg.height}, encoding: {msg.encoding}" + ) + logger.info( + f"Depth range: {depth_image.min():.3f} - {depth_image.max():.3f} meters" + ) + self._check_all_data_received() + + else: + logger.warning(f"Unsupported depth encoding: {msg.encoding}") + + except Exception as e: + logger.error(f"Error processing depth message: {e}") + + def _handle_camera_info_message(self, channel: str, data: bytes): + """Handle camera info message.""" + if self.camera_info_received: + return # Already got one, ignore subsequent messages + + try: + msg = LCMCameraInfo.decode(data) + + # Extract intrinsics from K matrix: [fx, 0, cx, 0, fy, cy, 0, 0, 1] + K = msg.K + fx = K[0] # K[0,0] + fy = K[4] # K[1,1] + cx = K[2] # K[0,2] + cy = K[5] # K[1,2] + + intrinsics = [fx, fy, cx, cy] + + with self.data_lock: + self.camera_intrinsics = intrinsics + self.camera_info_received = True + logger.info(f"Camera info received: {msg.width}x{msg.height}") + logger.info(f"Intrinsics: fx={fx:.1f}, fy={fy:.1f}, cx={cx:.1f}, cy={cy:.1f}") + self._check_all_data_received() + + except Exception as e: + logger.error(f"Error processing camera info message: {e}") + + def _check_all_data_received(self): + """Check if all required data has been received.""" + if self.rgb_received and self.depth_received and self.camera_info_received: + logger.info("✅ All 
required data received!") + self.data_ready_event.set() + + def wait_for_data(self, timeout: float = 30.0) -> bool: + """Wait for all data to be received.""" + logger.info("Waiting for RGB, depth, and camera info messages...") + + # Start LCM handling in a separate thread + lcm_thread = threading.Thread(target=self._lcm_handle_loop, daemon=True) + lcm_thread.start() + + # Wait for data with timeout + return self.data_ready_event.wait(timeout) + + def _lcm_handle_loop(self): + """LCM message handling loop.""" + try: + while not self.data_ready_event.is_set(): + self.lcm.handle_timeout(100) # 100ms timeout + except Exception as e: + logger.error(f"Error in LCM handling loop: {e}") + + def get_data(self): + """Get the collected data.""" + with self.data_lock: + return self.rgb_data, self.depth_data, self.camera_intrinsics + + +def create_point_cloud(color_img, depth_img, intrinsics): + """Create Open3D point cloud.""" + fx, fy, cx, cy = intrinsics + height, width = depth_img.shape + + o3d_intrinsics = o3d.camera.PinholeCameraIntrinsic(width, height, fx, fy, cx, cy) + color_o3d = o3d.geometry.Image(color_img) + depth_o3d = o3d.geometry.Image((depth_img * 1000).astype(np.uint16)) + + rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth( + color_o3d, depth_o3d, depth_scale=1000.0, convert_rgb_to_intensity=False + ) + + return o3d.geometry.PointCloud.create_from_rgbd_image(rgbd, o3d_intrinsics) + + +def run_processor(color_img, depth_img, intrinsics): + """Run processor and collect results.""" + # Create processor + processor = ManipulationProcessor( + camera_intrinsics=intrinsics, + grasp_server_url="ws://18.224.39.74:8000/ws/grasp", + enable_grasp_generation=False, + enable_segmentation=True, + ) + + # Process single frame directly + results = processor.process_frame(color_img, depth_img) + + # Debug: print available results + print(f"Available results: {list(results.keys())}") + + processor.cleanup() + + return results + + +def main(): + parser = 
argparse.ArgumentParser() + parser.add_argument( + "--lcm-url", default="udpm://239.255.76.67:7667?ttl=1", help="LCM URL for subscription" + ) + parser.add_argument( + "--timeout", type=float, default=30.0, help="Timeout in seconds to wait for messages" + ) + parser.add_argument( + "--save-images", action="store_true", help="Save received RGB and depth images to files" + ) + args = parser.parse_args() + + # Create data collector + collector = LCMDataCollector(args.lcm_url) + + # Wait for data + if not collector.wait_for_data(args.timeout): + logger.error(f"Timeout waiting for data after {args.timeout} seconds") + logger.error("Make sure Unity is running and publishing to the LCM topics") + return + + # Get the collected data + color_img, depth_img, intrinsics = collector.get_data() + + logger.info(f"Loaded images: color {color_img.shape}, depth {depth_img.shape}") + logger.info(f"Intrinsics: {intrinsics}") + + # Save images if requested + if args.save_images: + try: + cv2.imwrite("received_rgb.png", cv2.cvtColor(color_img, cv2.COLOR_RGB2BGR)) + # Save depth as 16-bit for visualization + depth_viz = (np.clip(depth_img * 1000, 0, 65535)).astype(np.uint16) + cv2.imwrite("received_depth.png", depth_viz) + logger.info("Saved received_rgb.png and received_depth.png") + except Exception as e: + logger.warning(f"Failed to save images: {e}") + + # Run processor + results = run_processor(color_img, depth_img, intrinsics) + + # Debug: Print what we received + print(f"\n✅ Processor Results:") + print(f" Available results: {list(results.keys())}") + print(f" Processing time: {results.get('processing_time', 0):.3f}s") + + # Show timing breakdown if available + if "timing_breakdown" in results: + breakdown = results["timing_breakdown"] + print(f" Timing breakdown:") + print(f" - Detection: {breakdown.get('detection', 0):.3f}s") + print(f" - Segmentation: {breakdown.get('segmentation', 0):.3f}s") + print(f" - Point cloud: {breakdown.get('pointcloud', 0):.3f}s") + print(f" - Misc 
extraction: {breakdown.get('misc_extraction', 0):.3f}s") + + # Print object information + detected_count = len(results.get("detected_objects", [])) + all_count = len(results.get("all_objects", [])) + + print(f" Detection objects: {detected_count}") + print(f" All objects processed: {all_count}") + + # Print misc clusters information + if "misc_clusters" in results and results["misc_clusters"]: + cluster_count = len(results["misc_clusters"]) + total_misc_points = sum( + len(np.asarray(cluster.points)) for cluster in results["misc_clusters"] + ) + print(f" Misc clusters: {cluster_count} clusters with {total_misc_points} total points") + else: + print(f" Misc clusters: None") + + # Print grasp summary + if "grasps" in results and results["grasps"]: + total_grasps = 0 + best_score = 0 + for grasp in results["grasps"]: + score = grasp.get("score", 0) + if score > best_score: + best_score = score + total_grasps += 1 + print(f" Grasps generated: {total_grasps} (best score: {best_score:.3f})") + else: + print(" Grasps: None generated") + + # Save results to pickle file + pickle_path = "manipulation_results.pkl" + print(f"\nSaving results to pickle file: {pickle_path}") + + def serialize_point_cloud(pcd): + """Convert Open3D PointCloud to serializable format.""" + if pcd is None: + return None + data = { + "points": np.asarray(pcd.points).tolist() if hasattr(pcd, "points") else [], + "colors": np.asarray(pcd.colors).tolist() + if hasattr(pcd, "colors") and pcd.colors + else [], + } + return data + + def serialize_voxel_grid(voxel_grid): + """Convert Open3D VoxelGrid to serializable format.""" + if voxel_grid is None: + return None + + # Extract voxel data + voxels = voxel_grid.get_voxels() + data = { + "voxel_size": voxel_grid.voxel_size, + "origin": np.asarray(voxel_grid.origin).tolist(), + "voxels": [ + ( + v.grid_index[0], + v.grid_index[1], + v.grid_index[2], + v.color[0], + v.color[1], + v.color[2], + ) + for v in voxels + ], + } + return data + + # Create a copy of 
results with non-picklable objects converted + pickle_data = { + "color_img": color_img, + "depth_img": depth_img, + "intrinsics": intrinsics, + "results": {}, + } + + # Convert and store all results, properly handling Open3D objects + for key, value in results.items(): + if key.endswith("_viz") or key in [ + "processing_time", + "timing_breakdown", + "detection2d_objects", + "segmentation2d_objects", + ]: + # These are already serializable + pickle_data["results"][key] = value + elif key == "full_pointcloud": + # Serialize PointCloud object + pickle_data["results"][key] = serialize_point_cloud(value) + print(f"Serialized {key}") + elif key == "misc_voxel_grid": + # Serialize VoxelGrid object + pickle_data["results"][key] = serialize_voxel_grid(value) + print(f"Serialized {key}") + elif key == "misc_clusters": + # List of PointCloud objects + if value: + serialized_clusters = [serialize_point_cloud(cluster) for cluster in value] + pickle_data["results"][key] = serialized_clusters + print(f"Serialized {key} ({len(serialized_clusters)} clusters)") + elif key == "detected_objects" or key == "all_objects": + # Objects with PointCloud attributes + serialized_objects = [] + for obj in value: + obj_dict = {k: v for k, v in obj.items() if k != "point_cloud"} + if "point_cloud" in obj: + obj_dict["point_cloud"] = serialize_point_cloud(obj.get("point_cloud")) + serialized_objects.append(obj_dict) + pickle_data["results"][key] = serialized_objects + print(f"Serialized {key} ({len(serialized_objects)} objects)") + else: + try: + # Try to pickle as is + pickle_data["results"][key] = value + print(f"Preserved {key} as is") + except (TypeError, ValueError): + print(f"Warning: Could not serialize {key}, skipping") + + with open(pickle_path, "wb") as f: + pickle.dump(pickle_data, f) + + print(f"Results saved successfully with all 3D data serialized!") + print(f"Pickled data keys: {list(pickle_data['results'].keys())}") + + # Visualization code has been moved to 
# Bring up the Go2 using the address supplied through the environment.
robot_ip = os.getenv("ROBOT_IP")
ros_control = UnitreeROSControl()
skills = MyUnitreeSkills()
robot = UnitreeGo2(ip=robot_ip, ros_control=ros_control, skills=skills)

# Issue a single forward velocity command.
robot.move_vel(x=0.5, y=0, yaw=0, duration=5)

# Keep the process alive after the command has been sent.
while True:
    time.sleep(1)
def parse_args():
    """Parse command-line options for the navigate-to-object test.

    Returns:
        argparse.Namespace with ``object`` (str), ``distance`` (float) and
        ``timeout`` (float).
    """
    parser = argparse.ArgumentParser(description="Navigate to an object using Qwen vision.")
    # The help strings use argparse's %(default)s placeholder so the
    # advertised default can never drift from the real one (the previous
    # text claimed 0.8 m / 30.0 s while the defaults were 1.0 m / 60.0 s).
    parser.add_argument(
        "--object",
        type=str,
        default="chair",
        help="Name of the object to navigate to (default: %(default)s)",
    )
    parser.add_argument(
        "--distance",
        type=float,
        default=1.0,
        help="Desired distance to maintain from object in meters (default: %(default)s)",
    )
    parser.add_argument(
        "--timeout",
        type=float,
        default=60.0,
        help="Maximum navigation time in seconds (default: %(default)s)",
    )
    return parser.parse_args()
NavigateToObject skill + robot_skills = robot.get_skills() + robot_skills.add(Navigate) + robot_skills.create_instance("Navigate", robot=robot) + + # Set up tracking and visualization streams + object_tracking_stream = robot.object_tracking_stream + viz_stream = object_tracking_stream.pipe( + RxOps.share(), + RxOps.map(lambda x: x["viz_frame"] if x is not None else None), + RxOps.filter(lambda x: x is not None), + ) + + # The local planner visualization stream is created during robot initialization + local_planner_stream = robot.local_planner_viz_stream + + local_planner_stream = local_planner_stream.pipe( + RxOps.share(), + RxOps.map(lambda x: x if x is not None else None), + RxOps.filter(lambda x: x is not None), + ) + + try: + # Set up web interface + logger.info("Initializing web interface") + streams = { + # "robot_video": video_stream, + "object_tracking": viz_stream, + "local_planner": local_planner_stream, + } + + web_interface = RobotWebInterface(port=5555, **streams) + + # Wait for camera and tracking to initialize + print("Waiting for camera and tracking to initialize...") + time.sleep(3) + + def navigate_to_object(): + try: + result = robot_skills.call( + "Navigate", robot=robot, query=object_name, timeout=timeout + ) + print(f"Navigation result: {result}") + except Exception as e: + print(f"Error during navigation: {e}") + + navigate_thread = threading.Thread(target=navigate_to_object, daemon=True) + navigate_thread.start() + + print( + f"Navigating to {object_name} with desired distance {distance}m and timeout {timeout}s..." 
+ ) + print("Web interface available at http://localhost:5555") + + # Start web server (blocking call) + web_interface.run() + + except KeyboardInterrupt: + print("\nInterrupted by user") + except Exception as e: + print(f"Error during navigation test: {e}") + finally: + print("Test completed") + robot.cleanup() + + +if __name__ == "__main__": + main() diff --git a/build/lib/tests/test_navigation_skills.py b/build/lib/tests/test_navigation_skills.py new file mode 100644 index 0000000000..9a91d1aba5 --- /dev/null +++ b/build/lib/tests/test_navigation_skills.py @@ -0,0 +1,269 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Simple test script for semantic / spatial memory skills. + +This script is a simplified version that focuses only on making the workflow work. 
def parse_args():
    """Build the CLI for the semantic-map test and parse ``sys.argv``.

    Default storage locations are derived from this file's directory:
    ``../assets/spatial_memory_vegas`` (made absolute) for visual memory and
    its ``chromadb_data`` sub-directory for the database.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    script_dir = os.path.dirname(__file__)
    memory_root = os.path.abspath(os.path.join(script_dir, "../assets/spatial_memory_vegas"))
    chroma_dir = os.path.join(memory_root, "chromadb_data")

    # (flag, keyword arguments) pairs, registered in this order.
    option_table = (
        (
            "--skip-build",
            {
                "action": "store_true",
                "help": "Skip building the map and run navigation with existing semantic and visual memory",
            },
        ),
        (
            "--query",
            {
                "type": str,
                "default": "kitchen",
                "help": "Text query for navigation (default: kitchen)",
            },
        ),
        (
            "--db-path",
            {"type": str, "default": chroma_dir, "help": "Path to ChromaDB database"},
        ),
        ("--justgo", {"type": str, "help": "Globally navigate to location"}),
        (
            "--visual-memory-dir",
            {"type": str, "default": memory_root, "help": "Directory for visual memory"},
        ),
        (
            "--visual-memory-file",
            {
                "type": str,
                "default": "visual_memory.pkl",
                "help": "Filename for visual memory",
            },
        ),
        (
            "--port",
            {"type": int, "default": 5555, "help": "Port for web visualization interface"},
        ),
    )

    cli = argparse.ArgumentParser(description="Simple test for semantic map skills.")
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()
def query_map(robot, args):
    """Run the Navigate skill for ``args.query`` and report what it found.

    Args:
        robot: Robot instance handed to the ``Navigate`` skill.
        args: Parsed CLI namespace; ``query``, ``db_path``,
            ``visual_memory_dir`` and ``visual_memory_file`` are read.

    Returns:
        The ``position`` entry of the skill result when the result is a dict
        flagged as successful, otherwise ``False``.
    """
    logger.info(f"Querying spatial memory for: '{args.query}'")

    memory_file = os.path.join(args.visual_memory_dir, args.visual_memory_file)
    nav_skill = Navigate(
        robot=robot,
        query=args.query,
        db_path=args.db_path,
        visual_memory_path=memory_file,
    )
    outcome = nav_skill()

    # Anything that is not a dict with a truthy "success" counts as a failure.
    succeeded = isinstance(outcome, dict) and outcome.get("success", False)
    if not succeeded:
        logger.error(f"Navigation query failed: {outcome}")
        return False

    position = outcome.get("position", (0, 0, 0))
    similarity = outcome.get("similarity", 0)
    logger.info(f"Found '{args.query}' at position: {position}")
    logger.info(f"Similarity score: {similarity:.4f}")
    return position
def run_navigation(robot, target):
    """Hand ``target`` to the robot's global planner and return its result.

    The call is made synchronously on the calling thread; this function does
    not start a thread itself.
    """
    logger.info(f"Starting navigation to target: {target}")
    planner = robot.global_planner
    outcome = planner.set_goal(target)
    return outcome
logger.error("No target found for navigation.") + return + + # Run navigation + navigate_thread = threading.Thread( + target=lambda: run_navigation(robot, target), daemon=True + ) + navigate_thread.start() + + # Wait for navigation to complete or user to interrupt + try: + while navigate_thread.is_alive(): + time.sleep(0.5) + logger.info("Navigation completed") + except KeyboardInterrupt: + logger.info("Navigation interrupted by user") + + # If web interface is running, keep the main thread alive + if web_interface: + logger.info( + "Navigation completed. Visualization still available. Press Ctrl+C to exit." + ) + try: + while True: + time.sleep(0.5) + except KeyboardInterrupt: + logger.info("Exiting...") + + finally: + # Clean up + logger.info("Cleaning up resources...") + try: + robot.cleanup() + except Exception as e: + logger.error(f"Error during cleanup: {e}") + + logger.info("Test completed successfully") + + +if __name__ == "__main__": + main() diff --git a/build/lib/tests/test_object_detection_agent_data_query_stream.py b/build/lib/tests/test_object_detection_agent_data_query_stream.py new file mode 100644 index 0000000000..00e5625119 --- /dev/null +++ b/build/lib/tests/test_object_detection_agent_data_query_stream.py @@ -0,0 +1,191 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def parse_args():
    """Parse the single ``--mode`` option (``robot`` or ``webcam``).

    Returns:
        argparse.Namespace whose ``mode`` attribute defaults to ``"webcam"``.
    """
    mode_option = {
        "type": str,
        "default": "webcam",
        "choices": ["robot", "webcam"],
        "help": 'Mode to run: "robot" or "webcam" (default: webcam)',
    }

    cli = argparse.ArgumentParser(
        description="Test ObjectDetectionStream for object detection and position estimation"
    )
    cli.add_argument("--mode", **mode_option)

    namespace = cli.parse_args()
    return namespace
transform function + object_detector = ObjectDetectionStream( + camera_intrinsics=robot.camera_intrinsics, + min_confidence=min_confidence, + class_filter=class_filter, + transform_to_map=robot.ros_control.transform_pose, + detector=detector, + video_stream=video_stream, + ) + + else: # webcam mode + print("Initializing in webcam mode...") + + # Define camera intrinsics for the webcam + # These are approximate values for a typical 640x480 webcam + width, height = 640, 480 + focal_length_mm = 3.67 # mm (typical webcam) + sensor_width_mm = 4.8 # mm (1/4" sensor) + + # Calculate focal length in pixels + focal_length_x_px = width * focal_length_mm / sensor_width_mm + focal_length_y_px = height * focal_length_mm / sensor_width_mm + + # Principal point (center of image) + cx, cy = width / 2, height / 2 + + # Camera intrinsics in [fx, fy, cx, cy] format + camera_intrinsics = [focal_length_x_px, focal_length_y_px, cx, cy] + + # Initialize video provider and ObjectDetectionStream + video_provider = VideoProvider("test_camera", video_source=0) # Default camera + # Create video stream + video_stream = backpressure( + video_provider.capture_video_as_observable(realtime=True, fps=30) + ) + + object_detector = ObjectDetectionStream( + camera_intrinsics=camera_intrinsics, + min_confidence=min_confidence, + class_filter=class_filter, + detector=detector, + video_stream=video_stream, + ) + + # Set placeholder robot for cleanup + robot = None + + # Create visualization stream for web interface + viz_stream = object_detector.get_stream().pipe( + ops.share(), + ops.map(lambda x: x["viz_frame"] if x is not None else None), + ops.filter(lambda x: x is not None), + ) + + # Create object data observable for Agent using the formatted stream + object_data_stream = object_detector.get_formatted_stream().pipe( + ops.share(), ops.filter(lambda x: x is not None) + ) + + # Create stop event for clean shutdown + stop_event = threading.Event() + + try: + # Set up web interface + 
print("Initializing web interface...") + web_interface = RobotWebInterface(port=web_port, object_detection=viz_stream) + + agent = ClaudeAgent( + dev_name="test_agent", + # input_query_stream=stt_node.emit_text(), + input_query_stream=web_interface.query_stream, + input_data_stream=object_data_stream, + system_query="Tell me what you see", + model_name="claude-3-7-sonnet-latest", + thinking_budget_tokens=0, + ) + + # Print configuration information + print("\nObjectDetectionStream Test Running:") + print(f"Mode: {args.mode}") + print(f"Web Interface: http://localhost:{web_port}") + print("\nPress Ctrl+C to stop the test\n") + + # Start web server (blocking call) + web_interface.run() + + except KeyboardInterrupt: + print("\nTest interrupted by user") + except Exception as e: + print(f"Error during test: {e}") + finally: + # Clean up resources + print("Cleaning up resources...") + stop_event.set() + + if args.mode == "robot" and robot: + robot.cleanup() + elif args.mode == "webcam": + if "video_provider" in locals(): + video_provider.dispose_all() + + print("Test completed") + + +if __name__ == "__main__": + main() diff --git a/build/lib/tests/test_object_detection_stream.py b/build/lib/tests/test_object_detection_stream.py new file mode 100644 index 0000000000..1cf8aeab01 --- /dev/null +++ b/build/lib/tests/test_object_detection_stream.py @@ -0,0 +1,240 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def parse_args():
    """Read the run mode for the ObjectDetectionStream test from the CLI.

    Returns:
        argparse.Namespace with ``mode`` set to ``"robot"`` or ``"webcam"``
        (``"webcam"`` when the flag is omitted).
    """
    description = "Test ObjectDetectionStream for object detection and position estimation"
    supported_modes = ["robot", "webcam"]
    mode_help = 'Mode to run: "robot" or "webcam" (default: webcam)'

    arg_parser = argparse.ArgumentParser(description=description)
    arg_parser.add_argument(
        "--mode", default="webcam", type=str, choices=supported_modes, help=mode_help
    )
    return arg_parser.parse_args()
def main():
    """Run the ObjectDetectionStream test in robot or webcam mode.

    Builds a video stream (robot camera or local webcam), wraps it in an
    ``ObjectDetectionStream``, prints detections to the console through a
    rate-limited ``ResultPrinter`` and serves the visualization frames via
    ``RobotWebInterface`` until interrupted.

    Exits with status 1 in robot mode when ``ROBOT_IP`` is not set.
    """
    # Get command line arguments
    args = parse_args()

    # Set up the result printer for console output
    result_printer = ResultPrinter(print_interval=1.0)

    # Set default parameters
    min_confidence = 0.6
    class_filter = None  # No class filtering
    web_port = 5555

    # Pre-bind everything the cleanup path inspects so the ``finally`` block
    # never trips over an unbound name when setup fails part-way.
    robot = None
    video_provider = None
    subscription = None

    # Initialize based on mode
    if args.mode == "robot":
        # Only robot mode needs the Unitree stack; this module does not import
        # these names at file level, so bring them into scope here.
        from dimos.robot.unitree.unitree_go2 import UnitreeGo2
        from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl

        print("Initializing in robot mode...")

        # Get robot IP from environment
        robot_ip = os.getenv("ROBOT_IP")
        if not robot_ip:
            print("Error: ROBOT_IP environment variable not set.")
            sys.exit(1)

        # Initialize robot
        robot = UnitreeGo2(
            ip=robot_ip,
            ros_control=UnitreeROSControl(),
            skills=MyUnitreeSkills(),
        )
        # Create video stream from robot's camera
        video_stream = robot.video_stream_ros

        # Initialize ObjectDetectionStream with robot and transform function.
        # No ``detector`` is passed: none is defined in this script, and the
        # webcam branch below likewise relies on the stream's own default.
        object_detector = ObjectDetectionStream(
            camera_intrinsics=robot.camera_intrinsics,
            min_confidence=min_confidence,
            class_filter=class_filter,
            transform_to_map=robot.ros_control.transform_pose,
            video_stream=video_stream,
            disable_depth=False,
        )

    else:  # webcam mode
        print("Initializing in webcam mode...")

        # Define camera intrinsics for the webcam
        # These are approximate values for a typical 640x480 webcam
        width, height = 640, 480
        focal_length_mm = 3.67  # mm (typical webcam)
        sensor_width_mm = 4.8  # mm (1/4" sensor)

        # Calculate focal length in pixels
        focal_length_x_px = width * focal_length_mm / sensor_width_mm
        focal_length_y_px = height * focal_length_mm / sensor_width_mm

        # Principal point (center of image)
        cx, cy = width / 2, height / 2

        # Camera intrinsics in [fx, fy, cx, cy] format
        camera_intrinsics = [focal_length_x_px, focal_length_y_px, cx, cy]

        # Initialize video provider and ObjectDetectionStream
        video_provider = VideoProvider("test_camera", video_source=0)  # Default camera
        # Create video stream
        video_stream = backpressure(
            video_provider.capture_video_as_observable(realtime=True, fps=30)
        )

        object_detector = ObjectDetectionStream(
            camera_intrinsics=camera_intrinsics,
            min_confidence=min_confidence,
            class_filter=class_filter,
            video_stream=video_stream,
            disable_depth=False,
            draw_masks=True,
        )

    # Create visualization stream for web interface
    viz_stream = object_detector.get_stream().pipe(
        ops.share(),
        ops.map(lambda x: x["viz_frame"] if x is not None else None),
        ops.filter(lambda x: x is not None),
    )

    # Create stop event for clean shutdown
    stop_event = threading.Event()

    # Define subscription callback to print results
    def on_next(result):
        if stop_event.is_set():
            return

        # Print detected objects to console
        if "objects" in result:
            result_printer.print_results(result["objects"])

    def on_error(error):
        print(f"Error in detection stream: {error}")
        stop_event.set()

    def on_completed():
        print("Detection stream completed")
        stop_event.set()

    try:
        # Subscribe to the detection stream
        subscription = object_detector.get_stream().subscribe(
            on_next=on_next, on_error=on_error, on_completed=on_completed
        )

        # Set up web interface
        print("Initializing web interface...")
        web_interface = RobotWebInterface(port=web_port, object_detection=viz_stream)

        # Print configuration information
        print("\nObjectDetectionStream Test Running:")
        print(f"Mode: {args.mode}")
        print(f"Web Interface: http://localhost:{web_port}")
        print("\nPress Ctrl+C to stop the test\n")

        # Start web server (blocking call)
        web_interface.run()

    except KeyboardInterrupt:
        print("\nTest interrupted by user")
    except Exception as e:
        print(f"Error during test: {e}")
    finally:
        # Clean up resources
        print("Cleaning up resources...")
        stop_event.set()

        if subscription is not None:
            subscription.dispose()

        if robot is not None:
            robot.cleanup()
        elif video_provider is not None:
            video_provider.dispose_all()

        print("Test completed")
# Global variables for bounding box selection
selecting_bbox = False
bbox_points = []
current_bbox = None
tracker_initialized = False
object_size = 0.30  # Hardcoded object size in meters (adjust based on your tracking target)


def mouse_callback(event, x, y, flags, param):
    """Handle click-and-drag bounding box selection in the OpenCV window.

    Args:
        event: OpenCV mouse event code.
        x: Cursor x position in window pixels.
        y: Cursor y position in window pixels.
        flags: OpenCV event flags (unused).
        param: Dict passed to ``cv2.setMouseCallback``; its ``"bbox_queue"``
            entry receives ``(bbox, object_size)`` for every finished selection.

    A selection is only finalized when the button is released after a drag
    that started in this window, and only when the box has non-zero area.
    """
    global selecting_bbox, bbox_points, current_bbox, tracker_initialized

    if event == cv2.EVENT_LBUTTONDOWN:
        # Start bbox selection
        selecting_bbox = True
        bbox_points = [(x, y)]
        current_bbox = None
        tracker_initialized = False

    elif event == cv2.EVENT_MOUSEMOVE and selecting_bbox:
        # Update current selection for visualization
        current_bbox = [bbox_points[0][0], bbox_points[0][1], x, y]

    elif event == cv2.EVENT_LBUTTONUP:
        # End bbox selection
        was_selecting = selecting_bbox
        selecting_bbox = False
        if not was_selecting or not bbox_points:
            # Stray release (no drag in progress): nothing to finalize.
            return

        x1, y1 = bbox_points[0]
        # Use the release position directly instead of indexing a list that
        # may still hold points from an earlier selection.
        bbox_points = [(x1, y1), (x, y)]

        if x1 == x or y1 == y:
            # A plain click yields a zero-area box; do not hand it to the tracker.
            current_bbox = None
            return

        # Ensure x1,y1 is top-left and x2,y2 is bottom-right
        current_bbox = [min(x1, x), min(y1, y), max(x1, x), max(y1, y)]

        # Add the bbox to the tracking queue
        bbox_queue = param.get("bbox_queue")
        if bbox_queue is not None and not tracker_initialized:
            bbox_queue.put((current_bbox, object_size))
            tracker_initialized = True


def main():
    """Run interactive single-object tracking on local camera 0.

    A background subscription annotates frames from ``ObjectTrackingStream``
    and pushes them into a bounded queue; the main thread displays them,
    forwards bounding boxes selected with the mouse to
    ``tracker_stream.track`` and exits on ``q`` or Ctrl+C.
    """
    # Create queues for thread communication
    frame_queue = queue.Queue(maxsize=5)
    bbox_queue = queue.Queue()
    stop_event = threading.Event()

    # Logitech C920e camera parameters at 480p
    # Convert physical parameters to pixel-based intrinsics
    width, height = 640, 480
    focal_length_mm = 3.67  # mm
    sensor_width_mm = 4.8  # mm (1/4" sensor)
    sensor_height_mm = 3.6  # mm

    # Calculate focal length in pixels
    focal_length_x_px = width * focal_length_mm / sensor_width_mm
    focal_length_y_px = height * focal_length_mm / sensor_height_mm

    # Principal point (assuming center of image)
    cx = width / 2
    cy = height / 2

    # Final camera intrinsics in [fx, fy, cx, cy] format
    camera_intrinsics = [focal_length_x_px, focal_length_y_px, cx, cy]

    # Initialize video provider and object tracking stream
    video_provider = VideoProvider("test_camera", video_source=0)
    tracker_stream = ObjectTrackingStream(
        camera_intrinsics=camera_intrinsics,
        camera_pitch=0.0,  # Adjust if your camera is tilted
        camera_height=0.5,  # Height of camera from ground in meters (adjust as needed)
    )

    # Create video stream
    video_stream = video_provider.capture_video_as_observable(realtime=True, fps=30)
    tracking_stream = tracker_stream.create_stream(video_stream)

    # Define callbacks for the tracking stream
    def on_next(result):
        if stop_event.is_set():
            return

        # Get the visualization frame
        viz_frame = result["viz_frame"]

        # If we're selecting a bbox, draw the current selection
        if selecting_bbox and current_bbox is not None:
            x1, y1, x2, y2 = current_bbox
            cv2.rectangle(viz_frame, (x1, y1), (x2, y2), (0, 255, 255), 2)

        # Add instructions
        cv2.putText(
            viz_frame,
            "Click and drag to select object",
            (10, 30),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.7,
            (255, 255, 255),
            2,
        )
        cv2.putText(
            viz_frame,
            f"Object size: {object_size:.2f}m",
            (10, 60),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.7,
            (255, 255, 255),
            2,
        )

        # Show tracking status
        status = "Tracking" if tracker_initialized else "Not tracking"
        cv2.putText(
            viz_frame,
            f"Status: {status}",
            (10, 90),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.7,
            (0, 255, 0) if tracker_initialized else (0, 0, 255),
            2,
        )

        # Put frame in queue for main thread to display
        try:
            frame_queue.put_nowait(viz_frame)
        except queue.Full:
            # Skip frame if queue is full
            pass

    def on_error(error):
        print(f"Error: {error}")
        stop_event.set()

    def on_completed():
        print("Stream completed")
        stop_event.set()

    # Start the subscription
    subscription = None

    try:
        # Subscribe to start processing in background thread
        subscription = tracking_stream.subscribe(
            on_next=on_next, on_error=on_error, on_completed=on_completed
        )

        print("Object tracking started. Click and drag to select an object. Press 'q' to exit.")

        # Create window and set mouse callback
        cv2.namedWindow("Object Tracker")
        cv2.setMouseCallback("Object Tracker", mouse_callback, {"bbox_queue": bbox_queue})

        # Main thread loop for displaying frames and handling bbox selection
        while not stop_event.is_set():
            # Check if there's a new bbox to track
            try:
                new_bbox, size = bbox_queue.get_nowait()
                print(f"New object selected: {new_bbox}, size: {size}m")
                # Initialize tracker with the new bbox and size
                tracker_stream.track(new_bbox, size=size)
            except queue.Empty:
                pass

            try:
                # Get frame with timeout
                viz_frame = frame_queue.get(timeout=1.0)

                # Display the frame
                cv2.imshow("Object Tracker", viz_frame)
                # Check for exit key
                if cv2.waitKey(1) & 0xFF == ord("q"):
                    print("Exit key pressed")
                    break

            except queue.Empty:
                # No frame available, check if we should continue
                if cv2.waitKey(1) & 0xFF == ord("q"):
                    print("Exit key pressed")
                    break
                continue

    except KeyboardInterrupt:
        print("\nKeyboard interrupt received. Stopping...")
    finally:
        # Signal threads to stop
        stop_event.set()

        # Clean up resources
        if subscription:
            subscription.dispose()

        video_provider.dispose_all()
        tracker_stream.cleanup()
        cv2.destroyAllWindows()
        print("Cleanup complete")


if __name__ == "__main__":
    main()
# Global variables for tracking control
object_size = 0.30  # Hardcoded object size in meters (adjust based on your tracking target)
tracking_object_name = "object"  # Will be updated by Qwen
object_name = "hairbrush"  # Example object name for Qwen

# Flags shared between the display loop, the stream callback and the detection
# thread. They are bound here because a ``global`` statement at module level
# does not create the names.
tracker_initialized = False
detection_in_progress = False

# Check if display is available before any camera resource is opened, so a
# failure here does not leave the capture device held.
if "DISPLAY" not in os.environ:
    raise RuntimeError(
        "No display available. Please set DISPLAY environment variable or run in headless mode."
    )

# Create queues for thread communication
frame_queue = queue.Queue(maxsize=5)
stop_event = threading.Event()

# Logitech C920e camera parameters at 480p
width, height = 640, 480
focal_length_mm = 3.67  # mm
sensor_width_mm = 4.8  # mm (1/4" sensor)
sensor_height_mm = 3.6  # mm

# Calculate focal length in pixels
focal_length_x_px = width * focal_length_mm / sensor_width_mm
focal_length_y_px = height * focal_length_mm / sensor_height_mm
cx, cy = width / 2, height / 2

# Final camera intrinsics in [fx, fy, cx, cy] format
camera_intrinsics = [focal_length_x_px, focal_length_y_px, cx, cy]

# Initialize video provider and object tracking stream
video_provider = VideoProvider("webcam", video_source=0)
tracker_stream = ObjectTrackingStream(
    camera_intrinsics=camera_intrinsics, camera_pitch=0.0, camera_height=0.5
)

# Create video streams
video_stream = video_provider.capture_video_as_observable(realtime=True, fps=10)
tracking_stream = tracker_stream.create_stream(video_stream)


# Define callbacks for the tracking stream
def on_next(result):
    """Annotate the visualization frame and queue it for the display loop."""
    if stop_event.is_set():
        return

    # Get the visualization frame
    viz_frame = result["viz_frame"]

    # Add information to the visualization
    cv2.putText(
        viz_frame,
        f"Tracking {tracking_object_name}",
        (10, 30),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.7,
        (255, 255, 255),
        2,
    )
    cv2.putText(
        viz_frame,
        f"Object size: {object_size:.2f}m",
        (10, 60),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.7,
        (255, 255, 255),
        2,
    )

    # Show tracking status
    status = "Tracking" if tracker_initialized else "Waiting for detection"
    color = (0, 255, 0) if tracker_initialized else (0, 0, 255)
    cv2.putText(viz_frame, f"Status: {status}", (10, 90), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)

    # If detection is in progress, show a message
    if detection_in_progress:
        cv2.putText(
            viz_frame, "Querying Qwen...", (10, 120), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2
        )

    # Put frame in queue for main thread to display
    try:
        frame_queue.put_nowait(viz_frame)
    except queue.Full:
        pass


def on_error(error):
    """Report a stream error and ask the display loop to stop."""
    print(f"Error: {error}")
    stop_event.set()


def on_completed():
    """Ask the display loop to stop once the stream ends."""
    print("Stream completed")
    stop_event.set()


def detection_task():
    """Ask Qwen for the object's bounding box and (re)start or stop tracking.

    Runs on a worker thread. ``detection_in_progress`` is always cleared on
    exit, which lets the display loop schedule the next query.
    """
    global detection_in_progress, tracker_initialized
    try:
        result = get_bbox_from_qwen(video_stream, object_name=object_name)
        print(f"Got result from Qwen: {result}")

        if result:
            bbox, size = result
            print(f"Detected object at {bbox} with size {size}")
            tracker_stream.track(bbox, size=size)
            tracker_initialized = True
            return

        print("No object detected by Qwen")
        tracker_initialized = False
        tracker_stream.stop_track()

    except Exception as e:
        print(f"Error in update_tracking: {e}")
        tracker_initialized = False
        tracker_stream.stop_track()
    finally:
        detection_in_progress = False


# Start the subscription
subscription = None

try:
    # Subscribe to start processing in background thread
    subscription = tracking_stream.subscribe(
        on_next=on_next, on_error=on_error, on_completed=on_completed
    )

    print("Object tracking with Qwen started. Press 'q' to exit.")
    print("Waiting for initial object detection...")

    # Main thread loop for displaying frames and updating tracking
    while not stop_event.is_set():
        # Check if we need to update tracking
        if not detection_in_progress:
            detection_in_progress = True
            print("Requesting object detection from Qwen...")

            print("detection_in_progress: ", detection_in_progress)
            print("tracker_initialized: ", tracker_initialized)

            # Run detection task in a separate thread
            threading.Thread(target=detection_task, daemon=True).start()

        try:
            # Get frame with timeout
            viz_frame = frame_queue.get(timeout=0.1)

            # Display the frame
            cv2.imshow("Object Tracking with Qwen", viz_frame)

            # Check for exit key
            if cv2.waitKey(1) & 0xFF == ord("q"):
                print("Exit key pressed")
                break

        except queue.Empty:
            # No frame available, check if we should continue
            if cv2.waitKey(1) & 0xFF == ord("q"):
                print("Exit key pressed")
                break
            continue

except KeyboardInterrupt:
    print("\nKeyboard interrupt received. Stopping...")
finally:
    # Signal threads to stop
    stop_event.set()

    # Clean up resources
    if subscription:
        subscription.dispose()

    video_provider.dispose_all()
    tracker_stream.cleanup()
    cv2.destroyAllWindows()
    print("Cleanup complete")
def main():
    """Start the ObserveStream skill, let it run briefly, then kill it.

    Wires a mock-connected ``UnitreeGo2`` and a ``ClaudeAgent`` to a
    ``RobotWebInterface`` (served from a daemon thread), registers the
    ``ObserveStream`` and ``KillSkill`` skills, records both skill calls in the
    agent's ``memory.txt`` and then idles until the user presses Ctrl+C.
    """
    # Seconds the observer is left running before it is killed. The log line
    # below is derived from this value so the two cannot drift apart.
    observe_seconds = 20.0

    # Initialize the robot with mock connection for testing
    robot = UnitreeGo2(
        ip=os.getenv("ROBOT_IP", "192.168.123.161"), skills=MyUnitreeSkills(), mock_connection=True
    )

    agent_response_subject = rx.subject.Subject()
    agent_response_stream = agent_response_subject.pipe(ops.share())

    streams = {"unitree_video": robot.get_ros_video_stream()}
    text_streams = {
        "agent_responses": agent_response_stream,
    }

    web_interface = RobotWebInterface(port=5555, text_streams=text_streams, **streams)

    agent = ClaudeAgent(
        dev_name="test_agent",
        input_query_stream=web_interface.query_stream,
        skills=robot.get_skills(),
        system_query="""You are an agent monitoring a robot's environment.
        When you see an image, describe what you see and alert if you notice any people or important changes.
        Be concise but thorough in your observations.""",
        model_name="claude-3-7-sonnet-latest",
        thinking_budget_tokens=10000,
    )

    # Forward agent responses to the text stream shown in the web interface.
    agent.get_response_observable().subscribe(lambda x: agent_response_subject.on_next(x))

    robot_skills = robot.get_skills()

    robot_skills.add(ObserveStream)
    robot_skills.add(KillSkill)

    robot_skills.create_instance("ObserveStream", robot=robot, agent=agent)
    robot_skills.create_instance("KillSkill", skill_library=robot_skills)

    web_interface_thread = threading.Thread(target=web_interface.run)
    web_interface_thread.daemon = True
    web_interface_thread.start()

    logger.info("Starting monitor skill...")

    memory_file = os.path.join(agent.output_dir, "memory.txt")
    with open(memory_file, "a") as f:
        f.write(
            "SKILL CALL: ObserveStream(timestep=10.0, query_text='What do you see in this image? Alert me if you see any people.', max_duration=120.0)"
        )

    result = robot_skills.call(
        "ObserveStream",
        timestep=10.0,  # 10 seconds between monitoring queries
        query_text="What do you see in this image? Alert me if you see any people.",
        max_duration=120.0,
    )  # Run for 120 seconds
    logger.info(f"Monitor skill result: {result}")

    logger.info(f"Running skills: {robot_skills.get_running_skills().keys()}")

    try:
        logger.info(f"Observer running. Will stop after {observe_seconds:.0f} seconds...")
        time.sleep(observe_seconds)

        logger.info(f"Running skills before kill: {robot_skills.get_running_skills().keys()}")
        logger.info("Killing the observer skill...")

        with open(memory_file, "a") as f:
            f.write("\n\nSKILL CALL: KillSkill(skill_name='observer')\n\n")

        # NOTE(review): the skill was registered as "ObserveStream" but is
        # killed by the name "observer"; presumably that is the name it
        # registers itself under while running -- confirm.
        kill_result = robot_skills.call("KillSkill", skill_name="observer")
        logger.info(f"Kill skill result: {kill_result}")

        logger.info(f"Running skills after kill: {robot_skills.get_running_skills().keys()}")

        # Keep test running until user interrupts
        while True:
            time.sleep(1.0)
    except KeyboardInterrupt:
        logger.info("Test interrupted by user")

    logger.info("Test completed")


if __name__ == "__main__":
    main()
def main():
    """Make the Unitree Go2 follow a person located by a Qwen frame query.

    Asks Qwen for the pixel coordinates of the person in a single video frame
    (retrying up to five times), starts ``robot.follow_human`` on a daemon
    thread with that point, and serves the raw and person-tracking video
    streams through ``RobotWebInterface`` until the server returns or the user
    presses Ctrl+C. ``robot.cleanup()`` always runs on exit.
    """
    # Function-scope imports, matching how this script already pulls in threading.
    import ast
    import threading

    # Hardcoded parameters
    timeout = 60.0  # Maximum time to follow a person (seconds)
    distance = 0.5  # Desired distance to maintain from target (meters)

    print("Initializing Unitree Go2 robot...")

    # Initialize the robot with ROS control and skills
    robot = UnitreeGo2(
        ip=os.getenv("ROBOT_IP"),
        ros_control=UnitreeROSControl(),
        skills=MyUnitreeSkills(),
    )

    tracking_stream = robot.person_tracking_stream
    viz_stream = tracking_stream.pipe(
        RxOps.share(),
        RxOps.map(lambda x: x["viz_frame"] if x is not None else None),
        RxOps.filter(lambda x: x is not None),
    )
    video_stream = robot.get_ros_video_stream()

    try:
        # Set up web interface
        logger.info("Initializing web interface")
        streams = {"unitree_video": video_stream, "person_tracking": viz_stream}

        web_interface = RobotWebInterface(port=5555, **streams)

        # Wait for camera and tracking to initialize
        print("Waiting for camera and tracking to initialize...")
        time.sleep(5)

        # Get initial point from Qwen
        max_retries = 5
        delay = 3

        for attempt in range(max_retries):
            try:
                response = (
                    query_single_frame_observable(
                        video_stream,
                        "Look at this frame and point to the person shirt. Return ONLY their center coordinates as a tuple (x,y).",
                    )
                    .pipe(RxOps.take(1))
                    .run()
                )  # Get first response

                # The reply is model-generated text: parse it as a literal
                # instead of eval() so it can never execute code.
                qwen_point = ast.literal_eval(response.strip())
                if not isinstance(qwen_point, (tuple, list)) or len(qwen_point) != 2:
                    raise ValueError(f"expected an (x, y) pair, got {response!r}")
                qwen_point = tuple(qwen_point)

                logger.info(f"Found person at coordinates {qwen_point}")
                break  # If successful, break out of retry loop
            except Exception as e:
                if attempt < max_retries - 1:
                    logger.error(
                        f"Person not found. Attempt {attempt + 1}/{max_retries} failed. Retrying in {delay}s... Error: {e}"
                    )
                    time.sleep(delay)
                else:
                    logger.error(f"Person not found after {max_retries} attempts. Last error: {e}")
                    return

        # Start following human in a separate thread
        follow_thread = threading.Thread(
            target=lambda: robot.follow_human(timeout=timeout, distance=distance, point=qwen_point),
            daemon=True,
        )
        follow_thread.start()

        print(f"Following human at point {qwen_point} for {timeout} seconds...")
        print("Web interface available at http://localhost:5555")

        # Start web server (blocking call)
        web_interface.run()

    except KeyboardInterrupt:
        print("\nInterrupted by user")
    except Exception as e:
        print(f"Error during test: {e}")
    finally:
        print("Test completed")
        robot.cleanup()


if __name__ == "__main__":
    main()
def main():
    """Run click-to-follow person tracking with visual servoing on camera 0.

    A background subscription queues visualization frames from
    ``PersonTrackingStream``. The main thread shows them, starts
    ``VisualServoing`` on the point the user clicks, overlays the commanded
    velocities, and exits on ``q`` or Ctrl+C.
    """
    # Bounded queues between the stream thread and the display loop
    frame_queue = queue.Queue(maxsize=5)
    result_queue = queue.Queue(maxsize=5)  # For tracking results
    stop_event = threading.Event()

    # Logitech C920e camera parameters at 480p, converted to [fx, fy, cx, cy]
    resolution = (640, 480)  # 480p resolution
    focal_length_mm = 3.67  # mm
    sensor_size_mm = (4.8, 3.6)  # mm (1/4" sensor)

    res_w, res_h = resolution
    fx = (res_w * focal_length_mm) / sensor_size_mm[0]
    fy = (res_h * focal_length_mm) / sensor_size_mm[1]
    cx = res_w / 2  # principal point at the image center
    cy = res_h / 2
    camera_intrinsics = [fx, fy, cx, cy]

    # Camera mounting parameters
    camera_pitch = np.deg2rad(-5)  # negative for downward pitch
    camera_height = 1.4  # meters

    # Video source and person tracker
    video_provider = VideoProvider("test_camera", video_source=0)
    person_tracker = PersonTrackingStream(
        camera_intrinsics=camera_intrinsics, camera_pitch=camera_pitch, camera_height=camera_height
    )

    video_stream = video_provider.capture_video_as_observable(realtime=False, fps=20)
    person_tracking_stream = person_tracker.create_stream(video_stream)

    visual_servoing = VisualServoing(
        tracking_stream=person_tracking_stream,
        max_linear_speed=0.5,
        max_angular_speed=0.75,
        desired_distance=2.5,
    )

    # Selection state, written by the mouse callback and the display loop
    selected_point = None
    tracking_active = False

    def offer(target_queue, item):
        """Enqueue ``item`` without blocking; drop it when the queue is full."""
        try:
            target_queue.put_nowait(item)
        except queue.Full:
            pass

    def on_next(result):
        if stop_event.is_set():
            return

        # The visualization frame already carries detections and tracking info
        viz_frame = result["viz_frame"]
        offer(result_queue, result)
        offer(frame_queue, viz_frame)

    def on_error(error):
        print(f"Error: {error}")
        stop_event.set()

    def on_completed():
        print("Stream completed")
        stop_event.set()

    def mouse_callback(event, x, y, flags, param):
        nonlocal selected_point, tracking_active

        if event != cv2.EVENT_LBUTTONDOWN:
            return
        # Remember the click; tracking is (re)started by the display loop
        selected_point = (x, y)
        tracking_active = False
        print(f"Selected point: {selected_point}")

    def exit_requested():
        """Pump the OpenCV event loop and report whether 'q' was pressed."""
        if cv2.waitKey(1) & 0xFF == ord("q"):
            print("Exit key pressed")
            return True
        return False

    subscription = None

    try:
        # Subscribing starts processing on the stream's background thread
        subscription = person_tracking_stream.subscribe(
            on_next=on_next, on_error=on_error, on_completed=on_completed
        )

        print("Person tracking visualization started.")
        print("Click on a person to start visual servoing. Press 'q' to exit.")

        cv2.namedWindow("Person Tracking")
        cv2.setMouseCallback("Person Tracking", mouse_callback)

        while not stop_event.is_set():
            try:
                # The timeout lets the loop re-check stop_event periodically
                frame = frame_queue.get(timeout=1.0)
            except queue.Empty:
                if exit_requested():
                    break
                continue

            if selected_point is not None:
                if not tracking_active:
                    tracking_active = visual_servoing.start_tracking(point=selected_point)
                    if not tracking_active:
                        print("Failed to start tracking")
                        selected_point = None

                if tracking_active:
                    servoing_result = visual_servoing.updateTracking()

                    linear_vel = servoing_result.get("linear_vel", 0.0)
                    angular_vel = servoing_result.get("angular_vel", 0.0)
                    running = visual_servoing.running

                    # Green while servoing is running, red otherwise
                    status_color = (0, 255, 0) if running else (0, 0, 255)

                    overlay_lines = (
                        (f"Linear: {linear_vel:.2f} m/s", 30),
                        (f"Angular: {angular_vel:.2f} rad/s", 60),
                        (f"Tracking: {'ON' if running else 'OFF'}", 90),
                    )
                    for text, baseline_y in overlay_lines:
                        cv2.putText(
                            frame,
                            text,
                            (10, baseline_y),
                            cv2.FONT_HERSHEY_SIMPLEX,
                            0.7,
                            status_color,
                            2,
                        )

                    # Tracking lost: forget the selection
                    if not running:
                        selected_point = None
                        tracking_active = False

            cv2.imshow("Person Tracking", frame)

            if exit_requested():
                break

    except KeyboardInterrupt:
        print("\nKeyboard interrupt received. Stopping...")
    finally:
        # Signal threads to stop
        stop_event.set()

        if subscription:
            subscription.dispose()

        visual_servoing.cleanup()
        video_provider.dispose_all()
        person_tracker.cleanup()
        cv2.destroyAllWindows()
        print("Cleanup complete")


if __name__ == "__main__":
    main()
def main():
    """Wire a planning agent and an execution agent to the robot and web UI.

    Queries typed into ``RobotWebInterface`` feed a ``PlanningAgent``; its
    responses feed an ``OpenAIAgent`` that executes robot skills. Both
    response streams are logged and mirrored to the web interface. The call
    blocks in ``web_interface.run()``.

    Raises:
        ValueError: If the ``ROBOT_IP`` environment variable is not set.
    """
    # Environment configuration
    robot_ip = os.getenv("ROBOT_IP")
    if not robot_ip:
        raise ValueError("ROBOT_IP environment variable is required")
    connection_method = os.getenv("CONN_TYPE") or "webrtc"
    output_dir = os.getenv("ROS_OUTPUT_DIR", os.path.join(os.getcwd(), "assets/output/ros"))

    # Pre-bound so the cleanup in ``finally`` can test each component
    robot = web_interface = planner = executor = None

    try:
        logger.info("Initializing Unitree Robot")
        robot = UnitreeGo2(
            ip=robot_ip,
            connection_method=connection_method,
            output_dir=output_dir,
            mock_connection=False,
            skills=MyUnitreeSkills(),
        )

        logger.info("Starting video stream")
        video_stream = robot.get_ros_video_stream()

        logger.info("Initializing robot skills")

        # One subject per agent mirrors its responses to the web UI
        logger.info("Creating response streams")
        planner_response_subject = rx.subject.Subject()
        planner_response_stream = planner_response_subject.pipe(ops.share())

        executor_response_subject = rx.subject.Subject()
        executor_response_stream = executor_response_subject.pipe(ops.share())

        logger.info("Initializing FastAPI server")
        web_interface = RobotWebInterface(
            port=5555,
            text_streams={
                "planner_responses": planner_response_stream,
                "executor_responses": executor_response_stream,
            },
            unitree_video=video_stream,
        )

        logger.info("Starting planning agent with web interface")
        planner = PlanningAgent(
            dev_name="TaskPlanner",
            model_name="gpt-4o",
            input_query_stream=web_interface.query_stream,
            skills=robot.get_skills(),
        )

        logger.info("Setting up agent response streams")
        planner_responses = planner.get_response_observable()

        # Mirror planner output to the UI, then log it
        planner_responses.subscribe(lambda item: planner_response_subject.on_next(item))
        planner_responses.subscribe(
            on_next=lambda item: logger.info(f"Planner response: {item}"),
            on_error=lambda err: logger.error(f"Planner error: {err}"),
            on_completed=lambda: logger.info("Planner completed"),
        )

        logger.info("Starting execution agent")
        system_query = dedent(
            """
            You are a robot execution agent that can execute tasks on a virtual
            robot. The sole text you will be given is the task to execute.
            You will be given a list of skills that you can use to execute the task.
            ONLY OUTPUT THE SKILLS TO EXECUTE, NOTHING ELSE.
            """
        )
        executor = OpenAIAgent(
            dev_name="StepExecutor",
            input_query_stream=planner_responses,
            output_dir=output_dir,
            skills=robot.get_skills(),
            system_query=system_query,
            pool_scheduler=make_single_thread_scheduler(),
        )

        executor_responses = executor.get_response_observable()

        # Log executor output, then mirror it to the UI
        executor_responses.subscribe(
            on_next=lambda item: logger.info(f"Executor response: {item}"),
            on_error=lambda err: logger.error(f"Executor error: {err}"),
            on_completed=lambda: logger.info("Executor completed"),
        )
        executor_responses.subscribe(lambda item: executor_response_subject.on_next(item))

        # Blocking call
        logger.info("Starting FastAPI server")
        web_interface.run()

    except KeyboardInterrupt:
        print("Stopping demo...")
    except Exception as e:
        logger.error(f"Error: {e}")
        return 1
    finally:
        logger.info("Cleaning up components")
        for component in (executor, planner, web_interface):
            if component:
                component.dispose_all()
        if robot:
            robot.cleanup()
        # Halt execution forever
        # NOTE(review): this loop never exits, so the ``return 1`` above and
        # the caller's ``sys.exit`` are never reached -- confirm that hanging
        # here after cleanup is intended.
        while True:
            time.sleep(1)


if __name__ == "__main__":
    sys.exit(main())

# Example Task: Move the robot forward by 1 meter, then turn 90 degrees clockwise, then move backward by 1 meter, then turn a random angle counterclockwise, then repeat this sequence 5 times.
def main():
    """Run the planner -> executor demo in terminal or web mode.

    Environment:
        ROBOT_IP: robot address; mandatory.
        CONN_TYPE: connection method, ``"webrtc"`` when unset or empty.
        ROS_OUTPUT_DIR: output directory, ``<cwd>/assets/output/ros`` when unset.
        USE_TERMINAL: ``"true"`` (case-insensitive) selects the terminal
            planner; anything else selects the web interface.

    Returns:
        ``1`` is returned by the generic error handler, but the ``finally``
        clause ends in an endless sleep loop, so control does not come back to
        the caller normally.

    Raises:
        ValueError: if ``ROBOT_IP`` is unset or empty.
    """
    robot_ip = os.getenv("ROBOT_IP")
    if not robot_ip:
        raise ValueError("ROBOT_IP environment variable is required")
    connection_method = os.getenv("CONN_TYPE") or "webrtc"
    output_dir = os.getenv("ROS_OUTPUT_DIR", os.path.join(os.getcwd(), "assets/output/ros"))
    # The environment switch is authoritative; a hard-coded override here would
    # make the web-interface branch below unreachable.
    use_terminal = os.getenv("USE_TERMINAL", "").lower() == "true"

    # Initialize components as None for proper cleanup
    robot = None
    web_interface = None
    planner = None
    executor = None

    try:
        logger.info("Initializing Unitree Robot")
        robot = UnitreeGo2(
            ip=robot_ip,
            connection_method=connection_method,
            output_dir=output_dir,
            mock_connection=True,
        )

        logger.info("Starting video stream")
        video_stream = robot.get_ros_video_stream()

        logger.info("Initializing robot skills")
        skills_instance = MyUnitreeSkills(robot=robot)

        if use_terminal:
            # Terminal mode - no web interface needed
            logger.info("Starting planning agent in terminal mode")
            planner = PlanningAgent(
                dev_name="TaskPlanner",
                model_name="gpt-4o",
                use_terminal=True,
                skills=skills_instance,
            )
        else:
            logger.info("Initializing FastAPI server")
            streams = {"unitree_video": video_stream}
            web_interface = RobotWebInterface(port=5555, **streams)

            logger.info("Starting planning agent with web interface")
            planner = PlanningAgent(
                dev_name="TaskPlanner",
                model_name="gpt-4o",
                input_query_stream=web_interface.query_stream,
                skills=skills_instance,
            )

        logger.info("Setting up agent response streams")
        planner_responses = planner.get_response_observable()

        logger.info("Starting execution agent")
        system_query = dedent(
            """
            You are a robot execution agent that can execute tasks on a virtual
            robot. You are given a task to execute and a list of skills that
            you can use to execute the task. ONLY OUTPUT THE SKILLS TO EXECUTE,
            NOTHING ELSE.
            """
        )
        executor = OpenAIAgent(
            dev_name="StepExecutor",
            input_query_stream=planner_responses,
            output_dir=output_dir,
            skills=skills_instance,
            system_query=system_query,
            pool_scheduler=make_single_thread_scheduler(),
        )

        executor_responses = executor.get_response_observable()

        # Subscribe to responses for logging
        executor_responses.subscribe(
            on_next=lambda x: logger.info(f"Executor response: {x}"),
            on_error=lambda e: logger.error(f"Executor error: {e}"),
            on_completed=lambda: logger.info("Executor completed"),
        )

        if use_terminal:
            logger.info("Waiting for planning session to complete")
            # Poll with a short sleep instead of spinning a CPU core at 100 %
            while not planner.plan_confirmed:
                time.sleep(0.1)
            logger.info("Planning session completed")
        else:
            # Blocking call
            logger.info("Starting FastAPI server")
            web_interface.run()

        # Informational, so logged at info level rather than error
        logger.info("NOTE: Keeping main thread alive")
        while True:
            time.sleep(1)

    except KeyboardInterrupt:
        print("Stopping demo...")
    except Exception as e:
        logger.error(f"Error: {e}")
        return 1
    finally:
        logger.info("Cleaning up components")
        if executor:
            executor.dispose_all()
        if planner:
            planner.dispose_all()
        if web_interface:
            web_interface.dispose_all()
        if robot:
            robot.cleanup()
        # Halt execution forever
        while True:
            time.sleep(1)
def main():
    """Test point cloud filtering using the concurrent stream-based ManipulationPipeline.

    Opens a ZED camera stream, runs it through ``ManipulationPipeline`` and
    serves the detection and point-cloud visualisations through
    ``RobotWebInterface`` until the server stops or the user interrupts.

    Whatever was created (camera stream, pipeline) is cleaned up on every exit
    path, including the ``sys.exit(1)`` taken when set-up fails part-way.

    Raises:
        SystemExit: with status 1 when set-up raises ``ImportError`` or
            ``RuntimeError``.
    """
    print("Testing point cloud filtering with ManipulationPipeline...")

    # Configuration
    min_confidence = 0.6
    web_port = 5555

    # Pre-bound so cleanup can tell what was actually created, without
    # resorting to locals() lookups.
    zed_stream = None
    pipeline = None
    setup_done = False

    try:
        try:
            zed_stream = ZEDCameraStream(resolution=sl.RESOLUTION.HD1080, fps=10)

            info = zed_stream.get_camera_info()
            camera_intrinsics = [info["fx"], info["fy"], info["cx"], info["cy"]]

            pipeline = ManipulationPipeline(
                camera_intrinsics=camera_intrinsics,
                min_confidence=min_confidence,
                max_objects=10,
            )

            # Shared so both visualisation branches consume the same frames
            zed_frame_stream = zed_stream.create_stream().pipe(ops.share())

            streams = pipeline.create_streams(zed_frame_stream)
            detection_viz_stream = streams["detection_viz"]
            pointcloud_viz_stream = streams["pointcloud_viz"]
        except ImportError:
            print("Error: ZED SDK not installed. Please install pyzed package.")
            sys.exit(1)
        except RuntimeError as e:
            print(f"Error: Failed to open ZED camera: {e}")
            sys.exit(1)
        setup_done = True

        try:
            print("Initializing web interface...")
            web_interface = RobotWebInterface(
                port=web_port,
                object_detection=detection_viz_stream,
                pointcloud_stream=pointcloud_viz_stream,
            )

            print("\nPoint Cloud Filtering Test Running:")
            print(f"Web Interface: http://localhost:{web_port}")
            print("Object Detection View: RGB with bounding boxes")
            print("Point Cloud View: Depth with colored point clouds and 3D bounding boxes")
            print(f"Confidence threshold: {min_confidence}")
            print("\nPress Ctrl+C to stop the test\n")

            # Blocking call
            web_interface.run()
        except KeyboardInterrupt:
            print("\nTest interrupted by user")
        except Exception as e:
            print(f"Error during test: {e}")
    finally:
        print("Cleaning up resources...")
        if zed_stream is not None:
            zed_stream.cleanup()
        if pipeline is not None:
            pipeline.cleanup()
        # Only announce completion when the test phase was actually reached
        if setup_done:
            print("Test completed")
def test_qwen_image_query():
    """Test querying Qwen with a single image.

    Does nothing beyond printing a notice when ``ALIBABA_API_KEY`` is unset.
    Otherwise loads ``assets/test_spatial_memory/frame_038.jpg`` relative to
    the current working directory, sends two queries about it through
    ``query_single_frame`` and prints each response.
    """
    # Skip if no API key
    if not os.getenv("ALIBABA_API_KEY"):
        print("ALIBABA_API_KEY not set")
        return

    image_path = os.path.join(os.getcwd(), "assets", "test_spatial_memory", "frame_038.jpg")

    # Context manager so the underlying file handle is closed even if a
    # query raises.
    with Image.open(image_path) as image:
        # Basic object detection query
        response = query_single_frame(
            image=image,
            query="What objects do you see in this image? Return as a comma-separated list.",
        )
        print(response)

        # Coordinate query
        response = query_single_frame(
            image=image,
            query="Return the center coordinates of any person in the image as a tuple (x,y)",
        )
        print(response)
def main():
    """Drive a Unitree Go2 towards a local goal while streaming visualisations.

    Builds the robot from ``ROBOT_IP``, publishes the camera and local-planner
    visualisation streams through ``RobotWebInterface`` on port 5555, starts
    ``navigate_to_goal_local`` on a daemon thread and then blocks in the web
    server.  On exit the robot is asked to lie down; a failure there is
    reported instead of being silently discarded.
    """
    print("Initializing Unitree Go2 robot with local planner visualization...")

    # Initialize the robot with webrtc interface
    robot = UnitreeGo2(ip=os.getenv("ROBOT_IP"), mode="ai")

    video_stream = robot.get_video_stream()

    # The local planner visualization stream is created during robot
    # initialization; drop empty emissions before they reach the web UI.
    local_planner_stream = robot.local_planner_viz_stream.pipe(
        RxOps.share(),
        RxOps.filter(lambda x: x is not None),
    )

    # Goal in the robot frame; also used for the status message so the two
    # cannot drift apart.
    goal_xy_robot = (3.0, 0.0)

    goal_following_thread = None
    try:
        streams = {"camera": video_stream, "local_planner": local_planner_stream}
        web_interface = RobotWebInterface(port=5555, **streams)

        print("Waiting for camera and systems to initialize...")
        time.sleep(2)

        print(
            f"Starting navigation to local goal ({goal_xy_robot[0]:g}m ahead) "
            "in a separate thread..."
        )
        goal_following_thread = threading.Thread(
            target=navigate_to_goal_local,
            kwargs={
                "robot": robot,
                "goal_xy_robot": goal_xy_robot,
                "distance": 0.0,
                "timeout": 300,
            },
            daemon=True,
        )
        goal_following_thread.start()

        print("Robot streams running")
        print("Web interface available at http://localhost:5555")
        print("Press Ctrl+C to exit")

        # Blocking call
        web_interface.run()

    except KeyboardInterrupt:
        print("\nInterrupted by user")
    except Exception as e:
        print(f"Error during test: {e}")
    finally:
        print("Cleaning up...")
        # Best effort: a failure to lie down must not mask the real exit
        # reason, but it should be visible. Catching Exception (not a bare
        # except) lets KeyboardInterrupt/SystemExit through.
        try:
            robot.liedown()
        except Exception as e:
            print(f"Failed to lie the robot down during cleanup: {e}")
        print("Test completed")
def process_frame(frame: np.ndarray):
    """Count the frame and, at most every 5 seconds, report throughput.

    The report (frame count, frame shape, average FPS since ``start_time``) is
    written to the logger and pushed onto ``ffmpeg_response_subject``.
    """
    global frame_counter, last_log_time
    frame_counter += 1
    now = time.monotonic()

    # Not yet time for the next periodic report
    if now - last_log_time < 5.0:
        return

    elapsed = now - start_time
    avg_fps = frame_counter / elapsed if elapsed > 0 else 0
    report = f"Received frame {frame_counter}. Shape: {frame.shape}. Avg FPS: {avg_fps:.2f}"
    logger.info(report)
    ffmpeg_response_subject.on_next(report)
    last_log_time = now


def handle_error(error: Exception):
    """Log a stream error, traceback included."""
    logger.error(f"Stream error: {error}", exc_info=True)


def handle_completion():
    """Log that the observable stream has completed."""
    logger.info("Stream completed.")
Shutting down...") +finally: + # Ensure resources are cleaned up regardless of how the loop exits + print("Disposing subscription...") + # subscription.dispose() + print("Disposing provider resources...") + provider.dispose_all() + print("Cleanup finished.") + +# Final check (optional, for debugging) +time.sleep(1) # Give background threads a moment +final_process = provider._ffmpeg_process +if final_process and final_process.poll() is None: + print(f"WARNING: ffmpeg process (PID: {final_process.pid}) may still be running after cleanup!") +else: + print("ffmpeg process appears terminated.") diff --git a/build/lib/tests/test_semantic_seg_robot.py b/build/lib/tests/test_semantic_seg_robot.py new file mode 100644 index 0000000000..eb5beb88e2 --- /dev/null +++ b/build/lib/tests/test_semantic_seg_robot.py @@ -0,0 +1,151 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def main():
    """Serve the robot's semantic-segmentation visualisation over the web UI.

    Builds a ``UnitreeGo2`` from ``ROBOT_IP``, runs its 5 fps ROS video stream
    through ``SemanticSegmentationStream`` and publishes the ``viz_frame`` of
    every result through ``RobotWebInterface`` on port 5555 (blocking).
    ``seg_stream.cleanup()`` is always called on exit.
    """
    # Unitree Go2 camera parameters at 1080p
    camera_params = {
        "resolution": (1920, 1080),  # 1080p resolution
        "focal_length": 3.2,  # mm
        "sensor_size": (4.8, 3.6),  # mm (1/4" sensor)
    }

    robot = UnitreeGo2(
        ip=os.getenv("ROBOT_IP"),
        ros_control=UnitreeROSControl(),
    )

    seg_stream = SemanticSegmentationStream(
        enable_mono_depth=False, camera_params=camera_params, gt_depth_scale=512.0
    )

    video_stream = robot.get_ros_video_stream(fps=5)
    segmentation_stream = seg_stream.create_stream(video_stream)

    try:
        print_emission_args = {
            "enabled": True,
            "dev_name": "SemanticSegmentation",
            "counts": {},
        }

        # Not used by the pipeline below; kept because construction with
        # delete_on_init=True presumably clears earlier output - TODO confirm.
        frame_processor = FrameProcessor(delete_on_init=True)

        # pipe() yields an observable, not a subscription handle, so there is
        # nothing to dispose here: the web interface does the subscribing.
        viz_stream = segmentation_stream.pipe(
            MyOps.print_emission(id="A", **print_emission_args),
            RxOps.share(),
            MyOps.print_emission(id="B", **print_emission_args),
            RxOps.map(lambda x: x.metadata["viz_frame"] if x is not None else None),
            MyOps.print_emission(id="C", **print_emission_args),
            RxOps.filter(lambda x: x is not None),
            MyOps.print_emission(id="D", **print_emission_args),
            MyOps.print_emission(id="E", **print_emission_args),
        )

        print("Semantic segmentation visualization started. Press 'q' to exit.")

        streams = {
            "segmentation_stream": viz_stream,
        }
        fast_api_server = RobotWebInterface(port=5555, **streams)
        fast_api_server.run()

    except KeyboardInterrupt:
        print("\nKeyboard interrupt received. Stopping...")
    finally:
        seg_stream.cleanup()
        cv2.destroyAllWindows()
        print("Cleanup complete")
def main():
    """Run an agent that acts on text queries enriched with detected objects.

    The robot's 5 fps video is sampled at 0.5 fps and segmented.  From each
    result three streams are derived: the visualisation frame, the depth
    visualisation and a text summary of detected objects.  Every query typed
    into the web interface is combined with the latest object summary and
    handed to an ``OpenAIAgent``.  All streams are published through
    ``RobotWebInterface`` on port 5555 (blocking).
    """
    # Unitree Go2 camera parameters at 1080p
    camera_params = {
        "resolution": (1920, 1080),  # 1080p resolution
        "focal_length": 3.2,  # mm
        "sensor_size": (4.8, 3.6),  # mm (1/4" sensor)
    }

    robot = UnitreeGo2(
        ip=os.getenv("ROBOT_IP"), ros_control=UnitreeROSControl(), skills=MyUnitreeSkills()
    )

    seg_stream = SemanticSegmentationStream(
        enable_mono_depth=True, camera_params=camera_params, gt_depth_scale=512.0
    )

    video_stream = robot.get_ros_video_stream(fps=5)
    # Sampling down to slow the rate of agent calls
    # TODO: add Agent parameter to handle this called api_call_interval
    segmentation_stream = seg_stream.create_stream(
        video_stream.pipe(MyVideoOps.with_fps_sampling(fps=0.5))
    )

    # Not used by the pipelines below; kept because construction with
    # delete_on_init=True presumably clears earlier output - TODO confirm.
    frame_processor = FrameProcessor(delete_on_init=True)

    # Kept under its own name so that `seg_stream` keeps referring to the
    # segmenter whose cleanup() is called in the finally clause.
    seg_viz_stream = segmentation_stream.pipe(
        RxOps.share(),
        RxOps.map(lambda x: x.metadata["viz_frame"] if x is not None else None),
        RxOps.filter(lambda x: x is not None),
    )

    depth_stream = segmentation_stream.pipe(
        RxOps.share(),
        RxOps.map(lambda x: x.metadata["depth_viz"] if x is not None else None),
        RxOps.filter(lambda x: x is not None),
    )

    def describe_objects(objects):
        """Render the detected objects as one line of text per object."""
        if not objects:
            return "No objects detected."
        return "\n".join(
            f"Object {obj['object_id']}: {obj['label']} (confidence: {obj['prob']:.2f})"
            + (f", depth: {obj['depth']:.2f}m" if "depth" in obj else "")
            for obj in objects
        )

    object_stream = segmentation_stream.pipe(
        RxOps.share(),
        RxOps.map(lambda x: x.metadata["objects"] if x is not None else None),
        RxOps.filter(lambda x: x is not None),
        RxOps.map(describe_objects),
    )

    text_query_stream = Subject()

    def print_enriched_query(text):
        """Echo the enriched query to the terminal in blue, line by line."""
        first_line, *other_lines = text.split("\n")
        print(f"\033[34mEnriched query: {first_line}\033[0m")
        for line in other_lines:
            print(f"\033[34m{line}\033[0m")

    # Combine text query with latest object data when a new text query arrives
    enriched_query_stream = text_query_stream.pipe(
        RxOps.with_latest_from(object_stream),
        RxOps.map(lambda pair: f"{pair[0]}\n\nCurrent objects detected:\n{pair[1]}"),
        RxOps.do_action(print_enriched_query),
    )

    segmentation_agent = OpenAIAgent(
        dev_name="SemanticSegmentationAgent",
        model_name="gpt-4o",
        system_query="You are a helpful assistant that can control a virtual robot with semantic segmentation / distnace data as a guide. Only output skill calls, no other text",
        input_query_stream=enriched_query_stream,
        process_all_inputs=False,
        pool_scheduler=get_scheduler(),
        skills=robot.get_skills(),
    )
    agent_response_stream = segmentation_agent.get_response_observable()

    print("Semantic segmentation visualization started. Press 'q' to exit.")

    streams = {
        "raw_stream": video_stream,
        "depth_stream": depth_stream,
        "seg_stream": seg_viz_stream,
    }
    text_streams = {
        "object_stream": object_stream,
        "enriched_query_stream": enriched_query_stream,
        "agent_response_stream": agent_response_stream,
    }

    try:
        fast_api_server = RobotWebInterface(port=5555, text_streams=text_streams, **streams)
        fast_api_server.query_stream.subscribe(lambda x: text_query_stream.on_next(x))
        fast_api_server.run()
    except KeyboardInterrupt:
        print("\nKeyboard interrupt received. Stopping...")
    finally:
        seg_stream.cleanup()
        cv2.destroyAllWindows()
        print("Cleanup complete")
def main():
    """Show webcam semantic segmentation and depth side by side in a window.

    Frames from video source 0 are segmented on the stream's own thread; each
    result is composed into one labelled image and handed to the main thread
    through a bounded queue, where it is displayed with OpenCV.  The loop ends
    on 'q', on Ctrl+C, or when the stream errors or completes.
    """
    # Bounded hand-off between the stream callback and the display loop
    pending_frames = queue.Queue(maxsize=5)
    shutdown = threading.Event()

    # Logitech C920e camera parameters at 480p
    camera_params = {
        "resolution": (640, 480),  # 480p resolution
        "focal_length": 3.67,  # mm
        "sensor_size": (4.8, 3.6),  # mm (1/4" sensor)
    }

    video_provider = VideoProvider("test_camera", video_source=0)
    seg_stream = SemanticSegmentationStream(
        enable_mono_depth=True, camera_params=camera_params, gt_depth_scale=512.0
    )

    video_stream = video_provider.capture_video_as_observable(realtime=False, fps=5)
    segmentation_stream = seg_stream.create_stream(video_stream)

    def on_next(segmentation):
        """Compose the side-by-side image and queue it for display."""
        if shutdown.is_set():
            return

        seg_image = segmentation.metadata["viz_frame"]
        depth_image = segmentation.metadata["depth_viz"]
        seg_h, seg_w = seg_image.shape[:2]
        depth_h, depth_w = depth_image.shape[:2]

        # Scale the depth image to the segmentation height, keeping its aspect
        scaled_w = int(depth_w * seg_h / depth_h)
        depth_scaled = cv2.resize(depth_image, (scaled_w, seg_h))

        side_by_side = np.hstack((seg_image, depth_scaled))

        label_font = cv2.FONT_HERSHEY_SIMPLEX
        white = (255, 255, 255)
        cv2.putText(side_by_side, "Semantic Segmentation", (10, 30), label_font, 0.8, white, 2)
        cv2.putText(side_by_side, "Depth Estimation", (seg_w + 10, 30), label_font, 0.8, white, 2)

        try:
            pending_frames.put_nowait(side_by_side)
        except queue.Full:
            # Display is behind; drop this frame
            pass

    def on_error(error):
        """Report the stream error and stop the display loop."""
        print(f"Error: {error}")
        shutdown.set()

    def on_completed():
        """Report stream completion and stop the display loop."""
        print("Stream completed")
        shutdown.set()

    subscription = None

    try:
        subscription = segmentation_stream.subscribe(
            on_next=on_next, on_error=on_error, on_completed=on_completed
        )

        print("Semantic segmentation visualization started. Press 'q' to exit.")

        # Display loop on the main thread
        while not shutdown.is_set():
            # The timeout lets the loop re-check the shutdown flag regularly
            try:
                latest = pending_frames.get(timeout=1.0)
            except queue.Empty:
                latest = None

            if latest is not None:
                cv2.imshow("Semantic Segmentation", latest)

            # The exit key is polled whether or not a frame arrived
            if cv2.waitKey(1) & 0xFF == ord("q"):
                print("Exit key pressed")
                break

    except KeyboardInterrupt:
        print("\nKeyboard interrupt received. Stopping...")
    finally:
        # Tell the stream callbacks to stop queueing frames
        shutdown.set()

        if subscription:
            subscription.dispose()

        video_provider.dispose_all()
        seg_stream.cleanup()
        cv2.destroyAllWindows()
        print("Cleanup complete")
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the skills module in the dimos package.""" + +import unittest +from unittest import mock + +import tests.test_header + +from dimos.skills.skills import AbstractSkill, SkillLibrary +from dimos.robot.robot import MockRobot +from dimos.robot.unitree.unitree_skills import MyUnitreeSkills +from dimos.types.constants import Colors +from dimos.agents.agent import OpenAIAgent + + +class TestSkill(AbstractSkill): + """A test skill that tracks its execution for testing purposes.""" + + _called: bool = False + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._called = False + + def __call__(self): + self._called = True + return "TestSkill executed successfully" + + +class SkillLibraryTest(unittest.TestCase): + """Tests for the SkillLibrary functionality.""" + + def setUp(self): + """Set up test fixtures before each test method.""" + self.robot = MockRobot() + self.skill_library = MyUnitreeSkills(robot=self.robot) + self.skill_library.initialize_skills() + + def test_skill_iteration(self): + """Test that skills can be properly iterated in the skill library.""" + skills_count = 0 + for skill in self.skill_library: + skills_count += 1 + self.assertTrue(hasattr(skill, "__name__")) + self.assertTrue(issubclass(skill, AbstractSkill)) + + self.assertGreater(skills_count, 0, "Skill library should contain at least one skill") + + def test_skill_registration(self): + """Test that skills can be properly registered in the skill library.""" + # Clear existing skills for isolated test + self.skill_library = 
MyUnitreeSkills(robot=self.robot) + original_count = len(list(self.skill_library)) + + # Add a custom test skill + test_skill = TestSkill + self.skill_library.add(test_skill) + + # Verify the skill was added + new_count = len(list(self.skill_library)) + self.assertEqual(new_count, original_count + 1) + + # Check if the skill can be found by name + found = False + for skill in self.skill_library: + if skill.__name__ == "TestSkill": + found = True + break + self.assertTrue(found, "Added skill should be found in skill library") + + def test_skill_direct_execution(self): + """Test that a skill can be executed directly.""" + test_skill = TestSkill() + self.assertFalse(test_skill._called) + result = test_skill() + self.assertTrue(test_skill._called) + self.assertEqual(result, "TestSkill executed successfully") + + def test_skill_library_execution(self): + """Test that a skill can be executed through the skill library.""" + # Add our test skill to the library + test_skill = TestSkill + self.skill_library.add(test_skill) + + # Create an instance to confirm it was executed + with mock.patch.object(TestSkill, "__call__", return_value="Success") as mock_call: + result = self.skill_library.call("TestSkill") + mock_call.assert_called_once() + self.assertEqual(result, "Success") + + def test_skill_not_found(self): + """Test that calling a non-existent skill raises an appropriate error.""" + with self.assertRaises(ValueError): + self.skill_library.call("NonExistentSkill") + + +class SkillWithAgentTest(unittest.TestCase): + """Tests for skills used with an agent.""" + + def setUp(self): + """Set up test fixtures before each test method.""" + self.robot = MockRobot() + self.skill_library = MyUnitreeSkills(robot=self.robot) + self.skill_library.initialize_skills() + + # Add a test skill + self.skill_library.add(TestSkill) + + # Create the agent + self.agent = OpenAIAgent( + dev_name="SkillTestAgent", + system_query="You are a skill testing agent. 
When prompted to perform an action, use the appropriate skill.", + skills=self.skill_library, + ) + + @mock.patch("dimos.agents.agent.OpenAIAgent.run_observable_query") + def test_agent_skill_identification(self, mock_query): + """Test that the agent can identify skills based on natural language.""" + # Mock the agent response + mock_response = mock.MagicMock() + mock_response.run.return_value = "I found the TestSkill and executed it." + mock_query.return_value = mock_response + + # Run the test + response = self.agent.run_observable_query("Please run the test skill").run() + + # Assertions + mock_query.assert_called_once_with("Please run the test skill") + self.assertEqual(response, "I found the TestSkill and executed it.") + + @mock.patch.object(TestSkill, "__call__") + @mock.patch("dimos.agents.agent.OpenAIAgent.run_observable_query") + def test_agent_skill_execution(self, mock_query, mock_skill_call): + """Test that the agent can execute skills properly.""" + # Mock the agent and skill call + mock_skill_call.return_value = "TestSkill executed successfully" + mock_response = mock.MagicMock() + mock_response.run.return_value = "Executed TestSkill successfully." 
+ mock_query.return_value = mock_response + + # Run the test + response = self.agent.run_observable_query("Execute the TestSkill skill").run() + + # We can't directly verify the skill was called since our mocking setup + # doesn't capture the internal skill execution of the agent, but we can + # verify the agent was properly called + mock_query.assert_called_once_with("Execute the TestSkill skill") + self.assertEqual(response, "Executed TestSkill successfully.") + + def test_agent_multi_skill_registration(self): + """Test that multiple skills can be registered with an agent.""" + + # Create a new skill + class AnotherTestSkill(AbstractSkill): + def __call__(self): + return "Another test skill executed" + + # Register the new skill + initial_count = len(list(self.skill_library)) + self.skill_library.add(AnotherTestSkill) + + # Verify two distinct skills now exist + self.assertEqual(len(list(self.skill_library)), initial_count + 1) + + # Verify both skills are found by name + skill_names = [skill.__name__ for skill in self.skill_library] + self.assertIn("TestSkill", skill_names) + self.assertIn("AnotherTestSkill", skill_names) + + +if __name__ == "__main__": + unittest.main() diff --git a/build/lib/tests/test_skills_rest.py b/build/lib/tests/test_skills_rest.py new file mode 100644 index 0000000000..70a15fcfd5 --- /dev/null +++ b/build/lib/tests/test_skills_rest.py @@ -0,0 +1,73 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import tests.test_header + +from textwrap import dedent +from dimos.skills.skills import SkillLibrary + +from dotenv import load_dotenv +from dimos.agents.claude_agent import ClaudeAgent +from dimos.web.robot_web_interface import RobotWebInterface +from dimos.skills.rest.rest import GenericRestSkill +import reactivex as rx +import reactivex.operators as ops + +# Load API key from environment +load_dotenv() + +# Create a skill library and add the GenericRestSkill +skills = SkillLibrary() +skills.add(GenericRestSkill) + +# Create a subject for agent responses +agent_response_subject = rx.subject.Subject() +agent_response_stream = agent_response_subject.pipe(ops.share()) + +# Create a text stream for agent responses in the web interface +text_streams = { + "agent_responses": agent_response_stream, +} +web_interface = RobotWebInterface(port=5555, text_streams=text_streams) + +# Create a ClaudeAgent instance +agent = ClaudeAgent( + dev_name="test_agent", + input_query_stream=web_interface.query_stream, + skills=skills, + system_query=dedent( + """ + You are a virtual agent. When given a query, respond by using + the appropriate tool calls if needed to execute commands on the robot. + + IMPORTANT: + Only return the response directly asked of the user. E.G. if the user asks for the time, + only return the time. If the user asks for the weather, only return the weather. + """ + ), + model_name="claude-3-7-sonnet-latest", + thinking_budget_tokens=2000, +) + +# Subscribe to agent responses and send them to the subject +agent.get_response_observable().subscribe(lambda x: agent_response_subject.on_next(x)) + +# Start the web interface +web_interface.run() + +# Run this query in the web interface: +# +# Make a web request to nist to get the current time. 
+# You should use http://worldclockapi.com/api/json/utc/now +# diff --git a/build/lib/tests/test_spatial_memory.py b/build/lib/tests/test_spatial_memory.py new file mode 100644 index 0000000000..b400749cb4 --- /dev/null +++ b/build/lib/tests/test_spatial_memory.py @@ -0,0 +1,297 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import time +import pickle +import numpy as np +import cv2 +import matplotlib.pyplot as plt +from matplotlib.patches import Circle +import reactivex +from reactivex import operators as ops +import chromadb + +from dimos.agents.memory.visual_memory import VisualMemory + +import tests.test_header + +from dimos.robot.unitree.unitree_go2 import UnitreeGo2 +from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl +from dimos.perception.spatial_perception import SpatialMemory + + +def extract_position(transform): + """Extract position coordinates from a transform message""" + if transform is None: + return (0, 0, 0) + + pos = transform.transform.translation + return (pos.x, pos.y, pos.z) + + +def setup_persistent_chroma_db(db_path="chromadb_data"): + """ + Set up a persistent ChromaDB database at the specified path. 
+ + Args: + db_path: Path to store the ChromaDB database + + Returns: + The ChromaDB client instance + """ + # Create a persistent ChromaDB client + full_db_path = os.path.join("/home/stash/dimensional/dimos/assets/test_spatial_memory", db_path) + print(f"Setting up persistent ChromaDB at: {full_db_path}") + + # Ensure the directory exists + os.makedirs(full_db_path, exist_ok=True) + + return chromadb.PersistentClient(path=full_db_path) + + +def main(): + print("Starting spatial memory test...") + + # Initialize ROS control and robot + ros_control = UnitreeROSControl(node_name="spatial_memory_test", mock_connection=False) + + robot = UnitreeGo2(ros_control=ros_control, ip=os.getenv("ROBOT_IP")) + + # Create counters for tracking + frame_count = 0 + transform_count = 0 + stored_count = 0 + + print("Setting up video stream...") + video_stream = robot.get_ros_video_stream() + + # Create transform stream at 1 Hz + print("Setting up transform stream...") + transform_stream = ros_control.get_transform_stream( + child_frame="map", + parent_frame="base_link", + rate_hz=1.0, # 1 transform per second + ) + + # Setup output directory for visual memory + visual_memory_dir = "/home/stash/dimensional/dimos/assets/test_spatial_memory" + os.makedirs(visual_memory_dir, exist_ok=True) + + # Setup persistent storage path for visual memory + visual_memory_path = os.path.join(visual_memory_dir, "visual_memory.pkl") + + # Try to load existing visual memory if it exists + if os.path.exists(visual_memory_path): + try: + print(f"Loading existing visual memory from {visual_memory_path}...") + visual_memory = VisualMemory.load(visual_memory_path, output_dir=visual_memory_dir) + print(f"Loaded {visual_memory.count()} images from previous runs") + except Exception as e: + print(f"Error loading visual memory: {e}") + visual_memory = VisualMemory(output_dir=visual_memory_dir) + else: + print("No existing visual memory found. 
Starting with empty visual memory.") + visual_memory = VisualMemory(output_dir=visual_memory_dir) + + # Setup a persistent database for ChromaDB + db_client = setup_persistent_chroma_db() + + # Create spatial perception instance with persistent storage + print("Creating SpatialMemory with persistent vector database...") + spatial_memory = SpatialMemory( + collection_name="test_spatial_memory", + min_distance_threshold=1, # Store frames every 1 meter + min_time_threshold=1, # Store frames at least every 1 second + chroma_client=db_client, # Use the persistent client + visual_memory=visual_memory, # Use the visual memory we loaded or created + ) + + # Combine streams using combine_latest + # This will pair up items properly without buffering + combined_stream = reactivex.combine_latest(video_stream, transform_stream).pipe( + ops.map( + lambda pair: { + "frame": pair[0], # First element is the frame + "position": extract_position(pair[1]), # Second element is the transform + } + ) + ) + + # Process with spatial memory + result_stream = spatial_memory.process_stream(combined_stream) + + # Simple callback to track stored frames and save them to the assets directory + def on_stored_frame(result): + nonlocal stored_count + # Only count actually stored frames (not debug frames) + if not result.get("stored", True) == False: + stored_count += 1 + pos = result["position"] + print(f"\nStored frame #{stored_count} at ({pos[0]:.2f}, {pos[1]:.2f}, {pos[2]:.2f})") + + # Save the frame to the assets directory + if "frame" in result: + frame_filename = f"/home/stash/dimensional/dimos/assets/test_spatial_memory/frame_{stored_count:03d}.jpg" + cv2.imwrite(frame_filename, result["frame"]) + print(f"Saved frame to {frame_filename}") + + # Subscribe to results + print("Subscribing to spatial perception results...") + result_subscription = result_stream.subscribe(on_stored_frame) + + print("\nRunning until interrupted...") + try: + while True: + time.sleep(1.0) + print(f"Running: 
{stored_count} frames stored so far", end="\r") + except KeyboardInterrupt: + print("\nTest interrupted by user") + finally: + # Clean up resources + print("\nCleaning up...") + if "result_subscription" in locals(): + result_subscription.dispose() + + # Visualize spatial memory with multiple object queries + visualize_spatial_memory_with_objects( + spatial_memory, + objects=[ + "kitchen", + "conference room", + "vacuum", + "office", + "bathroom", + "boxes", + "telephone booth", + ], + output_filename="spatial_memory_map.png", + ) + + # Save visual memory to disk for later use + saved_path = spatial_memory.vector_db.visual_memory.save("visual_memory.pkl") + print(f"Saved {spatial_memory.vector_db.visual_memory.count()} images to disk at {saved_path}") + + +def visualize_spatial_memory_with_objects( + spatial_memory, objects, output_filename="spatial_memory_map.png" +): + """ + Visualize a spatial memory map with multiple labeled objects. + + Args: + spatial_memory: SpatialMemory instance + objects: List of object names to query and visualize (e.g. 
["kitchen", "office"]) + output_filename: Filename to save the visualization + """ + # Define colors for different objects - will cycle through these + colors = ["red", "green", "orange", "purple", "brown", "cyan", "magenta", "yellow"] + + # Get all stored locations for background + locations = spatial_memory.vector_db.get_all_locations() + if not locations: + print("No locations stored in spatial memory.") + return + + # Extract coordinates from all stored locations + if len(locations[0]) >= 3: + x_coords = [loc[0] for loc in locations] + y_coords = [loc[1] for loc in locations] + else: + x_coords, y_coords = zip(*locations) + + # Create figure + plt.figure(figsize=(12, 10)) + + # Plot all points in blue + plt.scatter(x_coords, y_coords, c="blue", s=50, alpha=0.5, label="All Frames") + + # Container for all object coordinates + object_coords = {} + + # Query for each object and store the result + for i, obj in enumerate(objects): + color = colors[i % len(colors)] # Cycle through colors + print(f"\nProcessing {obj} query for visualization...") + + # Get best match for this object + results = spatial_memory.query_by_text(obj, limit=1) + if not results: + print(f"No results found for '{obj}'") + continue + + # Get the first (best) result + result = results[0] + metadata = result["metadata"] + + # Extract coordinates from the first metadata item + if isinstance(metadata, list) and metadata: + metadata = metadata[0] + + if isinstance(metadata, dict) and "x" in metadata and "y" in metadata: + x = metadata.get("x", 0) + y = metadata.get("y", 0) + + # Store coordinates for this object + object_coords[obj] = (x, y) + + # Plot this object's position + plt.scatter([x], [y], c=color, s=100, alpha=0.8, label=obj.title()) + + # Add annotation + obj_abbrev = obj[0].upper() if len(obj) > 0 else "X" + plt.annotate( + f"{obj_abbrev}", (x, y), textcoords="offset points", xytext=(0, 10), ha="center" + ) + + # Save the image to a file using the object name + if "image" in result and 
result["image"] is not None: + # Clean the object name to make it suitable for a filename + clean_name = obj.replace(" ", "_").lower() + output_img_filename = f"{clean_name}_result.jpg" + cv2.imwrite(output_img_filename, result["image"]) + print(f"Saved {obj} image to {output_img_filename}") + + # Finalize the plot + plt.title("Spatial Memory Map with Query Results") + plt.xlabel("X Position (m)") + plt.ylabel("Y Position (m)") + plt.grid(True) + plt.axis("equal") + plt.legend() + + # Add origin circle + plt.gca().add_patch(Circle((0, 0), 1.0, fill=False, color="blue", linestyle="--")) + + # Save the visualization + plt.savefig(output_filename, dpi=300) + print(f"Saved enhanced map visualization to {output_filename}") + + return object_coords + + # Final cleanup + print("Performing final cleanup...") + spatial_memory.cleanup() + + try: + robot.cleanup() + except Exception as e: + print(f"Error during robot cleanup: {e}") + + print("Test completed successfully") + + +if __name__ == "__main__": + main() diff --git a/build/lib/tests/test_spatial_memory_query.py b/build/lib/tests/test_spatial_memory_query.py new file mode 100644 index 0000000000..a0e77e9444 --- /dev/null +++ b/build/lib/tests/test_spatial_memory_query.py @@ -0,0 +1,297 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +Test script for querying an existing spatial memory database + +Usage: + python test_spatial_memory_query.py --query "kitchen table" --limit 5 --threshold 0.7 --save-all + python test_spatial_memory_query.py --query "robot" --limit 3 --save-one +""" + +import os +import sys +import argparse +import numpy as np +import cv2 +import matplotlib.pyplot as plt +import chromadb +from datetime import datetime + +import tests.test_header +from dimos.perception.spatial_perception import SpatialMemory +from dimos.agents.memory.visual_memory import VisualMemory + + +def setup_persistent_chroma_db(db_path): + """Set up a persistent ChromaDB client at the specified path.""" + print(f"Setting up persistent ChromaDB at: {db_path}") + os.makedirs(db_path, exist_ok=True) + return chromadb.PersistentClient(path=db_path) + + +def parse_args(): + """Parse command-line arguments.""" + parser = argparse.ArgumentParser(description="Query spatial memory database.") + parser.add_argument( + "--query", type=str, default=None, help="Text query to search for (e.g., 'kitchen table')" + ) + parser.add_argument("--limit", type=int, default=3, help="Maximum number of results to return") + parser.add_argument( + "--threshold", + type=float, + default=None, + help="Similarity threshold (0.0-1.0). 
Only return results above this threshold.", + ) + parser.add_argument("--save-all", action="store_true", help="Save all result images") + parser.add_argument("--save-one", action="store_true", help="Save only the best matching image") + parser.add_argument( + "--visualize", + action="store_true", + help="Create a visualization of all stored memory locations", + ) + parser.add_argument( + "--db-path", + type=str, + default="/home/stash/dimensional/dimos/assets/test_spatial_memory/chromadb_data", + help="Path to ChromaDB database", + ) + parser.add_argument( + "--visual-memory-path", + type=str, + default="/home/stash/dimensional/dimos/assets/test_spatial_memory/visual_memory.pkl", + help="Path to visual memory file", + ) + return parser.parse_args() + + +def main(): + args = parse_args() + print("Loading existing spatial memory database for querying...") + + # Setup the persistent ChromaDB client + db_client = setup_persistent_chroma_db(args.db_path) + + # Setup output directory for any saved results + output_dir = os.path.dirname(args.visual_memory_path) + + # Load the visual memory + print(f"Loading visual memory from {args.visual_memory_path}...") + if os.path.exists(args.visual_memory_path): + visual_memory = VisualMemory.load(args.visual_memory_path, output_dir=output_dir) + print(f"Loaded {visual_memory.count()} images from visual memory") + else: + visual_memory = VisualMemory(output_dir=output_dir) + print("No existing visual memory found. 
Query results won't include images.") + + # Create SpatialMemory with the existing database and visual memory + spatial_memory = SpatialMemory( + collection_name="test_spatial_memory", chroma_client=db_client, visual_memory=visual_memory + ) + + # Create a visualization if requested + if args.visualize: + print("\nCreating visualization of spatial memory...") + common_objects = [ + "kitchen", + "conference room", + "vacuum", + "office", + "bathroom", + "boxes", + "telephone booth", + ] + visualize_spatial_memory_with_objects( + spatial_memory, objects=common_objects, output_filename="spatial_memory_map.png" + ) + + # Handle query if provided + if args.query: + query = args.query + limit = args.limit + print(f"\nQuerying for: '{query}' (limit: {limit})...") + + # Run the query + results = spatial_memory.query_by_text(query, limit=limit) + + if not results: + print(f"No results found for query: '{query}'") + return + + # Filter by threshold if specified + if args.threshold is not None: + print(f"Filtering results with similarity threshold: {args.threshold}") + filtered_results = [] + for result in results: + # Distance is inverse of similarity (0 is perfect match) + # Convert to similarity score (1.0 is perfect match) + similarity = 1.0 - ( + result.get("distance", 0) if result.get("distance") is not None else 0 + ) + if similarity >= args.threshold: + filtered_results.append((result, similarity)) + + # Sort by similarity (highest first) + filtered_results.sort(key=lambda x: x[1], reverse=True) + + if not filtered_results: + print(f"No results met the similarity threshold of {args.threshold}") + return + + print(f"Found {len(filtered_results)} results above threshold") + results_with_scores = filtered_results + else: + # Add similarity scores for all results + results_with_scores = [] + for result in results: + similarity = 1.0 - ( + result.get("distance", 0) if result.get("distance") is not None else 0 + ) + results_with_scores.append((result, similarity)) + + # 
Process and display results + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + + for i, (result, similarity) in enumerate(results_with_scores): + metadata = result.get("metadata", {}) + if isinstance(metadata, list) and metadata: + metadata = metadata[0] + + # Display result information + print(f"\nResult {i + 1} for '{query}':") + print(f"Similarity: {similarity:.4f} (distance: {1.0 - similarity:.4f})") + + # Extract and display position information + if isinstance(metadata, dict): + x = metadata.get("x", 0) + y = metadata.get("y", 0) + z = metadata.get("z", 0) + print(f"Position: ({x:.2f}, {y:.2f}, {z:.2f})") + if "timestamp" in metadata: + print(f"Timestamp: {metadata['timestamp']}") + if "frame_id" in metadata: + print(f"Frame ID: {metadata['frame_id']}") + + # Save image if requested and available + if "image" in result and result["image"] is not None: + # Only save first image, or all images based on flags + if args.save_one and i > 0: + continue + if not (args.save_all or args.save_one): + continue + + # Create a descriptive filename + clean_query = query.replace(" ", "_").replace("/", "_").lower() + output_filename = f"{clean_query}_result_{i + 1}_{timestamp}.jpg" + + # Save the image + cv2.imwrite(output_filename, result["image"]) + print(f"Saved image to {output_filename}") + elif "image" in result and result["image"] is None: + print("Image data not available for this result") + else: + print('No query specified. 
Use --query "text to search for" to run a query.') + print("Use --help to see all available options.") + + print("\nQuery completed successfully!") + + +def visualize_spatial_memory_with_objects( + spatial_memory, objects, output_filename="spatial_memory_map.png" +): + """Visualize spatial memory with labeled objects.""" + # Define colors for different objects + colors = ["red", "green", "orange", "purple", "brown", "cyan", "magenta", "yellow"] + + # Get all stored locations for background + locations = spatial_memory.vector_db.get_all_locations() + if not locations: + print("No locations stored in spatial memory.") + return + + # Extract coordinates + if len(locations[0]) >= 3: + x_coords = [loc[0] for loc in locations] + y_coords = [loc[1] for loc in locations] + else: + x_coords, y_coords = zip(*locations) + + # Create figure + plt.figure(figsize=(12, 10)) + plt.scatter(x_coords, y_coords, c="blue", s=50, alpha=0.5, label="All Frames") + + # Container for object coordinates + object_coords = {} + + # Query for each object + for i, obj in enumerate(objects): + color = colors[i % len(colors)] + print(f"Processing {obj} query for visualization...") + + # Get best match + results = spatial_memory.query_by_text(obj, limit=1) + if not results: + print(f"No results found for '{obj}'") + continue + + # Process result + result = results[0] + metadata = result["metadata"] + + if isinstance(metadata, list) and metadata: + metadata = metadata[0] + + if isinstance(metadata, dict) and "x" in metadata and "y" in metadata: + x = metadata.get("x", 0) + y = metadata.get("y", 0) + + # Store coordinates + object_coords[obj] = (x, y) + + # Plot position + plt.scatter([x], [y], c=color, s=100, alpha=0.8, label=obj.title()) + + # Add annotation + obj_abbrev = obj[0].upper() if len(obj) > 0 else "X" + plt.annotate( + f"{obj_abbrev}", (x, y), textcoords="offset points", xytext=(0, 10), ha="center" + ) + + # Save image if available + if "image" in result and result["image"] is not None: 
+ clean_name = obj.replace(" ", "_").lower() + output_img_filename = f"{clean_name}_result.jpg" + cv2.imwrite(output_img_filename, result["image"]) + print(f"Saved {obj} image to {output_img_filename}") + + # Finalize plot + plt.title("Spatial Memory Map with Query Results") + plt.xlabel("X Position (m)") + plt.ylabel("Y Position (m)") + plt.grid(True) + plt.axis("equal") + plt.legend() + + # Add origin marker + plt.gca().add_patch(plt.Circle((0, 0), 1.0, fill=False, color="blue", linestyle="--")) + + # Save visualization + plt.savefig(output_filename, dpi=300) + print(f"Saved visualization to {output_filename}") + + return object_coords + + +if __name__ == "__main__": + main() diff --git a/build/lib/tests/test_standalone_chromadb.py b/build/lib/tests/test_standalone_chromadb.py new file mode 100644 index 0000000000..a5dc0e9b73 --- /dev/null +++ b/build/lib/tests/test_standalone_chromadb.py @@ -0,0 +1,87 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import tests.test_header +import os + +# ----- + +import chromadb +from langchain_openai import OpenAIEmbeddings +from langchain_chroma import Chroma + +OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") +if not OPENAI_API_KEY: + raise Exception("OpenAI key not specified.") + +collection_name = "my_collection" + +embeddings = OpenAIEmbeddings( + model="text-embedding-3-large", + dimensions=1024, + api_key=OPENAI_API_KEY, +) + +db_connection = Chroma( + collection_name=collection_name, + embedding_function=embeddings, +) + + +def add_vector(vector_id, vector_data): + """Add a vector to the ChromaDB collection.""" + if not db_connection: + raise Exception("Collection not initialized. Call connect() first.") + db_connection.add_texts( + ids=[vector_id], + texts=[vector_data], + metadatas=[{"name": vector_id}], + ) + + +add_vector("id0", "Food") +add_vector("id1", "Cat") +add_vector("id2", "Mouse") +add_vector("id3", "Bike") +add_vector("id4", "Dog") +add_vector("id5", "Tricycle") +add_vector("id6", "Car") +add_vector("id7", "Horse") +add_vector("id8", "Vehicle") +add_vector("id6", "Red") +add_vector("id7", "Orange") +add_vector("id8", "Yellow") + + +def get_vector(vector_id): + """Retrieve a vector from the ChromaDB by its identifier.""" + result = db_connection.get(include=["embeddings"], ids=[vector_id]) + return result + + +print(get_vector("id1")) +# print(get_vector("id3")) +# print(get_vector("id0")) +# print(get_vector("id2")) + + +def query(query_texts, n_results=2): + """Query the collection with a specific text and return up to n results.""" + if not db_connection: + raise Exception("Collection not initialized. 
Call connect() first.") + return db_connection.similarity_search(query=query_texts, k=n_results) + + +results = query("Colors") +print(results) diff --git a/build/lib/tests/test_standalone_fastapi.py b/build/lib/tests/test_standalone_fastapi.py new file mode 100644 index 0000000000..6fac013546 --- /dev/null +++ b/build/lib/tests/test_standalone_fastapi.py @@ -0,0 +1,81 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tests.test_header +import os + +import logging + +logging.basicConfig(level=logging.DEBUG) + +from fastapi import FastAPI, Response +import cv2 +import uvicorn +from starlette.responses import StreamingResponse + +app = FastAPI() + +# Note: Chrome does not allow for loading more than 6 simultaneous +# video streams. Use Safari or another browser for utilizing +# multiple simultaneous streams. Possibly build out functionality +# that will stop live streams. 
+ + +@app.get("/") +async def root(): + pid = os.getpid() # Get the current process ID + return {"message": f"Video Streaming Server, PID: {pid}"} + + +def video_stream_generator(): + pid = os.getpid() + print(f"Stream initiated by worker with PID: {pid}") # Log the PID when the generator is called + + # Use the correct path for your video source + cap = cv2.VideoCapture( + f"{os.getcwd()}/assets/trimmed_video_480p.mov" + ) # Change 0 to a filepath for video files + + if not cap.isOpened(): + yield (b"--frame\r\nContent-Type: text/plain\r\n\r\n" + b"Could not open video source\r\n") + return + + try: + while True: + ret, frame = cap.read() + # If frame is read correctly ret is True + if not ret: + print(f"Reached the end of the video, restarting... PID: {pid}") + cap.set( + cv2.CAP_PROP_POS_FRAMES, 0 + ) # Set the position of the next video frame to 0 (the beginning) + continue + _, buffer = cv2.imencode(".jpg", frame) + yield (b"--frame\r\nContent-Type: image/jpeg\r\n\r\n" + buffer.tobytes() + b"\r\n") + finally: + cap.release() + + +@app.get("/video") +async def video_endpoint(): + logging.debug("Attempting to open video stream.") + response = StreamingResponse( + video_stream_generator(), media_type="multipart/x-mixed-replace; boundary=frame" + ) + logging.debug("Streaming response set up.") + return response + + +if __name__ == "__main__": + uvicorn.run("__main__:app", host="0.0.0.0", port=5555, workers=20) diff --git a/build/lib/tests/test_standalone_hugging_face.py b/build/lib/tests/test_standalone_hugging_face.py new file mode 100644 index 0000000000..d0b2e68e61 --- /dev/null +++ b/build/lib/tests/test_standalone_hugging_face.py @@ -0,0 +1,147 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tests.test_header + +# from transformers import AutoModelForCausalLM, AutoTokenizer + +# model_name = "Qwen/QwQ-32B" + +# model = AutoModelForCausalLM.from_pretrained( +# model_name, +# torch_dtype="auto", +# device_map="auto" +# ) +# tokenizer = AutoTokenizer.from_pretrained(model_name) + +# prompt = "How many r's are in the word \"strawberry\"" +# messages = [ +# {"role": "user", "content": prompt} +# ] +# text = tokenizer.apply_chat_template( +# messages, +# tokenize=False, +# add_generation_prompt=True +# ) + +# model_inputs = tokenizer([text], return_tensors="pt").to(model.device) + +# generated_ids = model.generate( +# **model_inputs, +# max_new_tokens=32768 +# ) +# generated_ids = [ +# output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) +# ] + +# response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] +# print(response) + +# ----------------------------------------------------------------------------- + +# import requests +# import json + +# API_URL = "https://api-inference.huggingface.co/models/Qwen/QwQ-32B" +# api_key = os.getenv('HUGGINGFACE_ACCESS_TOKEN') + +# HEADERS = {"Authorization": f"Bearer {api_key}"} + +# prompt = "How many r's are in the word \"strawberry\"" +# messages = [ +# {"role": "user", "content": prompt} +# ] + +# # Format the prompt in the desired chat format +# chat_template = ( +# f"{messages[0]['content']}\n" +# "Assistant:" +# ) + +# payload = { +# "inputs": chat_template, +# "parameters": { +# "max_new_tokens": 32768, +# "temperature": 0.7 +# } 
+# } + +# # API request +# response = requests.post(API_URL, headers=HEADERS, json=payload) + +# # Handle response +# if response.status_code == 200: +# output = response.json()[0]['generated_text'] +# print(output.strip()) +# else: +# print(f"Error {response.status_code}: {response.text}") + +# ----------------------------------------------------------------------------- + +# import os +# import requests +# import time + +# API_URL = "https://api-inference.huggingface.co/models/Qwen/QwQ-32B" +# api_key = os.getenv('HUGGINGFACE_ACCESS_TOKEN') + +# HEADERS = {"Authorization": f"Bearer {api_key}"} + +# def query_with_retries(payload, max_retries=5, delay=15): +# for attempt in range(max_retries): +# response = requests.post(API_URL, headers=HEADERS, json=payload) +# if response.status_code == 200: +# return response.json()[0]['generated_text'] +# elif response.status_code == 500: # Service unavailable +# print(f"Attempt {attempt + 1}/{max_retries}: Model busy. Retrying in {delay} seconds...") +# time.sleep(delay) +# else: +# print(f"Error {response.status_code}: {response.text}") +# break +# return "Failed after multiple retries." 
+ +# prompt = "How many r's are in the word \"strawberry\"" +# messages = [{"role": "user", "content": prompt}] +# chat_template = f"{messages[0]['content']}\nAssistant:" + +# payload = { +# "inputs": chat_template, +# "parameters": {"max_new_tokens": 32768, "temperature": 0.7} +# } + +# output = query_with_retries(payload) +# print(output.strip()) + +# ----------------------------------------------------------------------------- + +import os +from huggingface_hub import InferenceClient + +# Use environment variable for API key +api_key = os.getenv("HUGGINGFACE_ACCESS_TOKEN") + +client = InferenceClient( + provider="hf-inference", + api_key=api_key, +) + +messages = [{"role": "user", "content": 'How many r\'s are in the word "strawberry"'}] + +completion = client.chat.completions.create( + model="Qwen/QwQ-32B", + messages=messages, + max_tokens=150, +) + +print(completion.choices[0].message) diff --git a/build/lib/tests/test_standalone_openai_json.py b/build/lib/tests/test_standalone_openai_json.py new file mode 100644 index 0000000000..ef839ae85b --- /dev/null +++ b/build/lib/tests/test_standalone_openai_json.py @@ -0,0 +1,108 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import tests.test_header +import os + +# ----- + +import dotenv + +dotenv.load_dotenv() + +import json +from textwrap import dedent +from openai import OpenAI +from pydantic import BaseModel + +MODEL = "gpt-4o-2024-08-06" + +math_tutor_prompt = """ + You are a helpful math tutor. You will be provided with a math problem, + and your goal will be to output a step by step solution, along with a final answer. + For each step, just provide the output as an equation use the explanation field to detail the reasoning. +""" + +bad_prompt = """ + Follow the instructions. +""" + +client = OpenAI() + + +class MathReasoning(BaseModel): + class Step(BaseModel): + explanation: str + output: str + + steps: list[Step] + final_answer: str + + +def get_math_solution(question: str): + completion = client.beta.chat.completions.parse( + model=MODEL, + messages=[ + {"role": "system", "content": dedent(bad_prompt)}, + {"role": "user", "content": question}, + ], + response_format=MathReasoning, + ) + return completion.choices[0].message + + +# Web Server +import http.server +import socketserver +import urllib.parse + +PORT = 5555 + + +class CustomHandler(http.server.SimpleHTTPRequestHandler): + def do_GET(self): + # Parse query parameters from the URL + parsed_path = urllib.parse.urlparse(self.path) + query_params = urllib.parse.parse_qs(parsed_path.query) + + # Check for a specific query parameter, e.g., 'problem' + problem = query_params.get("problem", [""])[ + 0 + ] # Default to an empty string if 'problem' isn't provided + + if problem: + print(f"Problem: {problem}") + solution = get_math_solution(problem) + + if solution.refusal: + print(f"Refusal: {solution.refusal}") + + print(f"Solution: {solution}") + self.send_response(200) + else: + solution = json.dumps( + {"error": "Please provide a math problem using the 'problem' query parameter."} + ) + self.send_response(400) + + self.send_header("Content-type", "application/json; charset=utf-8") + self.end_headers() + + # Write the 
message content + self.wfile.write(str(solution).encode()) + + +with socketserver.TCPServer(("", PORT), CustomHandler) as httpd: + print(f"Serving at port {PORT}") + httpd.serve_forever() diff --git a/build/lib/tests/test_standalone_openai_json_struct.py b/build/lib/tests/test_standalone_openai_json_struct.py new file mode 100644 index 0000000000..1b49aed8a7 --- /dev/null +++ b/build/lib/tests/test_standalone_openai_json_struct.py @@ -0,0 +1,92 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tests.test_header +import os + +# ----- + +from typing import List, Union, Dict + +import dotenv + +dotenv.load_dotenv() + +from textwrap import dedent +from openai import OpenAI +from pydantic import BaseModel + +MODEL = "gpt-4o-2024-08-06" + +math_tutor_prompt = """ + You are a helpful math tutor. You will be provided with a math problem, + and your goal will be to output a step by step solution, along with a final answer. + For each step, just provide the output as an equation use the explanation field to detail the reasoning. +""" + +general_prompt = """ + Follow the instructions. Output a step by step solution, along with a final answer. Use the explanation field to detail the reasoning. 
+""" + +client = OpenAI() + + +class MathReasoning(BaseModel): + class Step(BaseModel): + explanation: str + output: str + + steps: list[Step] + final_answer: str + + +def get_math_solution(question: str): + prompt = general_prompt + completion = client.beta.chat.completions.parse( + model=MODEL, + messages=[ + {"role": "system", "content": dedent(prompt)}, + {"role": "user", "content": question}, + ], + response_format=MathReasoning, + ) + return completion.choices[0].message + + +# Define Problem +problem = "What is the derivative of 3x^2" +print(f"Problem: {problem}") + +# Query for result +solution = get_math_solution(problem) + +# If the query was refused +if solution.refusal: + print(f"Refusal: {solution.refusal}") + exit() + +# If we were able to successfully parse the response back +parsed_solution = solution.parsed +if not parsed_solution: + print(f"Unable to Parse Solution") + exit() + +# Print solution from class definitions +print(f"Parsed: {parsed_solution}") + +steps = parsed_solution.steps +print(f"Steps: {steps}") + +final_answer = parsed_solution.final_answer +print(f"Final Answer: {final_answer}") diff --git a/build/lib/tests/test_standalone_openai_json_struct_func.py b/build/lib/tests/test_standalone_openai_json_struct_func.py new file mode 100644 index 0000000000..dcea40ffff --- /dev/null +++ b/build/lib/tests/test_standalone_openai_json_struct_func.py @@ -0,0 +1,177 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import tests.test_header +import os + +# ----- + +from typing import List, Union, Dict + +import dotenv + +dotenv.load_dotenv() + +import json +import requests +from textwrap import dedent +from openai import OpenAI, pydantic_function_tool +from pydantic import BaseModel, Field + +MODEL = "gpt-4o-2024-08-06" + +math_tutor_prompt = """ + You are a helpful math tutor. You will be provided with a math problem, + and your goal will be to output a step by step solution, along with a final answer. + For each step, just provide the output as an equation use the explanation field to detail the reasoning. +""" + +general_prompt = """ + Follow the instructions. Output a step by step solution, along with a final answer. Use the explanation field to detail the reasoning. +""" + +client = OpenAI() + + +class MathReasoning(BaseModel): + class Step(BaseModel): + explanation: str + output: str + + steps: list[Step] + final_answer: str + + +# region Function Calling +class GetWeather(BaseModel): + latitude: str = Field(..., description="latitude e.g. Bogotá, Colombia") + longitude: str = Field(..., description="longitude e.g. 
Bogotá, Colombia") + + +def get_weather(latitude, longitude): + response = requests.get( + f"https://api.open-meteo.com/v1/forecast?latitude={latitude}&longitude={longitude}&current=temperature_2m,wind_speed_10m&hourly=temperature_2m,relative_humidity_2m,wind_speed_10m&temperature_unit=fahrenheit" + ) + data = response.json() + return data["current"]["temperature_2m"] + + +def get_tools(): + return [pydantic_function_tool(GetWeather)] + + +tools = get_tools() + + +def call_function(name, args): + if name == "get_weather": + print(f"Running function: {name}") + print(f"Arguments are: {args}") + return get_weather(**args) + elif name == "GetWeather": + print(f"Running function: {name}") + print(f"Arguments are: {args}") + return get_weather(**args) + else: + return f"Local function not found: {name}" + + +def callback(message, messages, response_message, tool_calls): + if message is None or message.tool_calls is None: + print("No message or tools were called.") + return + + has_called_tools = False + for tool_call in message.tool_calls: + messages.append(response_message) + + has_called_tools = True + name = tool_call.function.name + args = json.loads(tool_call.function.arguments) + + result = call_function(name, args) + print(f"Function Call Results: {result}") + + messages.append( + {"role": "tool", "tool_call_id": tool_call.id, "content": str(result), "name": name} + ) + + # Complete the second call, after the functions have completed.
+ if has_called_tools: + print("Sending Second Query.") + completion_2 = client.beta.chat.completions.parse( + model=MODEL, + messages=messages, + response_format=MathReasoning, + tools=tools, + ) + print(f"Message: {completion_2.choices[0].message}") + return completion_2.choices[0].message + else: + print("No Need for Second Query.") + return None + + +# endregion Function Calling + + +def get_math_solution(question: str): + prompt = general_prompt + messages = [ + {"role": "system", "content": dedent(prompt)}, + {"role": "user", "content": question}, + ] + response = client.beta.chat.completions.parse( + model=MODEL, messages=messages, response_format=MathReasoning, tools=tools + ) + + response_message = response.choices[0].message + tool_calls = response_message.tool_calls + + new_response = callback(response.choices[0].message, messages, response_message, tool_calls) + + return new_response or response.choices[0].message + + +# Define Problem +problems = ["What is the derivative of 3x^2", "What's the weather like in San Fran today?"] +problem = problems[0] + +for problem in problems: + print("================") + print(f"Problem: {problem}") + + # Query for result + solution = get_math_solution(problem) + + # If the query was refused + if solution.refusal: + print(f"Refusal: {solution.refusal}") + break + + # If we were able to successfully parse the response back + parsed_solution = solution.parsed + if not parsed_solution: + print(f"Unable to Parse Solution") + print(f"Solution: {solution}") + break + + # Print solution from class definitions + print(f"Parsed: {parsed_solution}") + + steps = parsed_solution.steps + print(f"Steps: {steps}") + + final_answer = parsed_solution.final_answer + print(f"Final Answer: {final_answer}") diff --git a/build/lib/tests/test_standalone_openai_json_struct_func_playground.py b/build/lib/tests/test_standalone_openai_json_struct_func_playground.py new file mode 100644 index 0000000000..f4554de6be --- /dev/null +++ 
b/build/lib/tests/test_standalone_openai_json_struct_func_playground.py @@ -0,0 +1,222 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tests.test_header +import os + +# ----- +# # Milestone 1 + + +# from typing import List, Dict, Optional +# import requests +# import json +# from pydantic import BaseModel, Field +# from openai import OpenAI, pydantic_function_tool + +# # Environment setup +# import dotenv +# dotenv.load_dotenv() + +# # Constants and prompts +# MODEL = "gpt-4o-2024-08-06" +# GENERAL_PROMPT = ''' +# Follow the instructions. Output a step by step solution, along with a final answer. +# Use the explanation field to detail the reasoning. 
+# ''' + +# # Initialize OpenAI client +# client = OpenAI() + +# # Models and functions +# class Step(BaseModel): +# explanation: str +# output: str + +# class MathReasoning(BaseModel): +# steps: List[Step] +# final_answer: str + +# class GetWeather(BaseModel): +# latitude: str = Field(..., description="Latitude e.g., Bogotá, Colombia") +# longitude: str = Field(..., description="Longitude e.g., Bogotá, Colombia") + +# def fetch_weather(latitude: str, longitude: str) -> Dict: +# url = f"https://api.open-meteo.com/v1/forecast?latitude={latitude}&longitude={longitude}&current=temperature_2m,wind_speed_10m&hourly=temperature_2m,relative_humidity_2m,wind_speed_10m&temperature_unit=fahrenheit" +# response = requests.get(url) +# return response.json().get('current', {}) + +# # Tool management +# def get_tools() -> List[BaseModel]: +# return [pydantic_function_tool(GetWeather)] + +# def handle_function_call(tool_call: Dict) -> Optional[str]: +# if tool_call['name'] == "get_weather": +# result = fetch_weather(**tool_call['args']) +# return f"Temperature is {result['temperature_2m']}°F" +# return None + +# # Communication and processing with OpenAI +# def process_message_with_openai(question: str) -> MathReasoning: +# messages = [ +# {"role": "system", "content": GENERAL_PROMPT.strip()}, +# {"role": "user", "content": question} +# ] +# response = client.beta.chat.completions.parse( +# model=MODEL, +# messages=messages, +# response_format=MathReasoning, +# tools=get_tools() +# ) +# return response.choices[0].message + +# def get_math_solution(question: str) -> MathReasoning: +# solution = process_message_with_openai(question) +# return solution + +# # Example usage +# def main(): +# problems = [ +# "What is the derivative of 3x^2", +# "What's the weather like in San Francisco today?"
+# ] +# problem = problems[1] +# print(f"Problem: {problem}") + +# solution = get_math_solution(problem) +# if not solution: +# print("Failed to get a solution.") +# return + +# if not solution.parsed: +# print("Failed to get a parsed solution.") +# print(f"Solution: {solution}") +# return + +# print(f"Steps: {solution.parsed.steps}") +# print(f"Final Answer: {solution.parsed.final_answer}") + +# if __name__ == "__main__": +# main() + + +# # Milestone 1 + +# Milestone 2 +import json +import os +import requests + +from dotenv import load_dotenv + +load_dotenv() + +from openai import OpenAI + +client = OpenAI() + + +def get_current_weather(latitude, longitude): + """Get the current weather in a given latitude and longitude using the 7Timer API""" + base = "http://www.7timer.info/bin/api.pl" + request_url = f"{base}?lon={longitude}&lat={latitude}&product=civillight&output=json" + response = requests.get(request_url) + + # Parse response to extract the main weather data + weather_data = response.json() + current_data = weather_data.get("dataseries", [{}])[0] + + result = { + "latitude": latitude, + "longitude": longitude, + "temp": current_data.get("temp2m", {"max": "Unknown", "min": "Unknown"}), + "humidity": "Unknown", + } + + # Convert the dictionary to JSON string to match the given structure + return json.dumps(result) + + +def run_conversation(content): + messages = [{"role": "user", "content": content}] + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given latitude and longitude", + "parameters": { + "type": "object", + "properties": { + "latitude": { + "type": "string", + "description": "The latitude of a place", + }, + "longitude": { + "type": "string", + "description": "The longitude of a place", + }, + }, + "required": ["latitude", "longitude"], + }, + }, + } + ] + response = client.chat.completions.create( + model="gpt-3.5-turbo-0125", + messages=messages, + tools=tools, + 
tool_choice="auto", + ) + response_message = response.choices[0].message + tool_calls = response_message.tool_calls + + if tool_calls: + messages.append(response_message) + + available_functions = { + "get_current_weather": get_current_weather, + } + for tool_call in tool_calls: + print(f"Function: {tool_call.function.name}") + print(f"Params:{tool_call.function.arguments}") + function_name = tool_call.function.name + function_to_call = available_functions[function_name] + function_args = json.loads(tool_call.function.arguments) + function_response = function_to_call( + latitude=function_args.get("latitude"), + longitude=function_args.get("longitude"), + ) + print(f"API: {function_response}") + messages.append( + { + "tool_call_id": tool_call.id, + "role": "tool", + "name": function_name, + "content": function_response, + } + ) + + second_response = client.chat.completions.create( + model="gpt-3.5-turbo-0125", messages=messages, stream=True + ) + return second_response + + +if __name__ == "__main__": + question = "What's the weather like in Paris and San Francisco?" + response = run_conversation(question) + for chunk in response: + print(chunk.choices[0].delta.content or "", end="", flush=True) +# Milestone 2 diff --git a/build/lib/tests/test_standalone_project_out.py b/build/lib/tests/test_standalone_project_out.py new file mode 100644 index 0000000000..22aec63bae --- /dev/null +++ b/build/lib/tests/test_standalone_project_out.py @@ -0,0 +1,141 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import tests.test_header +import sys +import os + +# ----- + +import ast +import inspect +import types +import sys + + +def extract_function_info(filename): + with open(filename, "r") as f: + source = f.read() + tree = ast.parse(source, filename=filename) + + function_info = [] + + # Use a dictionary to track functions + module_globals = {} + + # Add the source to the locals (useful if you use local functions) + exec(source, module_globals) + + for node in ast.walk(tree): + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + docstring = ast.get_docstring(node) or "" + + # Attempt to get the callable object from the globals + try: + if node.name in module_globals: + func_obj = module_globals[node.name] + signature = inspect.signature(func_obj) + function_info.append( + {"name": node.name, "signature": str(signature), "docstring": docstring} + ) + else: + function_info.append( + { + "name": node.name, + "signature": "Could not get signature", + "docstring": docstring, + } + ) + except TypeError as e: + print( + f"Could not get function signature for {node.name} in {filename}: {e}", + file=sys.stderr, + ) + function_info.append( + { + "name": node.name, + "signature": "Could not get signature", + "docstring": docstring, + } + ) + + class_info = [] + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef): + docstring = ast.get_docstring(node) or "" + methods = [] + for method in node.body: + if isinstance(method, (ast.FunctionDef, ast.AsyncFunctionDef)): + method_docstring = ast.get_docstring(method) or "" + try: + if node.name in module_globals: + class_obj = module_globals[node.name] + method_obj = getattr(class_obj, method.name) + signature = inspect.signature(method_obj) + methods.append( + { + "name": method.name, + "signature": str(signature), + "docstring": method_docstring, + } + ) + else: + methods.append( + { + "name": method.name, 
+ "signature": "Could not get signature", + "docstring": method_docstring, + } + ) + except AttributeError as e: + print( + f"Could not get method signature for {node.name}.{method.name} in {filename}: {e}", + file=sys.stderr, + ) + methods.append( + { + "name": method.name, + "signature": "Could not get signature", + "docstring": method_docstring, + } + ) + except TypeError as e: + print( + f"Could not get method signature for {node.name}.{method.name} in {filename}: {e}", + file=sys.stderr, + ) + methods.append( + { + "name": method.name, + "signature": "Could not get signature", + "docstring": method_docstring, + } + ) + class_info.append({"name": node.name, "docstring": docstring, "methods": methods}) + + return {"function_info": function_info, "class_info": class_info} + + +# Usage: +file_path = "./dimos/agents/memory/base.py" +extracted_info = extract_function_info(file_path) +print(extracted_info) + +file_path = "./dimos/agents/memory/chroma_impl.py" +extracted_info = extract_function_info(file_path) +print(extracted_info) + +file_path = "./dimos/agents/agent.py" +extracted_info = extract_function_info(file_path) +print(extracted_info) diff --git a/build/lib/tests/test_standalone_rxpy_01.py b/build/lib/tests/test_standalone_rxpy_01.py new file mode 100644 index 0000000000..733930d430 --- /dev/null +++ b/build/lib/tests/test_standalone_rxpy_01.py @@ -0,0 +1,133 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
which_test = 2


def run_periodic_emission_test():
    """Test 1: Periodic Emission Test.

    Builds a ThreadPoolScheduler sized to the CPU count and an observable,
    ``secondly_emission``, driven by ``reactivex.interval(1.0)``. Each tick
    (an incrementing integer) is mapped to a "Value N emitted after N+1
    second(s)" message, echoed through ``emission_process`` and limited to 30
    emissions. The subscription prints every value, prints any error, and
    announces completion.

    Key components:
        * ThreadPoolScheduler: runs the pipeline on pool threads.
        * Observable sequence: one message per interval tick, 30 in total.
        * Subscription: logs emissions, errors and the completion event.

    The calling thread blocks on an Event until the stream completes or
    errors, mirroring test 2, so the outcome does not depend on whether the
    scheduler's worker threads keep the interpreter alive.
    """
    # Create a scheduler that uses as many threads as there are CPUs available
    optimal_thread_count = multiprocessing.cpu_count()
    pool_scheduler = ThreadPoolScheduler(optimal_thread_count)

    # Set on completion *and* on error so the wait below can never hang.
    finished_event = Event()

    def emission_process(value):
        print(f"Emitting: {value}")

    def handle_error(error):
        print(error)
        finished_event.set()

    def handle_completed():
        print("Emission completed.")
        finished_event.set()

    # Create an observable that emits on every interval tick
    secondly_emission = reactivex.interval(1.0, scheduler=pool_scheduler).pipe(
        ops.map(lambda x: f"Value {x} emitted after {x + 1} second(s)"),
        ops.do_action(emission_process),
        ops.take(30),  # Limit the emission to 30 times
    )

    # Subscribe to the observable to start emitting
    secondly_emission.subscribe(
        on_next=lambda x: print(x),
        on_error=handle_error,
        on_completed=handle_completed,
        scheduler=pool_scheduler,
    )

    finished_event.wait()


def run_combined_emission_test():
    """Test 2: Combined Emission Test.

    Uses the same CPU-sized ThreadPoolScheduler with two observables:

    * ``secondly_emission`` - ``reactivex.interval(1.0)`` mapped to
      "Second N", limited to 30 items.
    * ``immediate_emission`` - the sequence ['a', 'b', 'c', 'd', 'e'],
      repeated indefinitely.

    ``reactivex.zip`` pairs the two streams item by item; each pair is
    formatted as "Second N - Value: x" and logged.

    Key components:
        * Combined observable emissions: periodic and immediate streams
          merged into a single stream of pairs.
        * Event synchronization: the calling thread waits on
          ``completed_event`` so the program does not exit early.
        * Subscription management: both completion and error set the event.
          Previously only completion did, so a stream error left the main
          thread blocked forever.
    """
    # Create a scheduler with optimal threads
    optimal_thread_count = multiprocessing.cpu_count()
    pool_scheduler = ThreadPoolScheduler(optimal_thread_count)

    # Define an event to wait for the observable to finish (complete or fail)
    completed_event = Event()

    def handle_error(error):
        print(f"Error: {error}")
        completed_event.set()  # Unblock the main thread on failure too

    def handle_completed():
        print("Combined emission completed.")
        completed_event.set()  # Signal completion

    # Observable driven by the 1.0 interval
    secondly_emission = reactivex.interval(1.0, scheduler=pool_scheduler).pipe(
        ops.map(lambda x: f"Second {x + 1}"), ops.take(30)
    )

    # Observable that emits values immediately and repeatedly
    immediate_emission = reactivex.from_(["a", "b", "c", "d", "e"]).pipe(ops.repeat())

    # Combine emissions using zip
    combined_emissions = reactivex.zip(secondly_emission, immediate_emission).pipe(
        ops.map(lambda combined: f"{combined[0]} - Value: {combined[1]}"),
        ops.do_action(lambda s: print(f"Combined emission: {s}")),
    )

    # Subscribe to the combined emissions
    combined_emissions.subscribe(
        on_next=lambda x: print(x),
        on_error=handle_error,
        on_completed=handle_completed,
        scheduler=pool_scheduler,
    )

    # Wait for the observable to finish
    completed_event.wait()


# Dispatch at import time, exactly like the original flat script did.
if which_test == 1:
    run_periodic_emission_test()
elif which_test == 2:
    run_combined_emission_test()
class UnitreeAgentDemo:
    """Demo harness wiring a Unitree Go2 robot to one or more OpenAI agents.

    Configuration is read from environment variables at construction time
    (see ``_fetch_env_vars``). Each ``run_with_*`` method initialises the
    robot and starts a different combination of query/video driven agents.
    """

    # Shared tail of every "reference number" prompt. Kept in one place so
    # the query-driven and video-driven variants cannot drift apart.
    _REFERENCE_NUMBER_RULES = (
        "If the reference number is below 500, then output the reference number "
        "as the output only and do not call any functions or tools. "
        "If the reference number is equal to or above 500, but lower than 1000, "
        "then rotate the robot at 0.5 rad/s for 1 second. "
        "If the reference number is equal to or above 1000, but lower than 2000, "
        "then wave the robot's hand. "
        "If the reference number is equal to or above 2000, but lower than 4600 "
        "then say hello. "
        "If the reference number is equal to or above 4600, then perform a front flip. "
        "IF YOU DO NOT FOLLOW THESE INSTRUCTIONS EXACTLY, YOU WILL DIE!!!"
    )

    # Template for QueryDataProvider; "{query}" is substituted by the provider.
    _QUERY_TEMPLATE = (
        "{query}; Denote the number at the beginning of this query before the "
        "semicolon as the 'reference number'. Provide the reference number, "
        "without any other text in your response. " + _REFERENCE_NUMBER_RULES
    )

    # Prompt for agents that read the number from a video frame.
    _VIDEO_QUERY = (
        "Denote the number you see in the image as the 'reference number'. "
        "Only provide the reference number, without any other text in your response. "
        + _REFERENCE_NUMBER_RULES
    )

    def __init__(self):
        self.robot_ip = None
        self.connection_method = None
        self.serial_number = None
        self.output_dir = None
        # Set by _initialize_robot(); stays None until a run_with_* method is
        # called, which lets stop() be safe to call at any time.
        self.robot = None
        self._fetch_env_vars()

    def _fetch_env_vars(self):
        """Populate connection settings from the environment.

        Reads ROBOT_IP (required), CONN_TYPE, SERIAL_NUMBER and
        ROS_OUTPUT_DIR (defaults to ``<cwd>/assets/output/ros``).

        Raises:
            ValueError: if ROBOT_IP is unset or empty.
        """
        print("Fetching environment variables")

        def get_env_var(var_name, default=None, required=False):
            """Get environment variable with validation."""
            value = os.getenv(var_name, default)
            if required and not value:
                raise ValueError(f"{var_name} environment variable is required")
            return value

        self.robot_ip = get_env_var("ROBOT_IP", required=True)
        self.connection_method = get_env_var("CONN_TYPE")
        self.serial_number = get_env_var("SERIAL_NUMBER")
        self.output_dir = get_env_var(
            "ROS_OUTPUT_DIR", os.path.join(os.getcwd(), "assets/output/ros")
        )

    def _initialize_robot(self, with_video_stream=True):
        """Create the UnitreeGo2 instance and store it on ``self.robot``."""
        print(
            f"Initializing Unitree Robot {'with' if with_video_stream else 'without'} Video Stream"
        )
        self.robot = UnitreeGo2(
            ip=self.robot_ip,
            connection_method=self.connection_method,
            serial_number=self.serial_number,
            output_dir=self.output_dir,
            disable_video_stream=(not with_video_stream),
            mock_connection=MOCK_CONNECTION,
        )
        print(f"Robot initialized: {self.robot}")

    # ----- private builders shared by the run_with_* scenarios

    def _make_query_agent(self, dev_name, query_stream, skills):
        """Build a Perception agent fed by a text query stream."""
        return OpenAIAgent(
            dev_name=dev_name,
            agent_type="Perception",
            input_query_stream=query_stream,
            output_dir=self.output_dir,
            skills=skills,
        )

    def _make_video_agent(self, dev_name, skills, query):
        """Build a Perception agent fed by ``self.video_stream``."""
        return OpenAIAgent(
            dev_name=dev_name,
            agent_type="Perception",
            input_video_stream=self.video_stream,
            output_dir=self.output_dir,
            query=query,
            image_detail="high",
            skills=skills,
        )

    def _open_test_video_stream(self):
        """Return an observable over the bundled ``assets/framecount.mp4``."""
        from dimos.stream.video_provider import VideoProvider

        return VideoProvider(
            dev_name="UnitreeGo2", video_source=f"{os.getcwd()}/assets/framecount.mp4"
        ).capture_video_as_observable(realtime=False, fps=1)

    # -----

    def run_with_queries(self):
        """Drive a single agent from a synthetic, numbered query stream."""
        # Initialize robot
        self._initialize_robot(with_video_stream=False)

        # Initialize query stream
        query_provider = QueryDataProvider()

        # Create the skills available to the agent.
        # By default, this will create all skills in this class and make them available.
        skills_instance = MyUnitreeSkills(robot=self.robot)

        print("Starting Unitree Perception Agent")
        self.UnitreePerceptionAgent = self._make_query_agent(
            "UnitreePerceptionAgent", query_provider.data_stream, skills_instance
        )

        # Start the query stream with the numeric settings below
        # (start_count=1, end_count=10000, step=1, frequency=0.01).
        # NOTE(review): the unit of `frequency` is defined by QueryDataProvider,
        # which is not visible here - confirm before relying on the pacing.
        # Listening agents consume the queries and respond via skill
        # execution / 1-shot responses.
        query_provider.start_query_stream(
            query_template=self._QUERY_TEMPLATE,
            frequency=0.01,
            start_count=1,
            end_count=10000,
            step=1,
        )

    def run_with_test_video(self):
        """Drive a single agent from the bundled frame-counter test video."""
        # Initialize robot
        self._initialize_robot(with_video_stream=False)

        # Initialize test video stream
        self.video_stream = self._open_test_video_stream()

        # By default, this will create all skills in this class and make them
        # available to the agent.
        skills_instance = MyUnitreeSkills(robot=self.robot)

        print("Starting Unitree Perception Agent (Test Video)")
        self.UnitreePerceptionAgent = self._make_video_agent(
            "UnitreePerceptionAgent", skills_instance, self._VIDEO_QUERY
        )

    def run_with_ros_video(self):
        """Drive a single agent from the robot's ROS video stream."""
        # Initialize robot
        self._initialize_robot()

        # Initialize ROS video stream
        print("Starting Unitree Perception Stream")
        self.video_stream = self.robot.get_ros_video_stream()

        # By default, this will create all skills in this class and make them
        # available to the agent.
        skills_instance = MyUnitreeSkills(robot=self.robot)

        # Run recovery stand
        print("Running recovery stand")
        self.robot.webrtc_req(api_id=1006)

        # Wait for 1 second
        time.sleep(1)

        # Switch to sport mode
        print("Switching to sport mode")
        self.robot.webrtc_req(api_id=1011, parameter='{"gait_type": "sport"}')

        # Wait for 1 second
        time.sleep(1)

        print("Starting Unitree Perception Agent (ROS Video)")
        # Alternative movement-demo prompt that was known to work:
        #   "Move() 5 meters foward. Then spin 360 degrees to the right, and then
        #    Reverse() 5 meters, and then Move forward 3 meters"
        self.UnitreePerceptionAgent = self._make_video_agent(
            "UnitreePerceptionAgent",
            skills_instance,
            "Based on the image, execute the command seen in the image AND ONLY THE COMMAND IN THE IMAGE. IF YOU DO NOT FOLLOW THESE INSTRUCTIONS EXACTLY, YOU WILL DIE!!!",
        )

    def run_with_multiple_query_and_test_video_agents(self):
        """Run two query-driven and two video-driven agents concurrently."""
        # Initialize robot
        self._initialize_robot(with_video_stream=False)

        # Initialize query stream
        query_provider = QueryDataProvider()

        # Initialize test video stream
        self.video_stream = self._open_test_video_stream()

        # Create the skills available to the agent.
        # By default, this will create all skills in this class and make them available.
        skills_instance = MyUnitreeSkills(robot=self.robot)

        print("Starting Unitree Perception Agent")
        self.UnitreeQueryPerceptionAgent = self._make_query_agent(
            "UnitreeQueryPerceptionAgent", query_provider.data_stream, skills_instance
        )

        print("Starting Unitree Perception Agent Two")
        self.UnitreeQueryPerceptionAgentTwo = self._make_query_agent(
            "UnitreeQueryPerceptionAgentTwo", query_provider.data_stream, skills_instance
        )

        print("Starting Unitree Perception Agent (Test Video)")
        self.UnitreeVideoPerceptionAgent = self._make_video_agent(
            "UnitreeVideoPerceptionAgent", skills_instance, self._VIDEO_QUERY
        )

        print("Starting Unitree Perception Agent Two (Test Video)")
        self.UnitreeVideoPerceptionAgentTwo = self._make_video_agent(
            "UnitreeVideoPerceptionAgentTwo", skills_instance, self._VIDEO_QUERY
        )

        # Start the query stream (start_count=1, end_count=10000000, step=1,
        # frequency=0.01). See the note in run_with_queries about `frequency`.
        query_provider.start_query_stream(
            query_template=self._QUERY_TEMPLATE,
            frequency=0.01,
            start_count=1,
            end_count=10000000,
            step=1,
        )

    def run_with_queries_and_fast_api(self):
        """Serve the ROS video feed over FastAPI and take queries from it."""
        # Initialize robot
        self._initialize_robot(with_video_stream=True)

        # Initialize ROS video stream
        print("Starting Unitree Perception Stream")
        self.video_stream = self.robot.get_ros_video_stream()

        # Will be visible at http://[host]:[port]/video_feed/[key]
        streams = {
            "unitree_video": self.video_stream,
        }
        fast_api_server = FastAPIServer(port=5555, **streams)

        # Create the skills available to the agent.
        skills_instance = MyUnitreeSkills(robot=self.robot)

        print("Starting Unitree Perception Agent")
        self.UnitreeQueryPerceptionAgent = self._make_query_agent(
            "UnitreeQueryPerceptionAgent", fast_api_server.query_stream, skills_instance
        )

        # Run the FastAPI server (this will block)
        fast_api_server.run()

    # -----

    def stop(self):
        """Release the robot, if one was ever initialised."""
        print("Stopping Unitree Agent")
        if self.robot is not None:
            self.robot.cleanup()


if __name__ == "__main__":
    myUnitreeAgentDemo = UnitreeAgentDemo()

    test_to_run = 4

    demo_runners = {
        0: myUnitreeAgentDemo.run_with_queries,
        1: myUnitreeAgentDemo.run_with_test_video,
        2: myUnitreeAgentDemo.run_with_ros_video,
        3: myUnitreeAgentDemo.run_with_multiple_query_and_test_video_agents,
        4: myUnitreeAgentDemo.run_with_queries_and_fast_api,
    }
    # Explicit raise instead of `assert False`: asserts vanish under `python -O`.
    if test_to_run not in demo_runners:
        raise ValueError(f"Invalid test number: {test_to_run}")
    demo_runners[test_to_run]()

    # Keep the program running to allow the Unitree Agent Demo to operate continuously
    try:
        print("\nRunning Unitree Agent Demo (Press Ctrl+C to stop)...")
        while True:
            time.sleep(0.1)
    except KeyboardInterrupt:
        print("\nStopping Unitree Agent Demo")
        myUnitreeAgentDemo.stop()
    except Exception as e:
        print(f"Error in main loop: {e}")
def main():
    """Run the Unitree Go2 + OpenAI agent demo behind a FastAPI server.

    Reads ROBOT_IP (required), CONN_TYPE (falls back to "webrtc") and
    ROS_OUTPUT_DIR (falls back to ``<cwd>/assets/output/ros``) from the
    environment, wires the robot video stream and the agent's responses into
    a FastAPIServer, and blocks in ``web_interface.run()``.

    Returns:
        int: 0 on normal shutdown (including Ctrl+C), 1 if an exception was
        logged.

    Raises:
        ValueError: if ROBOT_IP is unset or empty.
    """
    # Get environment variables
    robot_ip = os.getenv("ROBOT_IP")
    if not robot_ip:
        raise ValueError("ROBOT_IP environment variable is required")
    connection_method = os.getenv("CONN_TYPE") or "webrtc"
    output_dir = os.getenv("ROS_OUTPUT_DIR", os.path.join(os.getcwd(), "assets/output/ros"))

    # Bound before the try block: if UnitreeGo2(...) itself raises, the
    # finally clause must still be able to inspect `robot` without hitting
    # an UnboundLocalError that would mask the original failure.
    robot = None
    try:
        # Initialize robot
        logger.info("Initializing Unitree Robot")
        robot = UnitreeGo2(
            ip=robot_ip,
            connection_method=connection_method,
            output_dir=output_dir,
            skills=MyUnitreeSkills(),
        )

        # Set up video stream
        logger.info("Starting video stream")
        video_stream = robot.get_ros_video_stream()

        # Create FastAPI server with video stream and text streams
        logger.info("Initializing FastAPI server")
        streams = {"unitree_video": video_stream}

        # Create a subject for agent responses
        agent_response_subject = rx.subject.Subject()
        agent_response_stream = agent_response_subject.pipe(ops.share())

        text_streams = {
            "agent_responses": agent_response_stream,
        }

        web_interface = FastAPIServer(port=5555, text_streams=text_streams, **streams)

        logger.info("Starting action primitive execution agent")
        agent = OpenAIAgent(
            dev_name="UnitreeQueryExecutionAgent",
            input_query_stream=web_interface.query_stream,
            output_dir=output_dir,
            skills=robot.get_skills(),
        )

        # Subscribe to agent responses and send them to the subject
        agent.get_response_observable().subscribe(lambda x: agent_response_subject.on_next(x))

        # Start server (blocking call)
        logger.info("Starting FastAPI server")
        web_interface.run()

    except KeyboardInterrupt:
        print("Stopping demo...")
    except Exception as e:
        logger.error(f"Error: {e}")
        return 1
    finally:
        if robot is not None:
            robot.cleanup()

    return 0


if __name__ == "__main__":
    sys.exit(main())
# Load API key from environment
load_dotenv()

# Allow command line arguments to control spatial memory parameters
# NOTE(review): the only flag actually defined below is --voice.
import argparse


def parse_arguments():
    """Parse the command line; the single supported flag is ``--voice``.

    Returns:
        argparse.Namespace with a boolean ``voice`` attribute (False unless
        ``--voice`` is passed).
    """
    parser = argparse.ArgumentParser(
        description="Run the robot with optional spatial memory parameters"
    )
    parser.add_argument(
        "--voice",
        action="store_true",
        help="Use voice input from microphone instead of web interface",
    )
    return parser.parse_args()


# Parsed at import time: everything below runs as a flat script.
args = parse_arguments()

# Initialize robot with spatial memory parameters
# NOTE(review): presumably new_memory=True starts from a fresh spatial memory -
# confirm against UnitreeGo2.
robot = UnitreeGo2(
    ip=os.getenv("ROBOT_IP"),
    skills=MyUnitreeSkills(),
    mock_connection=False,
    new_memory=True,
)

# Create a subject for agent responses
agent_response_subject = rx.subject.Subject()
agent_response_stream = agent_response_subject.pipe(ops.share())
local_planner_viz_stream = robot.local_planner_viz_stream.pipe(ops.share())

# Initialize object detection stream
min_confidence = 0.6
class_filter = None  # No class filtering
detector = Detic2DDetector(vocabulary=None, threshold=min_confidence)

# Create video stream from robot's camera
video_stream = backpressure(robot.get_ros_video_stream())

# Initialize ObjectDetectionStream with robot
object_detector = ObjectDetectionStream(
    camera_intrinsics=robot.camera_intrinsics,
    min_confidence=min_confidence,
    class_filter=class_filter,
    transform_to_map=robot.ros_control.transform_pose,
    detector=detector,
    video_stream=video_stream,
)

# Create visualization stream for web interface
# Only the "viz_frame" entry of each non-None detection result is forwarded.
viz_stream = backpressure(object_detector.get_stream()).pipe(
    ops.share(),
    ops.map(lambda x: x["viz_frame"] if x is not None else None),
    ops.filter(lambda x: x is not None),
)

# Get the formatted detection stream
formatted_detection_stream = object_detector.get_formatted_stream().pipe(
    ops.filter(lambda x: x is not None)
)


# Create a direct mapping that combines detection data with locations
def combine_with_locations(object_detections):
    """Append a "Saved Robot Locations" section to a detection summary.

    Args:
        object_detections: text emitted by ``formatted_detection_stream``; it
            is concatenated with a ``str``, so it is assumed to be a string -
            TODO confirm against ObjectDetectionStream.

    Returns:
        ``object_detections`` followed by one line per saved location (name,
        position and rotation to two decimals), or ``"None"`` when the list
        is empty. If anything raises, the error is printed and the input is
        returned unchanged so the stream keeps flowing.
    """
    # Get locations from spatial memory
    try:
        locations = robot.get_spatial_memory().get_robot_locations()

        # Format the locations section
        locations_text = "\n\nSaved Robot Locations:\n"
        if locations:
            for loc in locations:
                locations_text += f"- {loc.name}: Position ({loc.position[0]:.2f}, {loc.position[1]:.2f}, {loc.position[2]:.2f}), "
                locations_text += f"Rotation ({loc.rotation[0]:.2f}, {loc.rotation[1]:.2f}, {loc.rotation[2]:.2f})\n"
        else:
            locations_text += "None\n"

        # Simply concatenate the strings
        return object_detections + locations_text
    except Exception as e:
        # Deliberate best-effort: a failed lookup must not break the stream.
        print(f"Error adding locations: {e}")
        return object_detections


# Create the combined stream with a simple pipe operation
enhanced_data_stream = formatted_detection_stream.pipe(ops.map(combine_with_locations), ops.share())

# NOTE(review): "unitree_video" calls robot.get_ros_video_stream() a second
# time instead of reusing the backpressured `video_stream` above - confirm
# this is intentional.
streams = {
    "unitree_video": robot.get_ros_video_stream(),
    "local_planner_viz": local_planner_viz_stream,
    "object_detection": viz_stream,
}
text_streams = {
    "agent_responses": agent_response_stream,
}

web_interface = RobotWebInterface(port=5555, text_streams=text_streams, **streams)

# The STT node is created unconditionally; only its use as the agent's input
# stream (below) depends on --voice.
stt_node = stt()

# Read system query from prompt.txt file
# Path resolves to <parent of this file's directory>/assets/agent/prompt.txt.
with open(
    os.path.join(os.path.dirname(os.path.dirname(__file__)), "assets", "agent", "prompt.txt"), "r"
) as f:
    system_query = f.read()

# Create a ClaudeAgent instance with either voice input or web interface input based on flag
input_stream = stt_node.emit_text() if args.voice else web_interface.query_stream
print(f"Using {'voice input' if args.voice else 'web interface input'} for queries")

agent = ClaudeAgent(
    dev_name="test_agent",
    input_query_stream=input_stream,
    input_data_stream=enhanced_data_stream,  # Add the enhanced data stream
    skills=robot.get_skills(),
    system_query=system_query,
    model_name="claude-3-7-sonnet-latest",
    thinking_budget_tokens=0,
)

# Initialize TTS node only if voice flag is set
tts_node = None
if args.voice:
    print("Voice mode: Enabling TTS for speech output")
    tts_node = tts()
    tts_node.consume_text(agent.get_response_observable())
else:
    print("Web interface mode: Disabling TTS to avoid audio issues")

# NOTE(review): these skills are registered after the agent was constructed
# with skills=robot.get_skills(); the agent only sees them if get_skills()
# returns the same mutable library object each time - confirm.
robot_skills = robot.get_skills()
robot_skills.add(ObserveStream)
robot_skills.add(KillSkill)
robot_skills.add(NavigateWithText)
robot_skills.add(FollowHuman)
robot_skills.add(GetPose)
# Add Speak skill only if voice flag is set
if args.voice:
    robot_skills.add(Speak)
# robot_skills.add(NavigateToGoal)
robot_skills.create_instance("ObserveStream", robot=robot, agent=agent)
robot_skills.create_instance("KillSkill", robot=robot, skill_library=robot_skills)
robot_skills.create_instance("NavigateWithText", robot=robot)
robot_skills.create_instance("FollowHuman", robot=robot)
robot_skills.create_instance("GetPose", robot=robot)
# robot_skills.create_instance("NavigateToGoal", robot=robot)
# Create Speak skill instance only if voice flag is set
if args.voice:
    robot_skills.create_instance("Speak", tts_node=tts_node)

# Subscribe to agent responses and send them to the subject
agent.get_response_observable().subscribe(lambda x: agent_response_subject.on_next(x))

print("ObserveStream and Kill skills registered and ready for use")
# NOTE(review): nothing in this script writes memory.txt; this message looks stale.
print("Created memory.txt file")

# Presumably blocks until the web server exits - confirm against RobotWebInterface.
web_interface.run()
# (label, api_id) pairs in the exact order they are queued by main().
COMMAND_SEQUENCE = (
    ("Dance1", 1022),
    ("WiggleHips", 1033),
    ("Stretch", 1017),
    ("Hello", 1016),
    ("Dance2", 1023),
    ("Wallow", 1021),
    ("Scrape", 1029),
    ("FingerHeart", 1036),
    ("RecoveryStand", 1006),  # back to base position mid-sequence
    ("Hello", 1016),
    ("WiggleHips", 1033),
    ("FrontPounce", 1032),
    ("Dance1", 1022),
    ("Stretch", 1017),
    ("FrontJump", 1031),
    ("FingerHeart", 1036),
    ("Scrape", 1029),
    ("Hello", 1016),
    ("Dance2", 1023),
    ("RecoveryStand", 1006),  # finish in a stable stance
)


def main():
    """Test the WebRTC request queue with a sequence of back-to-back commands.

    Reads ROBOT_IP and CONNECTION_METHOD (default "LocalSTA", resolved as an
    attribute of ``WebRTCConnectionMethod``) from the environment, issues a
    recovery stand, then queues every entry of ``COMMAND_SEQUENCE`` via
    ``robot.webrtc_req`` and idles until Ctrl+C.

    Raises:
        ValueError: if CONNECTION_METHOD does not name a member of
            ``WebRTCConnectionMethod``.
    """
    print("Initializing UnitreeGo2...")

    # Get configuration from environment variables
    robot_ip = os.getenv("ROBOT_IP")
    method_name = os.getenv("CONNECTION_METHOD", "LocalSTA")
    try:
        connection_method = getattr(WebRTCConnectionMethod, method_name)
    except AttributeError:
        # Turn the opaque AttributeError into an actionable configuration error.
        raise ValueError(f"Unknown CONNECTION_METHOD: {method_name!r}") from None

    # Initialize ROS control
    ros_control = UnitreeROSControl(node_name="unitree_go2_test", use_raw=True)

    # Initialize robot
    robot = UnitreeGo2(
        ip=robot_ip,
        connection_method=connection_method,
        ros_control=ros_control,
        use_ros=True,
        use_webrtc=False,  # Using queue instead of direct WebRTC
    )

    # Everything after construction sits inside try/finally so the robot is
    # cleaned up even when a request fails part-way through the sequence.
    try:
        # Wait for initialization
        print("Waiting for robot to initialize...")
        time.sleep(5)

        # First put the robot in a good starting state
        print("Running recovery stand...")
        robot.webrtc_req(api_id=1006)  # RecoveryStand

        command_count = len(COMMAND_SEQUENCE)
        print(f"\n🤖 QUEUEING {command_count} COMMANDS BACK-TO-BACK 🤖\n")

        for label, api_id in COMMAND_SEQUENCE:
            robot.webrtc_req(api_id=api_id)
            print(f"Queued: {label} ({api_id})")

        print(
            f"\nAll {command_count} commands queued successfully! "
            "Watch the robot perform them in sequence."
        )
        print("The WebRTC queue manager will process them one by one when the robot is ready.")
        print("Press Ctrl+C to stop the program when you've seen enough.\n")

        # Keep the program running so the queue can be processed
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        print("\nStopping the test...")
    finally:
        # Cleanup
        print("Cleaning up resources...")
        robot.cleanup()
        print("Test completed.")


if __name__ == "__main__":
    main()
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import os +import time +import threading +from dimos.robot.unitree.unitree_go2 import UnitreeGo2 +from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl +from dimos.web.websocket_vis.server import WebsocketVis +from dimos.web.websocket_vis.helpers import vector_stream +from dimos.robot.global_planner.planner import AstarPlanner +from dimos.types.costmap import Costmap +from dimos.types.vector import Vector +from reactivex import operators as ops +import argparse +import pickle +import reactivex as rx +from dimos.web.robot_web_interface import RobotWebInterface + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple test for vis.") + parser.add_argument( + "--live", + action="store_true", + ) + parser.add_argument( + "--port", type=int, default=5555, help="Port for web visualization interface" + ) + return parser.parse_args() + + +def setup_web_interface(robot, port=5555): + """Set up web interface with robot video and local planner visualization""" + print(f"Setting up web interface on port {port}") + + # Get video stream from robot + video_stream = robot.video_stream_ros.pipe( + ops.share(), + ops.map(lambda frame: frame), + ops.filter(lambda frame: frame is not None), + ) + + # Get local planner visualization stream + local_planner_stream = robot.local_planner_viz_stream.pipe( + ops.share(), + ops.map(lambda frame: frame), + ops.filter(lambda frame: frame is not 
def main():
    """Run the websocket visualization test, live or against mock data.

    With ``--live`` the planner, robot pose and costmap come from a real
    Unitree Go2 reached through ROS, and (when the robot exposes both video
    streams) the robot web interface is served from a daemon thread.
    Without it, an ``AstarPlanner`` is fed a pickled costmap from the
    ``mockdata`` directory next to this script and a fixed robot position.

    In both modes a "click" message received over the websocket sets a new
    planner goal on a background thread. The function then blocks until
    Ctrl+C, at which point the websocket server is stopped.
    """
    args = parse_args()

    websocket_vis = WebsocketVis()
    websocket_vis.start()

    web_interface = None

    if args.live:
        ros_control = UnitreeROSControl(node_name="web_nav_test", mock_connection=False)
        robot = UnitreeGo2(ros_control=ros_control, ip=os.getenv("ROBOT_IP"))
        planner = robot.global_planner

        websocket_vis.connect(
            vector_stream("robot", lambda: robot.ros_control.transform_euler_pos("base_link"))
        )
        websocket_vis.connect(
            robot.ros_control.topic("map", Costmap).pipe(ops.map(lambda x: ["costmap", x]))
        )

        # Also set up the web interface with both streams
        if hasattr(robot, "video_stream_ros") and hasattr(robot, "local_planner_viz_stream"):
            web_interface = setup_web_interface(robot, port=args.port)

            # Start web interface in a separate thread
            viz_thread = threading.Thread(target=web_interface.run, daemon=True)
            viz_thread.start()
            print(f"Web interface available at http://localhost:{args.port}")

    else:
        # Resolve the directory through abspath: a bare "script.py" in
        # __file__ has no "/" to split on and would yield a bogus path.
        mock_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "mockdata")
        pickle_path = os.path.join(mock_dir, "vegas.pickle")
        print(f"Loading costmap from {pickle_path}")

        def load_costmap():
            # NOTE: pickle can execute arbitrary code; only load trusted files.
            # The context manager closes the handle on every call instead of
            # leaking one descriptor per costmap request.
            with open(pickle_path, "rb") as costmap_file:
                return pickle.load(costmap_file)

        def fake_local_nav(_path):
            # Pretend that local navigation took a second and succeeded.
            # (`time.sleep(1) and True` evaluated to None, never True.)
            time.sleep(1)
            return True

        planner = AstarPlanner(
            get_costmap=load_costmap,
            get_robot_pos=lambda: Vector(5.0, 5.0),
            set_local_nav=fake_local_nav,
        )

    def msg_handler(msgtype, data):
        # Only "click" messages are understood; everything else is ignored.
        if msgtype != "click":
            return
        target = Vector(data["position"])
        try:
            planner.set_goal(target)
        except Exception as e:
            print(f"Error setting goal: {e}")

    def threaded_msg_handler(msgtype, data):
        # Goal setting may block, so keep it off the websocket thread.
        thread = threading.Thread(target=msg_handler, args=(msgtype, data))
        thread.daemon = True
        thread.start()

    websocket_vis.connect(planner.vis_stream())
    websocket_vis.msg_handler = threaded_msg_handler

    print(f"WebSocket server started on port {websocket_vis.port}")
    print(planner.get_costmap())

    planner.plan(Vector(-4.8, -1.0))  # plan a path to the origin

    try:
        # Keep the server running
        while True:
            time.sleep(0.1)
    except KeyboardInterrupt:
        print("Stopping WebSocket server...")
        websocket_vis.stop()
        print("WebSocket server stopped")
def test_basic_functionality():
    """Test basic ZED camera functionality without actually opening the camera.

    Builds a ``ZEDCamera`` and a ``ZEDVisualizer`` and renders one
    side-by-side image from synthetic RGB/depth frames.

    Returns:
        bool: True when every step succeeds, False (after printing the
        error) when any import or call raises.
    """
    print("\nTesting basic functionality...")

    try:
        # Import numpy here: the module-level name ``np`` only exists when the
        # file is executed as a script, so relying on it made this check fail
        # with a swallowed NameError whenever the module was imported.
        import numpy as np
        import pyzed.sl as sl
        from dimos.hardware.zed_camera import ZEDCamera
        from dimos.perception.zed_visualizer import ZEDVisualizer

        # Test camera initialization (without opening)
        ZEDCamera(
            camera_id=0,
            resolution=sl.RESOLUTION.HD720,
            depth_mode=sl.DEPTH_MODE.NEURAL,
        )
        print("✓ ZEDCamera instance created successfully")

        # Test visualizer initialization
        visualizer = ZEDVisualizer(max_depth=10.0)
        print("✓ ZEDVisualizer instance created successfully")

        # Test creating a dummy visualization: black image, constant 2.0 depth
        dummy_rgb = np.zeros((480, 640, 3), dtype=np.uint8)
        dummy_depth = np.full((480, 640), 2.0, dtype=np.float32)

        visualizer.create_side_by_side_image(dummy_rgb, dummy_depth)
        print("✓ Dummy visualization created successfully")

        return True

    except Exception as e:
        print(f"✗ Basic functionality test failed: {e}")
        return False
You can now run the ZED demo:") + print(" python examples/zed_neural_depth_demo.py --display-time 10") + else: + print("✅ Setup is ready, but no camera detected.") + print(" Connect a ZED camera and run:") + print(" python examples/zed_neural_depth_demo.py --display-time 10") + + return True + + +if __name__ == "__main__": + # Add the project root to Python path + sys.path.append(str(Path(__file__).parent)) + + # Import numpy after path setup + import numpy as np + + success = main() + sys.exit(0 if success else 1) diff --git a/build/lib/tests/visualization_script.py b/build/lib/tests/visualization_script.py new file mode 100644 index 0000000000..d0c4c6af84 --- /dev/null +++ b/build/lib/tests/visualization_script.py @@ -0,0 +1,1041 @@ +#!/usr/bin/env python3 +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Visualize pickled manipulation pipeline results.""" + +import os +import sys +import pickle +import numpy as np +import json +import matplotlib + +# Try to use TkAgg backend for live display, fallback to Agg if not available +try: + matplotlib.use("TkAgg") +except: + try: + matplotlib.use("Qt5Agg") + except: + matplotlib.use("Agg") # Fallback to non-interactive +import matplotlib.pyplot as plt +import open3d as o3d + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from dimos.perception.pointcloud.utils import visualize_clustered_point_clouds, visualize_voxel_grid +from dimos.perception.grasp_generation.utils import visualize_grasps_3d +from dimos.perception.pointcloud.utils import visualize_pcd +from dimos.utils.logging_config import setup_logger +import trimesh + +import tf_lcm_py +import cv2 +from contextlib import contextmanager +import lcm_msgs +from lcm_msgs.sensor_msgs import JointState, PointCloud2, CameraInfo, PointCloud2, PointField +from lcm_msgs.std_msgs import Header +from typing import List, Tuple, Optional +import atexit +from datetime import datetime +import time + +from pydrake.all import ( + AddMultibodyPlantSceneGraph, + CoulombFriction, + Diagram, + DiagramBuilder, + InverseKinematics, + MeshcatVisualizer, + MeshcatVisualizerParams, + MultibodyPlant, + Parser, + RigidTransform, + RollPitchYaw, + RotationMatrix, + JointIndex, + Solve, + StartMeshcat, +) +from pydrake.geometry import ( + CollisionFilterDeclaration, + Mesh, + ProximityProperties, + InMemoryMesh, + Box, + Cylinder, +) +from pydrake.math import RigidTransform as DrakeRigidTransform +from pydrake.common import MemoryFile + +from pydrake.all import ( + MinimumDistanceLowerBoundConstraint, + MultibodyPlant, + Parser, + DiagramBuilder, + AddMultibodyPlantSceneGraph, + MeshcatVisualizer, + StartMeshcat, + RigidTransform, + Role, + RollPitchYaw, + RotationMatrix, + Solve, + InverseKinematics, + MeshcatVisualizerParams, + 
def deserialize_point_cloud(data):
    """Reconstruct an Open3D PointCloud from serialized data.

    Args:
        data: Mapping with optional ``"points"`` and ``"colors"`` entries
            (each a sequence or array of 3-vectors), or None.

    Returns:
        ``None`` when *data* is None, otherwise an
        ``o3d.geometry.PointCloud``. An entry that is missing, None or
        empty leaves the corresponding attribute of the cloud untouched.
    """
    if data is None:
        return None

    pcd = o3d.geometry.PointCloud()
    for field in ("points", "colors"):
        values = data[field] if field in data else None
        # Test the length explicitly: using a NumPy array as a truth value
        # raises ValueError, so `and data[field]` only worked for lists.
        if values is None or len(values) == 0:
            continue
        vectors = o3d.utility.Vector3dVector(np.asarray(values, dtype=np.float64))
        setattr(pcd, field, vectors)
    return pcd
def visualize_results(pickle_path="manipulation_results.pkl"):
    """Load pickled pipeline results and display them.

    The pickle must hold a dict with the keys ``results``, ``color_img``,
    ``depth_img`` and ``intrinsics``. Every 2D visualization present in
    ``results`` is drawn into one matplotlib figure, which is saved to
    ``visualization_results.png`` and shown in a blocking window. After
    that the serialized full point cloud, the misc clusters and the misc
    voxel grid are rebuilt and shown one after another.

    Args:
        pickle_path: Path of the pickle file to read.

    Returns:
        None. Problems while loading are printed, not raised.
    """
    print(f"Loading results from {pickle_path}...")
    try:
        with open(pickle_path, "rb") as handle:
            payload = pickle.load(handle)

        results = payload["results"]
        # These entries are not used below; they are looked up so that an
        # incomplete pickle is reported as a loading error.
        for required_key in ("color_img", "depth_img", "intrinsics"):
            _ = payload[required_key]

        print(f"Loaded results with keys: {list(results.keys())}")

    except FileNotFoundError:
        print(f"Error: Pickle file {pickle_path} not found.")
        print("Make sure to run test_manipulation_pipeline_single_frame_lcm.py first.")
        return
    except Exception as e:
        print(f"Error loading pickle file: {e}")
        return

    # Candidate 2D panels in display order; only those present are plotted.
    panel_titles = (
        ("detection_viz", "Object Detection"),
        ("segmentation_viz", "Semantic Segmentation"),
        ("pointcloud_viz", "All Objects Point Cloud"),
        ("detected_pointcloud_viz", "Detection Objects Point Cloud"),
        ("misc_pointcloud_viz", "Misc/Background Points"),
        ("grasp_overlay", "Grasp Overlay"),
    )
    plot_configs = [
        (key, title)
        for key, title in panel_titles
        if key in results and results[key] is not None
    ]
    num_plots = len(plot_configs)

    if num_plots == 0:
        print("No visualization results to display")
        return

    # Up to three panels go on a single row, more are spread over two rows.
    if num_plots > 3:
        rows = 2
        cols = (num_plots + 1) // 2
        _fig, axes = plt.subplots(rows, cols, figsize=(6 * cols, 5 * rows))
    else:
        _fig, axes = plt.subplots(1, num_plots, figsize=(6 * num_plots, 5))

    # Make the axes indexable by panel number in every layout.
    if num_plots == 1:
        axes = [axes]
    elif num_plots > 2:
        axes = axes.flatten()

    for axis, (key, title) in zip(axes, plot_configs):
        axis.imshow(results[key])
        axis.set_title(title)
        axis.axis("off")

    # Blank out grid cells that did not receive a panel.
    for spare_axis in axes[num_plots:]:
        spare_axis.axis("off")

    plt.tight_layout()

    # Save and show the plot
    output_path = "visualization_results.png"
    plt.savefig(output_path, dpi=150, bbox_inches="tight")
    print(f"Results visualization saved to: {output_path}")

    plt.show(block=True)
    plt.close()

    print("\nReconstructing 3D visualization objects from serialized data...")

    # --- full scene point cloud ---
    has_full_cloud = "full_pointcloud" in results and results["full_pointcloud"] is not None
    if not has_full_cloud:
        print("No full point cloud available for visualization")
    else:
        full_pcd = deserialize_point_cloud(results["full_pointcloud"])
        print(f"Reconstructed full point cloud with {len(np.asarray(full_pcd.points))} points")

        try:
            visualize_pcd(
                full_pcd,
                window_name="Reconstructed Full Scene Point Cloud",
                point_size=2.0,
                show_coordinate_frame=True,
            )
        except (KeyboardInterrupt, EOFError):
            print("\nSkipping full point cloud visualization")
        except Exception as e:
            print(f"Error in point cloud visualization: {e}")

    # --- misc / background clusters ---
    has_clusters = "misc_clusters" in results and results["misc_clusters"]
    if not has_clusters:
        print("No misc clusters available for visualization")
    else:
        misc_clusters = [deserialize_point_cloud(raw) for raw in results["misc_clusters"]]
        cluster_count = len(misc_clusters)
        total_misc_points = sum(len(np.asarray(cloud.points)) for cloud in misc_clusters)
        print(f"Reconstructed {cluster_count} misc clusters with {total_misc_points} total points")

        try:
            visualize_clustered_point_clouds(
                misc_clusters,
                window_name="Reconstructed Misc/Background Clusters (DBSCAN)",
                point_size=3.0,
                show_coordinate_frame=True,
            )
        except (KeyboardInterrupt, EOFError):
            print("\nSkipping misc clusters visualization")
        except Exception as e:
            print(f"Error in misc clusters visualization: {e}")

    # --- misc / background voxel grid ---
    has_voxels = "misc_voxel_grid" in results and results["misc_voxel_grid"] is not None
    if not has_voxels:
        print("No voxel grid available for visualization")
        return

    misc_voxel_grid = deserialize_voxel_grid(results["misc_voxel_grid"])
    if not misc_voxel_grid:
        print("Failed to reconstruct voxel grid")
        return

    voxel_count = len(misc_voxel_grid.get_voxels())
    print(f"Reconstructed voxel grid with {voxel_count} voxels")

    try:
        visualize_voxel_grid(
            misc_voxel_grid,
            window_name="Reconstructed Misc/Background Voxel Grid",
            show_coordinate_frame=True,
        )
    except (KeyboardInterrupt, EOFError):
        print("\nSkipping voxel grid visualization")
    except Exception as e:
        print(f"Error in voxel grid visualization: {e}")
tf_lcm_py.Buffer(30.0) + self._resources_to_cleanup.append(self.buffer) + with self.safe_lcm_instance() as lcm_instance: + self.tf_lcm_instance = lcm_instance + self._resources_to_cleanup.append(self.tf_lcm_instance) + # Create TransformListener with our LCM instance and buffer + self.listener = tf_lcm_py.TransformListener(self.tf_lcm_instance, self.buffer) + self._resources_to_cleanup.append(self.listener) + + # Check if URDF file exists + if not os.path.exists(urdf_path): + raise FileNotFoundError(f"URDF file not found: {urdf_path}") + + # Drake utils initialization + self.meshcat = StartMeshcat() + print(f"Meshcat started at: {self.meshcat.web_url()}") + + self.urdf_path = urdf_path + self.builder = DiagramBuilder() + + self.plant, self.scene_graph = AddMultibodyPlantSceneGraph(self.builder, time_step=0.01) + self.parser = Parser(self.plant) + + # Load the robot URDF + print(f"Loading URDF from: {self.urdf_path}") + self.model_instances = self.parser.AddModelsFromUrl(f"file://{self.urdf_path}") + self.kinematic_chain_joints = kinematic_chain_joints + self.model_instance = self.model_instances[0] if self.model_instances else None + + if not self.model_instances: + raise RuntimeError("Failed to load any model instances from URDF") + + print(f"Loaded {len(self.model_instances)} model instances") + + # Set up collision filtering + if links_to_ignore: + bodies = [] + for link_name in links_to_ignore: + try: + body = self.plant.GetBodyByName(link_name) + if body is not None: + bodies.extend(self.plant.GetBodiesWeldedTo(body)) + except RuntimeError: + print(f"Warning: Link '{link_name}' not found in URDF") + + if bodies: + arm_geoms = self.plant.CollectRegisteredGeometries(bodies) + decl = CollisionFilterDeclaration().ExcludeWithin(arm_geoms) + manager = self.scene_graph.collision_filter_manager() + manager.Apply(decl) + + # Load and process point cloud data + self._load_and_process_point_clouds() + + # Finalize the plant before adding visualizer + 
self.plant.Finalize() + + # Print some debug info about the plant + print(f"Plant has {self.plant.num_bodies()} bodies") + print(f"Plant has {self.plant.num_joints()} joints") + for i in range(self.plant.num_joints()): + joint = self.plant.get_joint(JointIndex(i)) + print(f" Joint {i}: {joint.name()} (type: {joint.type_name()})") + + # Add visualizer + self.visualizer = MeshcatVisualizer.AddToBuilder( + self.builder, self.scene_graph, self.meshcat, params=MeshcatVisualizerParams() + ) + + # Build the diagram + self.diagram = self.builder.Build() + self.diagram_context = self.diagram.CreateDefaultContext() + self.plant_context = self.plant.GetMyContextFromRoot(self.diagram_context) + + # Set up joint indices + self.joint_indices = [] + for joint_name in self.kinematic_chain_joints: + try: + joint = self.plant.GetJointByName(joint_name) + if joint.num_positions() > 0: + start_index = joint.position_start() + for i in range(joint.num_positions()): + self.joint_indices.append(start_index + i) + print( + f"Added joint '{joint_name}' at indices {start_index} to {start_index + joint.num_positions() - 1}" + ) + except RuntimeError: + print(f"Warning: Joint '{joint_name}' not found in URDF.") + + # Get important frames/bodies + try: + self.end_effector_link = self.plant.GetBodyByName("link6") + self.end_effector_frame = self.plant.GetFrameByName("link6") + print("Found end effector link6") + except RuntimeError: + print("Warning: link6 not found") + self.end_effector_link = None + self.end_effector_frame = None + + try: + self.camera_link = self.plant.GetBodyByName("camera_center_link") + print("Found camera_center_link") + except RuntimeError: + print("Warning: camera_center_link not found") + self.camera_link = None + + # Set robot to a reasonable initial configuration + self._set_initial_configuration() + + # Force initial visualization update + self._update_visualization() + + print("Drake environment initialization complete!") + print(f"Visit {self.meshcat.web_url()} 
to see the visualization") + + def _load_and_process_point_clouds(self): + """Load point cloud data from pickle file and add to scene""" + pickle_path = "manipulation_results.pkl" + try: + with open(pickle_path, "rb") as f: + data = pickle.load(f) + + results = data["results"] + print(f"Loaded results with keys: {list(results.keys())}") + + except FileNotFoundError: + print(f"Warning: Pickle file {pickle_path} not found.") + print("Skipping point cloud loading.") + return + except Exception as e: + print(f"Warning: Error loading pickle file: {e}") + return + + full_detected_pcd = o3d.geometry.PointCloud() + for obj in results["detected_objects"]: + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(obj["point_cloud_numpy"]) + full_detected_pcd += pcd + + self.process_and_add_object_class("all_objects", results) + self.process_and_add_object_class("misc_clusters", results) + misc_clusters = results["misc_clusters"] + print(type(misc_clusters[0]["points"])) + print(np.asarray(misc_clusters[0]["points"]).shape) + + def process_and_add_object_class(self, object_key: str, results: dict): + # Process detected objects + if object_key in results: + detected_objects = results[object_key] + if detected_objects: + print(f"Processing {len(detected_objects)} {object_key}") + all_decomposed_meshes = [] + + transform = self.get_transform("world", "camera_center_link") + for i in range(len(detected_objects)): + try: + if object_key == "misc_clusters": + points = np.asarray(detected_objects[i]["points"]) + elif "point_cloud_numpy" in detected_objects[i]: + points = detected_objects[i]["point_cloud_numpy"] + elif ( + "point_cloud" in detected_objects[i] + and detected_objects[i]["point_cloud"] + ): + # Handle serialized point cloud + points = np.array(detected_objects[i]["point_cloud"]["points"]) + else: + print(f"Warning: No point cloud data found for object {i}") + continue + + if len(points) < 10: # Need more points for mesh reconstruction + print( + 
f"Warning: Object {i} has too few points ({len(points)}) for mesh reconstruction" + ) + continue + + # Swap y-z axes since this is a common problem + points = np.column_stack((points[:, 0], points[:, 2], -points[:, 1])) + # Transform points to world frame + points = self.transform_point_cloud_with_open3d(points, transform) + + # Use fast DBSCAN clustering + convex hulls approach + clustered_hulls = self._create_clustered_convex_hulls(points, i) + all_decomposed_meshes.extend(clustered_hulls) + + print( + f"Created {len(clustered_hulls)} clustered convex hulls for object {i}" + ) + + except Exception as e: + print(f"Warning: Failed to process object {i}: {e}") + + if all_decomposed_meshes: + self.register_convex_hulls_as_collision(all_decomposed_meshes, object_key) + print(f"Registered {len(all_decomposed_meshes)} total clustered convex hulls") + else: + print("Warning: No valid clustered convex hulls created from detected objects") + else: + print("No detected objects found") + + def _create_clustered_convex_hulls( + self, points: np.ndarray, object_id: int + ) -> List[o3d.geometry.TriangleMesh]: + """ + Create convex hulls from DBSCAN clusters of point cloud data. + Fast approach: cluster points, then convex hull each cluster. 
+ + Args: + points: Nx3 numpy array of 3D points + object_id: ID for debugging/logging + + Returns: + List of Open3D triangle meshes (convex hulls of clusters) + """ + try: + # Create Open3D point cloud + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(points) + + # Quick outlier removal (optional, can skip for speed) + if len(points) > 50: # Only for larger point clouds + pcd, _ = pcd.remove_statistical_outlier(nb_neighbors=10, std_ratio=2.0) + points = np.asarray(pcd.points) + + if len(points) < 4: + print(f"Warning: Too few points after filtering for object {object_id}") + return [] + + # Try multiple DBSCAN parameter combinations to find clusters + clusters = [] + labels = None + + # Calculate some basic statistics for parameter estimation + if len(points) > 10: + # Compute nearest neighbor distances for better eps estimation + distances = pcd.compute_nearest_neighbor_distance() + avg_nn_distance = np.mean(distances) + std_nn_distance = np.std(distances) + + print( + f"Object {object_id}: {len(points)} points, avg_nn_dist={avg_nn_distance:.4f}" + ) + + for i in range(20): + try: + eps = avg_nn_distance * (2.0 + (i * 0.1)) + min_samples = 20 + labels = np.array(pcd.cluster_dbscan(eps=eps, min_points=min_samples)) + unique_labels = np.unique(labels) + clusters = unique_labels[unique_labels >= 0] # Remove noise label (-1) + + noise_points = np.sum(labels == -1) + clustered_points = len(points) - noise_points + + print( + f" Try {i + 1}: eps={eps:.4f}, min_samples={min_samples} → {len(clusters)} clusters, {clustered_points}/{len(points)} points clustered" + ) + + # Accept if we found clusters and most points are clustered + if ( + len(clusters) > 0 and clustered_points >= len(points) * 0.95 + ): # At least 30% of points clustered + print(f" ✓ Accepted parameter set {i + 1}") + break + + except Exception as e: + print( + f" Try {i + 1}: Failed with eps={eps:.4f}, min_samples={min_samples}: {e}" + ) + continue + + if len(clusters) == 0 or 
labels is None: + print( + f"No clusters found for object {object_id} after all attempts, using entire point cloud" + ) + # Fallback: use entire point cloud as single convex hull + hull_mesh, _ = pcd.compute_convex_hull() + hull_mesh.compute_vertex_normals() + return [hull_mesh] + + print( + f"Found {len(clusters)} clusters for object {object_id} (eps={eps:.3f}, min_samples={min_samples})" + ) + + # Create convex hull for each cluster + convex_hulls = [] + for cluster_id in clusters: + try: + # Get points for this cluster + cluster_mask = labels == cluster_id + cluster_points = points[cluster_mask] + + if len(cluster_points) < 4: + print( + f"Skipping cluster {cluster_id} with only {len(cluster_points)} points" + ) + continue + + # Create point cloud for this cluster + cluster_pcd = o3d.geometry.PointCloud() + cluster_pcd.points = o3d.utility.Vector3dVector(cluster_points) + + # Compute convex hull + hull_mesh, _ = cluster_pcd.compute_convex_hull() + hull_mesh.compute_vertex_normals() + + # Validate hull + if ( + len(np.asarray(hull_mesh.vertices)) >= 4 + and len(np.asarray(hull_mesh.triangles)) >= 4 + ): + convex_hulls.append(hull_mesh) + print( + f" Cluster {cluster_id}: {len(cluster_points)} points → convex hull with {len(np.asarray(hull_mesh.vertices))} vertices" + ) + else: + print(f" Skipping degenerate hull for cluster {cluster_id}") + + except Exception as e: + print(f"Error processing cluster {cluster_id} for object {object_id}: {e}") + + if not convex_hulls: + print( + f"No valid convex hulls created for object {object_id}, using entire point cloud" + ) + # Fallback: use entire point cloud as single convex hull + hull_mesh, _ = pcd.compute_convex_hull() + hull_mesh.compute_vertex_normals() + return [hull_mesh] + + return convex_hulls + + except Exception as e: + print(f"Error in DBSCAN clustering for object {object_id}: {e}") + # Final fallback: single convex hull + try: + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(points) + 
hull_mesh, _ = pcd.compute_convex_hull() + hull_mesh.compute_vertex_normals() + return [hull_mesh] + except: + return [] + + def _set_initial_configuration(self): + """Set the robot to a reasonable initial joint configuration""" + # Set all joints to zero initially + if self.joint_indices: + q = np.zeros(len(self.joint_indices)) + + # You can customize these values for a better initial pose + # For example, if you know good default joint angles: + if len(q) >= 6: # Assuming at least 6 DOF arm + q[1] = 0.0 # joint1 + q[2] = 0.0 # joint2 + q[3] = 0.0 # joint3 + q[4] = 0.0 # joint4 + q[5] = 0.0 # joint5 + q[6] = 0.0 # joint6 + + # Set the joint positions in the plant context + positions = self.plant.GetPositions(self.plant_context) + for i, joint_idx in enumerate(self.joint_indices): + if joint_idx < len(positions): + positions[joint_idx] = q[i] + + self.plant.SetPositions(self.plant_context, positions) + print(f"Set initial joint configuration: {q}") + else: + print("Warning: No joint indices found, using default configuration") + + def _update_visualization(self): + """Force update the visualization""" + try: + # Get the visualizer's context from the diagram context + visualizer_context = self.visualizer.GetMyContextFromRoot(self.diagram_context) + self.visualizer.ForcedPublish(visualizer_context) + print("Visualization updated successfully") + except Exception as e: + print(f"Error updating visualization: {e}") + + def set_joint_positions(self, joint_positions): + """Set specific joint positions and update visualization""" + if len(joint_positions) != len(self.joint_indices): + raise ValueError( + f"Expected {len(self.joint_indices)} joint positions, got {len(joint_positions)}" + ) + + positions = self.plant.GetPositions(self.plant_context) + for i, joint_idx in enumerate(self.joint_indices): + if joint_idx < len(positions): + positions[joint_idx] = joint_positions[i] + + self.plant.SetPositions(self.plant_context, positions) + self._update_visualization() + 
print(f"Updated joint positions: {joint_positions}") + + def register_convex_hulls_as_collision( + self, meshes: List[o3d.geometry.TriangleMesh], hull_type: str + ): + """Register convex hulls as collision and visual geometry""" + if not meshes: + print("No meshes to register") + return + + world = self.plant.world_body() + proximity = ProximityProperties() + + for i, mesh in enumerate(meshes): + try: + # Convert Open3D → numpy arrays → trimesh.Trimesh + vertices = np.asarray(mesh.vertices) + faces = np.asarray(mesh.triangles) + + if len(vertices) == 0 or len(faces) == 0: + print(f"Warning: Mesh {i} is empty, skipping") + continue + + tmesh = trimesh.Trimesh(vertices=vertices, faces=faces) + + # Export to OBJ in memory + tmesh_obj_blob = tmesh.export(file_type="obj") + mem_file = MemoryFile( + contents=tmesh_obj_blob, extension=".obj", filename_hint=f"convex_hull_{i}.obj" + ) + in_memory_mesh = InMemoryMesh() + in_memory_mesh.mesh_file = mem_file + drake_mesh = Mesh(in_memory_mesh, scale=1.0) + + pos = np.array([0.0, 0.0, 0.0]) + rpy = RollPitchYaw(0.0, 0.0, 0.0) + X_WG = DrakeRigidTransform(RotationMatrix(rpy), pos) + + # Register collision and visual geometry + self.plant.RegisterCollisionGeometry( + body=world, + X_BG=X_WG, + shape=drake_mesh, + name=f"convex_hull_collision_{i}_{hull_type}", + properties=proximity, + ) + self.plant.RegisterVisualGeometry( + body=world, + X_BG=X_WG, + shape=drake_mesh, + name=f"convex_hull_visual_{i}_{hull_type}", + diffuse_color=np.array([0.7, 0.5, 0.3, 0.8]), # Orange-ish color + ) + + print( + f"Registered convex hull {i} with {len(vertices)} vertices and {len(faces)} faces" + ) + + except Exception as e: + print(f"Warning: Failed to register mesh {i}: {e}") + + # Add a simple table for reference + try: + table_shape = Box(1.0, 1.0, 0.1) # Thinner table + table_pose = RigidTransform(p=[0.5, 0.0, -0.05]) # In front of robot + self.plant.RegisterCollisionGeometry( + world, table_pose, table_shape, "table_collision", proximity + 
) + self.plant.RegisterVisualGeometry( + world, table_pose, table_shape, "table_visual", [0.8, 0.6, 0.4, 1.0] + ) + print("Added reference table") + except Exception as e: + print(f"Warning: Failed to add table: {e}") + + def get_seeded_random_rgba(self, id: int): + np.random.seed(id) + return np.random.rand(4) + + @contextmanager + def safe_lcm_instance(self): + """Context manager for safely managing LCM instance lifecycle""" + lcm_instance = tf_lcm_py.LCM() + try: + yield lcm_instance + finally: + pass + + def cleanup_resources(self): + """Clean up resources before exiting""" + # Only clean up once when exiting + print("Cleaning up resources...") + # Force cleanup of resources in reverse order (last created first) + for resource in reversed(self._resources_to_cleanup): + try: + # For objects like TransformListener that might have a close or shutdown method + if hasattr(resource, "close"): + resource.close() + elif hasattr(resource, "shutdown"): + resource.shutdown() + + # Explicitly delete the resource + del resource + except Exception as e: + print(f"Error during cleanup: {e}") + + # Clear the resources list + self._resources_to_cleanup = [] + + def get_transform(self, target_frame, source_frame): + print("Getting transform from", source_frame, "to", target_frame) + attempts = 0 + max_attempts = 20 # Reduced from 120 to avoid long blocking + + while attempts < max_attempts: + try: + # Process LCM messages with error handling + if not self.tf_lcm_instance.handle_timeout(100): # 100ms timeout + # If handle_timeout returns false, we might need to re-check if LCM is still good + if not self.tf_lcm_instance.good(): + print("WARNING: LCM instance is no longer in a good state") + + # Get the most recent timestamp from the buffer instead of using current time + try: + timestamp = self.buffer.get_most_recent_timestamp() + if attempts % 10 == 0: + print(f"Using timestamp from buffer: {timestamp}") + except Exception as e: + # Fall back to current time if 
get_most_recent_timestamp fails + timestamp = datetime.now() + if not hasattr(timestamp, "timestamp"): + timestamp.timestamp = ( + lambda: time.mktime(timestamp.timetuple()) + timestamp.microsecond / 1e6 + ) + if attempts % 10 == 0: + print(f"Falling back to current time: {timestamp}") + + # Check if we can find the transform + if self.buffer.can_transform(target_frame, source_frame, timestamp): + # print(f"Found transform between '{target_frame}' and '{source_frame}'!") + + # Look up the transform with the timestamp from the buffer + transform = self.buffer.lookup_transform( + target_frame, + source_frame, + timestamp, + timeout=10.0, + time_tolerance=0.1, + lcm_module=lcm_msgs, + ) + + return transform + + # Increment counter and report status every 10 attempts + attempts += 1 + if attempts % 10 == 0: + print(f"Still waiting... (attempt {attempts}/{max_attempts})") + frames = self.buffer.get_all_frame_names() + if frames: + print(f"Frames received so far ({len(frames)} total):") + for frame in sorted(frames): + print(f" {frame}") + else: + print("No frames received yet") + + # Brief pause + time.sleep(0.5) + + except Exception as e: + print(f"Error during transform lookup: {e}") + attempts += 1 + time.sleep(1) # Longer pause after an error + + print(f"\nERROR: No transform found after {max_attempts} attempts") + return None + + def transform_point_cloud_with_open3d(self, points_np: np.ndarray, transform) -> np.ndarray: + """ + Transforms a point cloud using Open3D given a transform. + + Args: + points_np (np.ndarray): Nx3 array of 3D points. + transform: Transform from tf_lcm_py. + + Returns: + np.ndarray: Nx3 array of transformed 3D points. 
+ """ + if points_np.shape[1] != 3: + print("Input point cloud must have shape Nx3.") + return points_np + + # Convert transform to 4x4 numpy matrix + tf_matrix = np.eye(4) + + # Extract rotation quaternion components + qw = transform.transform.rotation.w + qx = transform.transform.rotation.x + qy = transform.transform.rotation.y + qz = transform.transform.rotation.z + + # Convert quaternion to rotation matrix + # Formula from: https://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation#Quaternion-derived_rotation_matrix + tf_matrix[0, 0] = 1 - 2 * qy * qy - 2 * qz * qz + tf_matrix[0, 1] = 2 * qx * qy - 2 * qz * qw + tf_matrix[0, 2] = 2 * qx * qz + 2 * qy * qw + + tf_matrix[1, 0] = 2 * qx * qy + 2 * qz * qw + tf_matrix[1, 1] = 1 - 2 * qx * qx - 2 * qz * qz + tf_matrix[1, 2] = 2 * qy * qz - 2 * qx * qw + + tf_matrix[2, 0] = 2 * qx * qz - 2 * qy * qw + tf_matrix[2, 1] = 2 * qy * qz + 2 * qx * qw + tf_matrix[2, 2] = 1 - 2 * qx * qx - 2 * qy * qy + + # Set translation + tf_matrix[0, 3] = transform.transform.translation.x + tf_matrix[1, 3] = transform.transform.translation.y + tf_matrix[2, 3] = transform.transform.translation.z + + # Create Open3D point cloud + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(points_np) + + # Apply transformation + pcd.transform(tf_matrix) + + # Return as NumPy array + return np.asarray(pcd.points) + + +# Updated main function +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Visualize manipulation results") + parser.add_argument("--visualize-only", action="store_true", help="Only visualize results") + args = parser.parse_args() + + if args.visualize_only: + visualize_results() + exit(0) + + try: + # Then set up Drake environment + kinematic_chain_joints = [ + "pillar_platform_joint", + "joint1", + "joint2", + "joint3", + "joint4", + "joint5", + "joint6", + ] + + links_to_ignore = [ + "devkit_base_link", + "pillar_platform", + "piper_angled_mount", + 
"pan_tilt_base", + "pan_tilt_head", + "pan_tilt_pan", + "base_link", + "link1", + "link2", + "link3", + "link4", + "link5", + "link6", + ] + + urdf_path = "./assets/devkit_base_descr.urdf" + urdf_path = os.path.abspath(urdf_path) + + print(f"Attempting to load URDF from: {urdf_path}") + + env = DrakeKinematicsEnv(urdf_path, kinematic_chain_joints, links_to_ignore) + env.set_joint_positions([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) + transform = env.get_transform("world", "camera_center_link") + print( + transform.transform.translation.x, + transform.transform.translation.y, + transform.transform.translation.z, + ) + print( + transform.transform.rotation.w, + transform.transform.rotation.x, + transform.transform.rotation.y, + transform.transform.rotation.z, + ) + + # Keep the visualization alive + print("\nVisualization is running. Press Ctrl+C to exit.") + while True: + time.sleep(1) + + except KeyboardInterrupt: + print("\nExiting...") + except Exception as e: + print(f"Error: {e}") + import traceback + + traceback.print_exc() diff --git a/build/lib/tests/zed_neural_depth_demo.py b/build/lib/tests/zed_neural_depth_demo.py new file mode 100644 index 0000000000..5edce9633f --- /dev/null +++ b/build/lib/tests/zed_neural_depth_demo.py @@ -0,0 +1,450 @@ +#!/usr/bin/env python3 +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +ZED Camera Neural Depth Demo - OpenCV Live Visualization with Data Saving + +This script demonstrates live visualization of ZED camera RGB and depth data using OpenCV. +Press SPACE to save RGB and depth images to rgbd_data2 folder. +Press ESC or 'q' to quit. +""" + +import os +import sys +import time +import argparse +import logging +from pathlib import Path +import numpy as np +import cv2 +import yaml +from datetime import datetime +import open3d as o3d + +# Add the project root to Python path +sys.path.append(str(Path(__file__).parent.parent)) + +try: + import pyzed.sl as sl +except ImportError: + print("ERROR: ZED SDK not found. Please install the ZED SDK and pyzed Python package.") + print("Download from: https://www.stereolabs.com/developers/release/") + sys.exit(1) + +from dimos.hardware.zed_camera import ZEDCamera +from dimos.perception.pointcloud.utils import visualize_pcd, visualize_clustered_point_clouds + +# Configure logging +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + + +class ZEDLiveVisualizer: + """Live OpenCV visualization for ZED camera data with saving functionality.""" + + def __init__(self, camera, max_depth=10.0, output_dir="assets/rgbd_data2"): + self.camera = camera + self.max_depth = max_depth + self.output_dir = Path(output_dir) + self.save_counter = 0 + + # Store captured pointclouds for later visualization + self.captured_pointclouds = [] + + # Display settings for 480p + self.display_width = 640 + self.display_height = 480 + + # Create output directory structure + self.setup_output_directory() + + # Get camera info for saving + self.camera_info = camera.get_camera_info() + + # Save camera info files once + self.save_camera_info() + + # OpenCV window name (single window) + self.window_name = "ZED Camera - RGB + Depth" + + # Create window + cv2.namedWindow(self.window_name, cv2.WINDOW_AUTOSIZE) + + def 
setup_output_directory(self): + """Create the output directory structure.""" + self.output_dir.mkdir(exist_ok=True) + (self.output_dir / "color").mkdir(exist_ok=True) + (self.output_dir / "depth").mkdir(exist_ok=True) + (self.output_dir / "pointclouds").mkdir(exist_ok=True) + logger.info(f"Created output directory: {self.output_dir}") + + def save_camera_info(self): + """Save camera info YAML files with ZED camera parameters.""" + # Get current timestamp + now = datetime.now() + timestamp_sec = int(now.timestamp()) + timestamp_nanosec = int((now.timestamp() % 1) * 1e9) + + # Get camera resolution + resolution = self.camera_info.get("resolution", {}) + width = int(resolution.get("width", 1280)) + height = int(resolution.get("height", 720)) + + # Extract left camera parameters (for RGB) from already available camera_info + left_cam = self.camera_info.get("left_cam", {}) + # Convert numpy values to Python floats + fx = float(left_cam.get("fx", 749.341552734375)) + fy = float(left_cam.get("fy", 748.5587768554688)) + cx = float(left_cam.get("cx", 639.4312744140625)) + cy = float(left_cam.get("cy", 357.2478942871094)) + + # Build distortion coefficients from ZED format + # ZED provides k1, k2, p1, p2, k3 - convert to rational_polynomial format + k1 = float(left_cam.get("k1", 0.0)) + k2 = float(left_cam.get("k2", 0.0)) + p1 = float(left_cam.get("p1", 0.0)) + p2 = float(left_cam.get("p2", 0.0)) + k3 = float(left_cam.get("k3", 0.0)) + distortion = [k1, k2, p1, p2, k3, 0.0, 0.0, 0.0] + + # Create camera info structure with plain Python types + camera_info = { + "D": distortion, + "K": [fx, 0.0, cx, 0.0, fy, cy, 0.0, 0.0, 1.0], + "P": [fx, 0.0, cx, 0.0, 0.0, fy, cy, 0.0, 0.0, 0.0, 1.0, 0.0], + "R": [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0], + "binning_x": 0, + "binning_y": 0, + "distortion_model": "rational_polynomial", + "header": { + "frame_id": "camera_color_optical_frame", + "stamp": {"nanosec": timestamp_nanosec, "sec": timestamp_sec}, + }, + "height": height, + 
"roi": {"do_rectify": False, "height": 0, "width": 0, "x_offset": 0, "y_offset": 0}, + "width": width, + } + + # Save color camera info + color_info_path = self.output_dir / "color_camera_info.yaml" + with open(color_info_path, "w") as f: + yaml.dump(camera_info, f, default_flow_style=False) + + # Save depth camera info (same as color for ZED) + depth_info_path = self.output_dir / "depth_camera_info.yaml" + with open(depth_info_path, "w") as f: + yaml.dump(camera_info, f, default_flow_style=False) + + logger.info(f"Saved camera info files to {self.output_dir}") + + def normalize_depth_for_display(self, depth_map): + """Normalize depth map for OpenCV visualization.""" + # Handle invalid values + valid_mask = (depth_map > 0) & np.isfinite(depth_map) + + if not np.any(valid_mask): + return np.zeros_like(depth_map, dtype=np.uint8) + + # Normalize to 0-255 for display + depth_norm = np.zeros_like(depth_map, dtype=np.float32) + depth_clipped = np.clip(depth_map[valid_mask], 0, self.max_depth) + depth_norm[valid_mask] = depth_clipped / self.max_depth + + # Convert to 8-bit and apply colormap + depth_8bit = (depth_norm * 255).astype(np.uint8) + depth_colored = cv2.applyColorMap(depth_8bit, cv2.COLORMAP_JET) + + return depth_colored + + def save_frame(self, rgb_img, depth_map): + """Save RGB, depth images, and pointcloud with proper naming convention.""" + # Generate filename with 5-digit zero-padding + filename = f"{self.save_counter:05d}.png" + pcd_filename = f"{self.save_counter:05d}.ply" + + # Save RGB image + rgb_path = self.output_dir / "color" / filename + cv2.imwrite(str(rgb_path), rgb_img) + + # Save depth image (convert to 16-bit for proper depth storage) + depth_path = self.output_dir / "depth" / filename + # Convert meters to millimeters and save as 16-bit + depth_mm = (depth_map * 1000).astype(np.uint16) + cv2.imwrite(str(depth_path), depth_mm) + + # Capture and save pointcloud + pcd = self.camera.capture_pointcloud() + if pcd is not None and 
len(np.asarray(pcd.points)) > 0: + pcd_path = self.output_dir / "pointclouds" / pcd_filename + o3d.io.write_point_cloud(str(pcd_path), pcd) + + # Store pointcloud for later visualization + self.captured_pointclouds.append(pcd) + + logger.info( + f"Saved frame {self.save_counter}: {rgb_path}, {depth_path}, and {pcd_path}" + ) + else: + logger.warning(f"Failed to capture pointcloud for frame {self.save_counter}") + logger.info(f"Saved frame {self.save_counter}: {rgb_path} and {depth_path}") + + self.save_counter += 1 + + def visualize_captured_pointclouds(self): + """Visualize all captured pointclouds using Open3D, one by one.""" + if not self.captured_pointclouds: + logger.info("No pointclouds captured to visualize") + return + + logger.info( + f"Visualizing {len(self.captured_pointclouds)} captured pointclouds one by one..." + ) + logger.info("Close each pointcloud window to proceed to the next one") + + for i, pcd in enumerate(self.captured_pointclouds): + if len(np.asarray(pcd.points)) > 0: + logger.info(f"Displaying pointcloud {i + 1}/{len(self.captured_pointclouds)}") + visualize_pcd(pcd, window_name=f"ZED Pointcloud {i + 1:05d}", point_size=2.0) + else: + logger.warning(f"Pointcloud {i + 1} is empty, skipping...") + + logger.info("Finished displaying all pointclouds") + + def update_display(self): + """Update the live display with new frames.""" + # Capture frame + left_img, right_img, depth_map = self.camera.capture_frame() + + if left_img is None or depth_map is None: + return False, None, None + + # Resize RGB to 480p + rgb_resized = cv2.resize(left_img, (self.display_width, self.display_height)) + + # Create depth visualization + depth_colored = self.normalize_depth_for_display(depth_map) + + # Resize depth to 480p + depth_resized = cv2.resize(depth_colored, (self.display_width, self.display_height)) + + # Add text overlays + text_color = (255, 255, 255) + font = cv2.FONT_HERSHEY_SIMPLEX + font_scale = 0.6 + thickness = 2 + + # Add title and instructions 
to RGB + cv2.putText( + rgb_resized, "RGB Camera Feed", (10, 25), font, font_scale, text_color, thickness + ) + cv2.putText( + rgb_resized, + "SPACE: Save | ESC/Q: Quit", + (10, 50), + font, + font_scale - 0.1, + text_color, + thickness, + ) + + # Add title and stats to depth + cv2.putText( + depth_resized, + f"Depth Map (0-{self.max_depth}m)", + (10, 25), + font, + font_scale, + text_color, + thickness, + ) + cv2.putText( + depth_resized, + f"Saved: {self.save_counter} frames", + (10, 50), + font, + font_scale - 0.1, + text_color, + thickness, + ) + + # Stack images horizontally + combined_display = np.hstack((rgb_resized, depth_resized)) + + # Display combined image + cv2.imshow(self.window_name, combined_display) + + return True, left_img, depth_map + + def handle_key_events(self, rgb_img, depth_map): + """Handle keyboard input.""" + key = cv2.waitKey(1) & 0xFF + + if key == ord(" "): # Space key - save frame + if rgb_img is not None and depth_map is not None: + self.save_frame(rgb_img, depth_map) + return "save" + elif key == 27 or key == ord("q"): # ESC or 'q' - quit + return "quit" + + return "continue" + + def cleanup(self): + """Clean up OpenCV windows.""" + cv2.destroyAllWindows() + + +def main(): + parser = argparse.ArgumentParser( + description="ZED Camera Neural Depth Demo - OpenCV with Data Saving" + ) + parser.add_argument("--camera-id", type=int, default=0, help="ZED camera ID (default: 0)") + parser.add_argument( + "--resolution", + type=str, + default="HD1080", + choices=["HD2K", "HD1080", "HD720", "VGA"], + help="Camera resolution (default: HD1080)", + ) + parser.add_argument( + "--max-depth", + type=float, + default=10.0, + help="Maximum depth for visualization in meters (default: 10.0)", + ) + parser.add_argument( + "--camera-fps", type=int, default=15, help="Camera capture FPS (default: 30)" + ) + parser.add_argument( + "--depth-mode", + type=str, + default="NEURAL", + choices=["NEURAL", "NEURAL_PLUS"], + help="Depth mode (NEURAL=faster, 
NEURAL_PLUS=more accurate)", + ) + parser.add_argument( + "--output-dir", + type=str, + default="assets/rgbd_data2", + help="Output directory for saved data (default: rgbd_data2)", + ) + + args = parser.parse_args() + + # Map resolution string to ZED enum + resolution_map = { + "HD2K": sl.RESOLUTION.HD2K, + "HD1080": sl.RESOLUTION.HD1080, + "HD720": sl.RESOLUTION.HD720, + "VGA": sl.RESOLUTION.VGA, + } + + depth_mode_map = {"NEURAL": sl.DEPTH_MODE.NEURAL, "NEURAL_PLUS": sl.DEPTH_MODE.NEURAL_PLUS} + + try: + # Initialize ZED camera with neural depth + logger.info( + f"Initializing ZED camera with {args.depth_mode} depth processing at {args.camera_fps} FPS..." + ) + camera = ZEDCamera( + camera_id=args.camera_id, + resolution=resolution_map[args.resolution], + depth_mode=depth_mode_map[args.depth_mode], + fps=args.camera_fps, + ) + + # Open camera + with camera: + # Get camera information + info = camera.get_camera_info() + logger.info(f"Camera Model: {info.get('model', 'Unknown')}") + logger.info(f"Serial Number: {info.get('serial_number', 'Unknown')}") + logger.info(f"Firmware: {info.get('firmware', 'Unknown')}") + logger.info(f"Resolution: {info.get('resolution', {})}") + logger.info(f"Baseline: {info.get('baseline', 0):.3f}m") + + # Initialize visualizer + visualizer = ZEDLiveVisualizer( + camera, max_depth=args.max_depth, output_dir=args.output_dir + ) + + logger.info("Starting live visualization...") + logger.info("Controls:") + logger.info(" SPACE - Save current RGB and depth frame") + logger.info(" ESC/Q - Quit") + + frame_count = 0 + start_time = time.time() + + try: + while True: + loop_start = time.time() + + # Update display + success, rgb_img, depth_map = visualizer.update_display() + + if success: + frame_count += 1 + + # Handle keyboard events + action = visualizer.handle_key_events(rgb_img, depth_map) + + if action == "quit": + break + elif action == "save": + # Frame was saved, no additional action needed + pass + + # Print performance stats every 60 
frames + if frame_count % 60 == 0: + elapsed = time.time() - start_time + fps = frame_count / elapsed + logger.info( + f"Frame {frame_count} | FPS: {fps:.1f} | Saved: {visualizer.save_counter}" + ) + + # Small delay to prevent CPU overload + elapsed = time.time() - loop_start + min_frame_time = 1.0 / 60.0 # Cap at 60 FPS + if elapsed < min_frame_time: + time.sleep(min_frame_time - elapsed) + + except KeyboardInterrupt: + logger.info("Stopped by user") + + # Final stats + total_time = time.time() - start_time + if total_time > 0: + avg_fps = frame_count / total_time + logger.info( + f"Final stats: {frame_count} frames in {total_time:.1f}s (avg {avg_fps:.1f} FPS)" + ) + logger.info(f"Total saved frames: {visualizer.save_counter}") + + # Visualize captured pointclouds + visualizer.visualize_captured_pointclouds() + + except Exception as e: + logger.error(f"Error during execution: {e}") + raise + finally: + if "visualizer" in locals(): + visualizer.cleanup() + logger.info("Demo completed") + + +if __name__ == "__main__": + main() From 31cb9014fcfaeb88803ddae9cfb6c088356b71f4 Mon Sep 17 00:00:00 2001 From: mustafab0 <39084056+mustafab0@users.noreply.github.com> Date: Wed, 16 Jul 2025 22:55:52 +0000 Subject: [PATCH 59/89] CI code cleanup --- build/lib/dimos/msgs/geometry_msgs/Twist.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/build/lib/dimos/msgs/geometry_msgs/Twist.py b/build/lib/dimos/msgs/geometry_msgs/Twist.py index b9d9630716..581c1d2e5f 100644 --- a/build/lib/dimos/msgs/geometry_msgs/Twist.py +++ b/build/lib/dimos/msgs/geometry_msgs/Twist.py @@ -1,3 +1,17 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """LCM type definitions This file automatically generated by lcm. DO NOT MODIFY BY HAND!!!! From 39ba624f00460e29cc2f8f993229c510461d21ea Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Wed, 16 Jul 2025 18:56:09 -0700 Subject: [PATCH 60/89] visual servoing fully working --- dimos/hardware/piper_arm.py | 54 +- .../{ibvs => visual_servoing}/detection3d.py | 170 ++-- dimos/manipulation/visual_servoing/pbvs.py | 783 ++++++++++++++++++ .../{ibvs => visual_servoing}/utils.py | 4 - dimos/utils/transform_utils.py | 147 +++- tests/test_ibvs.py | 316 ++++--- 6 files changed, 1219 insertions(+), 255 deletions(-) rename dimos/manipulation/{ibvs => visual_servoing}/detection3d.py (65%) create mode 100644 dimos/manipulation/visual_servoing/pbvs.py rename dimos/manipulation/{ibvs => visual_servoing}/utils.py (97%) diff --git a/dimos/hardware/piper_arm.py b/dimos/hardware/piper_arm.py index 19bb7f866e..943f35c4c3 100644 --- a/dimos/hardware/piper_arm.py +++ b/dimos/hardware/piper_arm.py @@ -26,6 +26,9 @@ import termios import tty import select +from scipy.spatial.transform import Rotation as R +from dimos.msgs.geometry_msgs import Pose, Vector3, Quaternion +from dimos.utils.transform_utils import euler_to_quaternion, quaternion_to_euler class PiperArm: @@ -72,7 +75,7 @@ def gotoZero(self): print(X, Y, Z, RX, RY, RZ) self.arm.MotionCtrl_2(0x01, 0x00, 100, 0x00) self.arm.EndPoseCtrl(X, Y, Z, RX, RY, RZ) - self.arm.GripperCtrl(abs(joint_6), 1000, 0x01, 0) + self.arm.GripperCtrl(0, 1000, 0x01, 0) def softStop(self): self.gotoZero() @@ -81,7 +84,7 @@ def softStop(self): 
self.arm.MotionCtrl_1(0x01, 0, 0) time.sleep(5) - def cmd_EE_pose(self, x, y, z, r, p, y_): + def cmd_ee_pose_values(self, x, y, z, r, p, y_): """Command end-effector to target pose in space (position + Euler angles)""" factor = 1000 pose = [x * factor, y * factor, z * factor, r * factor, p * factor, y_ * factor] @@ -90,26 +93,44 @@ def cmd_EE_pose(self, x, y, z, r, p, y_): int(pose[0]), int(pose[1]), int(pose[2]), int(pose[3]), int(pose[4]), int(pose[5]) ) - def get_EE_pose(self): - """Return the current end-effector pose as (x, y, z, r, p, y)""" + def cmd_ee_pose(self, pose: Pose): + """Command end-effector to target pose using Pose message""" + # Convert quaternion to euler angles + euler = quaternion_to_euler(pose.orientation, degrees=True) + + # Command the pose + self.cmd_ee_pose_values( + pose.position.x, pose.position.y, pose.position.z, + euler[0], euler[1], euler[2] + ) + + def get_ee_pose(self): + """Return the current end-effector pose as Pose message with position in meters and quaternion orientation""" pose = self.arm.GetArmEndPoseMsgs() + factor = 1000.0 # Extract individual pose values and convert to base units - # Position values are divided by 1000 to convert from SDK units to mm - # Rotation values are divided by 1000 to convert from SDK units to degrees - x = pose.end_pose.X_axis / 1000.0 - y = pose.end_pose.Y_axis / 1000.0 - z = pose.end_pose.Z_axis / 1000.0 - r = pose.end_pose.RX_axis / 1000.0 - p = pose.end_pose.RY_axis / 1000.0 - y_rot = pose.end_pose.RZ_axis / 1000.0 + # Position values are divided by 1000 to convert from SDK units to meters + # Rotation values are divided by 1000 to convert from SDK units to radians + x = pose.end_pose.X_axis / factor / factor # Convert mm to m + y = pose.end_pose.Y_axis / factor / factor # Convert mm to m + z = pose.end_pose.Z_axis / factor / factor # Convert mm to m + rx = pose.end_pose.RX_axis / factor + ry = pose.end_pose.RY_axis / factor + rz = pose.end_pose.RZ_axis / factor + + # Create position 
vector (already in meters) + position = Vector3(x, y, z) + + orientation = euler_to_quaternion(Vector3(rx, ry, rz), degrees=True) - return (x, y, z, r, p, y_rot) + return Pose(position, orientation) def cmd_gripper_ctrl(self, position): """Command end-effector gripper""" - position = position * 1000 + factor = 1000 + position = position * factor * factor - self.arm.GripperCtrl(abs(round(position)), 1000, 0x01, 0) + self.arm.GripperCtrl(abs(round(position)), factor, 0x01, 0) print(f"[PiperArm] Commanding gripper position: {position}") def resetArm(self): @@ -212,9 +233,6 @@ def disable(self): if __name__ == "__main__": arm = PiperArm() - print("get_EE_pose") - arm.get_EE_pose() - def get_key(timeout=0.1): """Non-blocking key reader for arrow keys.""" fd = sys.stdin.fileno() diff --git a/dimos/manipulation/ibvs/detection3d.py b/dimos/manipulation/visual_servoing/detection3d.py similarity index 65% rename from dimos/manipulation/ibvs/detection3d.py rename to dimos/manipulation/visual_servoing/detection3d.py index caf693c78e..2b6e7e518b 100644 --- a/dimos/manipulation/ibvs/detection3d.py +++ b/dimos/manipulation/visual_servoing/detection3d.py @@ -20,7 +20,6 @@ from typing import Dict, List, Optional, Any import numpy as np import cv2 -from scipy.spatial.transform import Rotation as R from dimos.utils.logging_config import setup_logger from dimos.perception.segmentation.sam_2d_seg import Sam2DSegmenter @@ -28,13 +27,14 @@ from dimos.perception.detection2d.utils import plot_results, calculate_object_size_from_bbox from dimos.msgs.geometry_msgs import Pose, Vector3, Quaternion -from dimos.types.vector import Vector from dimos.types.manipulation import ObjectData -from dimos.manipulation.ibvs.utils import estimate_object_depth +from dimos.manipulation.visual_servoing.utils import estimate_object_depth from dimos.utils.transform_utils import ( optical_to_robot_frame, pose_to_matrix, matrix_to_pose, + euler_to_quaternion, + compose_transforms, ) logger = 
setup_logger("dimos.perception.detection3d") @@ -84,22 +84,19 @@ def __init__( ) def process_frame( - self, rgb_image: np.ndarray, depth_image: np.ndarray, camera_pose: Optional[Any] = None - ) -> Dict[str, Any]: + self, rgb_image: np.ndarray, depth_image: np.ndarray, transform: Optional[np.ndarray] = None + ) -> List[ObjectData]: """ Process a single RGB-D frame to extract 3D object detections. Args: rgb_image: RGB image (H, W, 3) depth_image: Depth image (H, W) in meters - camera_pose: Optional camera pose in world frame (Pose object in ZED coordinates) + transform: Optional 4x4 transformation matrix to transform objects from camera frame to desired frame Returns: - Dictionary containing: - - detections: List of ObjectData objects with 3D pose information - - processing_time: Total processing time in seconds + List of ObjectData objects with 3D pose information """ - start_time = time.time() # Convert RGB to BGR for Sam (OpenCV format) bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR) @@ -109,7 +106,7 @@ def process_frame( # Early exit if no detections if not masks or len(masks) == 0: - return {"detections": [], "processing_time": time.time() - start_time} + return [] # Convert CUDA tensors to numpy arrays if needed numpy_masks = [] @@ -179,69 +176,68 @@ def process_frame( else: obj_data["color"] = np.array([128, 128, 128], dtype=np.uint8) - # Transform to world frame if camera pose is available - if camera_pose is not None: + # Transform to desired frame if transform matrix is provided + if transform is not None: # Get orientation as euler angles, default to no rotation if not available obj_cam_orientation = pose.get( "rotation", np.array([0.0, 0.0, 0.0]) ) # Default to no rotation - world_pose = self._transform_to_world( - obj_cam_pos, obj_cam_orientation, camera_pose + transformed_pose = self._transform_object_pose( + obj_cam_pos, obj_cam_orientation, transform ) - obj_data["world_position"] = world_pose.position - obj_data["position"] = 
world_pose.position # Use world position - obj_data["rotation"] = world_pose.orientation # Use world rotation + obj_data["position"] = transformed_pose.position + obj_data["rotation"] = transformed_pose.orientation else: - # If no camera pose, use camera coordinates + # If no transform, use camera coordinates obj_data["position"] = Vector3(obj_cam_pos[0], obj_cam_pos[1], obj_cam_pos[2]) detections.append(obj_data) - return {"detections": detections, "processing_time": time.time() - start_time} + return detections - def _transform_to_world( - self, obj_pos: np.ndarray, obj_orientation: np.ndarray, camera_pose: Pose + def _transform_object_pose( + self, obj_pos: np.ndarray, obj_orientation: np.ndarray, transform_matrix: np.ndarray ) -> Pose: """ - Transform object pose from optical frame to world frame. + Transform object pose from optical frame to desired frame using transformation matrix. Args: obj_pos: Object position in optical frame [x, y, z] obj_orientation: Object orientation in optical frame [roll, pitch, yaw] in radians - camera_pose: Camera pose in world frame (x forward, y left, z up) + transform_matrix: 4x4 transformation matrix from camera frame to desired frame Returns: - Object pose in world frame as Pose + Object pose in desired frame as Pose """ # Create object pose in optical frame - # Convert euler angles to quaternion - quat = R.from_euler("xyz", obj_orientation).as_quat() # [x, y, z, w] - obj_orientation_quat = Quaternion(quat[0], quat[1], quat[2], quat[3]) - - obj_pose_optical = Pose(Vector3(obj_pos[0], obj_pos[1], obj_pos[2]), obj_orientation_quat) - - # Transform object pose from optical frame to world frame convention - obj_pose_world_frame = optical_to_robot_frame(obj_pose_optical) + # Convert euler angles to quaternion using utility function + euler_vector = Vector3(obj_orientation[0], obj_orientation[1], obj_orientation[2]) + obj_orientation_quat = euler_to_quaternion(euler_vector) + + obj_pose_optical = Pose( + Vector3(obj_pos[0], 
obj_pos[1], obj_pos[2]), + obj_orientation_quat + ) - # Create transformation matrix from camera pose - T_world_camera = pose_to_matrix(camera_pose) + # Transform object pose from optical frame to robot frame convention first + obj_pose_robot_frame = optical_to_robot_frame(obj_pose_optical) # Create transformation matrix from object pose (relative to camera) - T_camera_object = pose_to_matrix(obj_pose_world_frame) + T_camera_object = pose_to_matrix(obj_pose_robot_frame) - # Combine transformations: T_world_object = T_world_camera * T_camera_object - T_world_object = T_world_camera @ T_camera_object + # Use compose_transforms to combine transformations + T_desired_object = compose_transforms(transform_matrix, T_camera_object) # Convert back to pose - world_pose = matrix_to_pose(T_world_object) + desired_pose = matrix_to_pose(T_desired_object) - return world_pose + return desired_pose def visualize_detections( self, rgb_image: np.ndarray, detections: List[ObjectData], - pbvs_controller: Optional[Any] = None, + show_coordinates: bool = True, ) -> np.ndarray: """ Visualize detections with 3D position overlay next to bounding boxes. 
@@ -249,7 +245,7 @@ def visualize_detections( Args: rgb_image: Original RGB image detections: List of ObjectData objects - pbvs_controller: Optional PBVS controller to get robot frame coordinates + show_coordinates: Whether to show 3D coordinates next to bounding boxes Returns: Visualization image @@ -267,61 +263,45 @@ def visualize_detections( # Use plot_results for basic visualization viz = plot_results(rgb_image, bboxes, track_ids, class_ids, confidences, names) - # Add 3D position overlay next to bounding boxes - fx, fy, cx, cy = self.camera_intrinsics - - for det in detections: - if "position" in det and "bbox" in det: - # Get position to display (robot frame if available, otherwise world frame) - world_position = det["position"] - display_position = world_position - frame_label = "" - - # Check if we should display robot frame coordinates - if pbvs_controller and pbvs_controller.manipulator_origin is not None: - robot_frame_data = pbvs_controller.get_object_pose_robot_frame(world_position) - if robot_frame_data: - display_position, _ = robot_frame_data - frame_label = "[R]" # Robot frame indicator - - bbox = det["bbox"] - - if isinstance(display_position, Vector3): - display_xyz = np.array( - [display_position.x, display_position.y, display_position.z] + # Add 3D position coordinates if requested + if show_coordinates: + for det in detections: + if "position" in det and "bbox" in det: + position = det["position"] + bbox = det["bbox"] + + if isinstance(position, Vector3): + pos_xyz = np.array([position.x, position.y, position.z]) + else: + pos_xyz = np.array([position["x"], position["y"], position["z"]]) + + # Get bounding box coordinates + x1, y1, x2, y2 = map(int, bbox) + + # Add position text next to bounding box (top-right corner) + pos_text = f"({pos_xyz[0]:.2f}, {pos_xyz[1]:.2f}, {pos_xyz[2]:.2f})" + text_x = x2 + 5 # Right edge of bbox + small offset + text_y = y1 + 15 # Top edge of bbox + small offset + + # Add background rectangle for better 
readability + text_size = cv2.getTextSize(pos_text, cv2.FONT_HERSHEY_SIMPLEX, 0.4, 1)[0] + cv2.rectangle( + viz, + (text_x - 2, text_y - text_size[1] - 2), + (text_x + text_size[0] + 2, text_y + 2), + (0, 0, 0), + -1, ) - else: - display_xyz = np.array( - [display_position["x"], display_position["y"], display_position["z"]] - ) - - # Get bounding box coordinates - x1, y1, x2, y2 = map(int, bbox) - # Add position text next to bounding box (top-right corner) - pos_text = f"{frame_label}({display_xyz[0]:.2f}, {display_xyz[1]:.2f}, {display_xyz[2]:.2f})" - text_x = x2 + 5 # Right edge of bbox + small offset - text_y = y1 + 15 # Top edge of bbox + small offset - - # Add background rectangle for better readability - text_size = cv2.getTextSize(pos_text, cv2.FONT_HERSHEY_SIMPLEX, 0.4, 1)[0] - cv2.rectangle( - viz, - (text_x - 2, text_y - text_size[1] - 2), - (text_x + text_size[0] + 2, text_y + 2), - (0, 0, 0), - -1, - ) - - cv2.putText( - viz, - pos_text, - (text_x, text_y), - cv2.FONT_HERSHEY_SIMPLEX, - 0.4, - (255, 255, 255), - 1, - ) + cv2.putText( + viz, + pos_text, + (text_x, text_y), + cv2.FONT_HERSHEY_SIMPLEX, + 0.4, + (255, 255, 255), + 1, + ) return viz diff --git a/dimos/manipulation/visual_servoing/pbvs.py b/dimos/manipulation/visual_servoing/pbvs.py new file mode 100644 index 0000000000..e3099ca7bc --- /dev/null +++ b/dimos/manipulation/visual_servoing/pbvs.py @@ -0,0 +1,783 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +Position-Based Visual Servoing (PBVS) system for robotic manipulation. +Supports both eye-in-hand and eye-to-hand configurations. +""" + +import numpy as np +from typing import Optional, Tuple, Dict, Any, List +import cv2 + +from scipy.spatial.transform import Rotation as R +from dimos.msgs.geometry_msgs import Pose, Vector3, Quaternion +from dimos.types.manipulation import ObjectData +from dimos.utils.logging_config import setup_logger +from dimos.utils.transform_utils import ( + yaw_towards_point, + euler_to_quaternion, +) + +logger = setup_logger("dimos.manipulation.pbvs") + + +class PBVS: + """ + High-level Position-Based Visual Servoing orchestrator. + + Handles: + - Object tracking and target management + - Pregrasp distance computation + - Grasp pose generation + - Coordination with low-level controller + + Note: This class is agnostic to camera mounting (eye-in-hand vs eye-to-hand). + The caller is responsible for providing appropriate camera and EE poses. + """ + + def __init__( + self, + position_gain: float = 0.5, + rotation_gain: float = 0.3, + max_velocity: float = 0.1, # m/s + max_angular_velocity: float = 0.5, # rad/s + target_tolerance: float = 0.01, # 1cm + tracking_distance_threshold: float = 0.05, # 5cm for target tracking + pregrasp_distance: float = 0.15, # 15cm pregrasp distance + direct_ee_control: bool = False, # If True, output target poses instead of velocities + ): + """ + Initialize PBVS system. 
+ + Args: + position_gain: Proportional gain for position control + rotation_gain: Proportional gain for rotation control + max_velocity: Maximum linear velocity command magnitude (m/s) + max_angular_velocity: Maximum angular velocity command magnitude (rad/s) + target_tolerance: Distance threshold for considering target reached (m) + tracking_distance_threshold: Max distance for target association (m) + pregrasp_distance: Distance to maintain before grasping (m) + direct_ee_control: If True, output target poses instead of velocity commands + """ + # Initialize low-level controller only if not in direct control mode + if not direct_ee_control: + self.controller = PBVSController( + position_gain=position_gain, + rotation_gain=rotation_gain, + max_velocity=max_velocity, + max_angular_velocity=max_angular_velocity, + target_tolerance=target_tolerance, + ) + else: + self.controller = None + + # Store parameters for direct mode error computation + self.target_tolerance = target_tolerance + + # Target tracking parameters + self.tracking_distance_threshold = tracking_distance_threshold + self.pregrasp_distance = pregrasp_distance + self.direct_ee_control = direct_ee_control + + # Target state + self.current_target = None + self.target_grasp_pose = None + + # For direct control mode visualization + self.last_position_error = None + self.last_target_reached = False + + logger.info( + f"Initialized PBVS system with controller gains: pos={position_gain}, rot={rotation_gain}, " + f"pregrasp_distance={pregrasp_distance}m" + ) + + def set_target(self, target_object: Dict[str, Any]) -> bool: + """ + Set a new target object for servoing. 
+ + Args: + target_object: Object dict with at least 'position' field + + Returns: + True if target was set successfully + """ + if target_object and "position" in target_object: + self.current_target = target_object + self.target_grasp_pose = None # Will be computed when needed + logger.info(f"New target set: ID {target_object.get('object_id', 'unknown')}") + return True + return False + + def clear_target(self): + """Clear the current target.""" + self.current_target = None + self.target_grasp_pose = None + self.last_position_error = None + self.last_target_reached = False + if self.controller: + self.controller.clear_state() + logger.info("Target cleared") + + def get_current_target(self): + """ + Get the current target object. + + Returns: + Current target ObjectData or None if no target selected + """ + return self.current_target + + def is_target_reached(self, ee_pose: Pose) -> bool: + """ + Check if the current target has been reached. + + Args: + ee_pose: Current end-effector pose + + Returns: + True if target is reached, False otherwise + """ + if not self.target_grasp_pose: + return False + + # Calculate position error + error_x = self.target_grasp_pose.position.x - ee_pose.position.x + error_y = self.target_grasp_pose.position.y - ee_pose.position.y + error_z = self.target_grasp_pose.position.z - ee_pose.position.z + + error_magnitude = np.sqrt(error_x**2 + error_y**2 + error_z**2) + return error_magnitude < self.target_tolerance + + def update_target_tracking(self, new_detections: List[ObjectData]) -> bool: + """ + Update target by matching to closest object in new detections. + If tracking is lost, keeps the old target pose. 
+ + Args: + new_detections: List of newly detected objects + + Returns: + True if target was successfully tracked, False if lost (but target is kept) + """ + if not self.current_target or "position" not in self.current_target: + return False + + if not new_detections: + logger.debug("No detections for target tracking - using last known pose") + return False + + # Get current target position + target_pos = self.current_target["position"] + if isinstance(target_pos, Vector3): + target_xyz = np.array([target_pos.x, target_pos.y, target_pos.z]) + else: + target_xyz = np.array([target_pos["x"], target_pos["y"], target_pos["z"]]) + + best_match = None + min_distance = float("inf") + + for detection in new_detections: + if "position" not in detection: + continue + + det_pos = detection["position"] + if isinstance(det_pos, Vector3): + det_xyz = np.array([det_pos.x, det_pos.y, det_pos.z]) + else: + det_xyz = np.array([det_pos["x"], det_pos["y"], det_pos["z"]]) + + distance = np.linalg.norm(target_xyz - det_xyz) + + if distance < self.tracking_distance_threshold: + best_match = detection + + if distance < min_distance: + min_distance = distance + + if best_match: + self.current_target = best_match + self.target_grasp_pose = None # Recompute grasp pose + return True + logger.info(f"Target tracking lost: closest target distance={min_distance:.3f}m") + return False + + def _update_target_grasp_pose(self, ee_pose: Pose): + """ + Update target grasp pose based on current target and EE pose. 
+ + Args: + ee_pose: Current end-effector pose + """ + if not self.current_target or "position" not in self.current_target: + return + + # Get target position + target_pos = self.current_target["position"] + + # Calculate orientation pointing from target towards EE + yaw_to_ee = yaw_towards_point( + Vector3(target_pos.x, target_pos.y, target_pos.z), + ee_pose.position + ) + + # Create target pose with proper orientation + # Convert euler angles to quaternion using utility function + euler = Vector3(0.0, 1.65, yaw_to_ee) # roll=0, pitch=90deg, yaw=calculated + target_orientation = euler_to_quaternion(euler) + + target_pose = Pose(target_pos, target_orientation) + + # Apply pregrasp distance + self.target_grasp_pose = self._apply_pregrasp_distance(target_pose, ee_pose) + + def _apply_pregrasp_distance(self, target_pose: Pose, ee_pose: Pose) -> Pose: + """ + Apply pregrasp distance to target pose by moving back towards EE. + + Args: + target_pose: Target pose + ee_pose: Current end-effector pose + + Returns: + Modified target pose with pregrasp distance applied + """ + # Get approach vector (from target position towards EE) + target_pos = np.array([target_pose.position.x, target_pose.position.y, target_pose.position.z]) + ee_pos = np.array([ee_pose.position.x, ee_pose.position.y, ee_pose.position.z]) + approach_vector = ee_pos - target_pos # Vector pointing towards EE + + # Normalize approach vector + approach_magnitude = np.linalg.norm(approach_vector) + if approach_magnitude > 1e-6: # Avoid division by zero + norm_approach_vector = approach_vector / approach_magnitude + else: + norm_approach_vector = np.array([0.0, 0.0, 0.0]) + + # Move back by pregrasp distance towards EE + offset_vector = self.pregrasp_distance * norm_approach_vector + + # Apply offset to target position + new_position = Vector3( + target_pose.position.x + offset_vector[0], + target_pose.position.y + offset_vector[1], + target_pose.position.z + offset_vector[2] + ) + + return Pose(new_position, 
target_pose.orientation) + + def compute_control( + self, ee_pose: Pose, new_detections: Optional[List[ObjectData]] = None + ) -> Tuple[Optional[Vector3], Optional[Vector3], bool, bool, Optional[Pose]]: + """ + Compute PBVS control with position and orientation servoing. + + Args: + ee_pose: Current end-effector pose + new_detections: Optional new detections for target tracking + + Returns: + Tuple of (velocity_command, angular_velocity_command, target_reached, has_target, target_pose) + - velocity_command: Linear velocity vector or None if no target (None in direct_ee_control mode) + - angular_velocity_command: Angular velocity vector or None if no target (None in direct_ee_control mode) + - target_reached: True if within target tolerance + - has_target: True if currently tracking a target + - target_pose: Target EE pose (only in direct_ee_control mode, otherwise None) + """ + # Check if we have a target + if not self.current_target or "position" not in self.current_target: + return None, None, False, False, None + + # Try to update target tracking if new detections provided + # Continue with last known pose even if tracking is lost + target_tracked = False + if new_detections is not None: + if self.update_target_tracking(new_detections): + target_tracked = True + else: + target_tracked = False + + # Update target grasp pose + self._update_target_grasp_pose(ee_pose) + + if self.target_grasp_pose is None: + logger.warning("Failed to compute grasp pose") + return None, None, False, False, None + + # Check if target reached using our separate function + target_reached = self.is_target_reached(ee_pose) + + # Return appropriate values based on control mode + if self.direct_ee_control: + # Direct control mode - compute errors for visualization only + self.last_position_error = Vector3( + self.target_grasp_pose.position.x - ee_pose.position.x, + self.target_grasp_pose.position.y - ee_pose.position.y, + self.target_grasp_pose.position.z - ee_pose.position.z + ) + 
self.last_target_reached = target_reached + return None, None, target_reached, target_tracked, self.target_grasp_pose + else: + # Velocity control mode - use controller + velocity_cmd, angular_velocity_cmd, controller_reached = self.controller.compute_control( + ee_pose, self.target_grasp_pose + ) + return velocity_cmd, angular_velocity_cmd, target_reached, target_tracked, None + + def get_object_pose_camera_frame( + self, object_pos: Vector3, camera_pose: Pose + ) -> Tuple[Vector3, Quaternion]: + """ + Get object pose in camera frame coordinates with orientation. + + Args: + object_pos: Object position in camera frame + camera_pose: Current camera pose + + Returns: + Tuple of (position, rotation) in camera frame + """ + # Calculate orientation pointing at camera + yaw_to_camera = yaw_towards_point(Vector3(object_pos.x, object_pos.y, object_pos.z)) + + # Convert euler angles to quaternion using utility function + euler = Vector3(0.0, 0.0, yaw_to_camera) # Level grasp + orientation = euler_to_quaternion(euler) + + return object_pos, orientation + + def create_status_overlay( + self, image: np.ndarray, + ) -> np.ndarray: + """ + Create PBVS status overlay on image. + + Args: + image: Input image + camera_intrinsics: Optional [fx, fy, cx, cy] (not used) + + Returns: + Image with PBVS status overlay + """ + if self.direct_ee_control: + # Use our own error data for direct control mode + return self._create_direct_status_overlay(image, self.current_target) + else: + # Use controller's overlay for velocity mode + return self.controller.create_status_overlay( + image, + self.current_target, + self.direct_ee_control, + ) + + def _create_direct_status_overlay(self, image: np.ndarray, current_target: Optional[ObjectData] = None) -> np.ndarray: + """ + Create status overlay for direct control mode. 
+ + Args: + image: Input image + current_target: Current target object + + Returns: + Image with status overlay + """ + viz_img = image.copy() + height, width = image.shape[:2] + + # Status panel + if current_target is not None: + panel_height = 160 # Adjusted panel for target, grasp pose, and pregrasp distance info + panel_y = height - panel_height + overlay = viz_img.copy() + cv2.rectangle(overlay, (0, panel_y), (width, height), (0, 0, 0), -1) + viz_img = cv2.addWeighted(viz_img, 0.7, overlay, 0.3, 0) + + # Status text + y = panel_y + 20 + cv2.putText( + viz_img, "PBVS Status (Direct EE)", (10, y), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2 + ) + + # Add frame info + cv2.putText( + viz_img, "Frame: Camera", (250, y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 0), 1 + ) + + if self.last_position_error: + error_mag = np.linalg.norm( + [self.last_position_error.x, self.last_position_error.y, self.last_position_error.z] + ) + color = (0, 255, 0) if self.last_target_reached else (0, 255, 255) + + cv2.putText( + viz_img, + f"Pos Error: {error_mag:.3f}m ({error_mag * 100:.1f}cm)", + (10, y + 25), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + color, + 1, + ) + + cv2.putText( + viz_img, + f"XYZ: ({self.last_position_error.x:.3f}, {self.last_position_error.y:.3f}, {self.last_position_error.z:.3f})", + (10, y + 45), + cv2.FONT_HERSHEY_SIMPLEX, + 0.4, + (200, 200, 200), + 1, + ) + + # Show target and grasp poses + if current_target and "position" in current_target: + target_pos = current_target["position"] + cv2.putText( + viz_img, + f"Target: ({target_pos.x:.3f}, {target_pos.y:.3f}, {target_pos.z:.3f})", + (10, y + 65), + cv2.FONT_HERSHEY_SIMPLEX, + 0.4, + (255, 255, 0), + 1, + ) + + if self.target_grasp_pose: + grasp_pos = self.target_grasp_pose.position + cv2.putText( + viz_img, + f"Grasp: ({grasp_pos.x:.3f}, {grasp_pos.y:.3f}, {grasp_pos.z:.3f})", + (10, y + 80), + cv2.FONT_HERSHEY_SIMPLEX, + 0.4, + (0, 255, 255), + 1, + ) + + # Show pregrasp distance if we have both poses + 
if current_target and "position" in current_target: + target_pos = current_target["position"] + distance = np.sqrt( + (grasp_pos.x - target_pos.x)**2 + + (grasp_pos.y - target_pos.y)**2 + + (grasp_pos.z - target_pos.z)**2 + ) + cv2.putText( + viz_img, + f"Pregrasp: {distance*1000:.1f}mm", + (10, y + 95), + cv2.FONT_HERSHEY_SIMPLEX, + 0.4, + (255, 200, 0), + 1, + ) + + if self.last_target_reached: + cv2.putText( + viz_img, + "TARGET REACHED", + (width - 150, y + 25), + cv2.FONT_HERSHEY_SIMPLEX, + 0.6, + (0, 255, 0), + 2, + ) + + return viz_img + + +class PBVSController: + """ + Low-level Position-Based Visual Servoing controller. + Pure control logic that computes velocity commands from poses. + + Handles: + - Position and orientation error computation + - Velocity command generation with gain control + - Target reached detection + """ + + def __init__( + self, + position_gain: float = 0.5, + rotation_gain: float = 0.3, + max_velocity: float = 0.1, # m/s + max_angular_velocity: float = 0.5, # rad/s + target_tolerance: float = 0.01, # 1cm + ): + """ + Initialize PBVS controller. 
+ + Args: + position_gain: Proportional gain for position control + rotation_gain: Proportional gain for rotation control + max_velocity: Maximum linear velocity command magnitude (m/s) + max_angular_velocity: Maximum angular velocity command magnitude (rad/s) + target_tolerance: Distance threshold for considering target reached (m) + """ + self.position_gain = position_gain + self.rotation_gain = rotation_gain + self.max_velocity = max_velocity + self.max_angular_velocity = max_angular_velocity + self.target_tolerance = target_tolerance + + # State variables for visualization + self.last_position_error = None + self.last_rotation_error = None + self.last_velocity_cmd = None + self.last_angular_velocity_cmd = None + self.last_target_reached = False + + logger.info( + f"Initialized PBVS controller: pos_gain={position_gain}, rot_gain={rotation_gain}, " + f"max_vel={max_velocity}m/s, max_ang_vel={max_angular_velocity}rad/s, " + f"target_tolerance={target_tolerance}m" + ) + + def clear_state(self): + """Clear controller state.""" + self.last_position_error = None + self.last_rotation_error = None + self.last_velocity_cmd = None + self.last_angular_velocity_cmd = None + self.last_target_reached = False + + def compute_control( + self, ee_pose: Pose, grasp_pose: Pose + ) -> Tuple[Optional[Vector3], Optional[Vector3], bool]: + """ + Compute PBVS control with position and orientation servoing. 
+ + Args: + ee_pose: Current end-effector pose + grasp_pose: Target grasp pose + + Returns: + Tuple of (velocity_command, angular_velocity_command, target_reached) + - velocity_command: Linear velocity vector + - angular_velocity_command: Angular velocity vector + - target_reached: True if within target tolerance + """ + # Calculate position error (target - EE position) + error = Vector3( + grasp_pose.position.x - ee_pose.position.x, + grasp_pose.position.y - ee_pose.position.y, + grasp_pose.position.z - ee_pose.position.z + ) + self.last_position_error = error + + # Compute velocity command with proportional control + velocity_cmd = Vector3( + error.x * self.position_gain, + error.y * self.position_gain, + error.z * self.position_gain, + ) + + # Limit velocity magnitude + vel_magnitude = np.linalg.norm([velocity_cmd.x, velocity_cmd.y, velocity_cmd.z]) + if vel_magnitude > self.max_velocity: + scale = self.max_velocity / vel_magnitude + velocity_cmd = Vector3( + float(velocity_cmd.x * scale), + float(velocity_cmd.y * scale), + float(velocity_cmd.z * scale), + ) + + self.last_velocity_cmd = velocity_cmd + + # Compute angular velocity for orientation control + angular_velocity_cmd = self._compute_angular_velocity(grasp_pose.orientation, ee_pose) + + # Check if target reached + error_magnitude = np.linalg.norm([error.x, error.y, error.z]) + target_reached = bool(error_magnitude < self.target_tolerance) + self.last_target_reached = target_reached + + return velocity_cmd, angular_velocity_cmd, target_reached + + def _compute_angular_velocity(self, target_rot: Quaternion, current_pose: Pose) -> Vector3: + """ + Compute angular velocity commands for orientation control. + Uses quaternion error computation for better numerical stability. 
+ + Args: + target_rot: Target orientation (quaternion) + current_pose: Current EE pose + + Returns: + Angular velocity command as Vector3 + """ + # Use quaternion error for better numerical stability + + # Convert to scipy Rotation objects + target_rot_scipy = R.from_quat([target_rot.x, target_rot.y, target_rot.z, target_rot.w]) + current_rot_scipy = R.from_quat([ + current_pose.orientation.x, + current_pose.orientation.y, + current_pose.orientation.z, + current_pose.orientation.w + ]) + + # Compute rotation error: error = target * current^(-1) + error_rot = target_rot_scipy * current_rot_scipy.inv() + + # Convert to axis-angle representation for control + error_axis_angle = error_rot.as_rotvec() + + # Use axis-angle directly as angular velocity error (small angle approximation) + roll_error = error_axis_angle[0] + pitch_error = error_axis_angle[1] + yaw_error = error_axis_angle[2] + + self.last_rotation_error = Vector3(roll_error, pitch_error, yaw_error) + + # Apply proportional control + angular_velocity = Vector3( + roll_error * self.rotation_gain, + pitch_error * self.rotation_gain, + yaw_error * self.rotation_gain, + ) + + # Limit angular velocity magnitude + ang_vel_magnitude = np.sqrt( + angular_velocity.x**2 + angular_velocity.y**2 + angular_velocity.z**2 + ) + if ang_vel_magnitude > self.max_angular_velocity: + scale = self.max_angular_velocity / ang_vel_magnitude + angular_velocity = Vector3( + angular_velocity.x * scale, + angular_velocity.y * scale, + angular_velocity.z * scale + ) + + self.last_angular_velocity_cmd = angular_velocity + + return angular_velocity + + def create_status_overlay( + self, + image: np.ndarray, + current_target: Optional[Dict[str, Any]] = None, + direct_ee_control: bool = False, + ) -> np.ndarray: + """ + Create PBVS status overlay on image. 
+ + Args: + image: Input image + current_target: Current target object (for display) + direct_ee_control: Whether in direct EE control mode + + Returns: + Image with PBVS status overlay + """ + viz_img = image.copy() + height, width = image.shape[:2] + + # Status panel + if current_target is not None: + panel_height = 160 # Adjusted panel height + panel_y = height - panel_height + overlay = viz_img.copy() + cv2.rectangle(overlay, (0, panel_y), (width, height), (0, 0, 0), -1) + viz_img = cv2.addWeighted(viz_img, 0.7, overlay, 0.3, 0) + + # Status text + y = panel_y + 20 + mode_text = "Direct EE" if direct_ee_control else "Velocity" + cv2.putText( + viz_img, f"PBVS Status ({mode_text})", (10, y), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2 + ) + + # Add frame info + cv2.putText( + viz_img, "Frame: Camera", (250, y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 0), 1 + ) + + if self.last_position_error: + error_mag = np.linalg.norm( + [ + self.last_position_error.x, + self.last_position_error.y, + self.last_position_error.z, + ] + ) + color = (0, 255, 0) if self.last_target_reached else (0, 255, 255) + + cv2.putText( + viz_img, + f"Pos Error: {error_mag:.3f}m ({error_mag * 100:.1f}cm)", + (10, y + 25), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + color, + 1, + ) + + cv2.putText( + viz_img, + f"XYZ: ({self.last_position_error.x:.3f}, {self.last_position_error.y:.3f}, {self.last_position_error.z:.3f})", + (10, y + 45), + cv2.FONT_HERSHEY_SIMPLEX, + 0.4, + (200, 200, 200), + 1, + ) + + if self.last_velocity_cmd and not direct_ee_control: + cv2.putText( + viz_img, + f"Lin Vel: ({self.last_velocity_cmd.x:.2f}, {self.last_velocity_cmd.y:.2f}, {self.last_velocity_cmd.z:.2f})m/s", + (10, y + 65), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + (255, 200, 0), + 1, + ) + + if self.last_rotation_error: + cv2.putText( + viz_img, + f"Rot Error: ({self.last_rotation_error.x:.2f}, {self.last_rotation_error.y:.2f}, {self.last_rotation_error.z:.2f})rad", + (10, y + 85), + cv2.FONT_HERSHEY_SIMPLEX, + 
0.4, + (200, 200, 200), + 1, + ) + + if self.last_angular_velocity_cmd and not direct_ee_control: + cv2.putText( + viz_img, + f"Ang Vel: ({self.last_angular_velocity_cmd.x:.2f}, {self.last_angular_velocity_cmd.y:.2f}, {self.last_angular_velocity_cmd.z:.2f})rad/s", + (10, y + 105), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + (255, 200, 0), + 1, + ) + + if self.last_target_reached: + cv2.putText( + viz_img, + "TARGET REACHED", + (width - 150, y + 25), + cv2.FONT_HERSHEY_SIMPLEX, + 0.6, + (0, 255, 0), + 2, + ) + + return viz_img diff --git a/dimos/manipulation/ibvs/utils.py b/dimos/manipulation/visual_servoing/utils.py similarity index 97% rename from dimos/manipulation/ibvs/utils.py rename to dimos/manipulation/visual_servoing/utils.py index 581d34dc8c..b35cf1b0c1 100644 --- a/dimos/manipulation/ibvs/utils.py +++ b/dimos/manipulation/visual_servoing/utils.py @@ -61,10 +61,6 @@ def estimate_object_depth( """ x1, y1, x2, y2 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3]) - # Quick bounds check - if x2 <= x1 or y2 <= y1: - return 0.05 - # Extract depth ROI once roi_depth = depth_image[y1:y2, x1:x2] diff --git a/dimos/utils/transform_utils.py b/dimos/utils/transform_utils.py index 46237ce0be..e57a17336b 100644 --- a/dimos/utils/transform_utils.py +++ b/dimos/utils/transform_utils.py @@ -15,9 +15,8 @@ import numpy as np from typing import Tuple, Dict, Any import logging -from scipy.spatial.transform import Rotation +from scipy.spatial.transform import Rotation as R -from dimos.types.vector import Vector from dimos.msgs.geometry_msgs import Pose, Vector3, Quaternion logger = logging.getLogger(__name__) @@ -48,12 +47,12 @@ def pose_to_matrix(pose: Pose) -> np.ndarray: # Create rotation matrix from quaternion using scipy quat = [pose.orientation.x, pose.orientation.y, pose.orientation.z, pose.orientation.w] - rotation = Rotation.from_quat(quat) - R = rotation.as_matrix() + rotation = R.from_quat(quat) + Rot = rotation.as_matrix() # Create 4x4 transform T = np.eye(4) - 
T[:3, :3] = R + T[:3, :3] = Rot T[:3, 3] = [tx, ty, tz] return T @@ -73,8 +72,8 @@ def matrix_to_pose(T: np.ndarray) -> Pose: pos = Vector3(T[0, 3], T[1, 3], T[2, 3]) # Extract rotation matrix and convert to quaternion - R = T[:3, :3] - rotation = Rotation.from_matrix(R) + Rot = T[:3, :3] + rotation = R.from_matrix(Rot) quat = rotation.as_quat() # Returns [x, y, z, w] orientation = Quaternion(quat[0], quat[1], quat[2], quat[3]) @@ -131,7 +130,7 @@ def optical_to_robot_frame(pose: Pose) -> Pose: # Rotation transformation using quaternions # First convert quaternion to rotation matrix quat_optical = [pose.orientation.x, pose.orientation.y, pose.orientation.z, pose.orientation.w] - R_optical = Rotation.from_quat(quat_optical).as_matrix() + R_optical = R.from_quat(quat_optical).as_matrix() # Coordinate frame transformation matrix from optical to robot # X_robot = Z_optical, Y_robot = -X_optical, Z_robot = -Y_optical @@ -147,7 +146,7 @@ def optical_to_robot_frame(pose: Pose) -> Pose: R_robot = T_frame @ R_optical @ T_frame.T # Convert back to quaternion - quat_robot = Rotation.from_matrix(R_robot).as_quat() # [x, y, z, w] + quat_robot = R.from_matrix(R_robot).as_quat() # [x, y, z, w] return Pose( Vector3(robot_x, robot_y, robot_z), @@ -173,7 +172,7 @@ def robot_to_optical_frame(pose: Pose) -> Pose: # Rotation transformation using quaternions quat_robot = [pose.orientation.x, pose.orientation.y, pose.orientation.z, pose.orientation.w] - R_robot = Rotation.from_quat(quat_robot).as_matrix() + R_robot = R.from_quat(quat_robot).as_matrix() # Coordinate frame transformation matrix from Robot to optical (inverse of optical to Robot) # This is the transpose of the forward transformation @@ -189,7 +188,7 @@ def robot_to_optical_frame(pose: Pose) -> Pose: R_optical = T_frame_inv @ R_robot @ T_frame_inv.T # Convert back to quaternion - quat_optical = Rotation.from_matrix(R_optical).as_quat() # [x, y, z, w] + quat_optical = R.from_matrix(R_optical).as_quat() # [x, y, z, w] return 
Pose( Vector3(optical_x, optical_y, optical_z), @@ -197,7 +196,7 @@ def robot_to_optical_frame(pose: Pose) -> Pose: ) -def yaw_towards_point(position: Vector, target_point: Vector = Vector(0.0, 0.0, 0.0)) -> float: +def yaw_towards_point(position: Vector3, target_point: Vector3 = Vector3(0.0, 0.0, 0.0)) -> float: """ Calculate yaw angle from target point to position (away from target). This is commonly used for object orientation in grasping applications. @@ -210,29 +209,30 @@ def yaw_towards_point(position: Vector, target_point: Vector = Vector(0.0, 0.0, Returns: Yaw angle in radians pointing from target_point to position """ - direction = position - target_point - return np.arctan2(direction.y, direction.x) + direction_x = position.x - target_point.x + direction_y = position.y - target_point.y + return np.arctan2(direction_y, direction_x) def transform_robot_to_map( - robot_position: Vector, robot_rotation: Vector, position: Vector, rotation: Vector -) -> Tuple[Vector, Vector]: + robot_position: Vector3, robot_rotation: Vector3, position: Vector3, rotation: Vector3 +) -> Tuple[Vector3, Vector3]: """Transform position and rotation from robot frame to map frame. 
Args: robot_position: Current robot position in map frame robot_rotation: Current robot rotation in map frame - position: Position in robot frame as Vector (x, y, z) - rotation: Rotation in robot frame as Vector (roll, pitch, yaw) in radians + position: Position in robot frame as Vector3 (x, y, z) + rotation: Rotation in robot frame as Vector3 (roll, pitch, yaw) in radians Returns: Tuple of (transformed_position, transformed_rotation) where: - - transformed_position: Vector (x, y, z) in map frame - - transformed_rotation: Vector (roll, pitch, yaw) in map frame + - transformed_position: Vector3 (x, y, z) in map frame + - transformed_rotation: Vector3 (roll, pitch, yaw) in map frame Example: - obj_pos_robot = Vector(1.0, 0.5, 0.0) # 1m forward, 0.5m left of robot - obj_rot_robot = Vector(0.0, 0.0, 0.0) # No rotation relative to robot + obj_pos_robot = Vector3(1.0, 0.5, 0.0) # 1m forward, 0.5m left of robot + obj_rot_robot = Vector3(0.0, 0.0, 0.0) # No rotation relative to robot map_pos, map_rot = transform_robot_to_map(robot_position, robot_rotation, obj_pos_robot, obj_rot_robot) """ @@ -262,7 +262,106 @@ def transform_robot_to_map( map_pitch = robot_rot.y + rot_pitch # Add robot's pitch map_yaw_rot = normalize_angle(robot_yaw + rot_yaw) # Add robot's yaw and normalize - transformed_position = Vector(map_x, map_y, map_z) - transformed_rotation = Vector(map_roll, map_pitch, map_yaw_rot) + transformed_position = Vector3(map_x, map_y, map_z) + transformed_rotation = Vector3(map_roll, map_pitch, map_yaw_rot) return transformed_position, transformed_rotation + + +def create_transform_from_6dof(translation: Vector3, euler_angles: Vector3) -> np.ndarray: + """ + Create a 4x4 transformation matrix from 6DOF parameters. 
+ + Args: + translation: Translation vector [x, y, z] in meters + euler_angles: Euler angles [rx, ry, rz] in radians (XYZ convention) + + Returns: + 4x4 transformation matrix + """ + # Create transformation matrix + T = np.eye(4) + + # Set translation + T[0:3, 3] = [translation.x, translation.y, translation.z] + + # Set rotation using scipy + if np.linalg.norm([euler_angles.x, euler_angles.y, euler_angles.z]) > 1e-6: + rotation = R.from_euler('xyz', [euler_angles.x, euler_angles.y, euler_angles.z]) + T[0:3, 0:3] = rotation.as_matrix() + + return T + + +def invert_transform(T: np.ndarray) -> np.ndarray: + """ + Invert a 4x4 transformation matrix efficiently. + + Args: + T: 4x4 transformation matrix + + Returns: + Inverted 4x4 transformation matrix + """ + # For homogeneous transform matrices, we can use the special structure: + # [R t]^-1 = [R^T -R^T*t] + # [0 1] [0 1 ] + + Rot = T[:3, :3] + t = T[:3, 3] + + T_inv = np.eye(4) + T_inv[:3, :3] = Rot.T + T_inv[:3, 3] = -Rot.T @ t + + return T_inv + + +def compose_transforms(*transforms: np.ndarray) -> np.ndarray: + """ + Compose multiple transformation matrices. + + Args: + *transforms: Variable number of 4x4 transformation matrices + + Returns: + Composed 4x4 transformation matrix (T1 @ T2 @ ... @ Tn) + """ + result = np.eye(4) + for T in transforms: + result = result @ T + return result + + +def euler_to_quaternion(euler_angles: Vector3, degrees: bool = False) -> Quaternion: + """ + Convert euler angles to quaternion. + + Args: + euler_angles: Euler angles as Vector3 [roll, pitch, yaw] in radians (XYZ convention) + + Returns: + Quaternion object [x, y, z, w] + """ + rotation = R.from_euler('xyz', [euler_angles.x, euler_angles.y, euler_angles.z], degrees=degrees) + quat = rotation.as_quat() # Returns [x, y, z, w] + return Quaternion(quat[0], quat[1], quat[2], quat[3]) + + +def quaternion_to_euler(quaternion: Quaternion, degrees: bool = False) -> Vector3: + """ + Convert quaternion to euler angles. 
+ + Args: + quaternion: Quaternion object [x, y, z, w] + + Returns: + Euler angles as Vector3 [roll, pitch, yaw] in radians (XYZ convention) + """ + quat = [quaternion.x, quaternion.y, quaternion.z, quaternion.w] + rotation = R.from_quat(quat) + euler = rotation.as_euler('xyz', degrees=degrees) # Returns [roll, pitch, yaw] + if not degrees: + return Vector3(normalize_angle(euler[0]), normalize_angle(euler[1]), normalize_angle(euler[2])) + else: + return Vector3(euler[0], euler[1], euler[2]) diff --git a/tests/test_ibvs.py b/tests/test_ibvs.py index 86b5b2b563..d4d90de16f 100644 --- a/tests/test_ibvs.py +++ b/tests/test_ibvs.py @@ -16,24 +16,31 @@ # Copyright 2025 Dimensional Inc. """ -Test script for PBVS with ZED camera supporting robot arm frame. -Click on objects to select targets (requires origin to be set first). -Press 'o' to set manipulator origin at current camera pose. +Test script for PBVS with eye-in-hand configuration. +Uses EE pose as odometry for camera pose estimation. +Click on objects to select targets. 
""" import cv2 import numpy as np import sys import os -import time -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +import tests.test_header from dimos.hardware.zed_camera import ZEDCamera -from dimos.manipulation.ibvs.detection3d import Detection3DProcessor -from dimos.manipulation.ibvs.utils import parse_zed_pose +from dimos.hardware.piper_arm import PiperArm +from dimos.manipulation.visual_servoing.detection3d import Detection3DProcessor from dimos.perception.common.utils import find_clicked_object -from dimos.manipulation.ibvs.pbvs import PBVSController +from dimos.manipulation.visual_servoing.pbvs import PBVS +from dimos.utils.transform_utils import ( + pose_to_matrix, + matrix_to_pose, + create_transform_from_6dof, + compose_transforms, + quaternion_to_euler, +) +from dimos.msgs.geometry_msgs import Vector3 try: import pyzed.sl as sl @@ -44,8 +51,6 @@ # Global for mouse events mouse_click = None -warning_message = None -warning_time = None def mouse_callback(event, x, y, flags, param): @@ -54,20 +59,91 @@ def mouse_callback(event, x, y, flags, param): mouse_click = (x, y) -def main(): - global mouse_click, warning_message, warning_time +def execute_grasp(arm, target_object, grasp_width_offset: float = 0.02) -> bool: + """ + Execute grasping by opening gripper to accommodate target object. 
+ + Args: + arm: Robot arm interface with gripper control + target_object: ObjectData with size information + safety_margin: Multiplier for gripper opening (default 1.5x object size) + + Returns: + True if grasp was executed, False if no target or no size data + """ + if not target_object: + print("❌ No target object provided for grasping") + return False + + if "size" not in target_object: + print("❌ Target has no size information for grasping") + return False + + # Get object size from detection3d data (already in meters) + object_size = target_object["size"] + object_width = object_size["width"] + object_height = object_size["height"] + object_depth = object_size["depth"] + + # Use the larger dimension (width or height) for gripper opening + # Depth is not relevant for gripper opening (that's approach direction) + + # Calculate gripper opening with safety margin + gripper_opening = object_width + grasp_width_offset + + # Clamp gripper opening to reasonable limits (0.5cm to 10cm) + gripper_opening = max(0.005, min(gripper_opening, 0.1)) # 0.5cm to 10cm + + print(f"🤏 Executing grasp: object size w={object_width*1000:.1f}mm h={object_height*1000:.1f}mm d={object_depth*1000:.1f}mm, " + f"offset={grasp_width_offset*1000:.1f}mm, opening gripper to {gripper_opening*1000:.1f}mm") + + # Command gripper to open + arm.cmd_gripper_ctrl(gripper_opening) + + return True - print("=== PBVS Test with Robot Frame Support ===") - print("IMPORTANT: Press 'o' to set manipulator origin FIRST") - print("Then click objects to select targets | 'r' - reset | 'q' - quit") - # Initialize camera +def main(): + global mouse_click + + # Control mode flag + DIRECT_EE_CONTROL = True # Set to True for direct EE pose control, False for velocity control + + print("=== PBVS Eye-in-Hand Test ===") + print("Using EE pose as odometry for camera pose") + print(f"Control mode: {'Direct EE Pose' if DIRECT_EE_CONTROL else 'Velocity Commands'}") + print("Click objects to select targets | 'r' - reset | 'q' 
- quit") + if DIRECT_EE_CONTROL: + print("SAFETY CONTROLS:") + print(" 's' - SOFT STOP (emergency stop)") + print(" 'h' - GO HOME (return to safe position)") + print(" 'SPACE' - EXECUTE target pose (only moves when pressed)") + print(" 'g' - EXECUTE GRASP (open gripper for target object)") + + # Initialize hardware zed = ZEDCamera(resolution=sl.RESOLUTION.HD720, depth_mode=sl.DEPTH_MODE.NEURAL) - if not zed.open() or not zed.enable_positional_tracking(): + if not zed.open(): print("Camera initialization failed!") return - # Get intrinsics + # Initialize robot arm + try: + arm = PiperArm() + print("Initialized Piper arm") + except Exception as e: + print(f"Failed to initialize Piper arm: {e}") + return + + # Define EE to camera transform (adjust these values for your setup) + # Format: [x, y, z, rx, ry, rz] in meters and radians + ee_to_camera_6dof = [-0.06, 0.03, -0.05, 0.0, -1.57, 0.0] + + # Create transform matrices + pos = Vector3(ee_to_camera_6dof[0], ee_to_camera_6dof[1], ee_to_camera_6dof[2]) + rot = Vector3(ee_to_camera_6dof[3], ee_to_camera_6dof[4], ee_to_camera_6dof[5]) + T_ee_to_camera = create_transform_from_6dof(pos, rot) + + # Get camera intrinsics cam_info = zed.get_camera_info() intrinsics = [ cam_info["left_cam"]["fx"], @@ -78,129 +154,115 @@ def main(): # Initialize processors detector = Detection3DProcessor(intrinsics) - pbvs = PBVSController(position_gain=0.3, rotation_gain=0.2, target_tolerance=0.025) + pbvs = PBVS(position_gain=0.3, rotation_gain=0.2, target_tolerance=0.05, pregrasp_distance=0.2, direct_ee_control=DIRECT_EE_CONTROL) # Setup window cv2.namedWindow("PBVS") cv2.setMouseCallback("PBVS", mouse_callback) + # Control state for direct EE mode + execute_target = False # Only move when space is pressed + last_valid_target = None + try: while True: # Capture - bgr, _, depth, pose_data = zed.capture_frame_with_pose() + bgr, _, depth, _ = zed.capture_frame_with_pose() if bgr is None or depth is None: continue # Process rgb = 
cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) - camera_pose = parse_zed_pose(pose_data) if pose_data else None - results = detector.process_frame(rgb, depth, camera_pose) - detections = results["detections"] + + # Get EE pose from robot (this serves as our odometry) + ee_pose = arm.get_ee_pose() + + # Transform EE pose to camera pose + ee_transform = pose_to_matrix(ee_pose) + camera_transform = compose_transforms(ee_transform, T_ee_to_camera) + camera_pose = matrix_to_pose(camera_transform) + + # Process detections using camera transform + detections = detector.process_frame(rgb, depth, camera_transform) # Handle click if mouse_click: clicked = find_clicked_object(mouse_click, detections) if clicked: - # Try to set target (will fail if no origin) - if not pbvs.set_target(clicked): - warning_message = "SET ORIGIN FIRST! Press 'o'" - warning_time = time.time() + pbvs.set_target(clicked) mouse_click = None - # Create visualization with position overlays (robot frame if available) - viz = detector.visualize_detections(rgb, detections, pbvs_controller=pbvs) + # Create visualization with position overlays + viz = detector.visualize_detections(rgb, detections) # PBVS control - if camera_pose: - vel_cmd, ang_vel_cmd, reached, has_target = pbvs.compute_control( - camera_pose, detections + vel_cmd, ang_vel_cmd, reached, target_tracked, target_pose = pbvs.compute_control( + ee_pose, detections + ) + + # Apply commands to robot based on control mode + if DIRECT_EE_CONTROL and target_pose and execute_target: + # Direct EE pose control - only when space is pressed + print(f"🎯 EXECUTING target pose: pos=({target_pose.position.x:.3f}, {target_pose.position.y:.3f}, {target_pose.position.z:.3f})") + last_valid_target = pbvs.get_current_target() + arm.cmd_ee_pose(target_pose) + execute_target = False # Reset flag after execution + elif not DIRECT_EE_CONTROL and vel_cmd and ang_vel_cmd: + # Velocity control + arm.cmd_vel_ee( + vel_cmd.x, vel_cmd.y, vel_cmd.z, + ang_vel_cmd.x, ang_vel_cmd.y, 
ang_vel_cmd.z ) - # Apply PBVS overlay - viz = pbvs.create_status_overlay(viz, intrinsics) + # Apply PBVS overlay + viz = pbvs.create_status_overlay(viz) - # Highlight target - if has_target and pbvs.current_target and "bbox" in pbvs.current_target: - x1, y1, x2, y2 = map(int, pbvs.current_target["bbox"]) - cv2.rectangle(viz, (x1, y1), (x2, y2), (0, 255, 0), 3) - cv2.putText( - viz, "TARGET", (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2 - ) + # Highlight target + current_target = pbvs.get_current_target() + if target_tracked and current_target and "bbox" in current_target: + x1, y1, x2, y2 = map(int, current_target["bbox"]) + cv2.rectangle(viz, (x1, y1), (x2, y2), (0, 255, 0), 3) + cv2.putText( + viz, "TARGET", (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2 + ) # Convert back to BGR for OpenCV display viz_bgr = cv2.cvtColor(viz, cv2.COLOR_RGB2BGR) - # Add camera pose info - if camera_pose: - # Show camera pose in appropriate frame - if pbvs.manipulator_origin is not None: - cam_robot = pbvs.get_camera_pose_robot_frame(camera_pose) - if cam_robot: - pose_text = f"Camera [Robot]: ({cam_robot.position.x:.2f}, {cam_robot.position.y:.2f}, {cam_robot.position.z:.2f})m" - else: - pose_text = f"Camera [ZED]: ({camera_pose.position.x:.2f}, {camera_pose.position.y:.2f}, {camera_pose.position.z:.2f})m" + # Add pose info + mode_text = "Direct EE" if DIRECT_EE_CONTROL else "Velocity" + cv2.putText( + viz_bgr, f"Eye-in-Hand ({mode_text})", (10, 30), + cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 255), 1 + ) + + camera_text = f"Camera: ({camera_pose.position.x:.2f}, {camera_pose.position.y:.2f}, {camera_pose.position.z:.2f})m" + cv2.putText( + viz_bgr, camera_text, (10, 50), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 0), 1 + ) + + ee_text = f"EE: ({ee_pose.position.x:.2f}, {ee_pose.position.y:.2f}, {ee_pose.position.z:.2f})m" + cv2.putText( + viz_bgr, ee_text, (10, 70), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1 + ) + + # Add direct EE control status + 
if DIRECT_EE_CONTROL: + if target_pose: + status_text = "Target Ready - Press SPACE to execute" + status_color = (0, 255, 255) # Yellow else: - pose_text = f"Camera [ZED]: ({camera_pose.position.x:.2f}, {camera_pose.position.y:.2f}, {camera_pose.position.z:.2f})m" - + status_text = "No target selected" + status_color = (100, 100, 100) # Gray + cv2.putText( - viz_bgr, pose_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 0), 1 + viz_bgr, status_text, (10, 90), cv2.FONT_HERSHEY_SIMPLEX, 0.5, status_color, 1 + ) + + cv2.putText( + viz_bgr, "s=STOP | h=HOME | SPACE=EXECUTE | g=GRASP", (10, 110), + cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1 ) - - # Show origin status - if pbvs.manipulator_origin is not None: - cv2.putText( - viz_bgr, - "Manipulator Origin SET", - (10, 50), - cv2.FONT_HERSHEY_SIMPLEX, - 0.5, - (0, 255, 0), - 1, - ) - else: - cv2.putText( - viz_bgr, - "Press 'o' to set manipulator origin", - (10, 50), - cv2.FONT_HERSHEY_SIMPLEX, - 0.5, - (255, 0, 0), - 1, - ) - - # Display warning message if active - if warning_message and warning_time: - # Show warning for 3 seconds - if time.time() - warning_time < 3.0: - # Draw warning box - height, width = viz_bgr.shape[:2] - box_height = 80 - box_y = height // 2 - box_height // 2 - - # Semi-transparent red background - overlay = viz_bgr.copy() - cv2.rectangle( - overlay, (50, box_y), (width - 50, box_y + box_height), (0, 0, 255), -1 - ) - viz_bgr = cv2.addWeighted(viz_bgr, 0.7, overlay, 0.3, 0) - - # Warning text - text_size = cv2.getTextSize(warning_message, cv2.FONT_HERSHEY_SIMPLEX, 1, 2)[0] - text_x = (width - text_size[0]) // 2 - text_y = box_y + box_height // 2 + text_size[1] // 2 - - cv2.putText( - viz_bgr, - warning_message, - (text_x, text_y), - cv2.FONT_HERSHEY_SIMPLEX, - 1, - (255, 255, 255), - 2, - ) - else: - warning_message = None - warning_time = None # Display cv2.imshow("PBVS", viz_bgr) @@ -211,11 +273,36 @@ def main(): break elif key == ord("r"): pbvs.clear_target() - elif key == 
ord("o") and camera_pose: - pbvs.set_manipulator_origin(camera_pose) - print( - f"Set manipulator origin at: ({camera_pose.position.x:.3f}, {camera_pose.position.y:.3f}, {camera_pose.position.z:.3f})" - ) + elif key == ord("s"): + # SOFT STOP - Emergency stop + print("🛑 SOFT STOP - Emergency stopping robot!") + arm.softStop() + elif key == ord("h"): + # GO HOME - Return to safe position + print("🏠 GO HOME - Returning to safe position...") + arm.gotoZero() + elif key == ord(" "): + # SPACE - Execute target pose (only in direct EE mode) + if DIRECT_EE_CONTROL and target_pose: + execute_target = True + target_euler = quaternion_to_euler(target_pose.orientation, degrees=True) + print("⚡ SPACE pressed - Target will execute on next frame!") + print(f"📍 Target pose: pos=({target_pose.position.x:.3f}, {target_pose.position.y:.3f}, {target_pose.position.z:.3f}) " + f"rot=({target_euler.x:.1f}°, {target_euler.y:.1f}°, {target_euler.z:.1f}°)") + elif key == ord("g"): + # G - Execute grasp (open gripper for target object) + current_target = pbvs.get_current_target() + if current_target: + last_valid_target = current_target + if last_valid_target: + print("🤏 GRASP - Opening gripper for target object...") + success = execute_grasp(arm, last_valid_target, grasp_width_offset=0.03) + if success: + print("✅ Gripper opened successfully") + else: + print("❌ Failed to execute grasp") + else: + print("❌ No target selected for grasping") except KeyboardInterrupt: pass @@ -223,6 +310,7 @@ def main(): cv2.destroyAllWindows() detector.cleanup() zed.close() + arm.disable() if __name__ == "__main__": From cc0bf7b6e3172ca227976c94b0e705ac36b424d8 Mon Sep 17 00:00:00 2001 From: alexlin2 <44330195+alexlin2@users.noreply.github.com> Date: Thu, 17 Jul 2025 01:59:59 +0000 Subject: [PATCH 61/89] CI code cleanup --- .../visual_servoing/detection3d.py | 7 +- dimos/manipulation/visual_servoing/pbvs.py | 176 ++++++++++-------- dimos/utils/transform_utils.py | 44 +++-- tests/test_ibvs.py | 91 +++++---- 4 
files changed, 179 insertions(+), 139 deletions(-) diff --git a/dimos/manipulation/visual_servoing/detection3d.py b/dimos/manipulation/visual_servoing/detection3d.py index 2b6e7e518b..c885f4af50 100644 --- a/dimos/manipulation/visual_servoing/detection3d.py +++ b/dimos/manipulation/visual_servoing/detection3d.py @@ -213,11 +213,8 @@ def _transform_object_pose( # Convert euler angles to quaternion using utility function euler_vector = Vector3(obj_orientation[0], obj_orientation[1], obj_orientation[2]) obj_orientation_quat = euler_to_quaternion(euler_vector) - - obj_pose_optical = Pose( - Vector3(obj_pos[0], obj_pos[1], obj_pos[2]), - obj_orientation_quat - ) + + obj_pose_optical = Pose(Vector3(obj_pos[0], obj_pos[1], obj_pos[2]), obj_orientation_quat) # Transform object pose from optical frame to robot frame convention first obj_pose_robot_frame = optical_to_robot_frame(obj_pose_optical) diff --git a/dimos/manipulation/visual_servoing/pbvs.py b/dimos/manipulation/visual_servoing/pbvs.py index e3099ca7bc..3477f1a4b3 100644 --- a/dimos/manipulation/visual_servoing/pbvs.py +++ b/dimos/manipulation/visual_servoing/pbvs.py @@ -36,17 +36,17 @@ class PBVS: """ High-level Position-Based Visual Servoing orchestrator. - + Handles: - Object tracking and target management - Pregrasp distance computation - Grasp pose generation - Coordination with low-level controller - + Note: This class is agnostic to camera mounting (eye-in-hand vs eye-to-hand). The caller is responsible for providing appropriate camera and EE poses. 
""" - + def __init__( self, position_gain: float = 0.5, @@ -82,28 +82,28 @@ def __init__( ) else: self.controller = None - + # Store parameters for direct mode error computation self.target_tolerance = target_tolerance - + # Target tracking parameters self.tracking_distance_threshold = tracking_distance_threshold self.pregrasp_distance = pregrasp_distance self.direct_ee_control = direct_ee_control - + # Target state self.current_target = None self.target_grasp_pose = None - + # For direct control mode visualization self.last_position_error = None self.last_target_reached = False - + logger.info( f"Initialized PBVS system with controller gains: pos={position_gain}, rot={rotation_gain}, " f"pregrasp_distance={pregrasp_distance}m" ) - + def set_target(self, target_object: Dict[str, Any]) -> bool: """ Set a new target object for servoing. @@ -120,7 +120,7 @@ def set_target(self, target_object: Dict[str, Any]) -> bool: logger.info(f"New target set: ID {target_object.get('object_id', 'unknown')}") return True return False - + def clear_target(self): """Clear the current target.""" self.current_target = None @@ -130,37 +130,37 @@ def clear_target(self): if self.controller: self.controller.clear_state() logger.info("Target cleared") - + def get_current_target(self): """ Get the current target object. - + Returns: Current target ObjectData or None if no target selected """ return self.current_target - + def is_target_reached(self, ee_pose: Pose) -> bool: """ Check if the current target has been reached. 
- + Args: ee_pose: Current end-effector pose - + Returns: True if target is reached, False otherwise """ if not self.target_grasp_pose: return False - + # Calculate position error error_x = self.target_grasp_pose.position.x - ee_pose.position.x error_y = self.target_grasp_pose.position.y - ee_pose.position.y error_z = self.target_grasp_pose.position.z - ee_pose.position.z - + error_magnitude = np.sqrt(error_x**2 + error_y**2 + error_z**2) return error_magnitude < self.target_tolerance - + def update_target_tracking(self, new_detections: List[ObjectData]) -> bool: """ Update target by matching to closest object in new detections. @@ -203,7 +203,7 @@ def update_target_tracking(self, new_detections: List[ObjectData]) -> bool: if distance < self.tracking_distance_threshold: best_match = detection - + if distance < min_distance: min_distance = distance @@ -213,36 +213,35 @@ def update_target_tracking(self, new_detections: List[ObjectData]) -> bool: return True logger.info(f"Target tracking lost: closest target distance={min_distance:.3f}m") return False - + def _update_target_grasp_pose(self, ee_pose: Pose): """ Update target grasp pose based on current target and EE pose. 
- + Args: ee_pose: Current end-effector pose """ if not self.current_target or "position" not in self.current_target: return - + # Get target position target_pos = self.current_target["position"] - + # Calculate orientation pointing from target towards EE yaw_to_ee = yaw_towards_point( - Vector3(target_pos.x, target_pos.y, target_pos.z), - ee_pose.position + Vector3(target_pos.x, target_pos.y, target_pos.z), ee_pose.position ) - + # Create target pose with proper orientation # Convert euler angles to quaternion using utility function euler = Vector3(0.0, 1.65, yaw_to_ee) # roll=0, pitch=90deg, yaw=calculated target_orientation = euler_to_quaternion(euler) - + target_pose = Pose(target_pos, target_orientation) - + # Apply pregrasp distance self.target_grasp_pose = self._apply_pregrasp_distance(target_pose, ee_pose) - + def _apply_pregrasp_distance(self, target_pose: Pose, ee_pose: Pose) -> Pose: """ Apply pregrasp distance to target pose by moving back towards EE. @@ -255,7 +254,9 @@ def _apply_pregrasp_distance(self, target_pose: Pose, ee_pose: Pose) -> Pose: Modified target pose with pregrasp distance applied """ # Get approach vector (from target position towards EE) - target_pos = np.array([target_pose.position.x, target_pose.position.y, target_pose.position.z]) + target_pos = np.array( + [target_pose.position.x, target_pose.position.y, target_pose.position.z] + ) ee_pos = np.array([ee_pose.position.x, ee_pose.position.y, ee_pose.position.z]) approach_vector = ee_pos - target_pos # Vector pointing towards EE @@ -273,11 +274,11 @@ def _apply_pregrasp_distance(self, target_pose: Pose, ee_pose: Pose) -> Pose: new_position = Vector3( target_pose.position.x + offset_vector[0], target_pose.position.y + offset_vector[1], - target_pose.position.z + offset_vector[2] + target_pose.position.z + offset_vector[2], ) return Pose(new_position, target_pose.orientation) - + def compute_control( self, ee_pose: Pose, new_detections: Optional[List[ObjectData]] = None ) -> 
Tuple[Optional[Vector3], Optional[Vector3], bool, bool, Optional[Pose]]: @@ -299,7 +300,7 @@ def compute_control( # Check if we have a target if not self.current_target or "position" not in self.current_target: return None, None, False, False, None - + # Try to update target tracking if new detections provided # Continue with last known pose even if tracking is lost target_tracked = False @@ -308,34 +309,34 @@ def compute_control( target_tracked = True else: target_tracked = False - + # Update target grasp pose self._update_target_grasp_pose(ee_pose) - + if self.target_grasp_pose is None: logger.warning("Failed to compute grasp pose") return None, None, False, False, None - + # Check if target reached using our separate function target_reached = self.is_target_reached(ee_pose) - + # Return appropriate values based on control mode if self.direct_ee_control: # Direct control mode - compute errors for visualization only self.last_position_error = Vector3( self.target_grasp_pose.position.x - ee_pose.position.x, self.target_grasp_pose.position.y - ee_pose.position.y, - self.target_grasp_pose.position.z - ee_pose.position.z + self.target_grasp_pose.position.z - ee_pose.position.z, ) self.last_target_reached = target_reached return None, None, target_reached, target_tracked, self.target_grasp_pose else: # Velocity control mode - use controller - velocity_cmd, angular_velocity_cmd, controller_reached = self.controller.compute_control( - ee_pose, self.target_grasp_pose + velocity_cmd, angular_velocity_cmd, controller_reached = ( + self.controller.compute_control(ee_pose, self.target_grasp_pose) ) return velocity_cmd, angular_velocity_cmd, target_reached, target_tracked, None - + def get_object_pose_camera_frame( self, object_pos: Vector3, camera_pose: Pose ) -> Tuple[Vector3, Quaternion]: @@ -351,15 +352,16 @@ def get_object_pose_camera_frame( """ # Calculate orientation pointing at camera yaw_to_camera = yaw_towards_point(Vector3(object_pos.x, object_pos.y, object_pos.z)) 
- + # Convert euler angles to quaternion using utility function euler = Vector3(0.0, 0.0, yaw_to_camera) # Level grasp orientation = euler_to_quaternion(euler) return object_pos, orientation - + def create_status_overlay( - self, image: np.ndarray, + self, + image: np.ndarray, ) -> np.ndarray: """ Create PBVS status overlay on image. @@ -377,19 +379,21 @@ def create_status_overlay( else: # Use controller's overlay for velocity mode return self.controller.create_status_overlay( - image, + image, self.current_target, self.direct_ee_control, ) - - def _create_direct_status_overlay(self, image: np.ndarray, current_target: Optional[ObjectData] = None) -> np.ndarray: + + def _create_direct_status_overlay( + self, image: np.ndarray, current_target: Optional[ObjectData] = None + ) -> np.ndarray: """ Create status overlay for direct control mode. - + Args: image: Input image current_target: Current target object - + Returns: Image with status overlay """ @@ -407,7 +411,13 @@ def _create_direct_status_overlay(self, image: np.ndarray, current_target: Optio # Status text y = panel_y + 20 cv2.putText( - viz_img, "PBVS Status (Direct EE)", (10, y), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2 + viz_img, + "PBVS Status (Direct EE)", + (10, y), + cv2.FONT_HERSHEY_SIMPLEX, + 0.6, + (0, 255, 255), + 2, ) # Add frame info @@ -417,7 +427,11 @@ def _create_direct_status_overlay(self, image: np.ndarray, current_target: Optio if self.last_position_error: error_mag = np.linalg.norm( - [self.last_position_error.x, self.last_position_error.y, self.last_position_error.z] + [ + self.last_position_error.x, + self.last_position_error.y, + self.last_position_error.z, + ] ) color = (0, 255, 0) if self.last_target_reached else (0, 255, 255) @@ -453,7 +467,7 @@ def _create_direct_status_overlay(self, image: np.ndarray, current_target: Optio (255, 255, 0), 1, ) - + if self.target_grasp_pose: grasp_pos = self.target_grasp_pose.position cv2.putText( @@ -465,18 +479,18 @@ def 
_create_direct_status_overlay(self, image: np.ndarray, current_target: Optio (0, 255, 255), 1, ) - + # Show pregrasp distance if we have both poses if current_target and "position" in current_target: target_pos = current_target["position"] distance = np.sqrt( - (grasp_pos.x - target_pos.x)**2 + - (grasp_pos.y - target_pos.y)**2 + - (grasp_pos.z - target_pos.z)**2 + (grasp_pos.x - target_pos.x) ** 2 + + (grasp_pos.y - target_pos.y) ** 2 + + (grasp_pos.z - target_pos.z) ** 2 ) cv2.putText( viz_img, - f"Pregrasp: {distance*1000:.1f}mm", + f"Pregrasp: {distance * 1000:.1f}mm", (10, y + 95), cv2.FONT_HERSHEY_SIMPLEX, 0.4, @@ -502,7 +516,7 @@ class PBVSController: """ Low-level Position-Based Visual Servoing controller. Pure control logic that computes velocity commands from poses. - + Handles: - Position and orientation error computation - Velocity command generation with gain control @@ -545,7 +559,7 @@ def __init__( f"max_vel={max_velocity}m/s, max_ang_vel={max_angular_velocity}rad/s, " f"target_tolerance={target_tolerance}m" ) - + def clear_state(self): """Clear controller state.""" self.last_position_error = None @@ -561,7 +575,7 @@ def compute_control( Compute PBVS control with position and orientation servoing. 
Args: - ee_pose: Current end-effector pose + ee_pose: Current end-effector pose grasp_pose: Target grasp pose Returns: @@ -574,7 +588,7 @@ def compute_control( error = Vector3( grasp_pose.position.x - ee_pose.position.x, grasp_pose.position.y - ee_pose.position.y, - grasp_pose.position.z - ee_pose.position.z + grasp_pose.position.z - ee_pose.position.z, ) self.last_position_error = error @@ -620,25 +634,27 @@ def _compute_angular_velocity(self, target_rot: Quaternion, current_pose: Pose) Angular velocity command as Vector3 """ # Use quaternion error for better numerical stability - + # Convert to scipy Rotation objects target_rot_scipy = R.from_quat([target_rot.x, target_rot.y, target_rot.z, target_rot.w]) - current_rot_scipy = R.from_quat([ - current_pose.orientation.x, - current_pose.orientation.y, - current_pose.orientation.z, - current_pose.orientation.w - ]) - + current_rot_scipy = R.from_quat( + [ + current_pose.orientation.x, + current_pose.orientation.y, + current_pose.orientation.z, + current_pose.orientation.w, + ] + ) + # Compute rotation error: error = target * current^(-1) error_rot = target_rot_scipy * current_rot_scipy.inv() - + # Convert to axis-angle representation for control error_axis_angle = error_rot.as_rotvec() - + # Use axis-angle directly as angular velocity error (small angle approximation) roll_error = error_axis_angle[0] - pitch_error = error_axis_angle[1] + pitch_error = error_axis_angle[1] yaw_error = error_axis_angle[2] self.last_rotation_error = Vector3(roll_error, pitch_error, yaw_error) @@ -657,9 +673,7 @@ def _compute_angular_velocity(self, target_rot: Quaternion, current_pose: Pose) if ang_vel_magnitude > self.max_angular_velocity: scale = self.max_angular_velocity / ang_vel_magnitude angular_velocity = Vector3( - angular_velocity.x * scale, - angular_velocity.y * scale, - angular_velocity.z * scale + angular_velocity.x * scale, angular_velocity.y * scale, angular_velocity.z * scale ) self.last_angular_velocity_cmd = 
angular_velocity @@ -667,8 +681,8 @@ def _compute_angular_velocity(self, target_rot: Quaternion, current_pose: Pose) return angular_velocity def create_status_overlay( - self, - image: np.ndarray, + self, + image: np.ndarray, current_target: Optional[Dict[str, Any]] = None, direct_ee_control: bool = False, ) -> np.ndarray: @@ -698,7 +712,13 @@ def create_status_overlay( y = panel_y + 20 mode_text = "Direct EE" if direct_ee_control else "Velocity" cv2.putText( - viz_img, f"PBVS Status ({mode_text})", (10, y), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2 + viz_img, + f"PBVS Status ({mode_text})", + (10, y), + cv2.FONT_HERSHEY_SIMPLEX, + 0.6, + (0, 255, 255), + 2, ) # Add frame info diff --git a/dimos/utils/transform_utils.py b/dimos/utils/transform_utils.py index e57a17336b..689091bc3b 100644 --- a/dimos/utils/transform_utils.py +++ b/dimos/utils/transform_utils.py @@ -271,59 +271,59 @@ def transform_robot_to_map( def create_transform_from_6dof(translation: Vector3, euler_angles: Vector3) -> np.ndarray: """ Create a 4x4 transformation matrix from 6DOF parameters. - + Args: translation: Translation vector [x, y, z] in meters euler_angles: Euler angles [rx, ry, rz] in radians (XYZ convention) - + Returns: 4x4 transformation matrix """ # Create transformation matrix T = np.eye(4) - + # Set translation T[0:3, 3] = [translation.x, translation.y, translation.z] - + # Set rotation using scipy if np.linalg.norm([euler_angles.x, euler_angles.y, euler_angles.z]) > 1e-6: - rotation = R.from_euler('xyz', [euler_angles.x, euler_angles.y, euler_angles.z]) + rotation = R.from_euler("xyz", [euler_angles.x, euler_angles.y, euler_angles.z]) T[0:3, 0:3] = rotation.as_matrix() - + return T def invert_transform(T: np.ndarray) -> np.ndarray: """ Invert a 4x4 transformation matrix efficiently. 
- + Args: T: 4x4 transformation matrix - + Returns: Inverted 4x4 transformation matrix """ # For homogeneous transform matrices, we can use the special structure: # [R t]^-1 = [R^T -R^T*t] # [0 1] [0 1 ] - + Rot = T[:3, :3] t = T[:3, 3] - + T_inv = np.eye(4) T_inv[:3, :3] = Rot.T T_inv[:3, 3] = -Rot.T @ t - + return T_inv def compose_transforms(*transforms: np.ndarray) -> np.ndarray: """ Compose multiple transformation matrices. - + Args: *transforms: Variable number of 4x4 transformation matrices - + Returns: Composed 4x4 transformation matrix (T1 @ T2 @ ... @ Tn) """ @@ -336,14 +336,16 @@ def compose_transforms(*transforms: np.ndarray) -> np.ndarray: def euler_to_quaternion(euler_angles: Vector3, degrees: bool = False) -> Quaternion: """ Convert euler angles to quaternion. - + Args: euler_angles: Euler angles as Vector3 [roll, pitch, yaw] in radians (XYZ convention) - + Returns: Quaternion object [x, y, z, w] """ - rotation = R.from_euler('xyz', [euler_angles.x, euler_angles.y, euler_angles.z], degrees=degrees) + rotation = R.from_euler( + "xyz", [euler_angles.x, euler_angles.y, euler_angles.z], degrees=degrees + ) quat = rotation.as_quat() # Returns [x, y, z, w] return Quaternion(quat[0], quat[1], quat[2], quat[3]) @@ -351,17 +353,19 @@ def euler_to_quaternion(euler_angles: Vector3, degrees: bool = False) -> Quatern def quaternion_to_euler(quaternion: Quaternion, degrees: bool = False) -> Vector3: """ Convert quaternion to euler angles. 
- + Args: quaternion: Quaternion object [x, y, z, w] - + Returns: Euler angles as Vector3 [roll, pitch, yaw] in radians (XYZ convention) """ quat = [quaternion.x, quaternion.y, quaternion.z, quaternion.w] rotation = R.from_quat(quat) - euler = rotation.as_euler('xyz', degrees=degrees) # Returns [roll, pitch, yaw] + euler = rotation.as_euler("xyz", degrees=degrees) # Returns [roll, pitch, yaw] if not degrees: - return Vector3(normalize_angle(euler[0]), normalize_angle(euler[1]), normalize_angle(euler[2])) + return Vector3( + normalize_angle(euler[0]), normalize_angle(euler[1]), normalize_angle(euler[2]) + ) else: return Vector3(euler[0], euler[1], euler[2]) diff --git a/tests/test_ibvs.py b/tests/test_ibvs.py index d4d90de16f..8ed21fabdc 100644 --- a/tests/test_ibvs.py +++ b/tests/test_ibvs.py @@ -62,44 +62,46 @@ def mouse_callback(event, x, y, flags, param): def execute_grasp(arm, target_object, grasp_width_offset: float = 0.02) -> bool: """ Execute grasping by opening gripper to accommodate target object. 
- + Args: arm: Robot arm interface with gripper control target_object: ObjectData with size information safety_margin: Multiplier for gripper opening (default 1.5x object size) - + Returns: True if grasp was executed, False if no target or no size data """ if not target_object: print("❌ No target object provided for grasping") return False - + if "size" not in target_object: print("❌ Target has no size information for grasping") return False - + # Get object size from detection3d data (already in meters) object_size = target_object["size"] object_width = object_size["width"] - object_height = object_size["height"] + object_height = object_size["height"] object_depth = object_size["depth"] - + # Use the larger dimension (width or height) for gripper opening # Depth is not relevant for gripper opening (that's approach direction) - + # Calculate gripper opening with safety margin gripper_opening = object_width + grasp_width_offset - + # Clamp gripper opening to reasonable limits (0.5cm to 10cm) gripper_opening = max(0.005, min(gripper_opening, 0.1)) # 0.5cm to 10cm - - print(f"🤏 Executing grasp: object size w={object_width*1000:.1f}mm h={object_height*1000:.1f}mm d={object_depth*1000:.1f}mm, " - f"offset={grasp_width_offset*1000:.1f}mm, opening gripper to {gripper_opening*1000:.1f}mm") - + + print( + f"🤏 Executing grasp: object size w={object_width * 1000:.1f}mm h={object_height * 1000:.1f}mm d={object_depth * 1000:.1f}mm, " + f"offset={grasp_width_offset * 1000:.1f}mm, opening gripper to {gripper_opening * 1000:.1f}mm" + ) + # Command gripper to open arm.cmd_gripper_ctrl(gripper_opening) - + return True @@ -108,7 +110,7 @@ def main(): # Control mode flag DIRECT_EE_CONTROL = True # Set to True for direct EE pose control, False for velocity control - + print("=== PBVS Eye-in-Hand Test ===") print("Using EE pose as odometry for camera pose") print(f"Control mode: {'Direct EE Pose' if DIRECT_EE_CONTROL else 'Velocity Commands'}") @@ -137,7 +139,7 @@ def main(): # Define 
EE to camera transform (adjust these values for your setup) # Format: [x, y, z, rx, ry, rz] in meters and radians ee_to_camera_6dof = [-0.06, 0.03, -0.05, 0.0, -1.57, 0.0] - + # Create transform matrices pos = Vector3(ee_to_camera_6dof[0], ee_to_camera_6dof[1], ee_to_camera_6dof[2]) rot = Vector3(ee_to_camera_6dof[3], ee_to_camera_6dof[4], ee_to_camera_6dof[5]) @@ -154,7 +156,13 @@ def main(): # Initialize processors detector = Detection3DProcessor(intrinsics) - pbvs = PBVS(position_gain=0.3, rotation_gain=0.2, target_tolerance=0.05, pregrasp_distance=0.2, direct_ee_control=DIRECT_EE_CONTROL) + pbvs = PBVS( + position_gain=0.3, + rotation_gain=0.2, + target_tolerance=0.05, + pregrasp_distance=0.2, + direct_ee_control=DIRECT_EE_CONTROL, + ) # Setup window cv2.namedWindow("PBVS") @@ -163,7 +171,7 @@ def main(): # Control state for direct EE mode execute_target = False # Only move when space is pressed last_valid_target = None - + try: while True: # Capture @@ -173,10 +181,10 @@ def main(): # Process rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) - + # Get EE pose from robot (this serves as our odometry) ee_pose = arm.get_ee_pose() - + # Transform EE pose to camera pose ee_transform = pose_to_matrix(ee_pose) camera_transform = compose_transforms(ee_transform, T_ee_to_camera) @@ -203,15 +211,16 @@ def main(): # Apply commands to robot based on control mode if DIRECT_EE_CONTROL and target_pose and execute_target: # Direct EE pose control - only when space is pressed - print(f"🎯 EXECUTING target pose: pos=({target_pose.position.x:.3f}, {target_pose.position.y:.3f}, {target_pose.position.z:.3f})") + print( + f"🎯 EXECUTING target pose: pos=({target_pose.position.x:.3f}, {target_pose.position.y:.3f}, {target_pose.position.z:.3f})" + ) last_valid_target = pbvs.get_current_target() arm.cmd_ee_pose(target_pose) execute_target = False # Reset flag after execution elif not DIRECT_EE_CONTROL and vel_cmd and ang_vel_cmd: # Velocity control arm.cmd_vel_ee( - vel_cmd.x, vel_cmd.y, 
vel_cmd.z, - ang_vel_cmd.x, ang_vel_cmd.y, ang_vel_cmd.z + vel_cmd.x, vel_cmd.y, vel_cmd.z, ang_vel_cmd.x, ang_vel_cmd.y, ang_vel_cmd.z ) # Apply PBVS overlay @@ -232,20 +241,23 @@ def main(): # Add pose info mode_text = "Direct EE" if DIRECT_EE_CONTROL else "Velocity" cv2.putText( - viz_bgr, f"Eye-in-Hand ({mode_text})", (10, 30), - cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 255), 1 + viz_bgr, + f"Eye-in-Hand ({mode_text})", + (10, 30), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + (0, 255, 255), + 1, ) - + camera_text = f"Camera: ({camera_pose.position.x:.2f}, {camera_pose.position.y:.2f}, {camera_pose.position.z:.2f})m" cv2.putText( viz_bgr, camera_text, (10, 50), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 0), 1 ) - + ee_text = f"EE: ({ee_pose.position.x:.2f}, {ee_pose.position.y:.2f}, {ee_pose.position.z:.2f})m" - cv2.putText( - viz_bgr, ee_text, (10, 70), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1 - ) - + cv2.putText(viz_bgr, ee_text, (10, 70), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1) + # Add direct EE control status if DIRECT_EE_CONTROL: if target_pose: @@ -254,14 +266,19 @@ def main(): else: status_text = "No target selected" status_color = (100, 100, 100) # Gray - + cv2.putText( viz_bgr, status_text, (10, 90), cv2.FONT_HERSHEY_SIMPLEX, 0.5, status_color, 1 ) - + cv2.putText( - viz_bgr, "s=STOP | h=HOME | SPACE=EXECUTE | g=GRASP", (10, 110), - cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1 + viz_bgr, + "s=STOP | h=HOME | SPACE=EXECUTE | g=GRASP", + (10, 110), + cv2.FONT_HERSHEY_SIMPLEX, + 0.4, + (255, 255, 255), + 1, ) # Display @@ -287,8 +304,10 @@ def main(): execute_target = True target_euler = quaternion_to_euler(target_pose.orientation, degrees=True) print("⚡ SPACE pressed - Target will execute on next frame!") - print(f"📍 Target pose: pos=({target_pose.position.x:.3f}, {target_pose.position.y:.3f}, {target_pose.position.z:.3f}) " - f"rot=({target_euler.x:.1f}°, {target_euler.y:.1f}°, {target_euler.z:.1f}°)") + print( + f"📍 Target pose: 
pos=({target_pose.position.x:.3f}, {target_pose.position.y:.3f}, {target_pose.position.z:.3f}) " + f"rot=({target_euler.x:.1f}°, {target_euler.y:.1f}°, {target_euler.z:.1f}°)" + ) elif key == ord("g"): # G - Execute grasp (open gripper for target object) current_target = pbvs.get_current_target() From cebc3832cd045648aefc46f6860efefcf75599f4 Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Thu, 17 Jul 2025 00:46:24 -0700 Subject: [PATCH 62/89] grasp working --- dimos/hardware/piper_arm.py | 2 +- .../visual_servoing/detection3d.py | 13 +- dimos/manipulation/visual_servoing/pbvs.py | 225 +++++++++++------- dimos/manipulation/visual_servoing/utils.py | 146 +++++++++++- dimos/perception/pointcloud/utils.py | 22 +- dimos/perception/segmentation/sam_2d_seg.py | 29 ++- tests/test_ibvs.py | 6 +- 7 files changed, 325 insertions(+), 118 deletions(-) diff --git a/dimos/hardware/piper_arm.py b/dimos/hardware/piper_arm.py index 943f35c4c3..50c97b7abf 100644 --- a/dimos/hardware/piper_arm.py +++ b/dimos/hardware/piper_arm.py @@ -130,7 +130,7 @@ def cmd_gripper_ctrl(self, position): factor = 1000 position = position * factor * factor - self.arm.GripperCtrl(abs(round(position)), factor, 0x01, 0) + self.arm.GripperCtrl(abs(round(position)), 250, 0x01, 0) print(f"[PiperArm] Commanding gripper position: {position}") def resetArm(self): diff --git a/dimos/manipulation/visual_servoing/detection3d.py b/dimos/manipulation/visual_servoing/detection3d.py index c885f4af50..fde02d5c05 100644 --- a/dimos/manipulation/visual_servoing/detection3d.py +++ b/dimos/manipulation/visual_servoing/detection3d.py @@ -54,6 +54,7 @@ def __init__( min_confidence: float = 0.6, min_points: int = 30, max_depth: float = 1.0, + max_object_size: float = 0.2, ): """ Initialize the real-time 3D detection processor. 
@@ -67,11 +68,13 @@ def __init__( self.camera_intrinsics = camera_intrinsics self.min_points = min_points self.max_depth = max_depth + self.max_object_size = max_object_size # Initialize Sam segmenter with tracking enabled but analysis disabled self.detector = Sam2DSegmenter( use_tracker=False, use_analyzer=False, + use_filtering=False, device="cuda" if cv2.cuda.getCudaEnabledDeviceCount() > 0 else "cpu", ) @@ -80,7 +83,7 @@ def __init__( logger.info( f"Initialized Detection3DProcessor with Sam segmenter, confidence={min_confidence}, " - f"min_points={min_points}, max_depth={max_depth}m" + f"min_points={min_points}, max_depth={max_depth}m, max_object_size={max_object_size}m" ) def process_frame( @@ -122,8 +125,6 @@ def process_frame( depth_image=depth_image, masks=numpy_masks, camera_intrinsics=self.camera_intrinsics, - min_points=self.min_points, - max_depth=self.max_depth, ) # Build detection results @@ -149,6 +150,9 @@ def process_frame( # Set depth and position in camera frame obj_data["depth"] = float(obj_cam_pos[2]) + if obj_cam_pos[2] > self.max_depth: + continue + obj_data["rotation"] = None # Calculate object size from bbox and depth @@ -167,6 +171,9 @@ def process_frame( "depth": max(depth_m, 0.01), # Minimum 1cm depth } + if min(obj_data["size"]["width"], obj_data["size"]["height"], obj_data["size"]["depth"]) > self.max_object_size: + continue + # Extract average color from the region x1, y1, x2, y2 = map(int, bbox) roi = rgb_image[y1:y2, x1:x2] diff --git a/dimos/manipulation/visual_servoing/pbvs.py b/dimos/manipulation/visual_servoing/pbvs.py index 3477f1a4b3..05c6664405 100644 --- a/dimos/manipulation/visual_servoing/pbvs.py +++ b/dimos/manipulation/visual_servoing/pbvs.py @@ -20,6 +20,7 @@ import numpy as np from typing import Optional, Tuple, Dict, Any, List import cv2 +from enum import Enum from scipy.spatial.transform import Rotation as R from dimos.msgs.geometry_msgs import Pose, Vector3, Quaternion @@ -29,14 +30,21 @@ yaw_towards_point, 
euler_to_quaternion, ) +from dimos.manipulation.visual_servoing.utils import find_best_object_match logger = setup_logger("dimos.manipulation.pbvs") +class GraspStage(Enum): + """Enum for different grasp stages.""" + PRE_GRASP = "pre_grasp" + GRASP = "grasp" + + class PBVS: """ High-level Position-Based Visual Servoing orchestrator. - + Handles: - Object tracking and target management - Pregrasp distance computation @@ -54,8 +62,10 @@ def __init__( max_velocity: float = 0.1, # m/s max_angular_velocity: float = 0.5, # rad/s target_tolerance: float = 0.01, # 1cm - tracking_distance_threshold: float = 0.05, # 5cm for target tracking + max_tracking_distance_threshold: float = 0.2, # Max distance for target tracking (m) + min_size_similarity: float = 0.7, # Min size similarity threshold (0.0-1.0) pregrasp_distance: float = 0.15, # 15cm pregrasp distance + grasp_distance: float = 0.05, # 5cm grasp distance (final approach) direct_ee_control: bool = False, # If True, output target poses instead of velocities ): """ @@ -67,8 +77,10 @@ def __init__( max_velocity: Maximum linear velocity command magnitude (m/s) max_angular_velocity: Maximum angular velocity command magnitude (rad/s) target_tolerance: Distance threshold for considering target reached (m) - tracking_distance_threshold: Max distance for target association (m) + max_tracking_distance: Maximum distance for valid target tracking (m) + min_size_similarity: Minimum size similarity for valid target tracking (0.0-1.0) pregrasp_distance: Distance to maintain before grasping (m) + grasp_distance: Distance for final grasp approach (m) direct_ee_control: If True, output target poses instead of velocity commands """ # Initialize low-level controller only if not in direct control mode @@ -87,21 +99,25 @@ def __init__( self.target_tolerance = target_tolerance # Target tracking parameters - self.tracking_distance_threshold = tracking_distance_threshold + self.max_tracking_distance_threshold = max_tracking_distance_threshold + 
self.min_size_similarity = min_size_similarity self.pregrasp_distance = pregrasp_distance + self.grasp_distance = grasp_distance self.direct_ee_control = direct_ee_control - # Target state + # Target state self.current_target = None self.target_grasp_pose = None - + self.grasp_stage = GraspStage.PRE_GRASP + # For direct control mode visualization self.last_position_error = None self.last_target_reached = False logger.info( f"Initialized PBVS system with controller gains: pos={position_gain}, rot={rotation_gain}, " - f"pregrasp_distance={pregrasp_distance}m" + f"pregrasp_distance={pregrasp_distance}m, grasp_distance={grasp_distance}m, " + f"tracking_thresholds: distance={max_tracking_distance_threshold}m, size={min_size_similarity:.2f}" ) def set_target(self, target_object: Dict[str, Any]) -> bool: @@ -117,6 +133,7 @@ def set_target(self, target_object: Dict[str, Any]) -> bool: if target_object and "position" in target_object: self.current_target = target_object self.target_grasp_pose = None # Will be computed when needed + self.grasp_stage = GraspStage.PRE_GRASP # Reset to pre-grasp stage logger.info(f"New target set: ID {target_object.get('object_id', 'unknown')}") return True return False @@ -125,6 +142,7 @@ def clear_target(self): """Clear the current target.""" self.current_target = None self.target_grasp_pose = None + self.grasp_stage = GraspStage.PRE_GRASP self.last_position_error = None self.last_target_reached = False if self.controller: @@ -139,27 +157,49 @@ def get_current_target(self): Current target ObjectData or None if no target selected """ return self.current_target + + def set_grasp_stage(self, stage: GraspStage): + """ + Set the grasp stage. + + Args: + stage: The new grasp stage + """ + self.grasp_stage = stage + + def is_target_reached(self, ee_pose: Pose) -> bool: """ - Check if the current target has been reached. - + Check if the current target stage has been reached. 
+ Args: ee_pose: Current end-effector pose - + Returns: - True if target is reached, False otherwise + True if current stage target is reached, False otherwise """ if not self.target_grasp_pose: return False - + # Calculate position error error_x = self.target_grasp_pose.position.x - ee_pose.position.x error_y = self.target_grasp_pose.position.y - ee_pose.position.y error_z = self.target_grasp_pose.position.z - ee_pose.position.z - + error_magnitude = np.sqrt(error_x**2 + error_y**2 + error_z**2) - return error_magnitude < self.target_tolerance + stage_reached = error_magnitude < self.target_tolerance + + # Handle stage transitions + if stage_reached and self.grasp_stage == GraspStage.PRE_GRASP: + return True # Signal that pre-grasp target was reached + elif stage_reached and self.grasp_stage == GraspStage.GRASP: + # Grasp reached, clear target + logger.info("Grasp position reached, clearing target") + self.clear_target() + return True + + return False def update_target_tracking(self, new_detections: List[ObjectData]) -> bool: """ @@ -179,39 +219,28 @@ def update_target_tracking(self, new_detections: List[ObjectData]) -> bool: logger.debug("No detections for target tracking - using last known pose") return False - # Get current target position - target_pos = self.current_target["position"] - if isinstance(target_pos, Vector3): - target_xyz = np.array([target_pos.x, target_pos.y, target_pos.z]) - else: - target_xyz = np.array([target_pos["x"], target_pos["y"], target_pos["z"]]) - - best_match = None - min_distance = float("inf") - - for detection in new_detections: - if "position" not in detection: - continue - - det_pos = detection["position"] - if isinstance(det_pos, Vector3): - det_xyz = np.array([det_pos.x, det_pos.y, det_pos.z]) - else: - det_xyz = np.array([det_pos["x"], det_pos["y"], det_pos["z"]]) - - distance = np.linalg.norm(target_xyz - det_xyz) - - if distance < self.tracking_distance_threshold: - best_match = detection - - if distance < min_distance: - 
min_distance = distance + # Use stage-dependent distance threshold + max_distance = self.max_tracking_distance_threshold + + # Find best match using standardized utility function + match_result = find_best_object_match( + target_obj=self.current_target, + candidates=new_detections, + max_distance=max_distance, + min_size_similarity=self.min_size_similarity + ) - if best_match: - self.current_target = best_match + if match_result.is_valid_match: + self.current_target = match_result.matched_object self.target_grasp_pose = None # Recompute grasp pose + logger.debug(f"Target tracking successful: distance={match_result.distance:.3f}m, " + f"size_similarity={match_result.size_similarity:.2f}, " + f"confidence={match_result.confidence:.2f}") return True - logger.info(f"Target tracking lost: closest target distance={min_distance:.3f}m") + + logger.debug(f"Target tracking lost: distance={match_result.distance:.3f}m, " + f"size_similarity={match_result.size_similarity:.2f}, " + f"thresholds: distance={max_distance:.3f}m, size={self.min_size_similarity:.2f}") return False def _update_target_grasp_pose(self, ee_pose: Pose): @@ -221,7 +250,7 @@ def _update_target_grasp_pose(self, ee_pose: Pose): Args: ee_pose: Current end-effector pose """ - if not self.current_target or "position" not in self.current_target: + if not self.current_target: return # Get target position @@ -234,24 +263,25 @@ def _update_target_grasp_pose(self, ee_pose: Pose): # Create target pose with proper orientation # Convert euler angles to quaternion using utility function - euler = Vector3(0.0, 1.65, yaw_to_ee) # roll=0, pitch=90deg, yaw=calculated + euler = Vector3(0.0, 1.57, yaw_to_ee) # roll=0, pitch=90deg, yaw=calculated target_orientation = euler_to_quaternion(euler) target_pose = Pose(target_pos, target_orientation) - # Apply pregrasp distance - self.target_grasp_pose = self._apply_pregrasp_distance(target_pose, ee_pose) + # Apply grasp distance + distance = self.pregrasp_distance if self.grasp_stage 
== GraspStage.PRE_GRASP else self.grasp_distance + self.target_grasp_pose = self._apply_grasp_distance(target_pose, ee_pose, distance) - def _apply_pregrasp_distance(self, target_pose: Pose, ee_pose: Pose) -> Pose: + def _apply_grasp_distance(self, target_pose: Pose, ee_pose: Pose, distance: float) -> Pose: """ - Apply pregrasp distance to target pose by moving back towards EE. + Apply appropriate grasp distance to target pose based on current stage. Args: target_pose: Target pose ee_pose: Current end-effector pose Returns: - Modified target pose with pregrasp distance applied + Modified target pose with appropriate distance applied """ # Get approach vector (from target position towards EE) target_pos = np.array( @@ -267,8 +297,8 @@ def _apply_pregrasp_distance(self, target_pose: Pose, ee_pose: Pose) -> Pose: else: norm_approach_vector = np.array([0.0, 0.0, 0.0]) - # Move back by pregrasp distance towards EE - offset_vector = self.pregrasp_distance * norm_approach_vector + # Move back by appropriate distance towards EE based on stage + offset_vector = distance * norm_approach_vector # Apply offset to target position new_position = Vector3( @@ -311,25 +341,38 @@ def compute_control( target_tracked = False # Update target grasp pose + if not self.current_target: + logger.info("No current target") + self._update_target_grasp_pose(ee_pose) if self.target_grasp_pose is None: logger.warning("Failed to compute grasp pose") return None, None, False, False, None - # Check if target reached using our separate function - target_reached = self.is_target_reached(ee_pose) - - # Return appropriate values based on control mode - if self.direct_ee_control: - # Direct control mode - compute errors for visualization only + # Compute errors for visualization before checking if reached (in case pose gets cleared) + if self.direct_ee_control and self.target_grasp_pose: self.last_position_error = Vector3( self.target_grasp_pose.position.x - ee_pose.position.x, 
self.target_grasp_pose.position.y - ee_pose.position.y, self.target_grasp_pose.position.z - ee_pose.position.z, ) - self.last_target_reached = target_reached - return None, None, target_reached, target_tracked, self.target_grasp_pose + + # Check if target reached using our separate function + target_reached = self.is_target_reached(ee_pose) + + # If stage transitioned, recompute target grasp pose + if target_reached and self.grasp_stage == GraspStage.GRASP and self.target_grasp_pose is None: + self._update_target_grasp_pose(ee_pose) + + # Return appropriate values based on control mode + if self.direct_ee_control: + # Direct control mode + if self.target_grasp_pose: + self.last_target_reached = target_reached + return None, None, target_reached, target_tracked, self.target_grasp_pose + else: + return None, None, False, target_tracked, None else: # Velocity control mode - use controller velocity_cmd, angular_velocity_cmd, controller_reached = ( @@ -402,7 +445,7 @@ def _create_direct_status_overlay( # Status panel if current_target is not None: - panel_height = 160 # Adjusted panel for target, grasp pose, and pregrasp distance info + panel_height = 175 # Adjusted panel for target, grasp pose, stage, and distance info panel_y = height - panel_height overlay = viz_img.copy() cv2.rectangle(overlay, (0, panel_y), (width, height), (0, 0, 0), -1) @@ -456,7 +499,7 @@ def _create_direct_status_overlay( ) # Show target and grasp poses - if current_target and "position" in current_target: + if current_target: target_pos = current_target["position"] cv2.putText( viz_img, @@ -481,19 +524,33 @@ def _create_direct_status_overlay( ) # Show pregrasp distance if we have both poses - if current_target and "position" in current_target: + if current_target: target_pos = current_target["position"] distance = np.sqrt( (grasp_pos.x - target_pos.x) ** 2 + (grasp_pos.y - target_pos.y) ** 2 + (grasp_pos.z - target_pos.z) ** 2 ) + + # Show current stage and distance + stage_text = f"Stage: 
{self.grasp_stage.value}" cv2.putText( viz_img, - f"Pregrasp: {distance * 1000:.1f}mm", + stage_text, (10, y + 95), cv2.FONT_HERSHEY_SIMPLEX, 0.4, + (255, 150, 255), + 1, + ) + + distance_text = f"Distance: {distance * 1000:.1f}mm" + cv2.putText( + viz_img, + distance_text, + (10, y + 110), + cv2.FONT_HERSHEY_SIMPLEX, + 0.4, (255, 200, 0), 1, ) @@ -594,9 +651,9 @@ def compute_control( # Compute velocity command with proportional control velocity_cmd = Vector3( - error.x * self.position_gain, - error.y * self.position_gain, - error.z * self.position_gain, + error.x * self.position_gain, + error.y * self.position_gain, + error.z * self.position_gain, ) # Limit velocity magnitude @@ -604,9 +661,9 @@ def compute_control( if vel_magnitude > self.max_velocity: scale = self.max_velocity / vel_magnitude velocity_cmd = Vector3( - float(velocity_cmd.x * scale), - float(velocity_cmd.y * scale), - float(velocity_cmd.z * scale), + float(velocity_cmd.x * scale), + float(velocity_cmd.y * scale), + float(velocity_cmd.z * scale), ) self.last_velocity_cmd = velocity_cmd @@ -634,36 +691,36 @@ def _compute_angular_velocity(self, target_rot: Quaternion, current_pose: Pose) Angular velocity command as Vector3 """ # Use quaternion error for better numerical stability - + # Convert to scipy Rotation objects target_rot_scipy = R.from_quat([target_rot.x, target_rot.y, target_rot.z, target_rot.w]) current_rot_scipy = R.from_quat( [ - current_pose.orientation.x, - current_pose.orientation.y, - current_pose.orientation.z, + current_pose.orientation.x, + current_pose.orientation.y, + current_pose.orientation.z, current_pose.orientation.w, ] ) - + # Compute rotation error: error = target * current^(-1) error_rot = target_rot_scipy * current_rot_scipy.inv() - + # Convert to axis-angle representation for control error_axis_angle = error_rot.as_rotvec() - + # Use axis-angle directly as angular velocity error (small angle approximation) roll_error = error_axis_angle[0] - pitch_error = 
error_axis_angle[1] + pitch_error = error_axis_angle[1] yaw_error = error_axis_angle[2] self.last_rotation_error = Vector3(roll_error, pitch_error, yaw_error) # Apply proportional control angular_velocity = Vector3( - roll_error * self.rotation_gain, - pitch_error * self.rotation_gain, - yaw_error * self.rotation_gain, + roll_error * self.rotation_gain, + pitch_error * self.rotation_gain, + yaw_error * self.rotation_gain, ) # Limit angular velocity magnitude @@ -786,8 +843,8 @@ def create_status_overlay( cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 200, 0), - 1, - ) + 1, + ) if self.last_target_reached: cv2.putText( diff --git a/dimos/manipulation/visual_servoing/utils.py b/dimos/manipulation/visual_servoing/utils.py index b35cf1b0c1..a34e25a439 100644 --- a/dimos/manipulation/visual_servoing/utils.py +++ b/dimos/manipulation/visual_servoing/utils.py @@ -13,11 +13,155 @@ # limitations under the License. import numpy as np -from typing import Dict, Any, Optional, List +from typing import Dict, Any, Optional, List, Tuple, Union +from dataclasses import dataclass from dimos.msgs.geometry_msgs import Pose, Vector3, Quaternion +@dataclass +class ObjectMatchResult: + """Result of object matching with confidence metrics.""" + matched_object: Optional[Dict[str, Any]] + confidence: float + distance: float + size_similarity: float + is_valid_match: bool + + +def calculate_object_similarity( + target_obj: Dict[str, Any], + candidate_obj: Dict[str, Any], + distance_weight: float = 0.6, + size_weight: float = 0.4 +) -> Tuple[float, float, float]: + """ + Calculate comprehensive similarity between two objects. 
+ + Args: + target_obj: Target object with 'position' and optionally 'size' + candidate_obj: Candidate object with 'position' and optionally 'size' + distance_weight: Weight for distance component (0-1) + size_weight: Weight for size component (0-1) + + Returns: + Tuple of (total_similarity, distance_m, size_similarity) + """ + # Extract positions + target_pos = target_obj.get("position", {}) + candidate_pos = candidate_obj.get("position", {}) + + if isinstance(target_pos, Vector3): + target_xyz = np.array([target_pos.x, target_pos.y, target_pos.z]) + else: + target_xyz = np.array([target_pos.get("x", 0), target_pos.get("y", 0), target_pos.get("z", 0)]) + + if isinstance(candidate_pos, Vector3): + candidate_xyz = np.array([candidate_pos.x, candidate_pos.y, candidate_pos.z]) + else: + candidate_xyz = np.array([candidate_pos.get("x", 0), candidate_pos.get("y", 0), candidate_pos.get("z", 0)]) + + # Calculate Euclidean distance + distance = np.linalg.norm(target_xyz - candidate_xyz) + distance_similarity = 1.0 / (1.0 + distance) # Exponential decay + + # Calculate size similarity by comparing each dimension individually + size_similarity = 1.0 # Default if no size info + target_size = target_obj.get("size", {}) + candidate_size = candidate_obj.get("size", {}) + + if target_size and candidate_size: + # Extract dimensions with defaults + target_dims = [ + target_size.get("width", 0.0), + target_size.get("height", 0.0), + target_size.get("depth", 0.0) + ] + candidate_dims = [ + candidate_size.get("width", 0.0), + candidate_size.get("height", 0.0), + candidate_size.get("depth", 0.0) + ] + + # Calculate similarity for each dimension pair + dim_similarities = [] + for target_dim, candidate_dim in zip(target_dims, candidate_dims): + if target_dim == 0.0 and candidate_dim == 0.0: + dim_similarities.append(1.0) # Both dimensions are zero + elif target_dim == 0.0 or candidate_dim == 0.0: + dim_similarities.append(0.0) # One dimension is zero, other is not + else: + # Calculate 
similarity as min/max ratio + max_dim = max(target_dim, candidate_dim) + min_dim = min(target_dim, candidate_dim) + dim_similarity = min_dim / max_dim if max_dim > 0 else 0.0 + dim_similarities.append(dim_similarity) + + # Return average similarity across all dimensions + size_similarity = np.mean(dim_similarities) if dim_similarities else 0.0 + + # Weighted combination + total_similarity = distance_weight * distance_similarity + size_weight * size_similarity + + return total_similarity, distance, size_similarity + + +def find_best_object_match( + target_obj: Dict[str, Any], + candidates: List[Dict[str, Any]], + max_distance: float = 0.1, + min_size_similarity: float = 0.4, + distance_weight: float = 0.7, + size_weight: float = 0.3 +) -> ObjectMatchResult: + """ + Find the best matching object from candidates using distance and size criteria. + + Args: + target_obj: Target object to match against + candidates: List of candidate objects + max_distance: Maximum allowed distance for valid match (meters) + min_size_similarity: Minimum size similarity for valid match (0-1) + distance_weight: Weight for distance in similarity calculation + size_weight: Weight for size in similarity calculation + + Returns: + ObjectMatchResult with best match and confidence metrics + """ + if not candidates or not target_obj.get("position"): + return ObjectMatchResult(None, 0.0, float('inf'), 0.0, False) + + best_match = None + best_confidence = 0.0 + best_distance = float('inf') + best_size_sim = 0.0 + + for candidate in candidates: + if not candidate.get("position"): + continue + + similarity, distance, size_sim = calculate_object_similarity( + target_obj, candidate, distance_weight, size_weight + ) + + # Check validity constraints + is_valid = distance <= max_distance and size_sim >= min_size_similarity + + if is_valid and similarity > best_confidence: + best_match = candidate + best_confidence = similarity + best_distance = distance + best_size_sim = size_sim + + return 
ObjectMatchResult( + matched_object=best_match, + confidence=best_confidence, + distance=best_distance, + size_similarity=best_size_sim, + is_valid_match=best_match is not None + ) + + def parse_zed_pose(zed_pose_data: Dict[str, Any]) -> Optional[Pose]: """ Parse ZED pose data dictionary into a Pose object. diff --git a/dimos/perception/pointcloud/utils.py b/dimos/perception/pointcloud/utils.py index 3ee1ea3923..be65635393 100644 --- a/dimos/perception/pointcloud/utils.py +++ b/dimos/perception/pointcloud/utils.py @@ -1087,8 +1087,6 @@ def extract_centroids_from_masks( depth_image: np.ndarray, masks: List[np.ndarray], camera_intrinsics: Union[List[float], np.ndarray], - min_points: int = 10, - max_depth: float = 10.0, ) -> List[Dict[str, Any]]: """ Extract 3D centroids and orientations from segmentation masks. @@ -1098,8 +1096,6 @@ def extract_centroids_from_masks( depth_image: Depth image (H, W) in meters masks: List of boolean masks (H, W) camera_intrinsics: Camera parameters as [fx, fy, cx, cy] or 3x3 matrix - min_points: Minimum number of valid 3D points required for a detection - max_depth: Maximum valid depth in meters Returns: List of dictionaries containing: @@ -1129,20 +1125,10 @@ def extract_centroids_from_masks( # Get depth values at mask locations depths = depth_image[y_coords, x_coords] - # Filter valid depths - valid_mask = (depths > 0) & (depths < max_depth) & np.isfinite(depths) - if valid_mask.sum() < min_points: - continue - - # Get valid coordinates and depths - valid_x = x_coords[valid_mask] - valid_y = y_coords[valid_mask] - valid_z = depths[valid_mask] - # Convert to 3D points in camera frame - X = (valid_x - cx) * valid_z / fx - Y = (valid_y - cy) * valid_z / fy - Z = valid_z + X = (x_coords - cx) * depths / fx + Y = (y_coords - cy) * depths / fy + Z = depths # Calculate centroid centroid_x = np.mean(X) @@ -1158,7 +1144,7 @@ def extract_centroids_from_masks( { "centroid": centroid, "orientation": orientation, - "num_points": 
int(valid_mask.sum()), + "num_points": int(mask.sum()), "mask_idx": mask_idx, } ) diff --git a/dimos/perception/segmentation/sam_2d_seg.py b/dimos/perception/segmentation/sam_2d_seg.py index 1b81dce07b..1f7c170cb2 100644 --- a/dimos/perception/segmentation/sam_2d_seg.py +++ b/dimos/perception/segmentation/sam_2d_seg.py @@ -47,6 +47,7 @@ def __init__( use_tracker=True, use_analyzer=True, use_rich_labeling=False, + use_filtering=True, ): self.device = device if is_cuda_available(): @@ -62,6 +63,7 @@ def __init__( self.use_tracker = use_tracker self.use_analyzer = use_analyzer self.use_rich_labeling = use_rich_labeling + self.use_filtering = use_filtering module_dir = os.path.dirname(__file__) self.tracker_config = os.path.join(module_dir, "config", "custom_tracker.yaml") @@ -98,7 +100,7 @@ def process_image(self, image): source=image, device=self.device, retina_masks=True, - conf=0.5, + conf=0.3, iou=0.9, persist=True, verbose=False, @@ -112,14 +114,23 @@ def process_image(self, image): ) # Filter results - ( - filtered_masks, - filtered_bboxes, - filtered_track_ids, - filtered_probs, - filtered_names, - filtered_texture_values, - ) = filter_segmentation_results(image, masks, bboxes, track_ids, probs, names, areas) + if self.use_filtering: + ( + filtered_masks, + filtered_bboxes, + filtered_track_ids, + filtered_probs, + filtered_names, + filtered_texture_values, + ) = filter_segmentation_results(image, masks, bboxes, track_ids, probs, names, areas) + else: + # Use original results without filtering + filtered_masks = masks + filtered_bboxes = bboxes + filtered_track_ids = track_ids + filtered_probs = probs + filtered_names = names + filtered_texture_values = [] if self.use_tracker: # Update tracker with filtered results diff --git a/tests/test_ibvs.py b/tests/test_ibvs.py index 8ed21fabdc..424e46f4ef 100644 --- a/tests/test_ibvs.py +++ b/tests/test_ibvs.py @@ -32,7 +32,7 @@ from dimos.hardware.piper_arm import PiperArm from 
dimos.manipulation.visual_servoing.detection3d import Detection3DProcessor from dimos.perception.common.utils import find_clicked_object -from dimos.manipulation.visual_servoing.pbvs import PBVS +from dimos.manipulation.visual_servoing.pbvs import PBVS, GraspStage from dimos.utils.transform_utils import ( pose_to_matrix, matrix_to_pose, @@ -160,7 +160,8 @@ def main(): position_gain=0.3, rotation_gain=0.2, target_tolerance=0.05, - pregrasp_distance=0.2, + pregrasp_distance=0.25, + grasp_distance=0.01, direct_ee_control=DIRECT_EE_CONTROL, ) @@ -315,6 +316,7 @@ def main(): last_valid_target = current_target if last_valid_target: print("🤏 GRASP - Opening gripper for target object...") + pbvs.set_grasp_stage(GraspStage.GRASP) success = execute_grasp(arm, last_valid_target, grasp_width_offset=0.03) if success: print("✅ Gripper opened successfully") From 2b9a02e72682333734c38530e88b1ec1d32371e0 Mon Sep 17 00:00:00 2001 From: alexlin2 <44330195+alexlin2@users.noreply.github.com> Date: Thu, 17 Jul 2025 07:47:15 +0000 Subject: [PATCH 63/89] CI code cleanup --- .../visual_servoing/detection3d.py | 9 +- dimos/manipulation/visual_servoing/pbvs.py | 107 ++++++++++-------- dimos/manipulation/visual_servoing/utils.py | 71 ++++++------ dimos/perception/segmentation/sam_2d_seg.py | 4 +- 4 files changed, 108 insertions(+), 83 deletions(-) diff --git a/dimos/manipulation/visual_servoing/detection3d.py b/dimos/manipulation/visual_servoing/detection3d.py index fde02d5c05..bccedea020 100644 --- a/dimos/manipulation/visual_servoing/detection3d.py +++ b/dimos/manipulation/visual_servoing/detection3d.py @@ -171,7 +171,14 @@ def process_frame( "depth": max(depth_m, 0.01), # Minimum 1cm depth } - if min(obj_data["size"]["width"], obj_data["size"]["height"], obj_data["size"]["depth"]) > self.max_object_size: + if ( + min( + obj_data["size"]["width"], + obj_data["size"]["height"], + obj_data["size"]["depth"], + ) + > self.max_object_size + ): continue # Extract average color from the region 
diff --git a/dimos/manipulation/visual_servoing/pbvs.py b/dimos/manipulation/visual_servoing/pbvs.py index 05c6664405..a3e6e9b4c6 100644 --- a/dimos/manipulation/visual_servoing/pbvs.py +++ b/dimos/manipulation/visual_servoing/pbvs.py @@ -37,6 +37,7 @@ class GraspStage(Enum): """Enum for different grasp stages.""" + PRE_GRASP = "pre_grasp" GRASP = "grasp" @@ -44,7 +45,7 @@ class GraspStage(Enum): class PBVS: """ High-level Position-Based Visual Servoing orchestrator. - + Handles: - Object tracking and target management - Pregrasp distance computation @@ -105,11 +106,11 @@ def __init__( self.grasp_distance = grasp_distance self.direct_ee_control = direct_ee_control - # Target state + # Target state self.current_target = None self.target_grasp_pose = None self.grasp_stage = GraspStage.PRE_GRASP - + # For direct control mode visualization self.last_position_error = None self.last_target_reached = False @@ -157,39 +158,37 @@ def get_current_target(self): Current target ObjectData or None if no target selected """ return self.current_target - + def set_grasp_stage(self, stage: GraspStage): """ Set the grasp stage. - + Args: stage: The new grasp stage """ self.grasp_stage = stage - - def is_target_reached(self, ee_pose: Pose) -> bool: """ Check if the current target stage has been reached. 
- + Args: ee_pose: Current end-effector pose - + Returns: True if current stage target is reached, False otherwise """ if not self.target_grasp_pose: return False - + # Calculate position error error_x = self.target_grasp_pose.position.x - ee_pose.position.x error_y = self.target_grasp_pose.position.y - ee_pose.position.y error_z = self.target_grasp_pose.position.z - ee_pose.position.z - + error_magnitude = np.sqrt(error_x**2 + error_y**2 + error_z**2) stage_reached = error_magnitude < self.target_tolerance - + # Handle stage transitions if stage_reached and self.grasp_stage == GraspStage.PRE_GRASP: return True # Signal that pre-grasp target was reached @@ -198,7 +197,7 @@ def is_target_reached(self, ee_pose: Pose) -> bool: logger.info("Grasp position reached, clearing target") self.clear_target() return True - + return False def update_target_tracking(self, new_detections: List[ObjectData]) -> bool: @@ -221,26 +220,30 @@ def update_target_tracking(self, new_detections: List[ObjectData]) -> bool: # Use stage-dependent distance threshold max_distance = self.max_tracking_distance_threshold - + # Find best match using standardized utility function match_result = find_best_object_match( target_obj=self.current_target, candidates=new_detections, max_distance=max_distance, - min_size_similarity=self.min_size_similarity + min_size_similarity=self.min_size_similarity, ) if match_result.is_valid_match: self.current_target = match_result.matched_object self.target_grasp_pose = None # Recompute grasp pose - logger.debug(f"Target tracking successful: distance={match_result.distance:.3f}m, " - f"size_similarity={match_result.size_similarity:.2f}, " - f"confidence={match_result.confidence:.2f}") + logger.debug( + f"Target tracking successful: distance={match_result.distance:.3f}m, " + f"size_similarity={match_result.size_similarity:.2f}, " + f"confidence={match_result.confidence:.2f}" + ) return True - - logger.debug(f"Target tracking lost: distance={match_result.distance:.3f}m, 
" - f"size_similarity={match_result.size_similarity:.2f}, " - f"thresholds: distance={max_distance:.3f}m, size={self.min_size_similarity:.2f}") + + logger.debug( + f"Target tracking lost: distance={match_result.distance:.3f}m, " + f"size_similarity={match_result.size_similarity:.2f}, " + f"thresholds: distance={max_distance:.3f}m, size={self.min_size_similarity:.2f}" + ) return False def _update_target_grasp_pose(self, ee_pose: Pose): @@ -269,7 +272,11 @@ def _update_target_grasp_pose(self, ee_pose: Pose): target_pose = Pose(target_pos, target_orientation) # Apply grasp distance - distance = self.pregrasp_distance if self.grasp_stage == GraspStage.PRE_GRASP else self.grasp_distance + distance = ( + self.pregrasp_distance + if self.grasp_stage == GraspStage.PRE_GRASP + else self.grasp_distance + ) self.target_grasp_pose = self._apply_grasp_distance(target_pose, ee_pose, distance) def _apply_grasp_distance(self, target_pose: Pose, ee_pose: Pose, distance: float) -> Pose: @@ -357,14 +364,18 @@ def compute_control( self.target_grasp_pose.position.y - ee_pose.position.y, self.target_grasp_pose.position.z - ee_pose.position.z, ) - + # Check if target reached using our separate function target_reached = self.is_target_reached(ee_pose) - + # If stage transitioned, recompute target grasp pose - if target_reached and self.grasp_stage == GraspStage.GRASP and self.target_grasp_pose is None: + if ( + target_reached + and self.grasp_stage == GraspStage.GRASP + and self.target_grasp_pose is None + ): self._update_target_grasp_pose(ee_pose) - + # Return appropriate values based on control mode if self.direct_ee_control: # Direct control mode @@ -531,7 +542,7 @@ def _create_direct_status_overlay( + (grasp_pos.y - target_pos.y) ** 2 + (grasp_pos.z - target_pos.z) ** 2 ) - + # Show current stage and distance stage_text = f"Stage: {self.grasp_stage.value}" cv2.putText( @@ -543,7 +554,7 @@ def _create_direct_status_overlay( (255, 150, 255), 1, ) - + distance_text = f"Distance: 
{distance * 1000:.1f}mm" cv2.putText( viz_img, @@ -651,9 +662,9 @@ def compute_control( # Compute velocity command with proportional control velocity_cmd = Vector3( - error.x * self.position_gain, - error.y * self.position_gain, - error.z * self.position_gain, + error.x * self.position_gain, + error.y * self.position_gain, + error.z * self.position_gain, ) # Limit velocity magnitude @@ -661,9 +672,9 @@ def compute_control( if vel_magnitude > self.max_velocity: scale = self.max_velocity / vel_magnitude velocity_cmd = Vector3( - float(velocity_cmd.x * scale), - float(velocity_cmd.y * scale), - float(velocity_cmd.z * scale), + float(velocity_cmd.x * scale), + float(velocity_cmd.y * scale), + float(velocity_cmd.z * scale), ) self.last_velocity_cmd = velocity_cmd @@ -691,36 +702,36 @@ def _compute_angular_velocity(self, target_rot: Quaternion, current_pose: Pose) Angular velocity command as Vector3 """ # Use quaternion error for better numerical stability - + # Convert to scipy Rotation objects target_rot_scipy = R.from_quat([target_rot.x, target_rot.y, target_rot.z, target_rot.w]) current_rot_scipy = R.from_quat( [ - current_pose.orientation.x, - current_pose.orientation.y, - current_pose.orientation.z, + current_pose.orientation.x, + current_pose.orientation.y, + current_pose.orientation.z, current_pose.orientation.w, ] ) - + # Compute rotation error: error = target * current^(-1) error_rot = target_rot_scipy * current_rot_scipy.inv() - + # Convert to axis-angle representation for control error_axis_angle = error_rot.as_rotvec() - + # Use axis-angle directly as angular velocity error (small angle approximation) roll_error = error_axis_angle[0] - pitch_error = error_axis_angle[1] + pitch_error = error_axis_angle[1] yaw_error = error_axis_angle[2] self.last_rotation_error = Vector3(roll_error, pitch_error, yaw_error) # Apply proportional control angular_velocity = Vector3( - roll_error * self.rotation_gain, - pitch_error * self.rotation_gain, - yaw_error * 
self.rotation_gain, + roll_error * self.rotation_gain, + pitch_error * self.rotation_gain, + yaw_error * self.rotation_gain, ) # Limit angular velocity magnitude @@ -843,8 +854,8 @@ def create_status_overlay( cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 200, 0), - 1, - ) + 1, + ) if self.last_target_reached: cv2.putText( diff --git a/dimos/manipulation/visual_servoing/utils.py b/dimos/manipulation/visual_servoing/utils.py index a34e25a439..0000bda999 100644 --- a/dimos/manipulation/visual_servoing/utils.py +++ b/dimos/manipulation/visual_servoing/utils.py @@ -22,6 +22,7 @@ @dataclass class ObjectMatchResult: """Result of object matching with confidence metrics.""" + matched_object: Optional[Dict[str, Any]] confidence: float distance: float @@ -30,59 +31,63 @@ class ObjectMatchResult: def calculate_object_similarity( - target_obj: Dict[str, Any], + target_obj: Dict[str, Any], candidate_obj: Dict[str, Any], distance_weight: float = 0.6, - size_weight: float = 0.4 + size_weight: float = 0.4, ) -> Tuple[float, float, float]: """ Calculate comprehensive similarity between two objects. 
- + Args: target_obj: Target object with 'position' and optionally 'size' candidate_obj: Candidate object with 'position' and optionally 'size' distance_weight: Weight for distance component (0-1) size_weight: Weight for size component (0-1) - + Returns: Tuple of (total_similarity, distance_m, size_similarity) """ # Extract positions target_pos = target_obj.get("position", {}) candidate_pos = candidate_obj.get("position", {}) - + if isinstance(target_pos, Vector3): target_xyz = np.array([target_pos.x, target_pos.y, target_pos.z]) else: - target_xyz = np.array([target_pos.get("x", 0), target_pos.get("y", 0), target_pos.get("z", 0)]) - + target_xyz = np.array( + [target_pos.get("x", 0), target_pos.get("y", 0), target_pos.get("z", 0)] + ) + if isinstance(candidate_pos, Vector3): candidate_xyz = np.array([candidate_pos.x, candidate_pos.y, candidate_pos.z]) else: - candidate_xyz = np.array([candidate_pos.get("x", 0), candidate_pos.get("y", 0), candidate_pos.get("z", 0)]) - + candidate_xyz = np.array( + [candidate_pos.get("x", 0), candidate_pos.get("y", 0), candidate_pos.get("z", 0)] + ) + # Calculate Euclidean distance distance = np.linalg.norm(target_xyz - candidate_xyz) distance_similarity = 1.0 / (1.0 + distance) # Exponential decay - + # Calculate size similarity by comparing each dimension individually size_similarity = 1.0 # Default if no size info target_size = target_obj.get("size", {}) candidate_size = candidate_obj.get("size", {}) - + if target_size and candidate_size: # Extract dimensions with defaults target_dims = [ - target_size.get("width", 0.0), - target_size.get("height", 0.0), - target_size.get("depth", 0.0) + target_size.get("width", 0.0), + target_size.get("height", 0.0), + target_size.get("depth", 0.0), ] candidate_dims = [ - candidate_size.get("width", 0.0), - candidate_size.get("height", 0.0), - candidate_size.get("depth", 0.0) + candidate_size.get("width", 0.0), + candidate_size.get("height", 0.0), + candidate_size.get("depth", 0.0), ] - + # 
Calculate similarity for each dimension pair dim_similarities = [] for target_dim, candidate_dim in zip(target_dims, candidate_dims): @@ -96,13 +101,13 @@ def calculate_object_similarity( min_dim = min(target_dim, candidate_dim) dim_similarity = min_dim / max_dim if max_dim > 0 else 0.0 dim_similarities.append(dim_similarity) - + # Return average similarity across all dimensions size_similarity = np.mean(dim_similarities) if dim_similarities else 0.0 - + # Weighted combination total_similarity = distance_weight * distance_similarity + size_weight * size_similarity - + return total_similarity, distance, size_similarity @@ -112,11 +117,11 @@ def find_best_object_match( max_distance: float = 0.1, min_size_similarity: float = 0.4, distance_weight: float = 0.7, - size_weight: float = 0.3 + size_weight: float = 0.3, ) -> ObjectMatchResult: """ Find the best matching object from candidates using distance and size criteria. - + Args: target_obj: Target object to match against candidates: List of candidate objects @@ -124,41 +129,41 @@ def find_best_object_match( min_size_similarity: Minimum size similarity for valid match (0-1) distance_weight: Weight for distance in similarity calculation size_weight: Weight for size in similarity calculation - + Returns: ObjectMatchResult with best match and confidence metrics """ if not candidates or not target_obj.get("position"): - return ObjectMatchResult(None, 0.0, float('inf'), 0.0, False) - + return ObjectMatchResult(None, 0.0, float("inf"), 0.0, False) + best_match = None best_confidence = 0.0 - best_distance = float('inf') + best_distance = float("inf") best_size_sim = 0.0 - + for candidate in candidates: if not candidate.get("position"): continue - + similarity, distance, size_sim = calculate_object_similarity( target_obj, candidate, distance_weight, size_weight ) - + # Check validity constraints is_valid = distance <= max_distance and size_sim >= min_size_similarity - + if is_valid and similarity > best_confidence: best_match 
= candidate best_confidence = similarity best_distance = distance best_size_sim = size_sim - + return ObjectMatchResult( matched_object=best_match, confidence=best_confidence, distance=best_distance, size_similarity=best_size_sim, - is_valid_match=best_match is not None + is_valid_match=best_match is not None, ) diff --git a/dimos/perception/segmentation/sam_2d_seg.py b/dimos/perception/segmentation/sam_2d_seg.py index 1f7c170cb2..972c67e905 100644 --- a/dimos/perception/segmentation/sam_2d_seg.py +++ b/dimos/perception/segmentation/sam_2d_seg.py @@ -122,7 +122,9 @@ def process_image(self, image): filtered_probs, filtered_names, filtered_texture_values, - ) = filter_segmentation_results(image, masks, bboxes, track_ids, probs, names, areas) + ) = filter_segmentation_results( + image, masks, bboxes, track_ids, probs, names, areas + ) else: # Use original results without filtering filtered_masks = masks From eb733fc48ef9e300641444dfe955310b5334bcac Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Thu, 17 Jul 2025 14:46:01 -0700 Subject: [PATCH 64/89] works really well, keeping a checkpoint --- .../visual_servoing/detection3d.py | 2 +- dimos/manipulation/visual_servoing/pbvs.py | 80 +++++++++++-------- dimos/perception/segmentation/sam_2d_seg.py | 3 +- tests/test_ibvs.py | 73 +++++++++++------ 4 files changed, 98 insertions(+), 60 deletions(-) diff --git a/dimos/manipulation/visual_servoing/detection3d.py b/dimos/manipulation/visual_servoing/detection3d.py index bccedea020..68d90f2c2b 100644 --- a/dimos/manipulation/visual_servoing/detection3d.py +++ b/dimos/manipulation/visual_servoing/detection3d.py @@ -74,7 +74,7 @@ def __init__( self.detector = Sam2DSegmenter( use_tracker=False, use_analyzer=False, - use_filtering=False, + use_filtering=True, device="cuda" if cv2.cuda.getCudaEnabledDeviceCount() > 0 else "cpu", ) diff --git a/dimos/manipulation/visual_servoing/pbvs.py b/dimos/manipulation/visual_servoing/pbvs.py index a3e6e9b4c6..0ce25519ff 100644 --- 
a/dimos/manipulation/visual_servoing/pbvs.py +++ b/dimos/manipulation/visual_servoing/pbvs.py @@ -28,6 +28,7 @@ from dimos.utils.logging_config import setup_logger from dimos.utils.transform_utils import ( yaw_towards_point, + pose_to_matrix, euler_to_quaternion, ) from dimos.manipulation.visual_servoing.utils import find_best_object_match @@ -105,6 +106,7 @@ def __init__( self.pregrasp_distance = pregrasp_distance self.grasp_distance = grasp_distance self.direct_ee_control = direct_ee_control + self.grasp_pitch_degrees = 45.0 # Default grasp pitch in degrees (45° between level and top-down) # Target state self.current_target = None @@ -168,6 +170,21 @@ def set_grasp_stage(self, stage: GraspStage): """ self.grasp_stage = stage + def set_grasp_pitch(self, pitch_degrees: float): + """ + Set the grasp pitch angle in degrees. + + Args: + pitch_degrees: Grasp pitch angle in degrees (0-90) + 0° = level grasp (horizontal) + 90° = top-down grasp (vertical) + """ + # Clamp to valid range + pitch_degrees = max(0.0, min(90.0, pitch_degrees)) + self.grasp_pitch_degrees = pitch_degrees + # Reset target grasp pose to recompute with new pitch + self.target_grasp_pose = None + def is_target_reached(self, ee_pose: Pose) -> bool: """ Check if the current target stage has been reached. 
@@ -265,8 +282,12 @@ def _update_target_grasp_pose(self, ee_pose: Pose): ) # Create target pose with proper orientation + # Convert grasp pitch from degrees to radians with mapping: + # 0° (level) -> π/2 (1.57 rad), 90° (top-down) -> π (3.14 rad) + pitch_radians = 1.57 + (self.grasp_pitch_degrees * np.pi / 180.0 / 2.0) + # Convert euler angles to quaternion using utility function - euler = Vector3(0.0, 1.57, yaw_to_ee) # roll=0, pitch=90deg, yaw=calculated + euler = Vector3(0.0, pitch_radians, yaw_to_ee) # roll=0, pitch=mapped, yaw=calculated target_orientation = euler_to_quaternion(euler) target_pose = Pose(target_pos, target_orientation) @@ -277,44 +298,39 @@ def _update_target_grasp_pose(self, ee_pose: Pose): if self.grasp_stage == GraspStage.PRE_GRASP else self.grasp_distance ) - self.target_grasp_pose = self._apply_grasp_distance(target_pose, ee_pose, distance) + self.target_grasp_pose = self._apply_grasp_distance(target_pose, distance) - def _apply_grasp_distance(self, target_pose: Pose, ee_pose: Pose, distance: float) -> Pose: + def _apply_grasp_distance(self, target_pose: Pose, distance: float) -> Pose: """ - Apply appropriate grasp distance to target pose based on current stage. + Apply grasp distance offset to target pose along its approach direction. 
Args: - target_pose: Target pose - ee_pose: Current end-effector pose + target_pose: Target grasp pose + distance: Distance to offset along the approach direction (meters) Returns: - Modified target pose with appropriate distance applied - """ - # Get approach vector (from target position towards EE) - target_pos = np.array( - [target_pose.position.x, target_pose.position.y, target_pose.position.z] + Target pose offset by the specified distance along its approach direction + """ + # Convert pose to transformation matrix to extract rotation + T_target = pose_to_matrix(target_pose) + rotation_matrix = T_target[:3, :3] + + # Define the approach vector based on the target pose orientation + # Assuming the gripper approaches along its local -z axis (common for downward grasps) + # You can change this to [1, 0, 0] for x-axis or [0, 1, 0] for y-axis based on your gripper + approach_vector_local = np.array([0, 0, -1]) + + # Transform approach vector to world coordinates + approach_vector_world = rotation_matrix @ approach_vector_local + + # Apply offset along the approach direction + offset_position = Vector3( + target_pose.position.x + distance * approach_vector_world[0], + target_pose.position.y + distance * approach_vector_world[1], + target_pose.position.z + distance * approach_vector_world[2], ) - ee_pos = np.array([ee_pose.position.x, ee_pose.position.y, ee_pose.position.z]) - approach_vector = ee_pos - target_pos # Vector pointing towards EE - - # Normalize approach vector - approach_magnitude = np.linalg.norm(approach_vector) - if approach_magnitude > 1e-6: # Avoid division by zero - norm_approach_vector = approach_vector / approach_magnitude - else: - norm_approach_vector = np.array([0.0, 0.0, 0.0]) - - # Move back by appropriate distance towards EE based on stage - offset_vector = distance * norm_approach_vector - - # Apply offset to target position - new_position = Vector3( - target_pose.position.x + offset_vector[0], - target_pose.position.y + offset_vector[1], 
- target_pose.position.z + offset_vector[2], - ) - - return Pose(new_position, target_pose.orientation) + + return Pose(offset_position, target_pose.orientation) def compute_control( self, ee_pose: Pose, new_detections: Optional[List[ObjectData]] = None diff --git a/dimos/perception/segmentation/sam_2d_seg.py b/dimos/perception/segmentation/sam_2d_seg.py index 972c67e905..462342872b 100644 --- a/dimos/perception/segmentation/sam_2d_seg.py +++ b/dimos/perception/segmentation/sam_2d_seg.py @@ -101,10 +101,9 @@ def process_image(self, image): device=self.device, retina_masks=True, conf=0.3, - iou=0.9, + iou=0.5, persist=True, verbose=False, - tracker=self.tracker_config, ) if len(results) > 0: diff --git a/tests/test_ibvs.py b/tests/test_ibvs.py index 424e46f4ef..b5bfe6afc5 100644 --- a/tests/test_ibvs.py +++ b/tests/test_ibvs.py @@ -25,6 +25,7 @@ import numpy as np import sys import os +import time import tests.test_header @@ -59,7 +60,7 @@ def mouse_callback(event, x, y, flags, param): mouse_click = (x, y) -def execute_grasp(arm, target_object, grasp_width_offset: float = 0.02) -> bool: +def execute_grasp(arm, target_object, target_pose, grasp_width_offset: float = 0.02) -> bool: """ Execute grasping by opening gripper to accommodate target object. 
@@ -102,6 +103,8 @@ def execute_grasp(arm, target_object, grasp_width_offset: float = 0.02) -> bool: # Command gripper to open arm.cmd_gripper_ctrl(gripper_opening) + arm.cmd_ee_pose(target_pose, line_mode=True) + return True @@ -120,7 +123,10 @@ def main(): print(" 's' - SOFT STOP (emergency stop)") print(" 'h' - GO HOME (return to safe position)") print(" 'SPACE' - EXECUTE target pose (only moves when pressed)") - print(" 'g' - EXECUTE GRASP (open gripper for target object)") + print(" 'g' - RELEASE GRIPPER (open gripper to 100mm)") + print("GRASP PITCH CONTROLS:") + print(" '↑' - Increase grasp pitch by 15° (towards top-down)") + print(" '↓' - Decrease grasp pitch by 15° (towards level)") # Initialize hardware zed = ZEDCamera(resolution=sl.RESOLUTION.HD720, depth_mode=sl.DEPTH_MODE.NEURAL) @@ -164,6 +170,10 @@ def main(): grasp_distance=0.01, direct_ee_control=DIRECT_EE_CONTROL, ) + + # Set custom grasp pitch (60 degrees - between level and top-down) + GRASP_PITCH_DEGREES = 0 # 0° = level grasp, 90° = top-down grasp + pbvs.set_grasp_pitch(GRASP_PITCH_DEGREES) # Setup window cv2.namedWindow("PBVS") @@ -172,6 +182,10 @@ def main(): # Control state for direct EE mode execute_target = False # Only move when space is pressed last_valid_target = None + + # Rate limiting for pose execution + MIN_EXECUTION_PERIOD = 1.0 # Minimum seconds between pose executions + last_execution_time = 0 try: while True: @@ -210,14 +224,22 @@ def main(): ) # Apply commands to robot based on control mode - if DIRECT_EE_CONTROL and target_pose and execute_target: - # Direct EE pose control - only when space is pressed - print( - f"🎯 EXECUTING target pose: pos=({target_pose.position.x:.3f}, {target_pose.position.y:.3f}, {target_pose.position.z:.3f})" - ) - last_valid_target = pbvs.get_current_target() - arm.cmd_ee_pose(target_pose) - execute_target = False # Reset flag after execution + if DIRECT_EE_CONTROL and target_pose: + # Check if enough time has passed since last execution + 
current_time = time.time() + if current_time - last_execution_time >= MIN_EXECUTION_PERIOD: + # Direct EE pose control + print( + f"🎯 EXECUTING target pose: pos=({target_pose.position.x:.3f}, {target_pose.position.y:.3f}, {target_pose.position.z:.3f})" + ) + last_valid_target = pbvs.get_current_target() + if pbvs.grasp_stage == GraspStage.PRE_GRASP: + arm.cmd_ee_pose(target_pose) + last_execution_time = current_time + elif pbvs.grasp_stage == GraspStage.GRASP and execute_target: + execute_grasp(arm, last_valid_target, target_pose, grasp_width_offset=0.03) + last_execution_time = current_time + execute_target = False # Reset flag after execution elif not DIRECT_EE_CONTROL and vel_cmd and ang_vel_cmd: # Velocity control arm.cmd_vel_ee( @@ -274,7 +296,7 @@ def main(): cv2.putText( viz_bgr, - "s=STOP | h=HOME | SPACE=EXECUTE | g=GRASP", + "s=STOP | h=HOME | SPACE=EXECUTE | g=RELEASE", (10, 110), cv2.FONT_HERSHEY_SIMPLEX, 0.4, @@ -304,26 +326,27 @@ def main(): if DIRECT_EE_CONTROL and target_pose: execute_target = True target_euler = quaternion_to_euler(target_pose.orientation, degrees=True) + if pbvs.grasp_stage == GraspStage.PRE_GRASP: + pbvs.set_grasp_stage(GraspStage.GRASP) print("⚡ SPACE pressed - Target will execute on next frame!") print( f"📍 Target pose: pos=({target_pose.position.x:.3f}, {target_pose.position.y:.3f}, {target_pose.position.z:.3f}) " f"rot=({target_euler.x:.1f}°, {target_euler.y:.1f}°, {target_euler.z:.1f}°)" ) + elif key == 82: # Up arrow key (increase pitch) + current_pitch = pbvs.grasp_pitch_degrees + new_pitch = min(90.0, current_pitch + 15.0) + pbvs.set_grasp_pitch(new_pitch) + print(f"↑ Grasp pitch increased to {new_pitch:.0f}° (0°=level, 90°=top-down)") + elif key == 84: # Down arrow key (decrease pitch) + current_pitch = pbvs.grasp_pitch_degrees + new_pitch = max(0.0, current_pitch - 15.0) + pbvs.set_grasp_pitch(new_pitch) + print(f"↓ Grasp pitch decreased to {new_pitch:.0f}° (0°=level, 90°=top-down)") elif key == ord("g"): - # G - 
Execute grasp (open gripper for target object) - current_target = pbvs.get_current_target() - if current_target: - last_valid_target = current_target - if last_valid_target: - print("🤏 GRASP - Opening gripper for target object...") - pbvs.set_grasp_stage(GraspStage.GRASP) - success = execute_grasp(arm, last_valid_target, grasp_width_offset=0.03) - if success: - print("✅ Gripper opened successfully") - else: - print("❌ Failed to execute grasp") - else: - print("❌ No target selected for grasping") + # G - Release gripper (open to 100mm) + print("🖐️ RELEASE - Opening gripper to 100mm...") + arm.release_gripper() except KeyboardInterrupt: pass From cb9fc5f9d6bee7bd86b4aa3d84f75e2947aa07a7 Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Thu, 17 Jul 2025 22:16:08 -0700 Subject: [PATCH 65/89] changed all visual servoing to use LCM types --- dimos/hardware/piper_arm.py | 2 +- dimos/manipulation/ibvs/pbvs.py | 679 ------------------ .../visual_servoing/detection3d.py | 290 ++++---- dimos/manipulation/visual_servoing/pbvs.py | 316 ++------ dimos/manipulation/visual_servoing/utils.py | 503 +++++++++++-- tests/test_ibvs.py | 144 ++-- 6 files changed, 706 insertions(+), 1228 deletions(-) delete mode 100644 dimos/manipulation/ibvs/pbvs.py diff --git a/dimos/hardware/piper_arm.py b/dimos/hardware/piper_arm.py index 50c97b7abf..ee528792d1 100644 --- a/dimos/hardware/piper_arm.py +++ b/dimos/hardware/piper_arm.py @@ -82,7 +82,7 @@ def softStop(self): time.sleep(1) self.arm.MotionCtrl_2(0x01, 0x00, 100, 0x00) self.arm.MotionCtrl_1(0x01, 0, 0) - time.sleep(5) + time.sleep(3) def cmd_ee_pose_values(self, x, y, z, r, p, y_): """Command end-effector to target pose in space (position + Euler angles)""" diff --git a/dimos/manipulation/ibvs/pbvs.py b/dimos/manipulation/ibvs/pbvs.py deleted file mode 100644 index c34d84d86b..0000000000 --- a/dimos/manipulation/ibvs/pbvs.py +++ /dev/null @@ -1,679 +0,0 @@ -# Copyright 2025 Dimensional Inc. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Position-Based Visual Servoing (PBVS) controller for eye-in-hand configuration. -Works with manipulator frame origin and proper robot arm conventions. -""" - -import numpy as np -from typing import Optional, Tuple, Dict, Any, List -import cv2 - -from scipy.spatial.transform import Rotation as R -from dimos.msgs.geometry_msgs import Pose, Vector3, Quaternion -from dimos.types.vector import Vector -from dimos.utils.logging_config import setup_logger -from dimos.utils.transform_utils import ( - pose_to_matrix, - apply_transform, - optical_to_robot_frame, - yaw_towards_point, -) - -logger = setup_logger("dimos.manipulation.pbvs") - - -class PBVSController: - """ - Position-Based Visual Servoing controller for eye-in-hand cameras. - Supports manipulator frame origin and robot arm conventions. 
- - Handles: - - Position and orientation error computation - - Velocity command generation with gain control - - Automatic target tracking across frames - - Frame transformations from ZED to robot conventions - - Pregrasp distance functionality - - 6DOF EE to camera transform handling - """ - - def __init__( - self, - position_gain: float = 0.5, - rotation_gain: float = 0.3, - max_velocity: float = 0.1, # m/s - max_angular_velocity: float = 0.5, # rad/s - target_tolerance: float = 0.01, # 5cm - tracking_distance_threshold: float = 0.05, # 5cm for target tracking - pregrasp_distance: float = 0.15, # 15cm pregrasp distance - ee_to_camera_transform: Vector = Vector( - [0.0, 0.0, -0.06, 0.0, -1.57, 0.0] - ), # 6DOF: [x,y,z,rx,ry,rz] - ): - """ - Initialize PBVS controller. - - Args: - position_gain: Proportional gain for position control - rotation_gain: Proportional gain for rotation control - max_velocity: Maximum linear velocity command magnitude (m/s) - max_angular_velocity: Maximum angular velocity command magnitude (rad/s) - target_tolerance: Distance threshold for considering target reached (m) - tracking_distance_threshold: Max distance for target association (m) - pregrasp_distance: Distance to maintain before grasping (m) - ee_to_camera_transform: 6DOF transform from EE to camera [x,y,z,rx,ry,rz] - """ - self.position_gain = position_gain - self.rotation_gain = rotation_gain - self.max_velocity = max_velocity - self.max_angular_velocity = max_angular_velocity - self.target_tolerance = target_tolerance - self.tracking_distance_threshold = tracking_distance_threshold - self.pregrasp_distance = pregrasp_distance - self.ee_to_camera_transform_vec = ee_to_camera_transform - - # State variables - self.current_target = None - self.last_position_error = None - self.last_rotation_error = None - self.last_velocity_cmd = None - self.last_angular_velocity_cmd = None - self.last_target_reached = False - - # Manipulator frame origin - self.manipulator_origin = None # 
Transform matrix from world to manipulator frame - self.manipulator_origin_pose = None # Original pose for reference - - # Create 6DOF EE to camera transform matrix - self.ee_to_camera_transform = self._create_ee_to_camera_transform() - - logger.info( - f"Initialized PBVS controller: pos_gain={position_gain}, rot_gain={rotation_gain}, " - f"max_vel={max_velocity}m/s, max_ang_vel={max_angular_velocity}rad/s, " - f"target_tolerance={target_tolerance}m, pregrasp_distance={pregrasp_distance}m, " - f"ee_to_camera_transform={ee_to_camera_transform.to_list()}" - ) - - def _create_ee_to_camera_transform(self) -> np.ndarray: - """ - Create 6DOF transform matrix from EE to camera frame. - - Returns: - 4x4 transformation matrix from EE to camera - """ - # Extract position and rotation from 6DOF vector - pos = self.ee_to_camera_transform_vec.to_list()[:3] - rot = self.ee_to_camera_transform_vec.to_list()[3:6] # euler angles: [rx, ry, rz] - - # Create transformation matrix - T_ee_to_cam = np.eye(4) - T_ee_to_cam[0:3, 3] = pos - - # Apply rotation using scipy (treating as euler angles) - if np.linalg.norm(rot) > 1e-6: - rotation = R.from_euler("xyz", rot) - T_ee_to_cam[0:3, 0:3] = rotation.as_matrix() - - return T_ee_to_cam - - def set_manipulator_origin(self, camera_pose: Pose): - """ - Set the manipulator frame origin based on current camera pose. - This establishes the robot arm coordinate frame. 
- - Args: - camera_pose: Current camera pose in world frame - """ - self.manipulator_origin_pose = camera_pose - - # Create transform matrix from ZED world to manipulator origin - # This is the inverse of the camera pose at origin - T_world_to_origin = pose_to_matrix(camera_pose) - self.manipulator_origin = np.linalg.inv(T_world_to_origin) - - logger.info( - f"Set manipulator origin at pose: pos=({camera_pose.position.x:.3f}, " - f"{camera_pose.position.y:.3f}, {camera_pose.position.z:.3f})" - ) - - def _apply_pregrasp_distance(self, target_pose: Pose) -> Pose: - """ - Apply pregrasp distance to target pose by moving back towards robot origin. - - Args: - target_pose: Target pose in robot frame - - Returns: - Modified target pose with pregrasp distance applied - """ - # Get approach vector (from target position towards robot origin) - target_pos = np.array( - [target_pose.position.x, target_pose.position.y, target_pose.position.z] - ) - robot_origin = np.array([0.0, 0.0, 0.0]) # Robot origin in robot frame - approach_vector = robot_origin - target_pos # Vector pointing towards robot - - # Normalize approach vector - approach_magnitude = np.linalg.norm(approach_vector) - if approach_magnitude > 1e-6: # Avoid division by zero - norm_approach_vector = approach_vector / approach_magnitude - else: - norm_approach_vector = np.array([0.0, 0.0, 0.0]) - - # Move back by pregrasp distance towards robot - offset_vector = self.pregrasp_distance * norm_approach_vector - - # Apply offset to target position - new_position = Vector3( - target_pose.position.x + offset_vector[0], - target_pose.position.y + offset_vector[1], - target_pose.position.z + offset_vector[2], - ) - - return Pose(new_position, target_pose.orientation) - - def _update_target_robot_frame(self): - """Update current target with robot frame coordinates.""" - if not self.current_target or "position" not in self.current_target: - return - - # Get target position in ZED world frame - target_pos = 
self.current_target["position"] - target_pose_zed = Pose(target_pos, Quaternion()) # Identity quaternion - - # Transform to manipulator frame - target_pose_manip = apply_transform(target_pose_zed, self.manipulator_origin) - - # Calculate orientation pointing at origin (in robot frame) - yaw_to_origin = yaw_towards_point( - Vector( - target_pose_manip.position.x, - target_pose_manip.position.y, - target_pose_manip.position.z, - ) - ) - - # Create target pose with proper orientation - # Convert euler angles to quaternion using scipy - euler = [0.0, 1.57, yaw_to_origin] # roll=0, pitch=90deg, yaw=calculated - quat = R.from_euler("xyz", euler).as_quat() # [x, y, z, w] - target_orientation = Quaternion(quat[0], quat[1], quat[2], quat[3]) - - target_pose_robot = Pose(target_pose_manip.position, target_orientation) - - # Apply pregrasp distance - target_pose_pregrasp = self._apply_pregrasp_distance(target_pose_robot) - - # Update target with robot frame pose - self.current_target["robot_position"] = target_pose_pregrasp.position - self.current_target["robot_rotation"] = target_pose_pregrasp.orientation - - def set_target(self, target_object: Dict[str, Any]) -> bool: - """ - Set a new target object for servoing. - Requires manipulator origin to be set. 
- - Args: - target_object: Object dict with at least 'position' field - - Returns: - True if target was set successfully, False if no origin set - """ - # Require origin to be set - if self.manipulator_origin is None: - logger.warning("Cannot set target: No manipulator origin set") - return False - - if target_object and "position" in target_object: - self.current_target = target_object - - # Update to robot frame - self._update_target_robot_frame() - - logger.info(f"New target set: ID {target_object.get('object_id', 'unknown')}") - return True - return False - - def clear_target(self): - """Clear the current target.""" - self.current_target = None - self.last_position_error = None - self.last_rotation_error = None - self.last_velocity_cmd = None - self.last_angular_velocity_cmd = None - self.last_target_reached = False - logger.info("Target cleared") - - def update_target_tracking(self, new_detections: List[Dict[str, Any]]) -> bool: - """ - Update target by matching to closest object in new detections. 
- - Args: - new_detections: List of newly detected objects - - Returns: - True if target was successfully tracked, False if lost - """ - if not self.current_target or "position" not in self.current_target: - return False - - if not new_detections: - logger.debug("No detections for target tracking") - return False - - # Get current target position (in ZED world frame for matching) - target_pos = self.current_target["position"] - if isinstance(target_pos, (Vector, Vector3)): - target_xyz = np.array([target_pos.x, target_pos.y, target_pos.z]) - else: - target_xyz = np.array([target_pos["x"], target_pos["y"], target_pos["z"]]) - - # Find closest match - min_distance = float("inf") - best_match = None - - for detection in new_detections: - if "position" not in detection: - continue - - det_pos = detection["position"] - if isinstance(det_pos, (Vector, Vector3)): - det_xyz = np.array([det_pos.x, det_pos.y, det_pos.z]) - else: - det_xyz = np.array([det_pos["x"], det_pos["y"], det_pos["z"]]) - - distance = np.linalg.norm(target_xyz - det_xyz) - - if distance < min_distance and distance < self.tracking_distance_threshold: - min_distance = distance - best_match = detection - - if best_match: - self.current_target = best_match - # Update to robot frame - self._update_target_robot_frame() - return True - return False - - def _get_ee_pose_from_camera(self, camera_pose: Pose) -> Pose: - """ - Get end-effector pose from camera pose using 6DOF EE to camera transform. 
- - Args: - camera_pose: Current camera pose in robot frame - - Returns: - End-effector pose in robot frame - """ - # Transform camera pose to EE frame - camera_transform = pose_to_matrix(camera_pose) - ee_transform = camera_transform @ np.linalg.inv(self.ee_to_camera_transform) - - # Extract position and rotation - ee_pos = Vector3(ee_transform[0:3, 3]) - ee_rot_matrix = ee_transform[0:3, 0:3] - - # Convert rotation matrix to quaternion - - # Ensure the rotation matrix is valid (orthogonal with det=1) - try: - rotation = R.from_matrix(ee_rot_matrix) - quat = rotation.as_quat() # [x, y, z, w] - ee_orientation = Quaternion(quat[0], quat[1], quat[2], quat[3]) - except ValueError as e: - logger.warning(f"Invalid rotation matrix in EE pose calculation: {e}") - # Fallback to identity quaternion - ee_orientation = Quaternion(0.0, 0.0, 0.0, 1.0) - - return Pose(ee_pos, ee_orientation) - - def compute_control( - self, camera_pose: Pose, new_detections: Optional[List[Dict[str, Any]]] = None - ) -> Tuple[Optional[Vector], Optional[Vector], bool, bool]: - """ - Compute PBVS control with position and orientation servoing. 
- - Args: - camera_pose: Current camera pose in ZED world frame - new_detections: Optional new detections for target tracking - - Returns: - Tuple of (velocity_command, angular_velocity_command, target_reached, has_target) - - velocity_command: Linear velocity vector or None if no target - - angular_velocity_command: Angular velocity vector or None if no target - - target_reached: True if within target tolerance - - has_target: True if currently tracking a target - """ - # Check if we have a target and origin - if not self.current_target or "position" not in self.current_target: - return None, None, False, False - - if self.manipulator_origin is None: - logger.warning("Cannot compute control: No manipulator origin set") - return None, None, False, False - - # Try to update target tracking if new detections provided - if new_detections is not None: - self.update_target_tracking(new_detections) - - # Transform camera pose to robot frame - camera_pose_robot = apply_transform(camera_pose, self.manipulator_origin) - - # Get EE pose from camera pose - ee_pose_robot = self._get_ee_pose_from_camera(camera_pose_robot) - - # Get target in robot frame - target_pos = self.current_target.get("robot_position") - target_rot = self.current_target.get("robot_rotation") - - if target_pos is None or target_rot is None: - logger.warning("Target position or rotation not available") - return None, None, False, False - - # Calculate position error (target - EE position) - error = Vector3( - target_pos.x - ee_pose_robot.position.x, - target_pos.y - ee_pose_robot.position.y, - target_pos.z - ee_pose_robot.position.z, - ) - self.last_position_error = error - - # Compute velocity command with proportional control - velocity_cmd = Vector( - [ - error.x * self.position_gain, - error.y * self.position_gain, - error.z * self.position_gain, - ] - ) - - # Limit velocity magnitude - vel_magnitude = np.linalg.norm([velocity_cmd.x, velocity_cmd.y, velocity_cmd.z]) - if vel_magnitude > 
self.max_velocity: - scale = self.max_velocity / vel_magnitude - velocity_cmd = Vector( - [ - float(velocity_cmd.x * scale), - float(velocity_cmd.y * scale), - float(velocity_cmd.z * scale), - ] - ) - - self.last_velocity_cmd = velocity_cmd - - # Compute angular velocity for orientation control - angular_velocity_cmd = self._compute_angular_velocity(target_rot, ee_pose_robot) - - # Check if target reached - error_magnitude = np.linalg.norm([error.x, error.y, error.z]) - target_reached = bool(error_magnitude < self.target_tolerance) - self.last_target_reached = target_reached - - # Clear target only if it's reached - if target_reached: - logger.info( - f"Target reached! Clearing target ID {self.current_target.get('object_id', 'unknown')}" - ) - self.clear_target() - - return velocity_cmd, angular_velocity_cmd, target_reached, True - - def _compute_angular_velocity(self, target_rot: Quaternion, current_pose: Pose) -> Vector: - """ - Compute angular velocity commands for orientation control. - Uses quaternion error computation for better numerical stability. 
- - Args: - target_rot: Target orientation (quaternion) - current_pose: Current EE pose - - Returns: - Angular velocity command as Vector - """ - # Use quaternion error for better numerical stability - - # Convert to scipy Rotation objects - target_rot_scipy = R.from_quat([target_rot.x, target_rot.y, target_rot.z, target_rot.w]) - current_rot_scipy = R.from_quat( - [ - current_pose.orientation.x, - current_pose.orientation.y, - current_pose.orientation.z, - current_pose.orientation.w, - ] - ) - - # Compute rotation error: error = target * current^(-1) - error_rot = target_rot_scipy * current_rot_scipy.inv() - - # Convert to axis-angle representation for control - error_axis_angle = error_rot.as_rotvec() - - # Use axis-angle directly as angular velocity error (small angle approximation) - roll_error = error_axis_angle[0] - pitch_error = error_axis_angle[1] - yaw_error = error_axis_angle[2] - - self.last_rotation_error = Vector([roll_error, pitch_error, yaw_error]) - - # Apply proportional control - angular_velocity = Vector( - [ - roll_error * self.rotation_gain, - pitch_error * self.rotation_gain, - yaw_error * self.rotation_gain, - ] - ) - - # Limit angular velocity magnitude - ang_vel_magnitude = np.sqrt( - angular_velocity.x**2 + angular_velocity.y**2 + angular_velocity.z**2 - ) - if ang_vel_magnitude > self.max_angular_velocity: - scale = self.max_angular_velocity / ang_vel_magnitude - angular_velocity = angular_velocity * scale - - self.last_angular_velocity_cmd = angular_velocity - - return angular_velocity - - def get_camera_pose_robot_frame(self, camera_pose_zed: Pose) -> Optional[Pose]: - """ - Get camera pose in robot frame coordinates. 
- - Args: - camera_pose_zed: Camera pose in ZED world frame - - Returns: - Camera pose in robot frame or None if no origin set - """ - if self.manipulator_origin is None: - return None - - camera_pose_manip = apply_transform(camera_pose_zed, self.manipulator_origin) - return camera_pose_manip - - def get_ee_pose_robot_frame(self, camera_pose_zed: Pose) -> Optional[Pose]: - """ - Get end-effector pose in robot frame coordinates. - - Args: - camera_pose_zed: Camera pose in ZED world frame - - Returns: - End-effector pose in robot frame or None if no origin set - """ - if self.manipulator_origin is None: - return None - - camera_pose_robot = apply_transform(camera_pose_zed, self.manipulator_origin) - return self._get_ee_pose_from_camera(camera_pose_robot) - - def get_object_pose_robot_frame( - self, object_pos_zed: Vector - ) -> Optional[Tuple[Vector, Vector]]: - """ - Get object pose in robot frame coordinates with orientation. - - Args: - object_pos_zed: Object position in ZED world frame - - Returns: - Tuple of (position, rotation) in robot frame or None if no origin set - """ - if self.manipulator_origin is None: - return None - - # Transform position - obj_pose_zed = Pose(object_pos_zed, Quaternion()) # Identity quaternion - obj_pose_manip = apply_transform(obj_pose_zed, self.manipulator_origin) - - # Calculate orientation pointing at origin - yaw_to_origin = yaw_towards_point( - Vector(obj_pose_manip.position.x, obj_pose_manip.position.y, obj_pose_manip.position.z) - ) - - # Convert euler angles to quaternion - euler = [0.0, 0.0, yaw_to_origin] # Level grasp - quat = R.from_euler("xyz", euler).as_quat() # [x, y, z, w] - orientation = Quaternion(quat[0], quat[1], quat[2], quat[3]) - - return obj_pose_manip.position, orientation - - def create_status_overlay( - self, image: np.ndarray, camera_intrinsics: Optional[list] = None - ) -> np.ndarray: - """ - Create PBVS status overlay on image. 
- - Args: - image: Input image - camera_intrinsics: Optional [fx, fy, cx, cy] (not used) - - Returns: - Image with PBVS status overlay - """ - viz_img = image.copy() - height, width = image.shape[:2] - - # Status panel - if self.current_target: - panel_height = 140 # Adjusted panel height - panel_y = height - panel_height - overlay = viz_img.copy() - cv2.rectangle(overlay, (0, panel_y), (width, height), (0, 0, 0), -1) - viz_img = cv2.addWeighted(viz_img, 0.7, overlay, 0.3, 0) - - # Status text - y = panel_y + 20 - cv2.putText( - viz_img, "PBVS Status", (10, y), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2 - ) - - # Add frame info - frame_text = ( - "Frame: Robot" if self.manipulator_origin is not None else "Frame: ZED World" - ) - cv2.putText( - viz_img, frame_text, (200, y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 0), 1 - ) - - if self.last_position_error: - error_mag = np.linalg.norm( - [ - self.last_position_error.x, - self.last_position_error.y, - self.last_position_error.z, - ] - ) - color = (0, 255, 0) if self.last_target_reached else (0, 255, 255) - - cv2.putText( - viz_img, - f"Pos Error: {error_mag:.3f}m ({error_mag * 100:.1f}cm)", - (10, y + 25), - cv2.FONT_HERSHEY_SIMPLEX, - 0.5, - color, - 1, - ) - - cv2.putText( - viz_img, - f"XYZ: ({self.last_position_error.x:.3f}, {self.last_position_error.y:.3f}, {self.last_position_error.z:.3f})", - (10, y + 45), - cv2.FONT_HERSHEY_SIMPLEX, - 0.4, - (200, 200, 200), - 1, - ) - - if self.last_velocity_cmd: - cv2.putText( - viz_img, - f"Lin Vel: ({self.last_velocity_cmd.x:.2f}, {self.last_velocity_cmd.y:.2f}, {self.last_velocity_cmd.z:.2f})m/s", - (10, y + 65), - cv2.FONT_HERSHEY_SIMPLEX, - 0.5, - (255, 200, 0), - 1, - ) - - if self.last_rotation_error: - cv2.putText( - viz_img, - f"Rot Error: ({self.last_rotation_error.x:.2f}, {self.last_rotation_error.y:.2f}, {self.last_rotation_error.z:.2f})rad", - (10, y + 85), - cv2.FONT_HERSHEY_SIMPLEX, - 0.4, - (200, 200, 200), - 1, - ) - - if 
self.last_angular_velocity_cmd: - cv2.putText( - viz_img, - f"Ang Vel: ({self.last_angular_velocity_cmd.x:.2f}, {self.last_angular_velocity_cmd.y:.2f}, {self.last_angular_velocity_cmd.z:.2f})rad/s", - (10, y + 105), - cv2.FONT_HERSHEY_SIMPLEX, - 0.5, - (255, 200, 0), - 1, - ) - - # Add config info - ee_transform = self.ee_to_camera_transform_vec.to_list() - cv2.putText( - viz_img, - f"Pregrasp: {self.pregrasp_distance:.3f}m | EE Transform: [{ee_transform[0]:.2f},{ee_transform[1]:.2f},{ee_transform[2]:.2f}]", - (10, y + 125), - cv2.FONT_HERSHEY_SIMPLEX, - 0.4, - (255, 255, 255), - 1, - ) - - if self.last_target_reached: - cv2.putText( - viz_img, - "TARGET REACHED", - (width - 150, y + 25), - cv2.FONT_HERSHEY_SIMPLEX, - 0.6, - (0, 255, 0), - 2, - ) - - return viz_img diff --git a/dimos/manipulation/visual_servoing/detection3d.py b/dimos/manipulation/visual_servoing/detection3d.py index 68d90f2c2b..52d9e524b0 100644 --- a/dimos/manipulation/visual_servoing/detection3d.py +++ b/dimos/manipulation/visual_servoing/detection3d.py @@ -16,19 +16,19 @@ Real-time 3D object detection processor that extracts object poses from RGB-D data. 
""" -import time -from typing import Dict, List, Optional, Any +from typing import List, Optional, Tuple import numpy as np import cv2 from dimos.utils.logging_config import setup_logger from dimos.perception.segmentation.sam_2d_seg import Sam2DSegmenter from dimos.perception.pointcloud.utils import extract_centroids_from_masks -from dimos.perception.detection2d.utils import plot_results, calculate_object_size_from_bbox +from dimos.perception.detection2d.utils import calculate_object_size_from_bbox -from dimos.msgs.geometry_msgs import Pose, Vector3, Quaternion -from dimos.types.manipulation import ObjectData -from dimos.manipulation.visual_servoing.utils import estimate_object_depth +from dimos_lcm.geometry_msgs import Pose, Vector3, Quaternion, Point +from dimos_lcm.vision_msgs import Detection3D, Detection3DArray, BoundingBox3D, ObjectHypothesisWithPose, ObjectHypothesis, Detection2D, Detection2DArray, BoundingBox2D, Pose2D, Point2D +from dimos_lcm.std_msgs import Header +from dimos.manipulation.visual_servoing.utils import estimate_object_depth, visualize_detections_3d from dimos.utils.transform_utils import ( optical_to_robot_frame, pose_to_matrix, @@ -88,7 +88,7 @@ def __init__( def process_frame( self, rgb_image: np.ndarray, depth_image: np.ndarray, transform: Optional[np.ndarray] = None - ) -> List[ObjectData]: + ) -> Tuple[Detection3DArray, Detection2DArray]: """ Process a single RGB-D frame to extract 3D object detections. 
@@ -98,7 +98,7 @@ def process_frame( transform: Optional 4x4 transformation matrix to transform objects from camera frame to desired frame Returns: - List of ObjectData objects with 3D pose information + Tuple of (Detection3DArray, Detection2DArray) with 3D and 2D information """ # Convert RGB to BGR for Sam (OpenCV format) @@ -109,7 +109,7 @@ def process_frame( # Early exit if no detections if not masks or len(masks) == 0: - return [] + return Detection3DArray(detections_length=0, header=Header(), detections=[]), Detection2DArray(detections_length=0, header=Header(), detections=[]) # Convert CUDA tensors to numpy arrays if needed numpy_masks = [] @@ -128,86 +128,116 @@ def process_frame( ) # Build detection results - detections = [] + detections_3d = [] + detections_2d = [] pose_dict = {p["mask_idx"]: p for p in poses if p["centroid"][2] < self.max_depth} for i, (bbox, name, prob, track_id) in enumerate(zip(bboxes, names, probs, track_ids)): - # Create ObjectData object - obj_data: ObjectData = { - "object_id": track_id, - "bbox": bbox.tolist() if isinstance(bbox, np.ndarray) else bbox, - "confidence": float(prob), - "label": name, - "movement_tolerance": 1.0, # Default to freely movable - "segmentation_mask": numpy_masks[i] if i < len(numpy_masks) else np.array([]), - } - - # Add 3D pose if available - if i in pose_dict: - pose = pose_dict[i] - obj_cam_pos = pose["centroid"] - - # Set depth and position in camera frame - obj_data["depth"] = float(obj_cam_pos[2]) - - if obj_cam_pos[2] > self.max_depth: - continue - - obj_data["rotation"] = None - - # Calculate object size from bbox and depth - width_m, height_m = calculate_object_size_from_bbox( - bbox, obj_cam_pos[2], self.camera_intrinsics - ) - # Calculate depth dimension using segmentation mask - depth_m = estimate_object_depth( - depth_image, numpy_masks[i] if i < len(numpy_masks) else None, bbox + # Skip if no 3D pose data + if i not in pose_dict: + continue + + pose = pose_dict[i] + obj_cam_pos = 
pose["centroid"] + + if obj_cam_pos[2] > self.max_depth: + continue + + # Calculate object size from bbox and depth + width_m, height_m = calculate_object_size_from_bbox( + bbox, obj_cam_pos[2], self.camera_intrinsics + ) + + # Calculate depth dimension using segmentation mask + depth_m = estimate_object_depth( + depth_image, numpy_masks[i] if i < len(numpy_masks) else None, bbox + ) + + size_x = max(width_m, 0.01) # Minimum 1cm width + size_y = max(height_m, 0.01) # Minimum 1cm height + size_z = max(depth_m, 0.01) # Minimum 1cm depth + + if min(size_x, size_y, size_z) > self.max_object_size: + continue + + # Transform to desired frame if transform matrix is provided + if transform is not None: + # Get orientation as euler angles, default to no rotation if not available + obj_cam_orientation = pose.get( + "rotation", np.array([0.0, 0.0, 0.0]) + ) # Default to no rotation + transformed_pose = self._transform_object_pose( + obj_cam_pos, obj_cam_orientation, transform + ) + center_pose = transformed_pose + else: + # If no transform, use camera coordinates + center_pose = Pose( + Point(obj_cam_pos[0], obj_cam_pos[1], obj_cam_pos[2]), + Quaternion(0.0, 0.0, 0.0, 1.0) # Default orientation ) - obj_data["size"] = { - "width": max(width_m, 0.01), # Minimum 1cm width - "height": max(height_m, 0.01), # Minimum 1cm height - "depth": max(depth_m, 0.01), # Minimum 1cm depth - } - - if ( - min( - obj_data["size"]["width"], - obj_data["size"]["height"], - obj_data["size"]["depth"], + # Create Detection3D object + detection = Detection3D( + results_length=1, + header=Header(), # Empty header + results=[ObjectHypothesisWithPose( + hypothesis=ObjectHypothesis( + class_id=name, + score=float(prob) ) - > self.max_object_size - ): - continue - - # Extract average color from the region - x1, y1, x2, y2 = map(int, bbox) - roi = rgb_image[y1:y2, x1:x2] - if roi.size > 0: - avg_color = np.mean(roi.reshape(-1, 3), axis=0) - obj_data["color"] = avg_color.astype(np.uint8) - else: - 
obj_data["color"] = np.array([128, 128, 128], dtype=np.uint8) - - # Transform to desired frame if transform matrix is provided - if transform is not None: - # Get orientation as euler angles, default to no rotation if not available - obj_cam_orientation = pose.get( - "rotation", np.array([0.0, 0.0, 0.0]) - ) # Default to no rotation - transformed_pose = self._transform_object_pose( - obj_cam_pos, obj_cam_orientation, transform + )], + bbox=BoundingBox3D( + center=center_pose, + size=Vector3(size_x, size_y, size_z) + ), + id=str(track_id) + ) + + detections_3d.append(detection) + + # Create corresponding Detection2D + x1, y1, x2, y2 = bbox + center_x = (x1 + x2) / 2.0 + center_y = (y1 + y2) / 2.0 + width = x2 - x1 + height = y2 - y1 + + detection_2d = Detection2D( + results_length=1, + header=Header(), + results=[ObjectHypothesisWithPose( + hypothesis=ObjectHypothesis( + class_id=name, + score=float(prob) ) - obj_data["position"] = transformed_pose.position - obj_data["rotation"] = transformed_pose.orientation - else: - # If no transform, use camera coordinates - obj_data["position"] = Vector3(obj_cam_pos[0], obj_cam_pos[1], obj_cam_pos[2]) - - detections.append(obj_data) - - return detections + )], + bbox=BoundingBox2D( + center=Pose2D( + position=Point2D(center_x, center_y), + theta=0.0 + ), + size_x=float(width), + size_y=float(height) + ), + id=str(track_id) + ) + detections_2d.append(detection_2d) + + # Create and return both arrays + return ( + Detection3DArray( + detections_length=len(detections_3d), + header=Header(), + detections=detections_3d + ), + Detection2DArray( + detections_length=len(detections_2d), + header=Header(), + detections=detections_2d + ) + ) def _transform_object_pose( self, obj_pos: np.ndarray, obj_orientation: np.ndarray, transform_matrix: np.ndarray @@ -228,7 +258,7 @@ def _transform_object_pose( euler_vector = Vector3(obj_orientation[0], obj_orientation[1], obj_orientation[2]) obj_orientation_quat = euler_to_quaternion(euler_vector) - 
obj_pose_optical = Pose(Vector3(obj_pos[0], obj_pos[1], obj_pos[2]), obj_orientation_quat) + obj_pose_optical = Pose(Point(obj_pos[0], obj_pos[1], obj_pos[2]), obj_orientation_quat) # Transform object pose from optical frame to robot frame convention first obj_pose_robot_frame = optical_to_robot_frame(obj_pose_optical) @@ -247,7 +277,8 @@ def _transform_object_pose( def visualize_detections( self, rgb_image: np.ndarray, - detections: List[ObjectData], + detections_3d: List[Detection3D], + detections_2d: List[Detection2D], show_coordinates: bool = True, ) -> np.ndarray: """ @@ -255,95 +286,50 @@ def visualize_detections( Args: rgb_image: Original RGB image - detections: List of ObjectData objects - show_coordinates: Whether to show 3D coordinates next to bounding boxes + detections_3d: List of Detection3D objects + detections_2d: List of Detection2D objects (must be 1:1 correspondence) + show_coordinates: Whether to show 3D coordinates Returns: Visualization image """ - if not detections: - return rgb_image.copy() - - # Extract data for plot_results function - bboxes = [det["bbox"] for det in detections] - track_ids = [det.get("object_id", i) for i, det in enumerate(detections)] - class_ids = [i for i in range(len(detections))] - confidences = [det["confidence"] for det in detections] - names = [det["label"] for det in detections] - - # Use plot_results for basic visualization - viz = plot_results(rgb_image, bboxes, track_ids, class_ids, confidences, names) - - # Add 3D position coordinates if requested - if show_coordinates: - for det in detections: - if "position" in det and "bbox" in det: - position = det["position"] - bbox = det["bbox"] - - if isinstance(position, Vector3): - pos_xyz = np.array([position.x, position.y, position.z]) - else: - pos_xyz = np.array([position["x"], position["y"], position["z"]]) - - # Get bounding box coordinates - x1, y1, x2, y2 = map(int, bbox) - - # Add position text next to bounding box (top-right corner) - pos_text = 
f"({pos_xyz[0]:.2f}, {pos_xyz[1]:.2f}, {pos_xyz[2]:.2f})" - text_x = x2 + 5 # Right edge of bbox + small offset - text_y = y1 + 15 # Top edge of bbox + small offset - - # Add background rectangle for better readability - text_size = cv2.getTextSize(pos_text, cv2.FONT_HERSHEY_SIMPLEX, 0.4, 1)[0] - cv2.rectangle( - viz, - (text_x - 2, text_y - text_size[1] - 2), - (text_x + text_size[0] + 2, text_y + 2), - (0, 0, 0), - -1, - ) - - cv2.putText( - viz, - pos_text, - (text_x, text_y), - cv2.FONT_HERSHEY_SIMPLEX, - 0.4, - (255, 255, 255), - 1, - ) - - return viz + # Extract 2D bboxes from Detection2D objects + from dimos.manipulation.visual_servoing.utils import bbox2d_to_corners + bboxes_2d = [] + for det_2d in detections_2d: + if det_2d.bbox: + x1, y1, x2, y2 = bbox2d_to_corners(det_2d.bbox) + bboxes_2d.append([x1, y1, x2, y2]) + + return visualize_detections_3d(rgb_image, detections_3d, show_coordinates, bboxes_2d) def get_closest_detection( - self, detections: List[ObjectData], class_filter: Optional[str] = None - ) -> Optional[ObjectData]: + self, detections: List[Detection3D], class_filter: Optional[str] = None + ) -> Optional[Detection3D]: """ Get the closest detection with valid 3D data. 
Args: - detections: List of ObjectData objects + detections: List of Detection3D objects class_filter: Optional class name to filter by Returns: - Closest ObjectData or None + Closest Detection3D or None """ - valid_detections = [ - d - for d in detections - if "position" in d and (class_filter is None or d["label"] == class_filter) - ] + valid_detections = [] + for d in detections: + # Check if has valid bbox center position + if d.bbox and d.bbox.center and d.bbox.center.position: + # Check class filter if specified + if class_filter is None or (d.results_length > 0 and d.results[0].hypothesis.class_id == class_filter): + valid_detections.append(d) if not valid_detections: return None # Sort by depth (Z coordinate) def get_z_coord(d): - pos = d["position"] - if isinstance(pos, Vector3): - return abs(pos.z) - return abs(pos["z"]) + return abs(d.bbox.center.position.z) return min(valid_detections, key=get_z_coord) diff --git a/dimos/manipulation/visual_servoing/pbvs.py b/dimos/manipulation/visual_servoing/pbvs.py index 0ce25519ff..682a165042 100644 --- a/dimos/manipulation/visual_servoing/pbvs.py +++ b/dimos/manipulation/visual_servoing/pbvs.py @@ -18,20 +18,23 @@ """ import numpy as np -from typing import Optional, Tuple, Dict, Any, List -import cv2 +from typing import Optional, Tuple from enum import Enum from scipy.spatial.transform import Rotation as R -from dimos.msgs.geometry_msgs import Pose, Vector3, Quaternion -from dimos.types.manipulation import ObjectData +from dimos_lcm.geometry_msgs import Pose, Vector3, Quaternion, Point +from dimos_lcm.vision_msgs import Detection3D, Detection3DArray from dimos.utils.logging_config import setup_logger from dimos.utils.transform_utils import ( yaw_towards_point, pose_to_matrix, euler_to_quaternion, ) -from dimos.manipulation.visual_servoing.utils import find_best_object_match +from dimos.manipulation.visual_servoing.utils import ( + find_best_object_match, + create_pbvs_status_overlay, + 
create_pbvs_controller_overlay, +) logger = setup_logger("dimos.manipulation.pbvs") @@ -64,7 +67,7 @@ def __init__( max_velocity: float = 0.1, # m/s max_angular_velocity: float = 0.5, # rad/s target_tolerance: float = 0.01, # 1cm - max_tracking_distance_threshold: float = 0.2, # Max distance for target tracking (m) + max_tracking_distance_threshold: float = 0.1, # Max distance for target tracking (m) min_size_similarity: float = 0.7, # Min size similarity threshold (0.0-1.0) pregrasp_distance: float = 0.15, # 15cm pregrasp distance grasp_distance: float = 0.05, # 5cm grasp distance (final approach) @@ -123,21 +126,21 @@ def __init__( f"tracking_thresholds: distance={max_tracking_distance_threshold}m, size={min_size_similarity:.2f}" ) - def set_target(self, target_object: Dict[str, Any]) -> bool: + def set_target(self, target_object: Detection3D) -> bool: """ Set a new target object for servoing. Args: - target_object: Object dict with at least 'position' field + target_object: Detection3D object Returns: True if target was set successfully """ - if target_object and "position" in target_object: + if target_object and target_object.bbox and target_object.bbox.center: self.current_target = target_object self.target_grasp_pose = None # Will be computed when needed self.grasp_stage = GraspStage.PRE_GRASP # Reset to pre-grasp stage - logger.info(f"New target set: ID {target_object.get('object_id', 'unknown')}") + logger.info(f"New target set: ID {target_object.id}") return True return False @@ -152,12 +155,12 @@ def clear_target(self): self.controller.clear_state() logger.info("Target cleared") - def get_current_target(self): + def get_current_target(self) -> Optional[Detection3D]: """ Get the current target object. 
Returns: - Current target ObjectData or None if no target selected + Current target Detection3D or None if no target selected """ return self.current_target @@ -217,7 +220,7 @@ def is_target_reached(self, ee_pose: Pose) -> bool: return False - def update_target_tracking(self, new_detections: List[ObjectData]) -> bool: + def update_target_tracking(self, new_detections: Detection3DArray) -> bool: """ Update target by matching to closest object in new detections. If tracking is lost, keeps the old target pose. @@ -228,10 +231,10 @@ def update_target_tracking(self, new_detections: List[ObjectData]) -> bool: Returns: True if target was successfully tracked, False if lost (but target is kept) """ - if not self.current_target or "position" not in self.current_target: + if not self.current_target or not self.current_target.bbox or not self.current_target.bbox.center: return False - if not new_detections: + if not new_detections or new_detections.detections_length == 0: logger.debug("No detections for target tracking - using last known pose") return False @@ -241,7 +244,7 @@ def update_target_tracking(self, new_detections: List[ObjectData]) -> bool: # Find best match using standardized utility function match_result = find_best_object_match( target_obj=self.current_target, - candidates=new_detections, + candidates=new_detections.detections, max_distance=max_distance, min_size_similarity=self.min_size_similarity, ) @@ -270,11 +273,11 @@ def _update_target_grasp_pose(self, ee_pose: Pose): Args: ee_pose: Current end-effector pose """ - if not self.current_target: + if not self.current_target or not self.current_target.bbox or not self.current_target.bbox.center: return # Get target position - target_pos = self.current_target["position"] + target_pos = self.current_target.bbox.center.position # Calculate orientation pointing from target towards EE yaw_to_ee = yaw_towards_point( @@ -324,7 +327,7 @@ def _apply_grasp_distance(self, target_pose: Pose, distance: float) -> Pose: 
approach_vector_world = rotation_matrix @ approach_vector_local # Apply offset along the approach direction - offset_position = Vector3( + offset_position = Point( target_pose.position.x + distance * approach_vector_world[0], target_pose.position.y + distance * approach_vector_world[1], target_pose.position.z + distance * approach_vector_world[2], @@ -333,7 +336,7 @@ def _apply_grasp_distance(self, target_pose: Pose, distance: float) -> Pose: return Pose(offset_position, target_pose.orientation) def compute_control( - self, ee_pose: Pose, new_detections: Optional[List[ObjectData]] = None + self, ee_pose: Pose, new_detections: Optional[Detection3DArray] = None ) -> Tuple[Optional[Vector3], Optional[Vector3], bool, bool, Optional[Pose]]: """ Compute PBVS control with position and orientation servoing. @@ -351,7 +354,7 @@ def compute_control( - target_pose: Target EE pose (only in direct_ee_control mode, otherwise None) """ # Check if we have a target - if not self.current_target or "position" not in self.current_target: + if not self.current_target or not self.current_target.bbox or not self.current_target.bbox.center: return None, None, False, False, None # Try to update target tracking if new detections provided @@ -438,14 +441,21 @@ def create_status_overlay( Args: image: Input image - camera_intrinsics: Optional [fx, fy, cx, cy] (not used) Returns: Image with PBVS status overlay """ if self.direct_ee_control: - # Use our own error data for direct control mode - return self._create_direct_status_overlay(image, self.current_target) + # Use direct control overlay + return create_pbvs_status_overlay( + image, + self.current_target, + self.last_position_error, + self.last_target_reached, + self.target_grasp_pose, + self.grasp_stage.value, + is_direct_control=True, + ) else: # Use controller's overlay for velocity mode return self.controller.create_status_overlay( @@ -454,146 +464,6 @@ def create_status_overlay( self.direct_ee_control, ) - def 
_create_direct_status_overlay( - self, image: np.ndarray, current_target: Optional[ObjectData] = None - ) -> np.ndarray: - """ - Create status overlay for direct control mode. - - Args: - image: Input image - current_target: Current target object - - Returns: - Image with status overlay - """ - viz_img = image.copy() - height, width = image.shape[:2] - - # Status panel - if current_target is not None: - panel_height = 175 # Adjusted panel for target, grasp pose, stage, and distance info - panel_y = height - panel_height - overlay = viz_img.copy() - cv2.rectangle(overlay, (0, panel_y), (width, height), (0, 0, 0), -1) - viz_img = cv2.addWeighted(viz_img, 0.7, overlay, 0.3, 0) - - # Status text - y = panel_y + 20 - cv2.putText( - viz_img, - "PBVS Status (Direct EE)", - (10, y), - cv2.FONT_HERSHEY_SIMPLEX, - 0.6, - (0, 255, 255), - 2, - ) - - # Add frame info - cv2.putText( - viz_img, "Frame: Camera", (250, y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 0), 1 - ) - - if self.last_position_error: - error_mag = np.linalg.norm( - [ - self.last_position_error.x, - self.last_position_error.y, - self.last_position_error.z, - ] - ) - color = (0, 255, 0) if self.last_target_reached else (0, 255, 255) - - cv2.putText( - viz_img, - f"Pos Error: {error_mag:.3f}m ({error_mag * 100:.1f}cm)", - (10, y + 25), - cv2.FONT_HERSHEY_SIMPLEX, - 0.5, - color, - 1, - ) - - cv2.putText( - viz_img, - f"XYZ: ({self.last_position_error.x:.3f}, {self.last_position_error.y:.3f}, {self.last_position_error.z:.3f})", - (10, y + 45), - cv2.FONT_HERSHEY_SIMPLEX, - 0.4, - (200, 200, 200), - 1, - ) - - # Show target and grasp poses - if current_target: - target_pos = current_target["position"] - cv2.putText( - viz_img, - f"Target: ({target_pos.x:.3f}, {target_pos.y:.3f}, {target_pos.z:.3f})", - (10, y + 65), - cv2.FONT_HERSHEY_SIMPLEX, - 0.4, - (255, 255, 0), - 1, - ) - - if self.target_grasp_pose: - grasp_pos = self.target_grasp_pose.position - cv2.putText( - viz_img, - f"Grasp: ({grasp_pos.x:.3f}, 
{grasp_pos.y:.3f}, {grasp_pos.z:.3f})", - (10, y + 80), - cv2.FONT_HERSHEY_SIMPLEX, - 0.4, - (0, 255, 255), - 1, - ) - - # Show pregrasp distance if we have both poses - if current_target: - target_pos = current_target["position"] - distance = np.sqrt( - (grasp_pos.x - target_pos.x) ** 2 - + (grasp_pos.y - target_pos.y) ** 2 - + (grasp_pos.z - target_pos.z) ** 2 - ) - - # Show current stage and distance - stage_text = f"Stage: {self.grasp_stage.value}" - cv2.putText( - viz_img, - stage_text, - (10, y + 95), - cv2.FONT_HERSHEY_SIMPLEX, - 0.4, - (255, 150, 255), - 1, - ) - - distance_text = f"Distance: {distance * 1000:.1f}mm" - cv2.putText( - viz_img, - distance_text, - (10, y + 110), - cv2.FONT_HERSHEY_SIMPLEX, - 0.4, - (255, 200, 0), - 1, - ) - - if self.last_target_reached: - cv2.putText( - viz_img, - "TARGET REACHED", - (width - 150, y + 25), - cv2.FONT_HERSHEY_SIMPLEX, - 0.6, - (0, 255, 0), - 2, - ) - - return viz_img class PBVSController: @@ -767,7 +637,7 @@ def _compute_angular_velocity(self, target_rot: Quaternion, current_pose: Pose) def create_status_overlay( self, image: np.ndarray, - current_target: Optional[Dict[str, Any]] = None, + current_target: Optional[Detection3D] = None, direct_ee_control: bool = False, ) -> np.ndarray: """ @@ -775,113 +645,19 @@ def create_status_overlay( Args: image: Input image - current_target: Current target object (for display) + current_target: Current target object Detection3D (for display) direct_ee_control: Whether in direct EE control mode Returns: Image with PBVS status overlay """ - viz_img = image.copy() - height, width = image.shape[:2] - - # Status panel - if current_target is not None: - panel_height = 160 # Adjusted panel height - panel_y = height - panel_height - overlay = viz_img.copy() - cv2.rectangle(overlay, (0, panel_y), (width, height), (0, 0, 0), -1) - viz_img = cv2.addWeighted(viz_img, 0.7, overlay, 0.3, 0) - - # Status text - y = panel_y + 20 - mode_text = "Direct EE" if direct_ee_control else 
"Velocity" - cv2.putText( - viz_img, - f"PBVS Status ({mode_text})", - (10, y), - cv2.FONT_HERSHEY_SIMPLEX, - 0.6, - (0, 255, 255), - 2, - ) - - # Add frame info - cv2.putText( - viz_img, "Frame: Camera", (250, y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 0), 1 - ) - - if self.last_position_error: - error_mag = np.linalg.norm( - [ - self.last_position_error.x, - self.last_position_error.y, - self.last_position_error.z, - ] - ) - color = (0, 255, 0) if self.last_target_reached else (0, 255, 255) - - cv2.putText( - viz_img, - f"Pos Error: {error_mag:.3f}m ({error_mag * 100:.1f}cm)", - (10, y + 25), - cv2.FONT_HERSHEY_SIMPLEX, - 0.5, - color, - 1, - ) - - cv2.putText( - viz_img, - f"XYZ: ({self.last_position_error.x:.3f}, {self.last_position_error.y:.3f}, {self.last_position_error.z:.3f})", - (10, y + 45), - cv2.FONT_HERSHEY_SIMPLEX, - 0.4, - (200, 200, 200), - 1, - ) - - if self.last_velocity_cmd and not direct_ee_control: - cv2.putText( - viz_img, - f"Lin Vel: ({self.last_velocity_cmd.x:.2f}, {self.last_velocity_cmd.y:.2f}, {self.last_velocity_cmd.z:.2f})m/s", - (10, y + 65), - cv2.FONT_HERSHEY_SIMPLEX, - 0.5, - (255, 200, 0), - 1, - ) - - if self.last_rotation_error: - cv2.putText( - viz_img, - f"Rot Error: ({self.last_rotation_error.x:.2f}, {self.last_rotation_error.y:.2f}, {self.last_rotation_error.z:.2f})rad", - (10, y + 85), - cv2.FONT_HERSHEY_SIMPLEX, - 0.4, - (200, 200, 200), - 1, - ) - - if self.last_angular_velocity_cmd and not direct_ee_control: - cv2.putText( - viz_img, - f"Ang Vel: ({self.last_angular_velocity_cmd.x:.2f}, {self.last_angular_velocity_cmd.y:.2f}, {self.last_angular_velocity_cmd.z:.2f})rad/s", - (10, y + 105), - cv2.FONT_HERSHEY_SIMPLEX, - 0.5, - (255, 200, 0), - 1, - ) - - if self.last_target_reached: - cv2.putText( - viz_img, - "TARGET REACHED", - (width - 150, y + 25), - cv2.FONT_HERSHEY_SIMPLEX, - 0.6, - (0, 255, 0), - 2, - ) - - return viz_img + return create_pbvs_controller_overlay( + image, + current_target, + 
self.last_position_error, + self.last_rotation_error, + self.last_velocity_cmd, + self.last_angular_velocity_cmd, + self.last_target_reached, + direct_ee_control, + ) diff --git a/dimos/manipulation/visual_servoing/utils.py b/dimos/manipulation/visual_servoing/utils.py index 0000bda999..098c49e7ab 100644 --- a/dimos/manipulation/visual_servoing/utils.py +++ b/dimos/manipulation/visual_servoing/utils.py @@ -13,17 +13,20 @@ # limitations under the License. import numpy as np -from typing import Dict, Any, Optional, List, Tuple, Union +from typing import Dict, Any, Optional, List, Tuple from dataclasses import dataclass -from dimos.msgs.geometry_msgs import Pose, Vector3, Quaternion +from dimos_lcm.geometry_msgs import Pose, Vector3, Quaternion, Point +from dimos_lcm.vision_msgs import Detection3D, Detection2D, BoundingBox2D +import cv2 +from dimos.perception.detection2d.utils import plot_results @dataclass class ObjectMatchResult: """Result of object matching with confidence metrics.""" - matched_object: Optional[Dict[str, Any]] + matched_object: Optional[Detection3D] confidence: float distance: float size_similarity: float @@ -31,8 +34,8 @@ class ObjectMatchResult: def calculate_object_similarity( - target_obj: Dict[str, Any], - candidate_obj: Dict[str, Any], + target_obj: Detection3D, + candidate_obj: Detection3D, distance_weight: float = 0.6, size_weight: float = 0.4, ) -> Tuple[float, float, float]: @@ -40,8 +43,8 @@ def calculate_object_similarity( Calculate comprehensive similarity between two objects. 
Args: - target_obj: Target object with 'position' and optionally 'size' - candidate_obj: Candidate object with 'position' and optionally 'size' + target_obj: Target Detection3D object + candidate_obj: Candidate Detection3D object distance_weight: Weight for distance component (0-1) size_weight: Weight for size component (0-1) @@ -49,22 +52,11 @@ def calculate_object_similarity( Tuple of (total_similarity, distance_m, size_similarity) """ # Extract positions - target_pos = target_obj.get("position", {}) - candidate_pos = candidate_obj.get("position", {}) + target_pos = target_obj.bbox.center.position + candidate_pos = candidate_obj.bbox.center.position - if isinstance(target_pos, Vector3): - target_xyz = np.array([target_pos.x, target_pos.y, target_pos.z]) - else: - target_xyz = np.array( - [target_pos.get("x", 0), target_pos.get("y", 0), target_pos.get("z", 0)] - ) - - if isinstance(candidate_pos, Vector3): - candidate_xyz = np.array([candidate_pos.x, candidate_pos.y, candidate_pos.z]) - else: - candidate_xyz = np.array( - [candidate_pos.get("x", 0), candidate_pos.get("y", 0), candidate_pos.get("z", 0)] - ) + target_xyz = np.array([target_pos.x, target_pos.y, target_pos.z]) + candidate_xyz = np.array([candidate_pos.x, candidate_pos.y, candidate_pos.z]) # Calculate Euclidean distance distance = np.linalg.norm(target_xyz - candidate_xyz) @@ -72,21 +64,13 @@ def calculate_object_similarity( # Calculate size similarity by comparing each dimension individually size_similarity = 1.0 # Default if no size info - target_size = target_obj.get("size", {}) - candidate_size = candidate_obj.get("size", {}) + target_size = target_obj.bbox.size + candidate_size = candidate_obj.bbox.size if target_size and candidate_size: - # Extract dimensions with defaults - target_dims = [ - target_size.get("width", 0.0), - target_size.get("height", 0.0), - target_size.get("depth", 0.0), - ] - candidate_dims = [ - candidate_size.get("width", 0.0), - candidate_size.get("height", 0.0), - 
candidate_size.get("depth", 0.0), - ] + # Extract dimensions + target_dims = [target_size.x, target_size.y, target_size.z] + candidate_dims = [candidate_size.x, candidate_size.y, candidate_size.z] # Calculate similarity for each dimension pair dim_similarities = [] @@ -112,8 +96,8 @@ def calculate_object_similarity( def find_best_object_match( - target_obj: Dict[str, Any], - candidates: List[Dict[str, Any]], + target_obj: Detection3D, + candidates: List[Detection3D], max_distance: float = 0.1, min_size_similarity: float = 0.4, distance_weight: float = 0.7, @@ -123,8 +107,8 @@ def find_best_object_match( Find the best matching object from candidates using distance and size criteria. Args: - target_obj: Target object to match against - candidates: List of candidate objects + target_obj: Target Detection3D to match against + candidates: List of candidate Detection3D objects max_distance: Maximum allowed distance for valid match (meters) min_size_similarity: Minimum size similarity for valid match (0-1) distance_weight: Weight for distance in similarity calculation @@ -133,7 +117,7 @@ def find_best_object_match( Returns: ObjectMatchResult with best match and confidence metrics """ - if not candidates or not target_obj.get("position"): + if not candidates or not target_obj.bbox or not target_obj.bbox.center: return ObjectMatchResult(None, 0.0, float("inf"), 0.0, False) best_match = None @@ -142,7 +126,7 @@ def find_best_object_match( best_size_sim = 0.0 for candidate in candidates: - if not candidate.get("position"): + if not candidate.bbox or not candidate.bbox.center: continue similarity, distance, size_sim = calculate_object_similarity( @@ -186,7 +170,7 @@ def parse_zed_pose(zed_pose_data: Dict[str, Any]) -> Optional[Pose]: # Extract position position = zed_pose_data.get("position", [0, 0, 0]) - pos_vector = Vector3(position[0], position[1], position[2]) + pos_vector = Point(position[0], position[1], position[2]) quat = zed_pose_data["rotation"] orientation = 
Quaternion(quat[0], quat[1], quat[2], quat[3]) @@ -244,3 +228,438 @@ def estimate_object_depth( return 0.10 else: return 0.05 + +# ============= Visualization Functions ============= + +def visualize_detections_3d( + rgb_image: np.ndarray, + detections: List[Detection3D], + show_coordinates: bool = True, + bboxes_2d: Optional[List[List[float]]] = None, +) -> np.ndarray: + """ + Visualize detections with 3D position overlay next to bounding boxes. + + Args: + rgb_image: Original RGB image + detections: List of Detection3D objects + show_coordinates: Whether to show 3D coordinates next to bounding boxes + bboxes_2d: Optional list of 2D bounding boxes corresponding to detections + + Returns: + Visualization image + """ + if not detections: + return rgb_image.copy() + + # If no 2D bboxes provided, skip visualization + if bboxes_2d is None: + return rgb_image.copy() + + # Extract data for plot_results function + bboxes = bboxes_2d + track_ids = [int(det.id) if det.id.isdigit() else i for i, det in enumerate(detections)] + class_ids = [i for i in range(len(detections))] + confidences = [det.results[0].hypothesis.score if det.results_length > 0 else 0.0 for det in detections] + names = [det.results[0].hypothesis.class_id if det.results_length > 0 else "unknown" for det in detections] + + # Use plot_results for basic visualization + viz = plot_results(rgb_image, bboxes, track_ids, class_ids, confidences, names) + + # Add 3D position coordinates if requested + if show_coordinates and bboxes_2d is not None: + for i, det in enumerate(detections): + if det.bbox and det.bbox.center and i < len(bboxes_2d): + position = det.bbox.center.position + bbox = bboxes_2d[i] + + pos_xyz = np.array([position.x, position.y, position.z]) + + # Get bounding box coordinates + x1, y1, x2, y2 = map(int, bbox) + + # Add position text next to bounding box (top-right corner) + pos_text = f"({pos_xyz[0]:.2f}, {pos_xyz[1]:.2f}, {pos_xyz[2]:.2f})" + text_x = x2 + 5 # Right edge of bbox + small offset 
+ text_y = y1 + 15 # Top edge of bbox + small offset + + # Add background rectangle for better readability + text_size = cv2.getTextSize(pos_text, cv2.FONT_HERSHEY_SIMPLEX, 0.4, 1)[0] + cv2.rectangle( + viz, + (text_x - 2, text_y - text_size[1] - 2), + (text_x + text_size[0] + 2, text_y + 2), + (0, 0, 0), + -1, + ) + + cv2.putText( + viz, + pos_text, + (text_x, text_y), + cv2.FONT_HERSHEY_SIMPLEX, + 0.4, + (255, 255, 255), + 1, + ) + + return viz + + +def create_pbvs_status_overlay( + image: np.ndarray, + current_target: Optional[Detection3D], + position_error: Optional[Vector3], + target_reached: bool, + target_grasp_pose: Optional[Pose], + grasp_stage: str, + is_direct_control: bool = False, +) -> np.ndarray: + """ + Create PBVS status overlay for direct control mode. + + Args: + image: Input image + current_target: Current target Detection3D + position_error: Position error vector + target_reached: Whether target is reached + target_grasp_pose: Target grasp pose + grasp_stage: Current grasp stage + is_direct_control: Whether in direct control mode + + Returns: + Image with status overlay + """ + viz_img = image.copy() + height, width = image.shape[:2] + + # Status panel + if current_target is not None: + panel_height = 175 # Adjusted panel for target, grasp pose, stage, and distance info + panel_y = height - panel_height + overlay = viz_img.copy() + cv2.rectangle(overlay, (0, panel_y), (width, height), (0, 0, 0), -1) + viz_img = cv2.addWeighted(viz_img, 0.7, overlay, 0.3, 0) + + # Status text + y = panel_y + 20 + mode_text = "Direct EE" if is_direct_control else "Velocity" + cv2.putText( + viz_img, + f"PBVS Status ({mode_text})", + (10, y), + cv2.FONT_HERSHEY_SIMPLEX, + 0.6, + (0, 255, 255), + 2, + ) + + # Add frame info + cv2.putText( + viz_img, "Frame: Camera", (250, y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 0), 1 + ) + + if position_error: + error_mag = np.linalg.norm( + [ + position_error.x, + position_error.y, + position_error.z, + ] + ) + color = (0, 
255, 0) if target_reached else (0, 255, 255) + + cv2.putText( + viz_img, + f"Pos Error: {error_mag:.3f}m ({error_mag * 100:.1f}cm)", + (10, y + 25), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + color, + 1, + ) + + cv2.putText( + viz_img, + f"XYZ: ({position_error.x:.3f}, {position_error.y:.3f}, {position_error.z:.3f})", + (10, y + 45), + cv2.FONT_HERSHEY_SIMPLEX, + 0.4, + (200, 200, 200), + 1, + ) + + # Show target and grasp poses + if current_target and current_target.bbox and current_target.bbox.center: + target_pos = current_target.bbox.center.position + cv2.putText( + viz_img, + f"Target: ({target_pos.x:.3f}, {target_pos.y:.3f}, {target_pos.z:.3f})", + (10, y + 65), + cv2.FONT_HERSHEY_SIMPLEX, + 0.4, + (255, 255, 0), + 1, + ) + + if target_grasp_pose: + grasp_pos = target_grasp_pose.position + cv2.putText( + viz_img, + f"Grasp: ({grasp_pos.x:.3f}, {grasp_pos.y:.3f}, {grasp_pos.z:.3f})", + (10, y + 80), + cv2.FONT_HERSHEY_SIMPLEX, + 0.4, + (0, 255, 255), + 1, + ) + + # Show pregrasp distance if we have both poses + if current_target and current_target.bbox and current_target.bbox.center: + target_pos = current_target.bbox.center.position + distance = np.sqrt( + (grasp_pos.x - target_pos.x) ** 2 + + (grasp_pos.y - target_pos.y) ** 2 + + (grasp_pos.z - target_pos.z) ** 2 + ) + + # Show current stage and distance + stage_text = f"Stage: {grasp_stage}" + cv2.putText( + viz_img, + stage_text, + (10, y + 95), + cv2.FONT_HERSHEY_SIMPLEX, + 0.4, + (255, 150, 255), + 1, + ) + + distance_text = f"Distance: {distance * 1000:.1f}mm" + cv2.putText( + viz_img, + distance_text, + (10, y + 110), + cv2.FONT_HERSHEY_SIMPLEX, + 0.4, + (255, 200, 0), + 1, + ) + + if target_reached: + cv2.putText( + viz_img, + "TARGET REACHED", + (width - 150, y + 25), + cv2.FONT_HERSHEY_SIMPLEX, + 0.6, + (0, 255, 0), + 2, + ) + + return viz_img + + +def create_pbvs_controller_overlay( + image: np.ndarray, + current_target: Optional[Detection3D], + position_error: Optional[Vector3], + rotation_error: 
Optional[Vector3], + velocity_cmd: Optional[Vector3], + angular_velocity_cmd: Optional[Vector3], + target_reached: bool, + direct_ee_control: bool = False, +) -> np.ndarray: + """ + Create PBVS controller status overlay on image. + + Args: + image: Input image + current_target: Current target Detection3D (for display) + position_error: Position error vector + rotation_error: Rotation error vector + velocity_cmd: Linear velocity command + angular_velocity_cmd: Angular velocity command + target_reached: Whether target is reached + direct_ee_control: Whether in direct EE control mode + + Returns: + Image with PBVS status overlay + """ + viz_img = image.copy() + height, width = image.shape[:2] + + # Status panel + if current_target is not None: + panel_height = 160 # Adjusted panel height + panel_y = height - panel_height + overlay = viz_img.copy() + cv2.rectangle(overlay, (0, panel_y), (width, height), (0, 0, 0), -1) + viz_img = cv2.addWeighted(viz_img, 0.7, overlay, 0.3, 0) + + # Status text + y = panel_y + 20 + mode_text = "Direct EE" if direct_ee_control else "Velocity" + cv2.putText( + viz_img, + f"PBVS Status ({mode_text})", + (10, y), + cv2.FONT_HERSHEY_SIMPLEX, + 0.6, + (0, 255, 255), + 2, + ) + + # Add frame info + cv2.putText( + viz_img, "Frame: Camera", (250, y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 0), 1 + ) + + if position_error: + error_mag = np.linalg.norm( + [ + position_error.x, + position_error.y, + position_error.z, + ] + ) + color = (0, 255, 0) if target_reached else (0, 255, 255) + + cv2.putText( + viz_img, + f"Pos Error: {error_mag:.3f}m ({error_mag * 100:.1f}cm)", + (10, y + 25), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + color, + 1, + ) + + cv2.putText( + viz_img, + f"XYZ: ({position_error.x:.3f}, {position_error.y:.3f}, {position_error.z:.3f})", + (10, y + 45), + cv2.FONT_HERSHEY_SIMPLEX, + 0.4, + (200, 200, 200), + 1, + ) + + if velocity_cmd and not direct_ee_control: + cv2.putText( + viz_img, + f"Lin Vel: ({velocity_cmd.x:.2f}, 
{velocity_cmd.y:.2f}, {velocity_cmd.z:.2f})m/s", + (10, y + 65), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + (255, 200, 0), + 1, + ) + + if rotation_error: + cv2.putText( + viz_img, + f"Rot Error: ({rotation_error.x:.2f}, {rotation_error.y:.2f}, {rotation_error.z:.2f})rad", + (10, y + 85), + cv2.FONT_HERSHEY_SIMPLEX, + 0.4, + (200, 200, 200), + 1, + ) + + if angular_velocity_cmd and not direct_ee_control: + cv2.putText( + viz_img, + f"Ang Vel: ({angular_velocity_cmd.x:.2f}, {angular_velocity_cmd.y:.2f}, {angular_velocity_cmd.z:.2f})rad/s", + (10, y + 105), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + (255, 200, 0), + 1, + ) + + if target_reached: + cv2.putText( + viz_img, + "TARGET REACHED", + (width - 150, y + 25), + cv2.FONT_HERSHEY_SIMPLEX, + 0.6, + (0, 255, 0), + 2, + ) + + return viz_img + + +def bbox2d_to_corners(bbox_2d: BoundingBox2D) -> Tuple[float, float, float, float]: + """ + Convert BoundingBox2D from center format to corner format. + + Args: + bbox_2d: BoundingBox2D with center and size + + Returns: + Tuple of (x1, y1, x2, y2) corner coordinates + """ + center_x = bbox_2d.center.position.x + center_y = bbox_2d.center.position.y + half_width = bbox_2d.size_x / 2.0 + half_height = bbox_2d.size_y / 2.0 + + x1 = center_x - half_width + y1 = center_y - half_height + x2 = center_x + half_width + y2 = center_y + half_height + + return x1, y1, x2, y2 + + +def find_clicked_detection( + click_pos: Tuple[int, int], + detections_2d: List[Detection2D], + detections_3d: List[Detection3D] +) -> Optional[Detection3D]: + """ + Find which detection was clicked based on 2D bounding boxes. 
+ + Args: + click_pos: (x, y) click position + detections_2d: List of Detection2D objects + detections_3d: List of Detection3D objects (must be 1:1 correspondence) + + Returns: + Corresponding Detection3D object if found, None otherwise + """ + click_x, click_y = click_pos + + for i, det_2d in enumerate(detections_2d): + if det_2d.bbox and i < len(detections_3d): + x1, y1, x2, y2 = bbox2d_to_corners(det_2d.bbox) + + if x1 <= click_x <= x2 and y1 <= click_y <= y2: + return detections_3d[i] + + return None + + +def get_detection2d_for_detection3d( + detection_3d: Detection3D, + detections_3d: List[Detection3D], + detections_2d: List[Detection2D] +) -> Optional[Detection2D]: + """ + Find the corresponding Detection2D for a given Detection3D. + + Args: + detection_3d: The Detection3D to match + detections_3d: List of all Detection3D objects + detections_2d: List of all Detection2D objects (must be 1:1 correspondence) + + Returns: + Corresponding Detection2D if found, None otherwise + """ + for i, det_3d in enumerate(detections_3d): + if det_3d.id == detection_3d.id and i < len(detections_2d): + return detections_2d[i] + return None \ No newline at end of file diff --git a/tests/test_ibvs.py b/tests/test_ibvs.py index b5bfe6afc5..ff3d4aa7a5 100644 --- a/tests/test_ibvs.py +++ b/tests/test_ibvs.py @@ -22,26 +22,26 @@ """ import cv2 -import numpy as np import sys -import os import time -import tests.test_header from dimos.hardware.zed_camera import ZEDCamera from dimos.hardware.piper_arm import PiperArm from dimos.manipulation.visual_servoing.detection3d import Detection3DProcessor -from dimos.perception.common.utils import find_clicked_object from dimos.manipulation.visual_servoing.pbvs import PBVS, GraspStage +from dimos.manipulation.visual_servoing.utils import ( + find_clicked_detection, + get_detection2d_for_detection3d, + bbox2d_to_corners, +) from dimos.utils.transform_utils import ( pose_to_matrix, matrix_to_pose, create_transform_from_6dof, compose_transforms, - 
quaternion_to_euler, ) -from dimos.msgs.geometry_msgs import Vector3 +from dimos_lcm.geometry_msgs import Vector3 try: import pyzed.sl as sl @@ -54,7 +54,7 @@ mouse_click = None -def mouse_callback(event, x, y, flags, param): +def mouse_callback(event, x, y, _flags, _param): global mouse_click if event == cv2.EVENT_LBUTTONDOWN: mouse_click = (x, y) @@ -66,8 +66,8 @@ def execute_grasp(arm, target_object, target_pose, grasp_width_offset: float = 0 Args: arm: Robot arm interface with gripper control - target_object: ObjectData with size information - safety_margin: Multiplier for gripper opening (default 1.5x object size) + target_object: Detection3D with size information + grasp_width_offset: Additional width to add to object size for gripper opening Returns: True if grasp was executed, False if no target or no size data @@ -76,33 +76,24 @@ def execute_grasp(arm, target_object, target_pose, grasp_width_offset: float = 0 print("❌ No target object provided for grasping") return False - if "size" not in target_object: + if not target_object.bbox or not target_object.bbox.size: print("❌ Target has no size information for grasping") return False # Get object size from detection3d data (already in meters) - object_size = target_object["size"] - object_width = object_size["width"] - object_height = object_size["height"] - object_depth = object_size["depth"] - - # Use the larger dimension (width or height) for gripper opening - # Depth is not relevant for gripper opening (that's approach direction) + object_size = target_object.bbox.size + object_width = object_size.x - # Calculate gripper opening with safety margin + # Calculate gripper opening with offset gripper_opening = object_width + grasp_width_offset # Clamp gripper opening to reasonable limits (0.5cm to 10cm) - gripper_opening = max(0.005, min(gripper_opening, 0.1)) # 0.5cm to 10cm + gripper_opening = max(0.005, min(gripper_opening, 0.1)) - print( - f"🤏 Executing grasp: object size w={object_width * 1000:.1f}mm 
h={object_height * 1000:.1f}mm d={object_depth * 1000:.1f}mm, " - f"offset={grasp_width_offset * 1000:.1f}mm, opening gripper to {gripper_opening * 1000:.1f}mm" - ) + print(f"🤏 Executing grasp: opening gripper to {gripper_opening * 1000:.1f}mm") # Command gripper to open arm.cmd_gripper_ctrl(gripper_opening) - arm.cmd_ee_pose(target_pose, line_mode=True) return True @@ -111,8 +102,8 @@ def execute_grasp(arm, target_object, target_pose, grasp_width_offset: float = 0 def main(): global mouse_click - # Control mode flag - DIRECT_EE_CONTROL = True # Set to True for direct EE pose control, False for velocity control + # Configuration + DIRECT_EE_CONTROL = True # True: direct EE pose control, False: velocity control print("=== PBVS Eye-in-Hand Test ===") print("Using EE pose as odometry for camera pose") @@ -206,21 +197,25 @@ def main(): camera_pose = matrix_to_pose(camera_transform) # Process detections using camera transform - detections = detector.process_frame(rgb, depth, camera_transform) + detection_3d_array, detection_2d_array = detector.process_frame(rgb, depth, camera_transform) # Handle click if mouse_click: - clicked = find_clicked_object(mouse_click, detections) - if clicked: - pbvs.set_target(clicked) + clicked_3d = find_clicked_detection( + mouse_click, + detection_2d_array.detections, + detection_3d_array.detections + ) + if clicked_3d: + pbvs.set_target(clicked_3d) mouse_click = None # Create visualization with position overlays - viz = detector.visualize_detections(rgb, detections) + viz = detector.visualize_detections(rgb, detection_3d_array.detections, detection_2d_array.detections) # PBVS control vel_cmd, ang_vel_cmd, reached, target_tracked, target_pose = pbvs.compute_control( - ee_pose, detections + ee_pose, detection_3d_array ) # Apply commands to robot based on control mode @@ -246,17 +241,25 @@ def main(): vel_cmd.x, vel_cmd.y, vel_cmd.z, ang_vel_cmd.x, ang_vel_cmd.y, ang_vel_cmd.z ) - # Apply PBVS overlay + # Add PBVS status overlay viz = 
pbvs.create_status_overlay(viz) # Highlight target current_target = pbvs.get_current_target() - if target_tracked and current_target and "bbox" in current_target: - x1, y1, x2, y2 = map(int, current_target["bbox"]) - cv2.rectangle(viz, (x1, y1), (x2, y2), (0, 255, 0), 3) - cv2.putText( - viz, "TARGET", (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2 + if target_tracked and current_target: + det_2d = get_detection2d_for_detection3d( + current_target, + detection_3d_array.detections, + detection_2d_array.detections ) + if det_2d and det_2d.bbox: + x1, y1, x2, y2 = bbox2d_to_corners(det_2d.bbox) + x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) + + cv2.rectangle(viz, (x1, y1), (x2, y2), (0, 255, 0), 3) + cv2.putText( + viz, "TARGET", (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2 + ) # Convert back to BGR for OpenCV display viz_bgr = cv2.cvtColor(viz, cv2.COLOR_RGB2BGR) @@ -281,71 +284,44 @@ def main(): ee_text = f"EE: ({ee_pose.position.x:.2f}, {ee_pose.position.y:.2f}, {ee_pose.position.z:.2f})m" cv2.putText(viz_bgr, ee_text, (10, 70), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1) - # Add direct EE control status + # Add control status if DIRECT_EE_CONTROL: - if target_pose: - status_text = "Target Ready - Press SPACE to execute" - status_color = (0, 255, 255) # Yellow - else: - status_text = "No target selected" - status_color = (100, 100, 100) # Gray - - cv2.putText( - viz_bgr, status_text, (10, 90), cv2.FONT_HERSHEY_SIMPLEX, 0.5, status_color, 1 - ) - - cv2.putText( - viz_bgr, - "s=STOP | h=HOME | SPACE=EXECUTE | g=RELEASE", - (10, 110), - cv2.FONT_HERSHEY_SIMPLEX, - 0.4, - (255, 255, 255), - 1, - ) + status_text = "Target Ready - Press SPACE to execute" if target_pose else "No target selected" + status_color = (0, 255, 255) if target_pose else (100, 100, 100) + cv2.putText(viz_bgr, status_text, (10, 90), cv2.FONT_HERSHEY_SIMPLEX, 0.5, status_color, 1) + cv2.putText(viz_bgr, "s=STOP | h=HOME | SPACE=EXECUTE | g=RELEASE", + (10, 
110), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1) # Display cv2.imshow("PBVS", viz_bgr) - # Keyboard + # Handle keyboard input key = cv2.waitKey(1) & 0xFF if key == ord("q"): break elif key == ord("r"): pbvs.clear_target() elif key == ord("s"): - # SOFT STOP - Emergency stop print("🛑 SOFT STOP - Emergency stopping robot!") arm.softStop() elif key == ord("h"): - # GO HOME - Return to safe position print("🏠 GO HOME - Returning to safe position...") arm.gotoZero() - elif key == ord(" "): - # SPACE - Execute target pose (only in direct EE mode) - if DIRECT_EE_CONTROL and target_pose: - execute_target = True - target_euler = quaternion_to_euler(target_pose.orientation, degrees=True) - if pbvs.grasp_stage == GraspStage.PRE_GRASP: - pbvs.set_grasp_stage(GraspStage.GRASP) - print("⚡ SPACE pressed - Target will execute on next frame!") - print( - f"📍 Target pose: pos=({target_pose.position.x:.3f}, {target_pose.position.y:.3f}, {target_pose.position.z:.3f}) " - f"rot=({target_euler.x:.1f}°, {target_euler.y:.1f}°, {target_euler.z:.1f}°)" - ) - elif key == 82: # Up arrow key (increase pitch) - current_pitch = pbvs.grasp_pitch_degrees - new_pitch = min(90.0, current_pitch + 15.0) + elif key == ord(" ") and DIRECT_EE_CONTROL and target_pose: + execute_target = True + if pbvs.grasp_stage == GraspStage.PRE_GRASP: + pbvs.set_grasp_stage(GraspStage.GRASP) + print("⚡ Executing target pose") + elif key == 82: # Up arrow - increase pitch + new_pitch = min(90.0, pbvs.grasp_pitch_degrees + 15.0) pbvs.set_grasp_pitch(new_pitch) - print(f"↑ Grasp pitch increased to {new_pitch:.0f}° (0°=level, 90°=top-down)") - elif key == 84: # Down arrow key (decrease pitch) - current_pitch = pbvs.grasp_pitch_degrees - new_pitch = max(0.0, current_pitch - 15.0) + print(f"↑ Grasp pitch: {new_pitch:.0f}°") + elif key == 84: # Down arrow - decrease pitch + new_pitch = max(0.0, pbvs.grasp_pitch_degrees - 15.0) pbvs.set_grasp_pitch(new_pitch) - print(f"↓ Grasp pitch decreased to {new_pitch:.0f}° 
(0°=level, 90°=top-down)") + print(f"↓ Grasp pitch: {new_pitch:.0f}°") elif key == ord("g"): - # G - Release gripper (open to 100mm) - print("🖐️ RELEASE - Opening gripper to 100mm...") + print("🖐️ Opening gripper") arm.release_gripper() except KeyboardInterrupt: From 3e0cd6c1b6119b82023f8a8fb9e12342dfcb5fc5 Mon Sep 17 00:00:00 2001 From: alexlin2 <44330195+alexlin2@users.noreply.github.com> Date: Fri, 18 Jul 2025 05:16:48 +0000 Subject: [PATCH 66/89] CI code cleanup --- .../visual_servoing/detection3d.py | 81 ++++++++++--------- dimos/manipulation/visual_servoing/pbvs.py | 35 +++++--- dimos/manipulation/visual_servoing/utils.py | 47 ++++++----- tests/test_ibvs.py | 41 ++++++---- 4 files changed, 117 insertions(+), 87 deletions(-) diff --git a/dimos/manipulation/visual_servoing/detection3d.py b/dimos/manipulation/visual_servoing/detection3d.py index 52d9e524b0..9ecc79bb1d 100644 --- a/dimos/manipulation/visual_servoing/detection3d.py +++ b/dimos/manipulation/visual_servoing/detection3d.py @@ -26,7 +26,18 @@ from dimos.perception.detection2d.utils import calculate_object_size_from_bbox from dimos_lcm.geometry_msgs import Pose, Vector3, Quaternion, Point -from dimos_lcm.vision_msgs import Detection3D, Detection3DArray, BoundingBox3D, ObjectHypothesisWithPose, ObjectHypothesis, Detection2D, Detection2DArray, BoundingBox2D, Pose2D, Point2D +from dimos_lcm.vision_msgs import ( + Detection3D, + Detection3DArray, + BoundingBox3D, + ObjectHypothesisWithPose, + ObjectHypothesis, + Detection2D, + Detection2DArray, + BoundingBox2D, + Pose2D, + Point2D, +) from dimos_lcm.std_msgs import Header from dimos.manipulation.visual_servoing.utils import estimate_object_depth, visualize_detections_3d from dimos.utils.transform_utils import ( @@ -109,7 +120,9 @@ def process_frame( # Early exit if no detections if not masks or len(masks) == 0: - return Detection3DArray(detections_length=0, header=Header(), detections=[]), Detection2DArray(detections_length=0, header=Header(), 
detections=[]) + return Detection3DArray( + detections_length=0, header=Header(), detections=[] + ), Detection2DArray(detections_length=0, header=Header(), detections=[]) # Convert CUDA tensors to numpy arrays if needed numpy_masks = [] @@ -133,14 +146,13 @@ def process_frame( pose_dict = {p["mask_idx"]: p for p in poses if p["centroid"][2] < self.max_depth} for i, (bbox, name, prob, track_id) in enumerate(zip(bboxes, names, probs, track_ids)): - # Skip if no 3D pose data if i not in pose_dict: continue - + pose = pose_dict[i] obj_cam_pos = pose["centroid"] - + if obj_cam_pos[2] > self.max_depth: continue @@ -175,68 +187,56 @@ def process_frame( # If no transform, use camera coordinates center_pose = Pose( Point(obj_cam_pos[0], obj_cam_pos[1], obj_cam_pos[2]), - Quaternion(0.0, 0.0, 0.0, 1.0) # Default orientation + Quaternion(0.0, 0.0, 0.0, 1.0), # Default orientation ) # Create Detection3D object detection = Detection3D( results_length=1, header=Header(), # Empty header - results=[ObjectHypothesisWithPose( - hypothesis=ObjectHypothesis( - class_id=name, - score=float(prob) + results=[ + ObjectHypothesisWithPose( + hypothesis=ObjectHypothesis(class_id=name, score=float(prob)) ) - )], - bbox=BoundingBox3D( - center=center_pose, - size=Vector3(size_x, size_y, size_z) - ), - id=str(track_id) + ], + bbox=BoundingBox3D(center=center_pose, size=Vector3(size_x, size_y, size_z)), + id=str(track_id), ) - + detections_3d.append(detection) - + # Create corresponding Detection2D x1, y1, x2, y2 = bbox center_x = (x1 + x2) / 2.0 center_y = (y1 + y2) / 2.0 width = x2 - x1 height = y2 - y1 - + detection_2d = Detection2D( results_length=1, header=Header(), - results=[ObjectHypothesisWithPose( - hypothesis=ObjectHypothesis( - class_id=name, - score=float(prob) + results=[ + ObjectHypothesisWithPose( + hypothesis=ObjectHypothesis(class_id=name, score=float(prob)) ) - )], + ], bbox=BoundingBox2D( - center=Pose2D( - position=Point2D(center_x, center_y), - theta=0.0 - ), + 
center=Pose2D(position=Point2D(center_x, center_y), theta=0.0), size_x=float(width), - size_y=float(height) + size_y=float(height), ), - id=str(track_id) + id=str(track_id), ) detections_2d.append(detection_2d) # Create and return both arrays return ( Detection3DArray( - detections_length=len(detections_3d), - header=Header(), - detections=detections_3d + detections_length=len(detections_3d), header=Header(), detections=detections_3d ), Detection2DArray( - detections_length=len(detections_2d), - header=Header(), - detections=detections_2d - ) + detections_length=len(detections_2d), header=Header(), detections=detections_2d + ), ) def _transform_object_pose( @@ -295,12 +295,13 @@ def visualize_detections( """ # Extract 2D bboxes from Detection2D objects from dimos.manipulation.visual_servoing.utils import bbox2d_to_corners + bboxes_2d = [] for det_2d in detections_2d: if det_2d.bbox: x1, y1, x2, y2 = bbox2d_to_corners(det_2d.bbox) bboxes_2d.append([x1, y1, x2, y2]) - + return visualize_detections_3d(rgb_image, detections_3d, show_coordinates, bboxes_2d) def get_closest_detection( @@ -321,7 +322,9 @@ def get_closest_detection( # Check if has valid bbox center position if d.bbox and d.bbox.center and d.bbox.center.position: # Check class filter if specified - if class_filter is None or (d.results_length > 0 and d.results[0].hypothesis.class_id == class_filter): + if class_filter is None or ( + d.results_length > 0 and d.results[0].hypothesis.class_id == class_filter + ): valid_detections.append(d) if not valid_detections: diff --git a/dimos/manipulation/visual_servoing/pbvs.py b/dimos/manipulation/visual_servoing/pbvs.py index 682a165042..e90f8e6996 100644 --- a/dimos/manipulation/visual_servoing/pbvs.py +++ b/dimos/manipulation/visual_servoing/pbvs.py @@ -109,7 +109,9 @@ def __init__( self.pregrasp_distance = pregrasp_distance self.grasp_distance = grasp_distance self.direct_ee_control = direct_ee_control - self.grasp_pitch_degrees = 45.0 # Default grasp pitch in 
degrees (45° between level and top-down) + self.grasp_pitch_degrees = ( + 45.0 # Default grasp pitch in degrees (45° between level and top-down) + ) # Target state self.current_target = None @@ -176,7 +178,7 @@ def set_grasp_stage(self, stage: GraspStage): def set_grasp_pitch(self, pitch_degrees: float): """ Set the grasp pitch angle in degrees. - + Args: pitch_degrees: Grasp pitch angle in degrees (0-90) 0° = level grasp (horizontal) @@ -231,7 +233,11 @@ def update_target_tracking(self, new_detections: Detection3DArray) -> bool: Returns: True if target was successfully tracked, False if lost (but target is kept) """ - if not self.current_target or not self.current_target.bbox or not self.current_target.bbox.center: + if ( + not self.current_target + or not self.current_target.bbox + or not self.current_target.bbox.center + ): return False if not new_detections or new_detections.detections_length == 0: @@ -273,7 +279,11 @@ def _update_target_grasp_pose(self, ee_pose: Pose): Args: ee_pose: Current end-effector pose """ - if not self.current_target or not self.current_target.bbox or not self.current_target.bbox.center: + if ( + not self.current_target + or not self.current_target.bbox + or not self.current_target.bbox.center + ): return # Get target position @@ -288,7 +298,7 @@ def _update_target_grasp_pose(self, ee_pose: Pose): # Convert grasp pitch from degrees to radians with mapping: # 0° (level) -> π/2 (1.57 rad), 90° (top-down) -> π (3.14 rad) pitch_radians = 1.57 + (self.grasp_pitch_degrees * np.pi / 180.0 / 2.0) - + # Convert euler angles to quaternion using utility function euler = Vector3(0.0, pitch_radians, yaw_to_ee) # roll=0, pitch=mapped, yaw=calculated target_orientation = euler_to_quaternion(euler) @@ -317,22 +327,22 @@ def _apply_grasp_distance(self, target_pose: Pose, distance: float) -> Pose: # Convert pose to transformation matrix to extract rotation T_target = pose_to_matrix(target_pose) rotation_matrix = T_target[:3, :3] - + # Define the 
approach vector based on the target pose orientation # Assuming the gripper approaches along its local -z axis (common for downward grasps) # You can change this to [1, 0, 0] for x-axis or [0, 1, 0] for y-axis based on your gripper approach_vector_local = np.array([0, 0, -1]) - + # Transform approach vector to world coordinates approach_vector_world = rotation_matrix @ approach_vector_local - + # Apply offset along the approach direction offset_position = Point( target_pose.position.x + distance * approach_vector_world[0], target_pose.position.y + distance * approach_vector_world[1], target_pose.position.z + distance * approach_vector_world[2], ) - + return Pose(offset_position, target_pose.orientation) def compute_control( @@ -354,7 +364,11 @@ def compute_control( - target_pose: Target EE pose (only in direct_ee_control mode, otherwise None) """ # Check if we have a target - if not self.current_target or not self.current_target.bbox or not self.current_target.bbox.center: + if ( + not self.current_target + or not self.current_target.bbox + or not self.current_target.bbox.center + ): return None, None, False, False, None # Try to update target tracking if new detections provided @@ -465,7 +479,6 @@ def create_status_overlay( ) - class PBVSController: """ Low-level Position-Based Visual Servoing controller. 
diff --git a/dimos/manipulation/visual_servoing/utils.py b/dimos/manipulation/visual_servoing/utils.py index 098c49e7ab..2dff33a410 100644 --- a/dimos/manipulation/visual_servoing/utils.py +++ b/dimos/manipulation/visual_servoing/utils.py @@ -229,8 +229,10 @@ def estimate_object_depth( else: return 0.05 + # ============= Visualization Functions ============= + def visualize_detections_3d( rgb_image: np.ndarray, detections: List[Detection3D], @@ -252,16 +254,21 @@ def visualize_detections_3d( if not detections: return rgb_image.copy() - # If no 2D bboxes provided, skip visualization + # If no 2D bboxes provided, skip visualization if bboxes_2d is None: return rgb_image.copy() - + # Extract data for plot_results function bboxes = bboxes_2d track_ids = [int(det.id) if det.id.isdigit() else i for i, det in enumerate(detections)] class_ids = [i for i in range(len(detections))] - confidences = [det.results[0].hypothesis.score if det.results_length > 0 else 0.0 for det in detections] - names = [det.results[0].hypothesis.class_id if det.results_length > 0 else "unknown" for det in detections] + confidences = [ + det.results[0].hypothesis.score if det.results_length > 0 else 0.0 for det in detections + ] + names = [ + det.results[0].hypothesis.class_id if det.results_length > 0 else "unknown" + for det in detections + ] # Use plot_results for basic visualization viz = plot_results(rgb_image, bboxes, track_ids, class_ids, confidences, names) @@ -595,10 +602,10 @@ def create_pbvs_controller_overlay( def bbox2d_to_corners(bbox_2d: BoundingBox2D) -> Tuple[float, float, float, float]: """ Convert BoundingBox2D from center format to corner format. 
- + Args: bbox_2d: BoundingBox2D with center and size - + Returns: Tuple of (x1, y1, x2, y2) corner coordinates """ @@ -606,60 +613,56 @@ def bbox2d_to_corners(bbox_2d: BoundingBox2D) -> Tuple[float, float, float, floa center_y = bbox_2d.center.position.y half_width = bbox_2d.size_x / 2.0 half_height = bbox_2d.size_y / 2.0 - + x1 = center_x - half_width y1 = center_y - half_height x2 = center_x + half_width y2 = center_y + half_height - + return x1, y1, x2, y2 def find_clicked_detection( - click_pos: Tuple[int, int], - detections_2d: List[Detection2D], - detections_3d: List[Detection3D] + click_pos: Tuple[int, int], detections_2d: List[Detection2D], detections_3d: List[Detection3D] ) -> Optional[Detection3D]: """ Find which detection was clicked based on 2D bounding boxes. - + Args: click_pos: (x, y) click position detections_2d: List of Detection2D objects detections_3d: List of Detection3D objects (must be 1:1 correspondence) - + Returns: Corresponding Detection3D object if found, None otherwise """ click_x, click_y = click_pos - + for i, det_2d in enumerate(detections_2d): if det_2d.bbox and i < len(detections_3d): x1, y1, x2, y2 = bbox2d_to_corners(det_2d.bbox) - + if x1 <= click_x <= x2 and y1 <= click_y <= y2: return detections_3d[i] - + return None def get_detection2d_for_detection3d( - detection_3d: Detection3D, - detections_3d: List[Detection3D], - detections_2d: List[Detection2D] + detection_3d: Detection3D, detections_3d: List[Detection3D], detections_2d: List[Detection2D] ) -> Optional[Detection2D]: """ Find the corresponding Detection2D for a given Detection3D. 
- + Args: detection_3d: The Detection3D to match detections_3d: List of all Detection3D objects detections_2d: List of all Detection2D objects (must be 1:1 correspondence) - + Returns: Corresponding Detection2D if found, None otherwise """ for i, det_3d in enumerate(detections_3d): if det_3d.id == detection_3d.id and i < len(detections_2d): return detections_2d[i] - return None \ No newline at end of file + return None diff --git a/tests/test_ibvs.py b/tests/test_ibvs.py index ff3d4aa7a5..6738058dca 100644 --- a/tests/test_ibvs.py +++ b/tests/test_ibvs.py @@ -161,7 +161,7 @@ def main(): grasp_distance=0.01, direct_ee_control=DIRECT_EE_CONTROL, ) - + # Set custom grasp pitch (60 degrees - between level and top-down) GRASP_PITCH_DEGREES = 0 # 0° = level grasp, 90° = top-down grasp pbvs.set_grasp_pitch(GRASP_PITCH_DEGREES) @@ -173,7 +173,7 @@ def main(): # Control state for direct EE mode execute_target = False # Only move when space is pressed last_valid_target = None - + # Rate limiting for pose execution MIN_EXECUTION_PERIOD = 1.0 # Minimum seconds between pose executions last_execution_time = 0 @@ -197,21 +197,23 @@ def main(): camera_pose = matrix_to_pose(camera_transform) # Process detections using camera transform - detection_3d_array, detection_2d_array = detector.process_frame(rgb, depth, camera_transform) + detection_3d_array, detection_2d_array = detector.process_frame( + rgb, depth, camera_transform + ) # Handle click if mouse_click: clicked_3d = find_clicked_detection( - mouse_click, - detection_2d_array.detections, - detection_3d_array.detections + mouse_click, detection_2d_array.detections, detection_3d_array.detections ) if clicked_3d: pbvs.set_target(clicked_3d) mouse_click = None # Create visualization with position overlays - viz = detector.visualize_detections(rgb, detection_3d_array.detections, detection_2d_array.detections) + viz = detector.visualize_detections( + rgb, detection_3d_array.detections, detection_2d_array.detections + ) # PBVS 
control vel_cmd, ang_vel_cmd, reached, target_tracked, target_pose = pbvs.compute_control( @@ -248,14 +250,12 @@ def main(): current_target = pbvs.get_current_target() if target_tracked and current_target: det_2d = get_detection2d_for_detection3d( - current_target, - detection_3d_array.detections, - detection_2d_array.detections + current_target, detection_3d_array.detections, detection_2d_array.detections ) if det_2d and det_2d.bbox: x1, y1, x2, y2 = bbox2d_to_corners(det_2d.bbox) x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) - + cv2.rectangle(viz, (x1, y1), (x2, y2), (0, 255, 0), 3) cv2.putText( viz, "TARGET", (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2 @@ -286,11 +286,22 @@ def main(): # Add control status if DIRECT_EE_CONTROL: - status_text = "Target Ready - Press SPACE to execute" if target_pose else "No target selected" + status_text = ( + "Target Ready - Press SPACE to execute" if target_pose else "No target selected" + ) status_color = (0, 255, 255) if target_pose else (100, 100, 100) - cv2.putText(viz_bgr, status_text, (10, 90), cv2.FONT_HERSHEY_SIMPLEX, 0.5, status_color, 1) - cv2.putText(viz_bgr, "s=STOP | h=HOME | SPACE=EXECUTE | g=RELEASE", - (10, 110), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1) + cv2.putText( + viz_bgr, status_text, (10, 90), cv2.FONT_HERSHEY_SIMPLEX, 0.5, status_color, 1 + ) + cv2.putText( + viz_bgr, + "s=STOP | h=HOME | SPACE=EXECUTE | g=RELEASE", + (10, 110), + cv2.FONT_HERSHEY_SIMPLEX, + 0.4, + (255, 255, 255), + 1, + ) # Display cv2.imshow("PBVS", viz_bgr) From 5441bf17c364dc9ea472e197f434898d7ca0f18b Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Fri, 18 Jul 2025 00:52:21 -0700 Subject: [PATCH 67/89] cleanup --- .../visual_servoing/detection3d.py | 2 +- dimos/manipulation/visual_servoing/utils.py | 53 +------------------ dimos/perception/common/utils.py | 49 ++++++++++++++--- tests/test_ibvs.py | 8 +-- 4 files changed, 49 insertions(+), 63 deletions(-) diff --git 
a/dimos/manipulation/visual_servoing/detection3d.py b/dimos/manipulation/visual_servoing/detection3d.py index 9ecc79bb1d..b54bb81fd3 100644 --- a/dimos/manipulation/visual_servoing/detection3d.py +++ b/dimos/manipulation/visual_servoing/detection3d.py @@ -24,6 +24,7 @@ from dimos.perception.segmentation.sam_2d_seg import Sam2DSegmenter from dimos.perception.pointcloud.utils import extract_centroids_from_masks from dimos.perception.detection2d.utils import calculate_object_size_from_bbox +from dimos.perception.common.utils import bbox2d_to_corners from dimos_lcm.geometry_msgs import Pose, Vector3, Quaternion, Point from dimos_lcm.vision_msgs import ( @@ -294,7 +295,6 @@ def visualize_detections( Visualization image """ # Extract 2D bboxes from Detection2D objects - from dimos.manipulation.visual_servoing.utils import bbox2d_to_corners bboxes_2d = [] for det_2d in detections_2d: diff --git a/dimos/manipulation/visual_servoing/utils.py b/dimos/manipulation/visual_servoing/utils.py index 2dff33a410..4e8a0a81b7 100644 --- a/dimos/manipulation/visual_servoing/utils.py +++ b/dimos/manipulation/visual_servoing/utils.py @@ -17,7 +17,7 @@ from dataclasses import dataclass from dimos_lcm.geometry_msgs import Pose, Vector3, Quaternion, Point -from dimos_lcm.vision_msgs import Detection3D, Detection2D, BoundingBox2D +from dimos_lcm.vision_msgs import Detection3D, Detection2D import cv2 from dimos.perception.detection2d.utils import plot_results @@ -599,56 +599,7 @@ def create_pbvs_controller_overlay( return viz_img -def bbox2d_to_corners(bbox_2d: BoundingBox2D) -> Tuple[float, float, float, float]: - """ - Convert BoundingBox2D from center format to corner format. 
- - Args: - bbox_2d: BoundingBox2D with center and size - - Returns: - Tuple of (x1, y1, x2, y2) corner coordinates - """ - center_x = bbox_2d.center.position.x - center_y = bbox_2d.center.position.y - half_width = bbox_2d.size_x / 2.0 - half_height = bbox_2d.size_y / 2.0 - - x1 = center_x - half_width - y1 = center_y - half_height - x2 = center_x + half_width - y2 = center_y + half_height - - return x1, y1, x2, y2 - - -def find_clicked_detection( - click_pos: Tuple[int, int], detections_2d: List[Detection2D], detections_3d: List[Detection3D] -) -> Optional[Detection3D]: - """ - Find which detection was clicked based on 2D bounding boxes. - - Args: - click_pos: (x, y) click position - detections_2d: List of Detection2D objects - detections_3d: List of Detection3D objects (must be 1:1 correspondence) - - Returns: - Corresponding Detection3D object if found, None otherwise - """ - click_x, click_y = click_pos - - for i, det_2d in enumerate(detections_2d): - if det_2d.bbox and i < len(detections_3d): - x1, y1, x2, y2 = bbox2d_to_corners(det_2d.bbox) - - if x1 <= click_x <= x2 and y1 <= click_y <= y2: - return detections_3d[i] - - return None - - -def get_detection2d_for_detection3d( +def match_detection_by_id( detection_3d: Detection3D, detections_3d: List[Detection3D], detections_2d: List[Detection2D] ) -> Optional[Detection2D]: """ diff --git a/dimos/perception/common/utils.py b/dimos/perception/common/utils.py index fc50e042ad..ce2a358661 100644 --- a/dimos/perception/common/utils.py +++ b/dimos/perception/common/utils.py @@ -18,6 +18,7 @@ from dimos.types.manipulation import ObjectData from dimos.types.vector import Vector from dimos.utils.logging_config import setup_logger +from dimos_lcm.vision_msgs import Detection3D, Detection2D, BoundingBox2D import torch logger = setup_logger("dimos.perception.common.utils") @@ -347,18 +348,50 @@ def point_in_bbox(point: Tuple[int, int], bbox: List[float]) -> bool: return x1 <= x <= x2 and y1 <= y <= y2 -def 
find_clicked_object(click_point: Tuple[int, int], objects: List[Any]) -> Optional[Any]: +def bbox2d_to_corners(bbox_2d: BoundingBox2D) -> Tuple[float, float, float, float]: """ - Find which object was clicked based on bounding boxes. + Convert BoundingBox2D from center format to corner format. Args: - click_point: (x, y) coordinates of mouse click - objects: List of objects with 'bbox' field + bbox_2d: BoundingBox2D with center and size Returns: - Clicked object or None + Tuple of (x1, y1, x2, y2) corner coordinates """ - for obj in objects: - if "bbox" in obj and point_in_bbox(click_point, obj["bbox"]): - return obj + center_x = bbox_2d.center.position.x + center_y = bbox_2d.center.position.y + half_width = bbox_2d.size_x / 2.0 + half_height = bbox_2d.size_y / 2.0 + + x1 = center_x - half_width + y1 = center_y - half_height + x2 = center_x + half_width + y2 = center_y + half_height + + return x1, y1, x2, y2 + + +def find_clicked_detection( + click_pos: Tuple[int, int], detections_2d: List[Detection2D], detections_3d: List[Detection3D] +) -> Optional[Detection3D]: + """ + Find which detection was clicked based on 2D bounding boxes. 
+ + Args: + click_pos: (x, y) click position + detections_2d: List of Detection2D objects + detections_3d: List of Detection3D objects (must be 1:1 correspondence) + + Returns: + Corresponding Detection3D object if found, None otherwise + """ + click_x, click_y = click_pos + + for i, det_2d in enumerate(detections_2d): + if det_2d.bbox and i < len(detections_3d): + x1, y1, x2, y2 = bbox2d_to_corners(det_2d.bbox) + + if x1 <= click_x <= x2 and y1 <= click_y <= y2: + return detections_3d[i] + return None diff --git a/tests/test_ibvs.py b/tests/test_ibvs.py index 6738058dca..33774ad030 100644 --- a/tests/test_ibvs.py +++ b/tests/test_ibvs.py @@ -30,11 +30,13 @@ from dimos.hardware.piper_arm import PiperArm from dimos.manipulation.visual_servoing.detection3d import Detection3DProcessor from dimos.manipulation.visual_servoing.pbvs import PBVS, GraspStage -from dimos.manipulation.visual_servoing.utils import ( +from dimos.perception.common.utils import ( find_clicked_detection, - get_detection2d_for_detection3d, bbox2d_to_corners, ) +from dimos.manipulation.visual_servoing.utils import ( + match_detection_by_id, +) from dimos.utils.transform_utils import ( pose_to_matrix, matrix_to_pose, @@ -249,7 +251,7 @@ def main(): # Highlight target current_target = pbvs.get_current_target() if target_tracked and current_target: - det_2d = get_detection2d_for_detection3d( + det_2d = match_detection_by_id( current_target, detection_3d_array.detections, detection_2d_array.detections ) if det_2d and det_2d.bbox: From bebc9fa8e77bf37fdabdd3f7696d48f7a6d15775 Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Fri, 18 Jul 2025 00:53:45 -0700 Subject: [PATCH 68/89] remove build directory --- It | 0 build/lib/dimos/__init__.py | 1 - build/lib/dimos/agents/__init__.py | 0 build/lib/dimos/agents/agent.py | 904 ---------- build/lib/dimos/agents/agent_config.py | 55 - .../dimos/agents/agent_ctransformers_gguf.py | 210 --- .../dimos/agents/agent_huggingface_local.py | 235 --- 
.../dimos/agents/agent_huggingface_remote.py | 143 -- build/lib/dimos/agents/cerebras_agent.py | 608 ------- build/lib/dimos/agents/claude_agent.py | 735 --------- build/lib/dimos/agents/memory/__init__.py | 0 build/lib/dimos/agents/memory/base.py | 133 -- build/lib/dimos/agents/memory/chroma_impl.py | 167 -- .../dimos/agents/memory/image_embedding.py | 263 --- .../dimos/agents/memory/spatial_vector_db.py | 268 --- .../agents/memory/test_image_embedding.py | 212 --- .../lib/dimos/agents/memory/visual_memory.py | 182 --- build/lib/dimos/agents/planning_agent.py | 317 ---- .../dimos/agents/prompt_builder/__init__.py | 0 build/lib/dimos/agents/prompt_builder/impl.py | 221 --- build/lib/dimos/agents/tokenizer/__init__.py | 0 build/lib/dimos/agents/tokenizer/base.py | 37 - .../agents/tokenizer/huggingface_tokenizer.py | 88 - .../agents/tokenizer/openai_tokenizer.py | 88 - build/lib/dimos/core/__init__.py | 103 -- build/lib/dimos/core/colors.py | 43 - build/lib/dimos/core/core.py | 260 --- build/lib/dimos/core/module.py | 172 -- build/lib/dimos/core/o3dpickle.py | 38 - build/lib/dimos/core/test_core.py | 199 --- build/lib/dimos/core/transport.py | 102 -- build/lib/dimos/environment/__init__.py | 0 .../dimos/environment/agent_environment.py | 139 -- .../dimos/environment/colmap_environment.py | 89 - build/lib/dimos/environment/environment.py | 172 -- build/lib/dimos/exceptions/__init__.py | 0 .../exceptions/agent_memory_exceptions.py | 89 - build/lib/dimos/hardware/__init__.py | 0 build/lib/dimos/hardware/camera.py | 52 - build/lib/dimos/hardware/end_effector.py | 21 - build/lib/dimos/hardware/interface.py | 51 - build/lib/dimos/hardware/piper_arm.py | 372 ----- build/lib/dimos/hardware/sensor.py | 35 - build/lib/dimos/hardware/stereo_camera.py | 26 - .../dimos/hardware/test_simple_module(1).py | 90 - build/lib/dimos/hardware/ufactory.py | 32 - build/lib/dimos/hardware/zed_camera.py | 514 ------ build/lib/dimos/manipulation/__init__.py | 0 
.../dimos/manipulation/manip_aio_pipeline.py | 590 ------- .../dimos/manipulation/manip_aio_processer.py | 411 ----- .../manipulation/manipulation_history.py | 418 ----- .../manipulation/manipulation_interface.py | 292 ---- .../manipulation/test_manipulation_history.py | 461 ------ build/lib/dimos/models/__init__.py | 0 build/lib/dimos/models/depth/__init__.py | 0 build/lib/dimos/models/depth/metric3d.py | 173 -- build/lib/dimos/models/labels/__init__.py | 0 build/lib/dimos/models/labels/llava-34b.py | 92 -- .../lib/dimos/models/manipulation/__init__.py | 0 build/lib/dimos/models/pointcloud/__init__.py | 0 .../models/pointcloud/pointcloud_utils.py | 214 --- .../lib/dimos/models/segmentation/__init__.py | 0 .../lib/dimos/models/segmentation/clipseg.py | 32 - build/lib/dimos/models/segmentation/sam.py | 35 - .../models/segmentation/segment_utils.py | 73 - build/lib/dimos/msgs/__init__.py | 0 build/lib/dimos/msgs/geometry_msgs/Pose.py | 181 -- .../dimos/msgs/geometry_msgs/PoseStamped.py | 76 - .../dimos/msgs/geometry_msgs/Quaternion.py | 167 -- build/lib/dimos/msgs/geometry_msgs/Twist.py | 87 - build/lib/dimos/msgs/geometry_msgs/Vector3.py | 467 ------ .../lib/dimos/msgs/geometry_msgs/__init__.py | 4 - .../lib/dimos/msgs/geometry_msgs/test_Pose.py | 555 ------- .../msgs/geometry_msgs/test_Quaternion.py | 210 --- .../dimos/msgs/geometry_msgs/test_Vector3.py | 462 ------ .../dimos/msgs/geometry_msgs/test_publish.py | 54 - build/lib/dimos/msgs/sensor_msgs/Image.py | 372 ----- .../lib/dimos/msgs/sensor_msgs/PointCloud2.py | 213 --- build/lib/dimos/msgs/sensor_msgs/__init__.py | 2 - .../msgs/sensor_msgs/test_PointCloud2.py | 81 - .../lib/dimos/msgs/sensor_msgs/test_image.py | 63 - build/lib/dimos/perception/__init__.py | 0 build/lib/dimos/perception/common/__init__.py | 3 - .../lib/dimos/perception/common/cuboid_fit.py | 331 ---- .../perception/common/detection2d_tracker.py | 385 ----- .../perception/common/export_tensorrt.py | 57 - 
build/lib/dimos/perception/common/ibvs.py | 280 ---- build/lib/dimos/perception/common/utils.py | 364 ----- .../dimos/perception/detection2d/__init__.py | 2 - .../perception/detection2d/detic_2d_det.py | 414 ----- .../detection2d/test_yolo_2d_det.py | 177 -- .../lib/dimos/perception/detection2d/utils.py | 338 ---- .../perception/detection2d/yolo_2d_det.py | 157 -- .../perception/grasp_generation/__init__.py | 1 - .../grasp_generation/grasp_generation.py | 228 --- .../perception/grasp_generation/utils.py | 621 ------- .../perception/object_detection_stream.py | 373 ----- build/lib/dimos/perception/object_tracker.py | 357 ---- build/lib/dimos/perception/person_tracker.py | 154 -- .../dimos/perception/pointcloud/__init__.py | 3 - .../dimos/perception/pointcloud/cuboid_fit.py | 414 ----- .../pointcloud/pointcloud_filtering.py | 674 -------- .../lib/dimos/perception/pointcloud/utils.py | 1451 ----------------- .../dimos/perception/segmentation/__init__.py | 2 - .../perception/segmentation/image_analyzer.py | 161 -- .../perception/segmentation/sam_2d_seg.py | 335 ---- .../segmentation/test_sam_2d_seg.py | 214 --- .../dimos/perception/segmentation/utils.py | 315 ---- build/lib/dimos/perception/semantic_seg.py | 245 --- .../dimos/perception/spatial_perception.py | 438 ----- .../dimos/perception/test_spatial_memory.py | 214 --- build/lib/dimos/perception/visual_servoing.py | 500 ------ build/lib/dimos/robot/__init__.py | 0 build/lib/dimos/robot/connection_interface.py | 70 - build/lib/dimos/robot/foxglove_bridge.py | 49 - .../robot/frontier_exploration/__init__.py | 1 - .../qwen_frontier_predictor.py | 368 ----- .../test_wavefront_frontier_goal_selector.py | 297 ---- .../dimos/robot/frontier_exploration/utils.py | 188 --- .../wavefront_frontier_goal_selector.py | 665 -------- .../dimos/robot/global_planner/__init__.py | 1 - build/lib/dimos/robot/global_planner/algo.py | 273 ---- .../lib/dimos/robot/global_planner/planner.py | 96 -- 
.../lib/dimos/robot/local_planner/__init__.py | 7 - .../robot/local_planner/local_planner.py | 1442 ---------------- build/lib/dimos/robot/local_planner/simple.py | 265 --- .../robot/local_planner/vfh_local_planner.py | 435 ----- build/lib/dimos/robot/position_stream.py | 162 -- build/lib/dimos/robot/recorder.py | 159 -- build/lib/dimos/robot/robot.py | 435 ----- build/lib/dimos/robot/ros_command_queue.py | 471 ------ build/lib/dimos/robot/ros_control.py | 867 ---------- build/lib/dimos/robot/ros_observable_topic.py | 240 --- build/lib/dimos/robot/ros_transform.py | 243 --- .../dimos/robot/test_ros_observable_topic.py | 255 --- build/lib/dimos/robot/unitree/__init__.py | 0 build/lib/dimos/robot/unitree/unitree_go2.py | 208 --- .../robot/unitree/unitree_ros_control.py | 157 -- .../lib/dimos/robot/unitree/unitree_skills.py | 314 ---- .../dimos/robot/unitree_webrtc/__init__.py | 0 .../dimos/robot/unitree_webrtc/connection.py | 309 ---- .../robot/unitree_webrtc/testing/__init__.py | 0 .../robot/unitree_webrtc/testing/helpers.py | 168 -- .../robot/unitree_webrtc/testing/mock.py | 91 -- .../robot/unitree_webrtc/testing/multimock.py | 142 -- .../robot/unitree_webrtc/testing/test_mock.py | 62 - .../unitree_webrtc/testing/test_multimock.py | 111 -- .../robot/unitree_webrtc/type/__init__.py | 0 .../dimos/robot/unitree_webrtc/type/lidar.py | 138 -- .../robot/unitree_webrtc/type/lowstate.py | 93 -- .../dimos/robot/unitree_webrtc/type/map.py | 150 -- .../robot/unitree_webrtc/type/odometry.py | 108 -- .../robot/unitree_webrtc/type/test_lidar.py | 51 - .../robot/unitree_webrtc/type/test_map.py | 80 - .../unitree_webrtc/type/test_odometry.py | 109 -- .../unitree_webrtc/type/test_timeseries.py | 44 - .../robot/unitree_webrtc/type/timeseries.py | 146 -- .../dimos/robot/unitree_webrtc/type/vector.py | 448 ----- .../dimos/robot/unitree_webrtc/unitree_go2.py | 224 --- .../robot/unitree_webrtc/unitree_skills.py | 279 ---- build/lib/dimos/simulation/__init__.py | 15 - 
build/lib/dimos/simulation/base/__init__.py | 0 .../dimos/simulation/base/simulator_base.py | 48 - .../lib/dimos/simulation/base/stream_base.py | 116 -- .../lib/dimos/simulation/genesis/__init__.py | 4 - .../lib/dimos/simulation/genesis/simulator.py | 158 -- build/lib/dimos/simulation/genesis/stream.py | 143 -- build/lib/dimos/simulation/isaac/__init__.py | 4 - build/lib/dimos/simulation/isaac/simulator.py | 43 - build/lib/dimos/simulation/isaac/stream.py | 136 -- build/lib/dimos/skills/__init__.py | 0 build/lib/dimos/skills/kill_skill.py | 62 - build/lib/dimos/skills/navigation.py | 699 -------- build/lib/dimos/skills/observe.py | 192 --- build/lib/dimos/skills/observe_stream.py | 245 --- build/lib/dimos/skills/rest/__init__.py | 0 build/lib/dimos/skills/rest/rest.py | 99 -- build/lib/dimos/skills/skills.py | 340 ---- build/lib/dimos/skills/speak.py | 166 -- build/lib/dimos/skills/unitree/__init__.py | 1 - .../lib/dimos/skills/unitree/unitree_speak.py | 280 ---- .../dimos/skills/visual_navigation_skills.py | 148 -- build/lib/dimos/stream/__init__.py | 0 build/lib/dimos/stream/audio/__init__.py | 0 build/lib/dimos/stream/audio/base.py | 114 -- .../dimos/stream/audio/node_key_recorder.py | 336 ---- .../lib/dimos/stream/audio/node_microphone.py | 131 -- .../lib/dimos/stream/audio/node_normalizer.py | 220 --- build/lib/dimos/stream/audio/node_output.py | 187 --- .../lib/dimos/stream/audio/node_simulated.py | 221 --- .../dimos/stream/audio/node_volume_monitor.py | 176 -- build/lib/dimos/stream/audio/pipelines.py | 52 - build/lib/dimos/stream/audio/utils.py | 26 - build/lib/dimos/stream/audio/volume.py | 108 -- build/lib/dimos/stream/data_provider.py | 183 --- build/lib/dimos/stream/frame_processor.py | 300 ---- build/lib/dimos/stream/ros_video_provider.py | 112 -- build/lib/dimos/stream/rtsp_video_provider.py | 380 ----- build/lib/dimos/stream/stream_merger.py | 45 - build/lib/dimos/stream/video_operators.py | 622 ------- build/lib/dimos/stream/video_provider.py | 235 
--- .../dimos/stream/video_providers/__init__.py | 0 .../dimos/stream/video_providers/unitree.py | 167 -- build/lib/dimos/stream/videostream.py | 41 - build/lib/dimos/types/__init__.py | 0 build/lib/dimos/types/constants.py | 24 - build/lib/dimos/types/costmap.py | 534 ------ build/lib/dimos/types/label.py | 39 - build/lib/dimos/types/manipulation.py | 155 -- build/lib/dimos/types/path.py | 414 ----- build/lib/dimos/types/pose.py | 149 -- build/lib/dimos/types/robot_capabilities.py | 27 - build/lib/dimos/types/robot_location.py | 130 -- build/lib/dimos/types/ros_polyfill.py | 103 -- build/lib/dimos/types/sample.py | 572 ------- build/lib/dimos/types/segmentation.py | 44 - build/lib/dimos/types/test_pose.py | 323 ---- build/lib/dimos/types/test_timestamped.py | 26 - build/lib/dimos/types/test_vector.py | 384 ----- build/lib/dimos/types/timestamped.py | 55 - build/lib/dimos/types/vector.py | 460 ------ build/lib/dimos/web/__init__.py | 0 .../lib/dimos/web/dimos_interface/__init__.py | 7 - .../dimos/web/dimos_interface/api/__init__.py | 0 .../dimos/web/dimos_interface/api/server.py | 362 ---- build/lib/dimos/web/edge_io.py | 26 - build/lib/dimos/web/fastapi_server.py | 224 --- build/lib/dimos/web/flask_server.py | 95 -- build/lib/dimos/web/robot_web_interface.py | 35 - build/lib/tests/__init__.py | 1 - .../tests/agent_manip_flow_fastapi_test.py | 153 -- .../lib/tests/agent_manip_flow_flask_test.py | 195 --- build/lib/tests/agent_memory_test.py | 61 - build/lib/tests/colmap_test.py | 25 - build/lib/tests/run.py | 361 ---- build/lib/tests/run_go2_ros.py | 178 -- build/lib/tests/run_navigation_only.py | 191 --- build/lib/tests/simple_agent_test.py | 39 - build/lib/tests/test_agent.py | 60 - build/lib/tests/test_agent_alibaba.py | 59 - .../tests/test_agent_ctransformers_gguf.py | 44 - .../lib/tests/test_agent_huggingface_local.py | 72 - .../test_agent_huggingface_local_jetson.py | 73 - .../tests/test_agent_huggingface_remote.py | 64 - build/lib/tests/test_audio_agent.py | 
39 - build/lib/tests/test_audio_robot_agent.py | 51 - build/lib/tests/test_cerebras_unitree_ros.py | 118 -- build/lib/tests/test_claude_agent_query.py | 29 - .../tests/test_claude_agent_skills_query.py | 135 -- build/lib/tests/test_command_pose_unitree.py | 82 - build/lib/tests/test_header.py | 58 - build/lib/tests/test_huggingface_llm_agent.py | 116 -- build/lib/tests/test_ibvs.py | 229 --- build/lib/tests/test_manipulation_agent.py | 346 ---- ...est_manipulation_perception_pipeline.py.py | 167 -- ...test_manipulation_pipeline_single_frame.py | 248 --- ..._manipulation_pipeline_single_frame_lcm.py | 431 ----- build/lib/tests/test_move_vel_unitree.py | 32 - .../tests/test_navigate_to_object_robot.py | 137 -- build/lib/tests/test_navigation_skills.py | 269 --- ...bject_detection_agent_data_query_stream.py | 191 --- .../lib/tests/test_object_detection_stream.py | 240 --- .../lib/tests/test_object_tracking_webcam.py | 222 --- .../tests/test_object_tracking_with_qwen.py | 216 --- build/lib/tests/test_observe_stream_skill.py | 131 -- .../lib/tests/test_person_following_robot.py | 113 -- .../lib/tests/test_person_following_webcam.py | 230 --- .../test_planning_agent_web_interface.py | 180 -- build/lib/tests/test_planning_robot_agent.py | 177 -- build/lib/tests/test_pointcloud_filtering.py | 105 -- build/lib/tests/test_qwen_image_query.py | 49 - build/lib/tests/test_robot.py | 86 - build/lib/tests/test_rtsp_video_provider.py | 146 -- build/lib/tests/test_semantic_seg_robot.py | 151 -- .../tests/test_semantic_seg_robot_agent.py | 141 -- build/lib/tests/test_semantic_seg_webcam.py | 140 -- build/lib/tests/test_skills.py | 185 --- build/lib/tests/test_skills_rest.py | 73 - build/lib/tests/test_spatial_memory.py | 297 ---- build/lib/tests/test_spatial_memory_query.py | 297 ---- build/lib/tests/test_standalone_chromadb.py | 87 - build/lib/tests/test_standalone_fastapi.py | 81 - .../lib/tests/test_standalone_hugging_face.py | 147 -- .../lib/tests/test_standalone_openai_json.py 
| 108 -- .../test_standalone_openai_json_struct.py | 92 -- ...test_standalone_openai_json_struct_func.py | 177 -- ...lone_openai_json_struct_func_playground.py | 222 --- .../lib/tests/test_standalone_project_out.py | 141 -- build/lib/tests/test_standalone_rxpy_01.py | 133 -- build/lib/tests/test_unitree_agent.py | 318 ---- .../test_unitree_agent_queries_fastapi.py | 105 -- build/lib/tests/test_unitree_ros_v0.0.4.py | 198 --- build/lib/tests/test_webrtc_queue.py | 156 -- build/lib/tests/test_websocketvis.py | 152 -- build/lib/tests/test_zed_setup.py | 182 --- build/lib/tests/visualization_script.py | 1041 ------------ build/lib/tests/zed_neural_depth_demo.py | 450 ----- 297 files changed, 54846 deletions(-) delete mode 100644 It delete mode 100644 build/lib/dimos/__init__.py delete mode 100644 build/lib/dimos/agents/__init__.py delete mode 100644 build/lib/dimos/agents/agent.py delete mode 100644 build/lib/dimos/agents/agent_config.py delete mode 100644 build/lib/dimos/agents/agent_ctransformers_gguf.py delete mode 100644 build/lib/dimos/agents/agent_huggingface_local.py delete mode 100644 build/lib/dimos/agents/agent_huggingface_remote.py delete mode 100644 build/lib/dimos/agents/cerebras_agent.py delete mode 100644 build/lib/dimos/agents/claude_agent.py delete mode 100644 build/lib/dimos/agents/memory/__init__.py delete mode 100644 build/lib/dimos/agents/memory/base.py delete mode 100644 build/lib/dimos/agents/memory/chroma_impl.py delete mode 100644 build/lib/dimos/agents/memory/image_embedding.py delete mode 100644 build/lib/dimos/agents/memory/spatial_vector_db.py delete mode 100644 build/lib/dimos/agents/memory/test_image_embedding.py delete mode 100644 build/lib/dimos/agents/memory/visual_memory.py delete mode 100644 build/lib/dimos/agents/planning_agent.py delete mode 100644 build/lib/dimos/agents/prompt_builder/__init__.py delete mode 100644 build/lib/dimos/agents/prompt_builder/impl.py delete mode 100644 build/lib/dimos/agents/tokenizer/__init__.py delete 
mode 100644 build/lib/dimos/agents/tokenizer/base.py delete mode 100644 build/lib/dimos/agents/tokenizer/huggingface_tokenizer.py delete mode 100644 build/lib/dimos/agents/tokenizer/openai_tokenizer.py delete mode 100644 build/lib/dimos/core/__init__.py delete mode 100644 build/lib/dimos/core/colors.py delete mode 100644 build/lib/dimos/core/core.py delete mode 100644 build/lib/dimos/core/module.py delete mode 100644 build/lib/dimos/core/o3dpickle.py delete mode 100644 build/lib/dimos/core/test_core.py delete mode 100644 build/lib/dimos/core/transport.py delete mode 100644 build/lib/dimos/environment/__init__.py delete mode 100644 build/lib/dimos/environment/agent_environment.py delete mode 100644 build/lib/dimos/environment/colmap_environment.py delete mode 100644 build/lib/dimos/environment/environment.py delete mode 100644 build/lib/dimos/exceptions/__init__.py delete mode 100644 build/lib/dimos/exceptions/agent_memory_exceptions.py delete mode 100644 build/lib/dimos/hardware/__init__.py delete mode 100644 build/lib/dimos/hardware/camera.py delete mode 100644 build/lib/dimos/hardware/end_effector.py delete mode 100644 build/lib/dimos/hardware/interface.py delete mode 100644 build/lib/dimos/hardware/piper_arm.py delete mode 100644 build/lib/dimos/hardware/sensor.py delete mode 100644 build/lib/dimos/hardware/stereo_camera.py delete mode 100644 build/lib/dimos/hardware/test_simple_module(1).py delete mode 100644 build/lib/dimos/hardware/ufactory.py delete mode 100644 build/lib/dimos/hardware/zed_camera.py delete mode 100644 build/lib/dimos/manipulation/__init__.py delete mode 100644 build/lib/dimos/manipulation/manip_aio_pipeline.py delete mode 100644 build/lib/dimos/manipulation/manip_aio_processer.py delete mode 100644 build/lib/dimos/manipulation/manipulation_history.py delete mode 100644 build/lib/dimos/manipulation/manipulation_interface.py delete mode 100644 build/lib/dimos/manipulation/test_manipulation_history.py delete mode 100644 
build/lib/dimos/models/__init__.py delete mode 100644 build/lib/dimos/models/depth/__init__.py delete mode 100644 build/lib/dimos/models/depth/metric3d.py delete mode 100644 build/lib/dimos/models/labels/__init__.py delete mode 100644 build/lib/dimos/models/labels/llava-34b.py delete mode 100644 build/lib/dimos/models/manipulation/__init__.py delete mode 100644 build/lib/dimos/models/pointcloud/__init__.py delete mode 100644 build/lib/dimos/models/pointcloud/pointcloud_utils.py delete mode 100644 build/lib/dimos/models/segmentation/__init__.py delete mode 100644 build/lib/dimos/models/segmentation/clipseg.py delete mode 100644 build/lib/dimos/models/segmentation/sam.py delete mode 100644 build/lib/dimos/models/segmentation/segment_utils.py delete mode 100644 build/lib/dimos/msgs/__init__.py delete mode 100644 build/lib/dimos/msgs/geometry_msgs/Pose.py delete mode 100644 build/lib/dimos/msgs/geometry_msgs/PoseStamped.py delete mode 100644 build/lib/dimos/msgs/geometry_msgs/Quaternion.py delete mode 100644 build/lib/dimos/msgs/geometry_msgs/Twist.py delete mode 100644 build/lib/dimos/msgs/geometry_msgs/Vector3.py delete mode 100644 build/lib/dimos/msgs/geometry_msgs/__init__.py delete mode 100644 build/lib/dimos/msgs/geometry_msgs/test_Pose.py delete mode 100644 build/lib/dimos/msgs/geometry_msgs/test_Quaternion.py delete mode 100644 build/lib/dimos/msgs/geometry_msgs/test_Vector3.py delete mode 100644 build/lib/dimos/msgs/geometry_msgs/test_publish.py delete mode 100644 build/lib/dimos/msgs/sensor_msgs/Image.py delete mode 100644 build/lib/dimos/msgs/sensor_msgs/PointCloud2.py delete mode 100644 build/lib/dimos/msgs/sensor_msgs/__init__.py delete mode 100644 build/lib/dimos/msgs/sensor_msgs/test_PointCloud2.py delete mode 100644 build/lib/dimos/msgs/sensor_msgs/test_image.py delete mode 100644 build/lib/dimos/perception/__init__.py delete mode 100644 build/lib/dimos/perception/common/__init__.py delete mode 100644 build/lib/dimos/perception/common/cuboid_fit.py 
delete mode 100644 build/lib/dimos/perception/common/detection2d_tracker.py delete mode 100644 build/lib/dimos/perception/common/export_tensorrt.py delete mode 100644 build/lib/dimos/perception/common/ibvs.py delete mode 100644 build/lib/dimos/perception/common/utils.py delete mode 100644 build/lib/dimos/perception/detection2d/__init__.py delete mode 100644 build/lib/dimos/perception/detection2d/detic_2d_det.py delete mode 100644 build/lib/dimos/perception/detection2d/test_yolo_2d_det.py delete mode 100644 build/lib/dimos/perception/detection2d/utils.py delete mode 100644 build/lib/dimos/perception/detection2d/yolo_2d_det.py delete mode 100644 build/lib/dimos/perception/grasp_generation/__init__.py delete mode 100644 build/lib/dimos/perception/grasp_generation/grasp_generation.py delete mode 100644 build/lib/dimos/perception/grasp_generation/utils.py delete mode 100644 build/lib/dimos/perception/object_detection_stream.py delete mode 100644 build/lib/dimos/perception/object_tracker.py delete mode 100644 build/lib/dimos/perception/person_tracker.py delete mode 100644 build/lib/dimos/perception/pointcloud/__init__.py delete mode 100644 build/lib/dimos/perception/pointcloud/cuboid_fit.py delete mode 100644 build/lib/dimos/perception/pointcloud/pointcloud_filtering.py delete mode 100644 build/lib/dimos/perception/pointcloud/utils.py delete mode 100644 build/lib/dimos/perception/segmentation/__init__.py delete mode 100644 build/lib/dimos/perception/segmentation/image_analyzer.py delete mode 100644 build/lib/dimos/perception/segmentation/sam_2d_seg.py delete mode 100644 build/lib/dimos/perception/segmentation/test_sam_2d_seg.py delete mode 100644 build/lib/dimos/perception/segmentation/utils.py delete mode 100644 build/lib/dimos/perception/semantic_seg.py delete mode 100644 build/lib/dimos/perception/spatial_perception.py delete mode 100644 build/lib/dimos/perception/test_spatial_memory.py delete mode 100644 build/lib/dimos/perception/visual_servoing.py delete mode 
100644 build/lib/dimos/robot/__init__.py delete mode 100644 build/lib/dimos/robot/connection_interface.py delete mode 100644 build/lib/dimos/robot/foxglove_bridge.py delete mode 100644 build/lib/dimos/robot/frontier_exploration/__init__.py delete mode 100644 build/lib/dimos/robot/frontier_exploration/qwen_frontier_predictor.py delete mode 100644 build/lib/dimos/robot/frontier_exploration/test_wavefront_frontier_goal_selector.py delete mode 100644 build/lib/dimos/robot/frontier_exploration/utils.py delete mode 100644 build/lib/dimos/robot/frontier_exploration/wavefront_frontier_goal_selector.py delete mode 100644 build/lib/dimos/robot/global_planner/__init__.py delete mode 100644 build/lib/dimos/robot/global_planner/algo.py delete mode 100644 build/lib/dimos/robot/global_planner/planner.py delete mode 100644 build/lib/dimos/robot/local_planner/__init__.py delete mode 100644 build/lib/dimos/robot/local_planner/local_planner.py delete mode 100644 build/lib/dimos/robot/local_planner/simple.py delete mode 100644 build/lib/dimos/robot/local_planner/vfh_local_planner.py delete mode 100644 build/lib/dimos/robot/position_stream.py delete mode 100644 build/lib/dimos/robot/recorder.py delete mode 100644 build/lib/dimos/robot/robot.py delete mode 100644 build/lib/dimos/robot/ros_command_queue.py delete mode 100644 build/lib/dimos/robot/ros_control.py delete mode 100644 build/lib/dimos/robot/ros_observable_topic.py delete mode 100644 build/lib/dimos/robot/ros_transform.py delete mode 100644 build/lib/dimos/robot/test_ros_observable_topic.py delete mode 100644 build/lib/dimos/robot/unitree/__init__.py delete mode 100644 build/lib/dimos/robot/unitree/unitree_go2.py delete mode 100644 build/lib/dimos/robot/unitree/unitree_ros_control.py delete mode 100644 build/lib/dimos/robot/unitree/unitree_skills.py delete mode 100644 build/lib/dimos/robot/unitree_webrtc/__init__.py delete mode 100644 build/lib/dimos/robot/unitree_webrtc/connection.py delete mode 100644 
build/lib/dimos/robot/unitree_webrtc/testing/__init__.py delete mode 100644 build/lib/dimos/robot/unitree_webrtc/testing/helpers.py delete mode 100644 build/lib/dimos/robot/unitree_webrtc/testing/mock.py delete mode 100644 build/lib/dimos/robot/unitree_webrtc/testing/multimock.py delete mode 100644 build/lib/dimos/robot/unitree_webrtc/testing/test_mock.py delete mode 100644 build/lib/dimos/robot/unitree_webrtc/testing/test_multimock.py delete mode 100644 build/lib/dimos/robot/unitree_webrtc/type/__init__.py delete mode 100644 build/lib/dimos/robot/unitree_webrtc/type/lidar.py delete mode 100644 build/lib/dimos/robot/unitree_webrtc/type/lowstate.py delete mode 100644 build/lib/dimos/robot/unitree_webrtc/type/map.py delete mode 100644 build/lib/dimos/robot/unitree_webrtc/type/odometry.py delete mode 100644 build/lib/dimos/robot/unitree_webrtc/type/test_lidar.py delete mode 100644 build/lib/dimos/robot/unitree_webrtc/type/test_map.py delete mode 100644 build/lib/dimos/robot/unitree_webrtc/type/test_odometry.py delete mode 100644 build/lib/dimos/robot/unitree_webrtc/type/test_timeseries.py delete mode 100644 build/lib/dimos/robot/unitree_webrtc/type/timeseries.py delete mode 100644 build/lib/dimos/robot/unitree_webrtc/type/vector.py delete mode 100644 build/lib/dimos/robot/unitree_webrtc/unitree_go2.py delete mode 100644 build/lib/dimos/robot/unitree_webrtc/unitree_skills.py delete mode 100644 build/lib/dimos/simulation/__init__.py delete mode 100644 build/lib/dimos/simulation/base/__init__.py delete mode 100644 build/lib/dimos/simulation/base/simulator_base.py delete mode 100644 build/lib/dimos/simulation/base/stream_base.py delete mode 100644 build/lib/dimos/simulation/genesis/__init__.py delete mode 100644 build/lib/dimos/simulation/genesis/simulator.py delete mode 100644 build/lib/dimos/simulation/genesis/stream.py delete mode 100644 build/lib/dimos/simulation/isaac/__init__.py delete mode 100644 build/lib/dimos/simulation/isaac/simulator.py delete mode 100644 
build/lib/dimos/simulation/isaac/stream.py delete mode 100644 build/lib/dimos/skills/__init__.py delete mode 100644 build/lib/dimos/skills/kill_skill.py delete mode 100644 build/lib/dimos/skills/navigation.py delete mode 100644 build/lib/dimos/skills/observe.py delete mode 100644 build/lib/dimos/skills/observe_stream.py delete mode 100644 build/lib/dimos/skills/rest/__init__.py delete mode 100644 build/lib/dimos/skills/rest/rest.py delete mode 100644 build/lib/dimos/skills/skills.py delete mode 100644 build/lib/dimos/skills/speak.py delete mode 100644 build/lib/dimos/skills/unitree/__init__.py delete mode 100644 build/lib/dimos/skills/unitree/unitree_speak.py delete mode 100644 build/lib/dimos/skills/visual_navigation_skills.py delete mode 100644 build/lib/dimos/stream/__init__.py delete mode 100644 build/lib/dimos/stream/audio/__init__.py delete mode 100644 build/lib/dimos/stream/audio/base.py delete mode 100644 build/lib/dimos/stream/audio/node_key_recorder.py delete mode 100644 build/lib/dimos/stream/audio/node_microphone.py delete mode 100644 build/lib/dimos/stream/audio/node_normalizer.py delete mode 100644 build/lib/dimos/stream/audio/node_output.py delete mode 100644 build/lib/dimos/stream/audio/node_simulated.py delete mode 100644 build/lib/dimos/stream/audio/node_volume_monitor.py delete mode 100644 build/lib/dimos/stream/audio/pipelines.py delete mode 100644 build/lib/dimos/stream/audio/utils.py delete mode 100644 build/lib/dimos/stream/audio/volume.py delete mode 100644 build/lib/dimos/stream/data_provider.py delete mode 100644 build/lib/dimos/stream/frame_processor.py delete mode 100644 build/lib/dimos/stream/ros_video_provider.py delete mode 100644 build/lib/dimos/stream/rtsp_video_provider.py delete mode 100644 build/lib/dimos/stream/stream_merger.py delete mode 100644 build/lib/dimos/stream/video_operators.py delete mode 100644 build/lib/dimos/stream/video_provider.py delete mode 100644 build/lib/dimos/stream/video_providers/__init__.py delete mode 
100644 build/lib/dimos/stream/video_providers/unitree.py delete mode 100644 build/lib/dimos/stream/videostream.py delete mode 100644 build/lib/dimos/types/__init__.py delete mode 100644 build/lib/dimos/types/constants.py delete mode 100644 build/lib/dimos/types/costmap.py delete mode 100644 build/lib/dimos/types/label.py delete mode 100644 build/lib/dimos/types/manipulation.py delete mode 100644 build/lib/dimos/types/path.py delete mode 100644 build/lib/dimos/types/pose.py delete mode 100644 build/lib/dimos/types/robot_capabilities.py delete mode 100644 build/lib/dimos/types/robot_location.py delete mode 100644 build/lib/dimos/types/ros_polyfill.py delete mode 100644 build/lib/dimos/types/sample.py delete mode 100644 build/lib/dimos/types/segmentation.py delete mode 100644 build/lib/dimos/types/test_pose.py delete mode 100644 build/lib/dimos/types/test_timestamped.py delete mode 100644 build/lib/dimos/types/test_vector.py delete mode 100644 build/lib/dimos/types/timestamped.py delete mode 100644 build/lib/dimos/types/vector.py delete mode 100644 build/lib/dimos/web/__init__.py delete mode 100644 build/lib/dimos/web/dimos_interface/__init__.py delete mode 100644 build/lib/dimos/web/dimos_interface/api/__init__.py delete mode 100644 build/lib/dimos/web/dimos_interface/api/server.py delete mode 100644 build/lib/dimos/web/edge_io.py delete mode 100644 build/lib/dimos/web/fastapi_server.py delete mode 100644 build/lib/dimos/web/flask_server.py delete mode 100644 build/lib/dimos/web/robot_web_interface.py delete mode 100644 build/lib/tests/__init__.py delete mode 100644 build/lib/tests/agent_manip_flow_fastapi_test.py delete mode 100644 build/lib/tests/agent_manip_flow_flask_test.py delete mode 100644 build/lib/tests/agent_memory_test.py delete mode 100644 build/lib/tests/colmap_test.py delete mode 100644 build/lib/tests/run.py delete mode 100644 build/lib/tests/run_go2_ros.py delete mode 100644 build/lib/tests/run_navigation_only.py delete mode 100644 
build/lib/tests/simple_agent_test.py delete mode 100644 build/lib/tests/test_agent.py delete mode 100644 build/lib/tests/test_agent_alibaba.py delete mode 100644 build/lib/tests/test_agent_ctransformers_gguf.py delete mode 100644 build/lib/tests/test_agent_huggingface_local.py delete mode 100644 build/lib/tests/test_agent_huggingface_local_jetson.py delete mode 100644 build/lib/tests/test_agent_huggingface_remote.py delete mode 100644 build/lib/tests/test_audio_agent.py delete mode 100644 build/lib/tests/test_audio_robot_agent.py delete mode 100644 build/lib/tests/test_cerebras_unitree_ros.py delete mode 100644 build/lib/tests/test_claude_agent_query.py delete mode 100644 build/lib/tests/test_claude_agent_skills_query.py delete mode 100644 build/lib/tests/test_command_pose_unitree.py delete mode 100644 build/lib/tests/test_header.py delete mode 100644 build/lib/tests/test_huggingface_llm_agent.py delete mode 100644 build/lib/tests/test_ibvs.py delete mode 100644 build/lib/tests/test_manipulation_agent.py delete mode 100644 build/lib/tests/test_manipulation_perception_pipeline.py.py delete mode 100644 build/lib/tests/test_manipulation_pipeline_single_frame.py delete mode 100644 build/lib/tests/test_manipulation_pipeline_single_frame_lcm.py delete mode 100644 build/lib/tests/test_move_vel_unitree.py delete mode 100644 build/lib/tests/test_navigate_to_object_robot.py delete mode 100644 build/lib/tests/test_navigation_skills.py delete mode 100644 build/lib/tests/test_object_detection_agent_data_query_stream.py delete mode 100644 build/lib/tests/test_object_detection_stream.py delete mode 100644 build/lib/tests/test_object_tracking_webcam.py delete mode 100644 build/lib/tests/test_object_tracking_with_qwen.py delete mode 100644 build/lib/tests/test_observe_stream_skill.py delete mode 100644 build/lib/tests/test_person_following_robot.py delete mode 100644 build/lib/tests/test_person_following_webcam.py delete mode 100644 
build/lib/tests/test_planning_agent_web_interface.py delete mode 100644 build/lib/tests/test_planning_robot_agent.py delete mode 100644 build/lib/tests/test_pointcloud_filtering.py delete mode 100644 build/lib/tests/test_qwen_image_query.py delete mode 100644 build/lib/tests/test_robot.py delete mode 100644 build/lib/tests/test_rtsp_video_provider.py delete mode 100644 build/lib/tests/test_semantic_seg_robot.py delete mode 100644 build/lib/tests/test_semantic_seg_robot_agent.py delete mode 100644 build/lib/tests/test_semantic_seg_webcam.py delete mode 100644 build/lib/tests/test_skills.py delete mode 100644 build/lib/tests/test_skills_rest.py delete mode 100644 build/lib/tests/test_spatial_memory.py delete mode 100644 build/lib/tests/test_spatial_memory_query.py delete mode 100644 build/lib/tests/test_standalone_chromadb.py delete mode 100644 build/lib/tests/test_standalone_fastapi.py delete mode 100644 build/lib/tests/test_standalone_hugging_face.py delete mode 100644 build/lib/tests/test_standalone_openai_json.py delete mode 100644 build/lib/tests/test_standalone_openai_json_struct.py delete mode 100644 build/lib/tests/test_standalone_openai_json_struct_func.py delete mode 100644 build/lib/tests/test_standalone_openai_json_struct_func_playground.py delete mode 100644 build/lib/tests/test_standalone_project_out.py delete mode 100644 build/lib/tests/test_standalone_rxpy_01.py delete mode 100644 build/lib/tests/test_unitree_agent.py delete mode 100644 build/lib/tests/test_unitree_agent_queries_fastapi.py delete mode 100644 build/lib/tests/test_unitree_ros_v0.0.4.py delete mode 100644 build/lib/tests/test_webrtc_queue.py delete mode 100644 build/lib/tests/test_websocketvis.py delete mode 100644 build/lib/tests/test_zed_setup.py delete mode 100644 build/lib/tests/visualization_script.py delete mode 100644 build/lib/tests/zed_neural_depth_demo.py diff --git a/It b/It deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/build/lib/dimos/__init__.py 
b/build/lib/dimos/__init__.py deleted file mode 100644 index 8b13789179..0000000000 --- a/build/lib/dimos/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/build/lib/dimos/agents/__init__.py b/build/lib/dimos/agents/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/build/lib/dimos/agents/agent.py b/build/lib/dimos/agents/agent.py deleted file mode 100644 index 1ce2216fe7..0000000000 --- a/build/lib/dimos/agents/agent.py +++ /dev/null @@ -1,904 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Agent framework for LLM-based autonomous systems. - -This module provides a flexible foundation for creating agents that can: -- Process image and text inputs through LLM APIs -- Store and retrieve contextual information using semantic memory -- Handle tool/function calling -- Process streaming inputs asynchronously - -The module offers base classes (Agent, LLMAgent) and concrete implementations -like OpenAIAgent that connect to specific LLM providers. 
-""" - -from __future__ import annotations - -# Standard library imports -import json -import os -import threading -from typing import Any, Tuple, Optional, Union - -# Third-party imports -from dotenv import load_dotenv -from openai import NOT_GIVEN, OpenAI -from pydantic import BaseModel -from reactivex import Observer, create, Observable, empty, operators as RxOps, just -from reactivex.disposable import CompositeDisposable, Disposable -from reactivex.scheduler import ThreadPoolScheduler -from reactivex.subject import Subject - -# Local imports -from dimos.agents.memory.base import AbstractAgentSemanticMemory -from dimos.agents.memory.chroma_impl import OpenAISemanticMemory -from dimos.agents.prompt_builder.impl import PromptBuilder -from dimos.agents.tokenizer.base import AbstractTokenizer -from dimos.agents.tokenizer.openai_tokenizer import OpenAITokenizer -from dimos.skills.skills import AbstractSkill, SkillLibrary -from dimos.stream.frame_processor import FrameProcessor -from dimos.stream.stream_merger import create_stream_merger -from dimos.stream.video_operators import Operators as MyOps, VideoOperators as MyVidOps -from dimos.utils.threadpool import get_scheduler -from dimos.utils.logging_config import setup_logger - -# Initialize environment variables -load_dotenv() - -# Initialize logger for the agent module -logger = setup_logger("dimos.agents") - -# Constants -_TOKEN_BUDGET_PARTS = 4 # Number of parts to divide token budget -_MAX_SAVED_FRAMES = 100 # Maximum number of frames to save - - -# ----------------------------------------------------------------------------- -# region Agent Base Class -# ----------------------------------------------------------------------------- -class Agent: - """Base agent that manages memory and subscriptions.""" - - def __init__( - self, - dev_name: str = "NA", - agent_type: str = "Base", - agent_memory: Optional[AbstractAgentSemanticMemory] = None, - pool_scheduler: Optional[ThreadPoolScheduler] = None, - ): - """ - 
Initializes a new instance of the Agent. - - Args: - dev_name (str): The device name of the agent. - agent_type (str): The type of the agent (e.g., 'Base', 'Vision'). - agent_memory (AbstractAgentSemanticMemory): The memory system for the agent. - pool_scheduler (ThreadPoolScheduler): The scheduler to use for thread pool operations. - If None, the global scheduler from get_scheduler() will be used. - """ - self.dev_name = dev_name - self.agent_type = agent_type - self.agent_memory = agent_memory or OpenAISemanticMemory() - self.disposables = CompositeDisposable() - self.pool_scheduler = pool_scheduler if pool_scheduler else get_scheduler() - - def dispose_all(self): - """Disposes of all active subscriptions managed by this agent.""" - if self.disposables: - self.disposables.dispose() - else: - logger.info("No disposables to dispose.") - - -# endregion Agent Base Class - - -# ----------------------------------------------------------------------------- -# region LLMAgent Base Class (Generic LLM Agent) -# ----------------------------------------------------------------------------- -class LLMAgent(Agent): - """Generic LLM agent containing common logic for LLM-based agents. - - This class implements functionality for: - - Updating the query - - Querying the agent's memory (for RAG) - - Building prompts via a prompt builder - - Handling tooling callbacks in responses - - Subscribing to image and query streams - - Emitting responses as an observable stream - - Subclasses must implement the `_send_query` method, which is responsible - for sending the prompt to a specific LLM API. - - Attributes: - query (str): The current query text to process. - prompt_builder (PromptBuilder): Handles construction of prompts. - system_query (str): System prompt for RAG context situations. - image_detail (str): Detail level for image processing ('low','high','auto'). - max_input_tokens_per_request (int): Maximum input token count. 
- max_output_tokens_per_request (int): Maximum output token count. - max_tokens_per_request (int): Total maximum token count. - rag_query_n (int): Number of results to fetch from memory. - rag_similarity_threshold (float): Minimum similarity for RAG results. - frame_processor (FrameProcessor): Processes video frames. - output_dir (str): Directory for output files. - response_subject (Subject): Subject that emits agent responses. - process_all_inputs (bool): Whether to process every input emission (True) or - skip emissions when the agent is busy processing a previous input (False). - """ - - logging_file_memory_lock = threading.Lock() - - def __init__( - self, - dev_name: str = "NA", - agent_type: str = "LLM", - agent_memory: Optional[AbstractAgentSemanticMemory] = None, - pool_scheduler: Optional[ThreadPoolScheduler] = None, - process_all_inputs: bool = False, - system_query: Optional[str] = None, - max_output_tokens_per_request: int = 16384, - max_input_tokens_per_request: int = 128000, - input_query_stream: Optional[Observable] = None, - input_data_stream: Optional[Observable] = None, - input_video_stream: Optional[Observable] = None, - ): - """ - Initializes a new instance of the LLMAgent. - - Args: - dev_name (str): The device name of the agent. - agent_type (str): The type of the agent. - agent_memory (AbstractAgentSemanticMemory): The memory system for the agent. - pool_scheduler (ThreadPoolScheduler): The scheduler to use for thread pool operations. - If None, the global scheduler from get_scheduler() will be used. - process_all_inputs (bool): Whether to process every input emission (True) or - skip emissions when the agent is busy processing a previous input (False). - """ - super().__init__(dev_name, agent_type, agent_memory, pool_scheduler) - # These attributes can be configured by a subclass if needed. 
- self.query: Optional[str] = None - self.prompt_builder: Optional[PromptBuilder] = None - self.system_query: Optional[str] = system_query - self.image_detail: str = "low" - self.max_input_tokens_per_request: int = max_input_tokens_per_request - self.max_output_tokens_per_request: int = max_output_tokens_per_request - self.max_tokens_per_request: int = ( - self.max_input_tokens_per_request + self.max_output_tokens_per_request - ) - self.rag_query_n: int = 4 - self.rag_similarity_threshold: float = 0.45 - self.frame_processor: Optional[FrameProcessor] = None - self.output_dir: str = os.path.join(os.getcwd(), "assets", "agent") - self.process_all_inputs: bool = process_all_inputs - os.makedirs(self.output_dir, exist_ok=True) - - # Subject for emitting responses - self.response_subject = Subject() - - # Conversation history for maintaining context between calls - self.conversation_history = [] - - # Initialize input streams - self.input_video_stream = input_video_stream - self.input_query_stream = ( - input_query_stream - if (input_data_stream is None) - else ( - input_query_stream.pipe( - RxOps.with_latest_from(input_data_stream), - RxOps.map( - lambda combined: { - "query": combined[0], - "objects": combined[1] - if len(combined) > 1 - else "No object data available", - } - ), - RxOps.map( - lambda data: f"{data['query']}\n\nCurrent objects detected:\n{data['objects']}" - ), - RxOps.do_action( - lambda x: print(f"\033[34mEnriched query: {x.split(chr(10))[0]}\033[0m") - or [print(f"\033[34m{line}\033[0m") for line in x.split(chr(10))[1:]] - ), - ) - ) - ) - - # Setup stream subscriptions based on inputs provided - if (self.input_video_stream is not None) and (self.input_query_stream is not None): - self.merged_stream = create_stream_merger( - data_input_stream=self.input_video_stream, text_query_stream=self.input_query_stream - ) - - logger.info("Subscribing to merged input stream...") - # Define a query extractor for the merged stream - query_extractor = lambda 
emission: (emission[0], emission[1][0]) - self.disposables.add( - self.subscribe_to_image_processing( - self.merged_stream, query_extractor=query_extractor - ) - ) - else: - # If no merged stream, fall back to individual streams - if self.input_video_stream is not None: - logger.info("Subscribing to input video stream...") - self.disposables.add(self.subscribe_to_image_processing(self.input_video_stream)) - if self.input_query_stream is not None: - logger.info("Subscribing to input query stream...") - self.disposables.add(self.subscribe_to_query_processing(self.input_query_stream)) - - def _update_query(self, incoming_query: Optional[str]) -> None: - """Updates the query if an incoming query is provided. - - Args: - incoming_query (str): The new query text. - """ - if incoming_query is not None: - self.query = incoming_query - - def _get_rag_context(self) -> Tuple[str, str]: - """Queries the agent memory to retrieve RAG context. - - Returns: - Tuple[str, str]: A tuple containing the formatted results (for logging) - and condensed results (for use in the prompt). - """ - results = self.agent_memory.query( - query_texts=self.query, - n_results=self.rag_query_n, - similarity_threshold=self.rag_similarity_threshold, - ) - formatted_results = "\n".join( - f"Document ID: {doc.id}\nMetadata: {doc.metadata}\nContent: {doc.page_content}\nScore: {score}\n" - for (doc, score) in results - ) - condensed_results = " | ".join(f"{doc.page_content}" for (doc, _) in results) - logger.info(f"Agent Memory Query Results:\n{formatted_results}") - logger.info("=== Results End ===") - return formatted_results, condensed_results - - def _build_prompt( - self, - base64_image: Optional[str], - dimensions: Optional[Tuple[int, int]], - override_token_limit: bool, - condensed_results: str, - ) -> list: - """Builds a prompt message using the prompt builder. - - Args: - base64_image (str): Optional Base64-encoded image. - dimensions (Tuple[int, int]): Optional image dimensions. 
- override_token_limit (bool): Whether to override token limits. - condensed_results (str): The condensed RAG context. - - Returns: - list: A list of message dictionaries to be sent to the LLM. - """ - # Budget for each component of the prompt - budgets = { - "system_prompt": self.max_input_tokens_per_request // _TOKEN_BUDGET_PARTS, - "user_query": self.max_input_tokens_per_request // _TOKEN_BUDGET_PARTS, - "image": self.max_input_tokens_per_request // _TOKEN_BUDGET_PARTS, - "rag": self.max_input_tokens_per_request // _TOKEN_BUDGET_PARTS, - } - - # Define truncation policies for each component - policies = { - "system_prompt": "truncate_end", - "user_query": "truncate_middle", - "image": "do_not_truncate", - "rag": "truncate_end", - } - - return self.prompt_builder.build( - user_query=self.query, - override_token_limit=override_token_limit, - base64_image=base64_image, - image_width=dimensions[0] if dimensions is not None else None, - image_height=dimensions[1] if dimensions is not None else None, - image_detail=self.image_detail, - rag_context=condensed_results, - system_prompt=self.system_query, - budgets=budgets, - policies=policies, - ) - - def _handle_tooling(self, response_message, messages): - """Handles tooling callbacks in the response message. - - If tool calls are present, the corresponding functions are executed and - a follow-up query is sent. - - Args: - response_message: The response message containing tool calls. - messages (list): The original list of messages sent. - - Returns: - The final response message after processing tool calls, if any. - """ - - # TODO: Make this more generic or move implementation to OpenAIAgent. - # This is presently OpenAI-specific. 
- def _tooling_callback(message, messages, response_message, skill_library: SkillLibrary): - has_called_tools = False - new_messages = [] - for tool_call in message.tool_calls: - has_called_tools = True - name = tool_call.function.name - args = json.loads(tool_call.function.arguments) - result = skill_library.call(name, **args) - logger.info(f"Function Call Results: {result}") - new_messages.append( - { - "role": "tool", - "tool_call_id": tool_call.id, - "content": str(result), - "name": name, - } - ) - if has_called_tools: - logger.info("Sending Another Query.") - messages.append(response_message) - messages.extend(new_messages) - # Delegate to sending the query again. - return self._send_query(messages) - else: - logger.info("No Need for Another Query.") - return None - - if response_message.tool_calls is not None: - return _tooling_callback( - response_message, messages, response_message, self.skill_library - ) - return None - - def _observable_query( - self, - observer: Observer, - base64_image: Optional[str] = None, - dimensions: Optional[Tuple[int, int]] = None, - override_token_limit: bool = False, - incoming_query: Optional[str] = None, - ): - """Prepares and sends a query to the LLM, emitting the response to the observer. - - Args: - observer (Observer): The observer to emit responses to. - base64_image (str): Optional Base64-encoded image. - dimensions (Tuple[int, int]): Optional image dimensions. - override_token_limit (bool): Whether to override token limits. - incoming_query (str): Optional query to update the agent's query. - - Raises: - Exception: Propagates any exceptions encountered during processing. 
- """ - try: - self._update_query(incoming_query) - _, condensed_results = self._get_rag_context() - messages = self._build_prompt( - base64_image, dimensions, override_token_limit, condensed_results - ) - # logger.debug(f"Sending Query: {messages}") - logger.info("Sending Query.") - response_message = self._send_query(messages) - logger.info(f"Received Response: {response_message}") - if response_message is None: - raise Exception("Response message does not exist.") - - # TODO: Make this more generic. The parsed tag and tooling handling may be OpenAI-specific. - # If no skill library is provided or there are no tool calls, emit the response directly. - if ( - self.skill_library is None - or self.skill_library.get_tools() in (None, NOT_GIVEN) - or response_message.tool_calls is None - ): - final_msg = ( - response_message.parsed - if hasattr(response_message, "parsed") and response_message.parsed - else ( - response_message.content - if hasattr(response_message, "content") - else response_message - ) - ) - observer.on_next(final_msg) - self.response_subject.on_next(final_msg) - else: - response_message_2 = self._handle_tooling(response_message, messages) - final_msg = ( - response_message_2 if response_message_2 is not None else response_message - ) - if isinstance(final_msg, BaseModel): # TODO: Test - final_msg = str(final_msg.content) - observer.on_next(final_msg) - self.response_subject.on_next(final_msg) - observer.on_completed() - except Exception as e: - logger.error(f"Query failed in {self.dev_name}: {e}") - observer.on_error(e) - self.response_subject.on_error(e) - - def _send_query(self, messages: list) -> Any: - """Sends the query to the LLM API. - - This method must be implemented by subclasses with specifics of the LLM API. - - Args: - messages (list): The prompt messages to be sent. - - Returns: - Any: The response message from the LLM. - - Raises: - NotImplementedError: Always, unless overridden. 
- """ - raise NotImplementedError("Subclasses must implement _send_query method.") - - def _log_response_to_file(self, response, output_dir: str = None): - """Logs the LLM response to a file. - - Args: - response: The response message to log. - output_dir (str): The directory where the log file is stored. - """ - if output_dir is None: - output_dir = self.output_dir - if response is not None: - with self.logging_file_memory_lock: - log_path = os.path.join(output_dir, "memory.txt") - with open(log_path, "a") as file: - file.write(f"{self.dev_name}: {response}\n") - logger.info(f"LLM Response [{self.dev_name}]: {response}") - - def subscribe_to_image_processing( - self, frame_observable: Observable, query_extractor=None - ) -> Disposable: - """Subscribes to a stream of video frames for processing. - - This method sets up a subscription to process incoming video frames. - Each frame is encoded and then sent to the LLM by directly calling the - _observable_query method. The response is then logged to a file. - - Args: - frame_observable (Observable): An observable emitting video frames or - (query, frame) tuples if query_extractor is provided. - query_extractor (callable, optional): Function to extract query and frame from - each emission. If None, assumes emissions are - raw frames and uses self.system_query. - - Returns: - Disposable: A disposable representing the subscription. - """ - # Initialize frame processor if not already set - if self.frame_processor is None: - self.frame_processor = FrameProcessor(delete_on_init=True) - - print_emission_args = {"enabled": True, "dev_name": self.dev_name, "counts": {}} - - def _process_frame(emission) -> Observable: - """ - Processes a frame or (query, frame) tuple. 
- """ - # Extract query and frame - if query_extractor: - query, frame = query_extractor(emission) - else: - query = self.system_query - frame = emission - return just(frame).pipe( - MyOps.print_emission(id="B", **print_emission_args), - RxOps.observe_on(self.pool_scheduler), - MyOps.print_emission(id="C", **print_emission_args), - RxOps.subscribe_on(self.pool_scheduler), - MyOps.print_emission(id="D", **print_emission_args), - MyVidOps.with_jpeg_export( - self.frame_processor, - suffix=f"{self.dev_name}_frame_", - save_limit=_MAX_SAVED_FRAMES, - ), - MyOps.print_emission(id="E", **print_emission_args), - MyVidOps.encode_image(), - MyOps.print_emission(id="F", **print_emission_args), - RxOps.filter( - lambda base64_and_dims: base64_and_dims is not None - and base64_and_dims[0] is not None - and base64_and_dims[1] is not None - ), - MyOps.print_emission(id="G", **print_emission_args), - RxOps.flat_map( - lambda base64_and_dims: create( - lambda observer, _: self._observable_query( - observer, - base64_image=base64_and_dims[0], - dimensions=base64_and_dims[1], - incoming_query=query, - ) - ) - ), # Use the extracted query - MyOps.print_emission(id="H", **print_emission_args), - ) - - # Use a mutable flag to ensure only one frame is processed at a time. 
- is_processing = [False] - - def process_if_free(emission): - if not self.process_all_inputs and is_processing[0]: - # Drop frame if a request is in progress and process_all_inputs is False - return empty() - else: - is_processing[0] = True - return _process_frame(emission).pipe( - MyOps.print_emission(id="I", **print_emission_args), - RxOps.observe_on(self.pool_scheduler), - MyOps.print_emission(id="J", **print_emission_args), - RxOps.subscribe_on(self.pool_scheduler), - MyOps.print_emission(id="K", **print_emission_args), - RxOps.do_action( - on_completed=lambda: is_processing.__setitem__(0, False), - on_error=lambda e: is_processing.__setitem__(0, False), - ), - MyOps.print_emission(id="L", **print_emission_args), - ) - - observable = frame_observable.pipe( - MyOps.print_emission(id="A", **print_emission_args), - RxOps.flat_map(process_if_free), - MyOps.print_emission(id="M", **print_emission_args), - ) - - disposable = observable.subscribe( - on_next=lambda response: self._log_response_to_file(response, self.output_dir), - on_error=lambda e: logger.error(f"Error encountered: {e}"), - on_completed=lambda: logger.info(f"Stream processing completed for {self.dev_name}"), - ) - self.disposables.add(disposable) - return disposable - - def subscribe_to_query_processing(self, query_observable: Observable) -> Disposable: - """Subscribes to a stream of queries for processing. - - This method sets up a subscription to process incoming queries by directly - calling the _observable_query method. The responses are logged to a file. - - Args: - query_observable (Observable): An observable emitting queries. - - Returns: - Disposable: A disposable representing the subscription. - """ - print_emission_args = {"enabled": False, "dev_name": self.dev_name, "counts": {}} - - def _process_query(query) -> Observable: - """ - Processes a single query by logging it and passing it to _observable_query. - Returns an observable that emits the LLM response. 
- """ - return just(query).pipe( - MyOps.print_emission(id="Pr A", **print_emission_args), - RxOps.flat_map( - lambda query: create( - lambda observer, _: self._observable_query(observer, incoming_query=query) - ) - ), - MyOps.print_emission(id="Pr B", **print_emission_args), - ) - - # A mutable flag indicating whether a query is currently being processed. - is_processing = [False] - - def process_if_free(query): - logger.info(f"Processing Query: {query}") - if not self.process_all_inputs and is_processing[0]: - # Drop query if a request is already in progress and process_all_inputs is False - return empty() - else: - is_processing[0] = True - logger.info("Processing Query.") - return _process_query(query).pipe( - MyOps.print_emission(id="B", **print_emission_args), - RxOps.observe_on(self.pool_scheduler), - MyOps.print_emission(id="C", **print_emission_args), - RxOps.subscribe_on(self.pool_scheduler), - MyOps.print_emission(id="D", **print_emission_args), - RxOps.do_action( - on_completed=lambda: is_processing.__setitem__(0, False), - on_error=lambda e: is_processing.__setitem__(0, False), - ), - MyOps.print_emission(id="E", **print_emission_args), - ) - - observable = query_observable.pipe( - MyOps.print_emission(id="A", **print_emission_args), - RxOps.flat_map(lambda query: process_if_free(query)), - MyOps.print_emission(id="F", **print_emission_args), - ) - - disposable = observable.subscribe( - on_next=lambda response: self._log_response_to_file(response, self.output_dir), - on_error=lambda e: logger.error(f"Error processing query for {self.dev_name}: {e}"), - on_completed=lambda: logger.info(f"Stream processing completed for {self.dev_name}"), - ) - self.disposables.add(disposable) - return disposable - - def get_response_observable(self) -> Observable: - """Gets an observable that emits responses from this agent. - - Returns: - Observable: An observable that emits string responses from the agent. 
- """ - return self.response_subject.pipe( - RxOps.observe_on(self.pool_scheduler), - RxOps.subscribe_on(self.pool_scheduler), - RxOps.share(), - ) - - def run_observable_query(self, query_text: str, **kwargs) -> Observable: - """Creates an observable that processes a one-off text query to Agent and emits the response. - - This method provides a simple way to send a text query and get an observable - stream of the response. It's designed for one-off queries rather than - continuous processing of input streams. Useful for testing and development. - - Args: - query_text (str): The query text to process. - **kwargs: Additional arguments to pass to _observable_query. Supported args vary by agent type. - For example, ClaudeAgent supports: base64_image, dimensions, override_token_limit, - reset_conversation, thinking_budget_tokens - - Returns: - Observable: An observable that emits the response as a string. - """ - return create( - lambda observer, _: self._observable_query( - observer, incoming_query=query_text, **kwargs - ) - ) - - def dispose_all(self): - """Disposes of all active subscriptions managed by this agent.""" - super().dispose_all() - self.response_subject.on_completed() - - -# endregion LLMAgent Base Class (Generic LLM Agent) - - -# ----------------------------------------------------------------------------- -# region OpenAIAgent Subclass (OpenAI-Specific Implementation) -# ----------------------------------------------------------------------------- -class OpenAIAgent(LLMAgent): - """OpenAI agent implementation that uses OpenAI's API for processing. - - This class implements the _send_query method to interact with OpenAI's API. - It also sets up OpenAI-specific parameters, such as the client, model name, - tokenizer, and response model. 
- """ - - def __init__( - self, - dev_name: str, - agent_type: str = "Vision", - query: str = "What do you see?", - input_query_stream: Optional[Observable] = None, - input_data_stream: Optional[Observable] = None, - input_video_stream: Optional[Observable] = None, - output_dir: str = os.path.join(os.getcwd(), "assets", "agent"), - agent_memory: Optional[AbstractAgentSemanticMemory] = None, - system_query: Optional[str] = None, - max_input_tokens_per_request: int = 128000, - max_output_tokens_per_request: int = 16384, - model_name: str = "gpt-4o", - prompt_builder: Optional[PromptBuilder] = None, - tokenizer: Optional[AbstractTokenizer] = None, - rag_query_n: int = 4, - rag_similarity_threshold: float = 0.45, - skills: Optional[Union[AbstractSkill, list[AbstractSkill], SkillLibrary]] = None, - response_model: Optional[BaseModel] = None, - frame_processor: Optional[FrameProcessor] = None, - image_detail: str = "low", - pool_scheduler: Optional[ThreadPoolScheduler] = None, - process_all_inputs: Optional[bool] = None, - openai_client: Optional[OpenAI] = None, - ): - """ - Initializes a new instance of the OpenAIAgent. - - Args: - dev_name (str): The device name of the agent. - agent_type (str): The type of the agent. - query (str): The default query text. - input_query_stream (Observable): An observable for query input. - input_data_stream (Observable): An observable for data input. - input_video_stream (Observable): An observable for video frames. - output_dir (str): Directory for output files. - agent_memory (AbstractAgentSemanticMemory): The memory system. - system_query (str): The system prompt to use with RAG context. - max_input_tokens_per_request (int): Maximum tokens for input. - max_output_tokens_per_request (int): Maximum tokens for output. - model_name (str): The OpenAI model name to use. - prompt_builder (PromptBuilder): Custom prompt builder. - tokenizer (AbstractTokenizer): Custom tokenizer for token counting. 
- rag_query_n (int): Number of results to fetch in RAG queries. - rag_similarity_threshold (float): Minimum similarity for RAG results. - skills (Union[AbstractSkill, List[AbstractSkill], SkillLibrary]): Skills available to the agent. - response_model (BaseModel): Optional Pydantic model for responses. - frame_processor (FrameProcessor): Custom frame processor. - image_detail (str): Detail level for images ("low", "high", "auto"). - pool_scheduler (ThreadPoolScheduler): The scheduler to use for thread pool operations. - If None, the global scheduler from get_scheduler() will be used. - process_all_inputs (bool): Whether to process all inputs or skip when busy. - If None, defaults to True for text queries and merged streams, False for video streams. - openai_client (OpenAI): The OpenAI client to use. This can be used to specify - a custom OpenAI client if targetting another provider. - """ - # Determine appropriate default for process_all_inputs if not provided - if process_all_inputs is None: - if input_query_stream is not None: - process_all_inputs = True - else: - process_all_inputs = False - - super().__init__( - dev_name=dev_name, - agent_type=agent_type, - agent_memory=agent_memory, - pool_scheduler=pool_scheduler, - process_all_inputs=process_all_inputs, - system_query=system_query, - input_query_stream=input_query_stream, - input_data_stream=input_data_stream, - input_video_stream=input_video_stream, - ) - self.client = openai_client or OpenAI() - self.query = query - self.output_dir = output_dir - os.makedirs(self.output_dir, exist_ok=True) - - # Configure skill library. 
- self.skills = skills - self.skill_library = None - if isinstance(self.skills, SkillLibrary): - self.skill_library = self.skills - elif isinstance(self.skills, list): - self.skill_library = SkillLibrary() - for skill in self.skills: - self.skill_library.add(skill) - elif isinstance(self.skills, AbstractSkill): - self.skill_library = SkillLibrary() - self.skill_library.add(self.skills) - - self.response_model = response_model if response_model is not None else NOT_GIVEN - self.model_name = model_name - self.tokenizer = tokenizer or OpenAITokenizer(model_name=self.model_name) - self.prompt_builder = prompt_builder or PromptBuilder( - self.model_name, tokenizer=self.tokenizer - ) - self.rag_query_n = rag_query_n - self.rag_similarity_threshold = rag_similarity_threshold - self.image_detail = image_detail - self.max_output_tokens_per_request = max_output_tokens_per_request - self.max_input_tokens_per_request = max_input_tokens_per_request - self.max_tokens_per_request = max_input_tokens_per_request + max_output_tokens_per_request - - # Add static context to memory. 
- self._add_context_to_memory() - - self.frame_processor = frame_processor or FrameProcessor(delete_on_init=True) - - logger.info("OpenAI Agent Initialized.") - - def _add_context_to_memory(self): - """Adds initial context to the agent's memory.""" - context_data = [ - ( - "id0", - "Optical Flow is a technique used to track the movement of objects in a video sequence.", - ), - ( - "id1", - "Edge Detection is a technique used to identify the boundaries of objects in an image.", - ), - ("id2", "Video is a sequence of frames captured at regular intervals."), - ( - "id3", - "Colors in Optical Flow are determined by the movement of light, and can be used to track the movement of objects.", - ), - ( - "id4", - "Json is a data interchange format that is easy for humans to read and write, and easy for machines to parse and generate.", - ), - ] - for doc_id, text in context_data: - self.agent_memory.add_vector(doc_id, text) - - def _send_query(self, messages: list) -> Any: - """Sends the query to OpenAI's API. - - Depending on whether a response model is provided, the appropriate API - call is made. - - Args: - messages (list): The prompt messages to send. - - Returns: - The response message from OpenAI. - - Raises: - Exception: If no response message is returned. - ConnectionError: If there's an issue connecting to the API. - ValueError: If the messages or other parameters are invalid. 
- """ - try: - if self.response_model is not NOT_GIVEN: - response = self.client.beta.chat.completions.parse( - model=self.model_name, - messages=messages, - response_format=self.response_model, - tools=( - self.skill_library.get_tools() - if self.skill_library is not None - else NOT_GIVEN - ), - max_tokens=self.max_output_tokens_per_request, - ) - else: - response = self.client.chat.completions.create( - model=self.model_name, - messages=messages, - max_tokens=self.max_output_tokens_per_request, - tools=( - self.skill_library.get_tools() - if self.skill_library is not None - else NOT_GIVEN - ), - ) - response_message = response.choices[0].message - if response_message is None: - logger.error("Response message does not exist.") - raise Exception("Response message does not exist.") - return response_message - except ConnectionError as ce: - logger.error(f"Connection error with API: {ce}") - raise - except ValueError as ve: - logger.error(f"Invalid parameters: {ve}") - raise - except Exception as e: - logger.error(f"Unexpected error in API call: {e}") - raise - - def stream_query(self, query_text: str) -> Observable: - """Creates an observable that processes a text query and emits the response. - - This method provides a simple way to send a text query and get an observable - stream of the response. It's designed for one-off queries rather than - continuous processing of input streams. - - Args: - query_text (str): The query text to process. - - Returns: - Observable: An observable that emits the response as a string. - """ - return create( - lambda observer, _: self._observable_query(observer, incoming_query=query_text) - ) - - -# endregion OpenAIAgent Subclass (OpenAI-Specific Implementation) diff --git a/build/lib/dimos/agents/agent_config.py b/build/lib/dimos/agents/agent_config.py deleted file mode 100644 index 0ffbcd2983..0000000000 --- a/build/lib/dimos/agents/agent_config.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright 2025 Dimensional Inc. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import List -from dimos.agents.agent import Agent - - -class AgentConfig: - def __init__(self, agents: List[Agent] = None): - """ - Initialize an AgentConfig with a list of agents. - - Args: - agents (List[Agent], optional): List of Agent instances. Defaults to empty list. - """ - self.agents = agents if agents is not None else [] - - def add_agent(self, agent: Agent): - """ - Add an agent to the configuration. - - Args: - agent (Agent): Agent instance to add - """ - self.agents.append(agent) - - def remove_agent(self, agent: Agent): - """ - Remove an agent from the configuration. - - Args: - agent (Agent): Agent instance to remove - """ - if agent in self.agents: - self.agents.remove(agent) - - def get_agents(self) -> List[Agent]: - """ - Get the list of configured agents. - - Returns: - List[Agent]: List of configured agents - """ - return self.agents diff --git a/build/lib/dimos/agents/agent_ctransformers_gguf.py b/build/lib/dimos/agents/agent_ctransformers_gguf.py deleted file mode 100644 index 32d6fc59ca..0000000000 --- a/build/lib/dimos/agents/agent_ctransformers_gguf.py +++ /dev/null @@ -1,210 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -# Standard library imports -import logging -import os -from typing import Any, Optional - -# Third-party imports -from dotenv import load_dotenv -from reactivex import Observable, create -from reactivex.scheduler import ThreadPoolScheduler -from reactivex.subject import Subject -import torch - -# Local imports -from dimos.agents.agent import LLMAgent -from dimos.agents.memory.base import AbstractAgentSemanticMemory -from dimos.agents.prompt_builder.impl import PromptBuilder -from dimos.utils.logging_config import setup_logger - -# Initialize environment variables -load_dotenv() - -# Initialize logger for the agent module -logger = setup_logger("dimos.agents", level=logging.DEBUG) - -from ctransformers import AutoModelForCausalLM as CTransformersModel - - -class CTransformersTokenizerAdapter: - def __init__(self, model): - self.model = model - - def encode(self, text, **kwargs): - return self.model.tokenize(text) - - def decode(self, token_ids, **kwargs): - return self.model.detokenize(token_ids) - - def token_count(self, text): - return len(self.tokenize_text(text)) if text else 0 - - def tokenize_text(self, text): - return self.model.tokenize(text) - - def detokenize_text(self, tokenized_text): - try: - return self.model.detokenize(tokenized_text) - except Exception as e: - raise ValueError(f"Failed to detokenize text. 
Error: {str(e)}") - - def apply_chat_template(self, conversation, tokenize=False, add_generation_prompt=True): - prompt = "" - for message in conversation: - role = message["role"] - content = message["content"] - if role == "system": - prompt += f"<|system|>\n{content}\n" - elif role == "user": - prompt += f"<|user|>\n{content}\n" - elif role == "assistant": - prompt += f"<|assistant|>\n{content}\n" - if add_generation_prompt: - prompt += "<|assistant|>\n" - return prompt - - -# CTransformers Agent Class -class CTransformersGGUFAgent(LLMAgent): - def __init__( - self, - dev_name: str, - agent_type: str = "HF-LLM", - model_name: str = "TheBloke/Llama-2-7B-GGUF", - model_file: str = "llama-2-7b.Q4_K_M.gguf", - model_type: str = "llama", - gpu_layers: int = 50, - device: str = "auto", - query: str = "How many r's are in the word 'strawberry'?", - input_query_stream: Optional[Observable] = None, - input_video_stream: Optional[Observable] = None, - output_dir: str = os.path.join(os.getcwd(), "assets", "agent"), - agent_memory: Optional[AbstractAgentSemanticMemory] = None, - system_query: Optional[str] = "You are a helpful assistant.", - max_output_tokens_per_request: int = 10, - max_input_tokens_per_request: int = 250, - prompt_builder: Optional[PromptBuilder] = None, - pool_scheduler: Optional[ThreadPoolScheduler] = None, - process_all_inputs: Optional[bool] = None, - ): - # Determine appropriate default for process_all_inputs if not provided - if process_all_inputs is None: - # Default to True for text queries, False for video streams - if input_query_stream is not None and input_video_stream is None: - process_all_inputs = True - else: - process_all_inputs = False - - super().__init__( - dev_name=dev_name, - agent_type=agent_type, - agent_memory=agent_memory, - pool_scheduler=pool_scheduler, - process_all_inputs=process_all_inputs, - system_query=system_query, - max_output_tokens_per_request=max_output_tokens_per_request, - 
max_input_tokens_per_request=max_input_tokens_per_request, - ) - - self.query = query - self.output_dir = output_dir - os.makedirs(self.output_dir, exist_ok=True) - - self.model_name = model_name - self.device = device - if self.device == "auto": - self.device = "cuda" if torch.cuda.is_available() else "cpu" - if self.device == "cuda": - print(f"Using GPU: {torch.cuda.get_device_name(0)}") - else: - print("GPU not available, using CPU") - print(f"Device: {self.device}") - - self.model = CTransformersModel.from_pretrained( - model_name, model_file=model_file, model_type=model_type, gpu_layers=gpu_layers - ) - - self.tokenizer = CTransformersTokenizerAdapter(self.model) - - self.prompt_builder = prompt_builder or PromptBuilder( - self.model_name, tokenizer=self.tokenizer - ) - - self.max_output_tokens_per_request = max_output_tokens_per_request - - # self.stream_query(self.query).subscribe(lambda x: print(x)) - - self.input_video_stream = input_video_stream - self.input_query_stream = input_query_stream - - # Ensure only one input stream is provided. - if self.input_video_stream is not None and self.input_query_stream is not None: - raise ValueError( - "More than one input stream provided. Please provide only one input stream." 
- ) - - if self.input_video_stream is not None: - logger.info("Subscribing to input video stream...") - self.disposables.add(self.subscribe_to_image_processing(self.input_video_stream)) - if self.input_query_stream is not None: - logger.info("Subscribing to input query stream...") - self.disposables.add(self.subscribe_to_query_processing(self.input_query_stream)) - - def _send_query(self, messages: list) -> Any: - try: - _BLUE_PRINT_COLOR: str = "\033[34m" - _RESET_COLOR: str = "\033[0m" - - # === FIX: Flatten message content === - flat_messages = [] - for msg in messages: - role = msg["role"] - content = msg["content"] - if isinstance(content, list): - # Assume it's a list of {'type': 'text', 'text': ...} - text_parts = [c["text"] for c in content if isinstance(c, dict) and "text" in c] - content = " ".join(text_parts) - flat_messages.append({"role": role, "content": content}) - - print(f"{_BLUE_PRINT_COLOR}Messages: {flat_messages}{_RESET_COLOR}") - - print("Applying chat template...") - prompt_text = self.tokenizer.apply_chat_template( - conversation=flat_messages, tokenize=False, add_generation_prompt=True - ) - print("Chat template applied.") - print(f"Prompt text:\n{prompt_text}") - - response = self.model(prompt_text, max_new_tokens=self.max_output_tokens_per_request) - print("Model response received.") - return response - - except Exception as e: - logger.error(f"Error during HuggingFace query: {e}") - return "Error processing request." - - def stream_query(self, query_text: str) -> Subject: - """ - Creates an observable that processes a text query and emits the response. 
- """ - return create( - lambda observer, _: self._observable_query(observer, incoming_query=query_text) - ) - - -# endregion HuggingFaceLLMAgent Subclass (HuggingFace-Specific Implementation) diff --git a/build/lib/dimos/agents/agent_huggingface_local.py b/build/lib/dimos/agents/agent_huggingface_local.py deleted file mode 100644 index 14f970c3bc..0000000000 --- a/build/lib/dimos/agents/agent_huggingface_local.py +++ /dev/null @@ -1,235 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import annotations - -# Standard library imports -import logging -import os -from typing import Any, Optional - -# Third-party imports -from dotenv import load_dotenv -from reactivex import Observable, create -from reactivex.scheduler import ThreadPoolScheduler -from reactivex.subject import Subject -import torch -from transformers import AutoModelForCausalLM - -# Local imports -from dimos.agents.agent import LLMAgent -from dimos.agents.memory.base import AbstractAgentSemanticMemory -from dimos.agents.memory.chroma_impl import LocalSemanticMemory -from dimos.agents.prompt_builder.impl import PromptBuilder -from dimos.agents.tokenizer.base import AbstractTokenizer -from dimos.agents.tokenizer.huggingface_tokenizer import HuggingFaceTokenizer -from dimos.utils.logging_config import setup_logger - -# Initialize environment variables -load_dotenv() - -# Initialize logger for the agent module -logger = setup_logger("dimos.agents", level=logging.DEBUG) - - -# HuggingFaceLLMAgent Class -class HuggingFaceLocalAgent(LLMAgent): - def __init__( - self, - dev_name: str, - agent_type: str = "HF-LLM", - model_name: str = "Qwen/Qwen2.5-3B", - device: str = "auto", - query: str = "How many r's are in the word 'strawberry'?", - input_query_stream: Optional[Observable] = None, - input_video_stream: Optional[Observable] = None, - output_dir: str = os.path.join(os.getcwd(), "assets", "agent"), - agent_memory: Optional[AbstractAgentSemanticMemory] = None, - system_query: Optional[str] = None, - max_output_tokens_per_request: int = None, - max_input_tokens_per_request: int = None, - prompt_builder: Optional[PromptBuilder] = None, - tokenizer: Optional[AbstractTokenizer] = None, - image_detail: str = "low", - pool_scheduler: Optional[ThreadPoolScheduler] = None, - process_all_inputs: Optional[bool] = None, - ): - # Determine appropriate default for process_all_inputs if not provided - if process_all_inputs is None: - # Default to True for text queries, False for video 
streams - if input_query_stream is not None and input_video_stream is None: - process_all_inputs = True - else: - process_all_inputs = False - - super().__init__( - dev_name=dev_name, - agent_type=agent_type, - agent_memory=agent_memory or LocalSemanticMemory(), - pool_scheduler=pool_scheduler, - process_all_inputs=process_all_inputs, - system_query=system_query, - ) - - self.query = query - self.output_dir = output_dir - os.makedirs(self.output_dir, exist_ok=True) - - self.model_name = model_name - self.device = device - if self.device == "auto": - self.device = "cuda" if torch.cuda.is_available() else "cpu" - if self.device == "cuda": - print(f"Using GPU: {torch.cuda.get_device_name(0)}") - else: - print("GPU not available, using CPU") - print(f"Device: {self.device}") - - self.tokenizer = tokenizer or HuggingFaceTokenizer(self.model_name) - - self.prompt_builder = prompt_builder or PromptBuilder( - self.model_name, tokenizer=self.tokenizer - ) - - self.model = AutoModelForCausalLM.from_pretrained( - model_name, - torch_dtype=torch.float16 if self.device == "cuda" else torch.float32, - device_map=self.device, - ) - - self.max_output_tokens_per_request = max_output_tokens_per_request - - # self.stream_query(self.query).subscribe(lambda x: print(x)) - - self.input_video_stream = input_video_stream - self.input_query_stream = input_query_stream - - # Ensure only one input stream is provided. - if self.input_video_stream is not None and self.input_query_stream is not None: - raise ValueError( - "More than one input stream provided. Please provide only one input stream." 
- ) - - if self.input_video_stream is not None: - logger.info("Subscribing to input video stream...") - self.disposables.add(self.subscribe_to_image_processing(self.input_video_stream)) - if self.input_query_stream is not None: - logger.info("Subscribing to input query stream...") - self.disposables.add(self.subscribe_to_query_processing(self.input_query_stream)) - - def _send_query(self, messages: list) -> Any: - _BLUE_PRINT_COLOR: str = "\033[34m" - _RESET_COLOR: str = "\033[0m" - - try: - # Log the incoming messages - print(f"{_BLUE_PRINT_COLOR}Messages: {str(messages)}{_RESET_COLOR}") - - # Process with chat template - try: - print("Applying chat template...") - prompt_text = self.tokenizer.tokenizer.apply_chat_template( - conversation=[{"role": "user", "content": str(messages)}], - tokenize=False, - add_generation_prompt=True, - ) - print("Chat template applied.") - - # Tokenize the prompt - print("Preparing model inputs...") - model_inputs = self.tokenizer.tokenizer([prompt_text], return_tensors="pt").to( - self.model.device - ) - print("Model inputs prepared.") - - # Generate the response - print("Generating response...") - generated_ids = self.model.generate( - **model_inputs, max_new_tokens=self.max_output_tokens_per_request - ) - - # Extract the generated tokens (excluding the input prompt tokens) - print("Processing generated output...") - generated_ids = [ - output_ids[len(input_ids) :] - for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) - ] - - # Convert tokens back to text - response = self.tokenizer.tokenizer.batch_decode( - generated_ids, skip_special_tokens=True - )[0] - print("Response successfully generated.") - - return response - - except AttributeError as e: - # Handle case where tokenizer doesn't have the expected methods - logger.warning(f"Chat template not available: {e}. 
Using simple format.") - # Continue with execution and use simple format - - except Exception as e: - # Log any other errors but continue execution - logger.warning( - f"Error in chat template processing: {e}. Falling back to simple format." - ) - - # Fallback approach for models without chat template support - # This code runs if the try block above raises an exception - print("Using simple prompt format...") - - # Convert messages to a simple text format - if ( - isinstance(messages, list) - and messages - and isinstance(messages[0], dict) - and "content" in messages[0] - ): - prompt_text = messages[0]["content"] - else: - prompt_text = str(messages) - - # Tokenize the prompt - model_inputs = self.tokenizer.tokenize_text(prompt_text) - model_inputs = torch.tensor([model_inputs], device=self.model.device) - - # Generate the response - generated_ids = self.model.generate( - input_ids=model_inputs, max_new_tokens=self.max_output_tokens_per_request - ) - - # Extract the generated tokens - generated_ids = generated_ids[0][len(model_inputs[0]) :] - - # Convert tokens back to text - response = self.tokenizer.detokenize_text(generated_ids.tolist()) - print("Response generated using simple format.") - - return response - - except Exception as e: - # Catch all other errors - logger.error(f"Error during query processing: {e}", exc_info=True) - return "Error processing request. Please try again." - - def stream_query(self, query_text: str) -> Subject: - """ - Creates an observable that processes a text query and emits the response. 
- """ - return create( - lambda observer, _: self._observable_query(observer, incoming_query=query_text) - ) - - -# endregion HuggingFaceLLMAgent Subclass (HuggingFace-Specific Implementation) diff --git a/build/lib/dimos/agents/agent_huggingface_remote.py b/build/lib/dimos/agents/agent_huggingface_remote.py deleted file mode 100644 index d98b277706..0000000000 --- a/build/lib/dimos/agents/agent_huggingface_remote.py +++ /dev/null @@ -1,143 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import annotations - -# Standard library imports -import logging -import os -from typing import Any, Optional - -# Third-party imports -from dotenv import load_dotenv -from huggingface_hub import InferenceClient -from reactivex import create, Observable -from reactivex.scheduler import ThreadPoolScheduler -from reactivex.subject import Subject - -# Local imports -from dimos.agents.agent import LLMAgent -from dimos.agents.memory.base import AbstractAgentSemanticMemory -from dimos.agents.prompt_builder.impl import PromptBuilder -from dimos.agents.tokenizer.base import AbstractTokenizer -from dimos.agents.tokenizer.huggingface_tokenizer import HuggingFaceTokenizer -from dimos.utils.logging_config import setup_logger - -# Initialize environment variables -load_dotenv() - -# Initialize logger for the agent module -logger = setup_logger("dimos.agents", level=logging.DEBUG) - - -# HuggingFaceLLMAgent Class -class HuggingFaceRemoteAgent(LLMAgent): - def __init__( - self, - dev_name: str, - agent_type: str = "HF-LLM", - model_name: str = "Qwen/QwQ-32B", - query: str = "How many r's are in the word 'strawberry'?", - input_query_stream: Optional[Observable] = None, - input_video_stream: Optional[Observable] = None, - output_dir: str = os.path.join(os.getcwd(), "assets", "agent"), - agent_memory: Optional[AbstractAgentSemanticMemory] = None, - system_query: Optional[str] = None, - max_output_tokens_per_request: int = 16384, - prompt_builder: Optional[PromptBuilder] = None, - tokenizer: Optional[AbstractTokenizer] = None, - image_detail: str = "low", - pool_scheduler: Optional[ThreadPoolScheduler] = None, - process_all_inputs: Optional[bool] = None, - api_key: Optional[str] = None, - hf_provider: Optional[str] = None, - hf_base_url: Optional[str] = None, - ): - # Determine appropriate default for process_all_inputs if not provided - if process_all_inputs is None: - # Default to True for text queries, False for video streams - if input_query_stream is not None 
and input_video_stream is None: - process_all_inputs = True - else: - process_all_inputs = False - - super().__init__( - dev_name=dev_name, - agent_type=agent_type, - agent_memory=agent_memory, - pool_scheduler=pool_scheduler, - process_all_inputs=process_all_inputs, - system_query=system_query, - ) - - self.query = query - self.output_dir = output_dir - os.makedirs(self.output_dir, exist_ok=True) - - self.model_name = model_name - self.prompt_builder = prompt_builder or PromptBuilder( - self.model_name, tokenizer=tokenizer or HuggingFaceTokenizer(self.model_name) - ) - - self.model_name = model_name - - self.max_output_tokens_per_request = max_output_tokens_per_request - - self.api_key = api_key or os.getenv("HF_TOKEN") - self.provider = hf_provider or "hf-inference" - self.base_url = hf_base_url or os.getenv("HUGGINGFACE_PRV_ENDPOINT") - self.client = InferenceClient( - provider=self.provider, - base_url=self.base_url, - api_key=self.api_key, - ) - - # self.stream_query(self.query).subscribe(lambda x: print(x)) - - self.input_video_stream = input_video_stream - self.input_query_stream = input_query_stream - - # Ensure only one input stream is provided. - if self.input_video_stream is not None and self.input_query_stream is not None: - raise ValueError( - "More than one input stream provided. Please provide only one input stream." 
- ) - - if self.input_video_stream is not None: - logger.info("Subscribing to input video stream...") - self.disposables.add(self.subscribe_to_image_processing(self.input_video_stream)) - if self.input_query_stream is not None: - logger.info("Subscribing to input query stream...") - self.disposables.add(self.subscribe_to_query_processing(self.input_query_stream)) - - def _send_query(self, messages: list) -> Any: - try: - completion = self.client.chat.completions.create( - model=self.model_name, - messages=messages, - max_tokens=self.max_output_tokens_per_request, - ) - - return completion.choices[0].message - except Exception as e: - logger.error(f"Error during HuggingFace query: {e}") - return "Error processing request." - - def stream_query(self, query_text: str) -> Subject: - """ - Creates an observable that processes a text query and emits the response. - """ - return create( - lambda observer, _: self._observable_query(observer, incoming_query=query_text) - ) diff --git a/build/lib/dimos/agents/cerebras_agent.py b/build/lib/dimos/agents/cerebras_agent.py deleted file mode 100644 index 854beb848d..0000000000 --- a/build/lib/dimos/agents/cerebras_agent.py +++ /dev/null @@ -1,608 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Cerebras agent implementation for the DIMOS agent framework. 
- -This module provides a CerebrasAgent class that implements the LLMAgent interface -for Cerebras inference API using the official Cerebras Python SDK. -""" - -from __future__ import annotations - -import os -import threading -import copy -from typing import Any, Dict, List, Optional, Union, Tuple -import logging -import json -import re -import time - -from cerebras.cloud.sdk import Cerebras -from dotenv import load_dotenv -from pydantic import BaseModel -from reactivex import Observable -from reactivex.observer import Observer -from reactivex.scheduler import ThreadPoolScheduler - -# Local imports -from dimos.agents.agent import LLMAgent -from dimos.agents.memory.base import AbstractAgentSemanticMemory -from dimos.agents.prompt_builder.impl import PromptBuilder -from dimos.agents.tokenizer.base import AbstractTokenizer -from dimos.skills.skills import AbstractSkill, SkillLibrary -from dimos.stream.frame_processor import FrameProcessor -from dimos.utils.logging_config import setup_logger -from dimos.agents.tokenizer.openai_tokenizer import OpenAITokenizer - -# Initialize environment variables -load_dotenv() - -# Initialize logger for the Cerebras agent -logger = setup_logger("dimos.agents.cerebras") - - -# Response object compatible with LLMAgent -class CerebrasResponseMessage(dict): - def __init__( - self, - content="", - tool_calls=None, - ): - self.content = content - self.tool_calls = tool_calls or [] - self.parsed = None - - # Initialize as dict with the proper structure - super().__init__(self.to_dict()) - - def __str__(self): - # Return a string representation for logging - if self.content: - return self.content - elif self.tool_calls: - # Return JSON representation of the first tool call - if self.tool_calls: - tool_call = self.tool_calls[0] - tool_json = { - "name": tool_call.function.name, - "arguments": json.loads(tool_call.function.arguments), - } - return json.dumps(tool_json) - return "[No content]" - - def to_dict(self): - """Convert to dictionary 
format for JSON serialization.""" - result = {"role": "assistant", "content": self.content or ""} - - if self.tool_calls: - result["tool_calls"] = [] - for tool_call in self.tool_calls: - result["tool_calls"].append( - { - "id": tool_call.id, - "type": "function", - "function": { - "name": tool_call.function.name, - "arguments": tool_call.function.arguments, - }, - } - ) - - return result - - -class CerebrasAgent(LLMAgent): - """Cerebras agent implementation using the official Cerebras Python SDK. - - This class implements the _send_query method to interact with Cerebras API - using their official SDK, allowing most of the LLMAgent logic to be reused. - """ - - def __init__( - self, - dev_name: str, - agent_type: str = "Vision", - query: str = "What do you see?", - input_query_stream: Optional[Observable] = None, - input_video_stream: Optional[Observable] = None, - input_data_stream: Optional[Observable] = None, - output_dir: str = os.path.join(os.getcwd(), "assets", "agent"), - agent_memory: Optional[AbstractAgentSemanticMemory] = None, - system_query: Optional[str] = None, - max_input_tokens_per_request: int = 128000, - max_output_tokens_per_request: int = 16384, - model_name: str = "llama-4-scout-17b-16e-instruct", - skills: Optional[Union[AbstractSkill, list[AbstractSkill], SkillLibrary]] = None, - response_model: Optional[BaseModel] = None, - frame_processor: Optional[FrameProcessor] = None, - image_detail: str = "low", - pool_scheduler: Optional[ThreadPoolScheduler] = None, - process_all_inputs: Optional[bool] = None, - tokenizer: Optional[AbstractTokenizer] = None, - prompt_builder: Optional[PromptBuilder] = None, - ): - """ - Initializes a new instance of the CerebrasAgent. - - Args: - dev_name (str): The device name of the agent. - agent_type (str): The type of the agent. - query (str): The default query text. - input_query_stream (Observable): An observable for query input. - input_video_stream (Observable): An observable for video frames. 
- input_data_stream (Observable): An observable for data input. - output_dir (str): Directory for output files. - agent_memory (AbstractAgentSemanticMemory): The memory system. - system_query (str): The system prompt to use with RAG context. - max_input_tokens_per_request (int): Maximum tokens for input. - max_output_tokens_per_request (int): Maximum tokens for output. - model_name (str): The Cerebras model name to use. Available options: - - llama-4-scout-17b-16e-instruct (default, fastest) - - llama3.1-8b - - llama-3.3-70b - - qwen-3-32b - - deepseek-r1-distill-llama-70b (private preview) - skills (Union[AbstractSkill, List[AbstractSkill], SkillLibrary]): Skills available to the agent. - response_model (BaseModel): Optional Pydantic model for structured responses. - frame_processor (FrameProcessor): Custom frame processor. - image_detail (str): Detail level for images ("low", "high", "auto"). - pool_scheduler (ThreadPoolScheduler): The scheduler to use for thread pool operations. - process_all_inputs (bool): Whether to process all inputs or skip when busy. - tokenizer (AbstractTokenizer): The tokenizer for the agent. - prompt_builder (PromptBuilder): The prompt builder for the agent. 
- """ - # Determine appropriate default for process_all_inputs if not provided - if process_all_inputs is None: - # Default to True for text queries, False for video streams - if input_query_stream is not None and input_video_stream is None: - process_all_inputs = True - else: - process_all_inputs = False - - super().__init__( - dev_name=dev_name, - agent_type=agent_type, - agent_memory=agent_memory, - pool_scheduler=pool_scheduler, - process_all_inputs=process_all_inputs, - system_query=system_query, - input_query_stream=input_query_stream, - input_video_stream=input_video_stream, - input_data_stream=input_data_stream, - ) - - # Initialize Cerebras client - self.client = Cerebras() - - self.query = query - self.output_dir = output_dir - os.makedirs(self.output_dir, exist_ok=True) - - # Initialize conversation history for multi-turn conversations - self.conversation_history = [] - self._history_lock = threading.Lock() - - # Configure skills - self.skills = skills - self.skill_library = None - if isinstance(self.skills, SkillLibrary): - self.skill_library = self.skills - elif isinstance(self.skills, list): - self.skill_library = SkillLibrary() - for skill in self.skills: - self.skill_library.add(skill) - elif isinstance(self.skills, AbstractSkill): - self.skill_library = SkillLibrary() - self.skill_library.add(self.skills) - - self.response_model = response_model - self.model_name = model_name - self.image_detail = image_detail - self.max_output_tokens_per_request = max_output_tokens_per_request - self.max_input_tokens_per_request = max_input_tokens_per_request - self.max_tokens_per_request = max_input_tokens_per_request + max_output_tokens_per_request - - # Add static context to memory. 
- self._add_context_to_memory() - - # Initialize tokenizer and prompt builder - self.tokenizer = tokenizer or OpenAITokenizer( - model_name="gpt-4o" - ) # Use GPT-4 tokenizer for better accuracy - self.prompt_builder = prompt_builder or PromptBuilder( - model_name=self.model_name, - max_tokens=self.max_input_tokens_per_request, - tokenizer=self.tokenizer, - ) - - logger.info("Cerebras Agent Initialized.") - - def _add_context_to_memory(self): - """Adds initial context to the agent's memory.""" - context_data = [ - ( - "id0", - "Optical Flow is a technique used to track the movement of objects in a video sequence.", - ), - ( - "id1", - "Edge Detection is a technique used to identify the boundaries of objects in an image.", - ), - ("id2", "Video is a sequence of frames captured at regular intervals."), - ( - "id3", - "Colors in Optical Flow are determined by the movement of light, and can be used to track the movement of objects.", - ), - ( - "id4", - "Json is a data interchange format that is easy for humans to read and write, and easy for machines to parse and generate.", - ), - ] - for doc_id, text in context_data: - self.agent_memory.add_vector(doc_id, text) - - def _build_prompt( - self, - messages: list, - base64_image: Optional[Union[str, List[str]]] = None, - dimensions: Optional[Tuple[int, int]] = None, - override_token_limit: bool = False, - condensed_results: str = "", - ) -> list: - """Builds a prompt message specifically for Cerebras API. - - Args: - messages (list): Existing messages list to build upon. - base64_image (Union[str, List[str]]): Optional Base64-encoded image(s). - dimensions (Tuple[int, int]): Optional image dimensions. - override_token_limit (bool): Whether to override token limits. - condensed_results (str): The condensed RAG context. - - Returns: - list: Messages formatted for Cerebras API. 
- """ - # Add system message if provided and not already in history - if self.system_query and (not messages or messages[0].get("role") != "system"): - messages.insert(0, {"role": "system", "content": self.system_query}) - logger.info("Added system message to conversation") - - # Append user query while handling RAG - if condensed_results: - user_message = {"role": "user", "content": f"{condensed_results}\n\n{self.query}"} - logger.info("Created user message with RAG context") - else: - user_message = {"role": "user", "content": self.query} - - messages.append(user_message) - - if base64_image is not None: - # Handle both single image (str) and multiple images (List[str]) - images = [base64_image] if isinstance(base64_image, str) else base64_image - - # For Cerebras, we'll add images inline with text (OpenAI-style format) - for img in images: - img_content = [ - {"type": "text", "text": "Here is an image to analyze:"}, - { - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{img}", - "detail": self.image_detail, - }, - }, - ] - messages.append({"role": "user", "content": img_content}) - - logger.info(f"Added {len(images)} image(s) to conversation") - - # Use new truncation function - messages = self._truncate_messages(messages, override_token_limit) - - return messages - - def _truncate_messages(self, messages: list, override_token_limit: bool = False) -> list: - """Truncate messages if total tokens exceed 16k using existing truncate_tokens method. 
- - Args: - messages (list): List of message dictionaries - override_token_limit (bool): Whether to skip truncation - - Returns: - list: Messages with content truncated if needed - """ - if override_token_limit: - return messages - - total_tokens = 0 - for message in messages: - if isinstance(message.get("content"), str): - total_tokens += self.prompt_builder.tokenizer.token_count(message["content"]) - elif isinstance(message.get("content"), list): - for item in message["content"]: - if item.get("type") == "text": - total_tokens += self.prompt_builder.tokenizer.token_count(item["text"]) - elif item.get("type") == "image_url": - total_tokens += 85 - - if total_tokens > 16000: - excess_tokens = total_tokens - 16000 - current_tokens = total_tokens - - # Start from oldest messages and truncate until under 16k - for i in range(len(messages)): - if current_tokens <= 16000: - break - - msg = messages[i] - if msg.get("role") == "system": - continue - - if isinstance(msg.get("content"), str): - original_tokens = self.prompt_builder.tokenizer.token_count(msg["content"]) - # Calculate how much to truncate from this message - tokens_to_remove = min(excess_tokens, original_tokens // 3) - new_max_tokens = max(50, original_tokens - tokens_to_remove) - - msg["content"] = self.prompt_builder.truncate_tokens( - msg["content"], new_max_tokens, "truncate_end" - ) - - new_tokens = self.prompt_builder.tokenizer.token_count(msg["content"]) - tokens_saved = original_tokens - new_tokens - current_tokens -= tokens_saved - excess_tokens -= tokens_saved - - logger.info( - f"Truncated older messages using truncate_tokens, final tokens: {current_tokens}" - ) - else: - logger.info(f"No truncation needed, total tokens: {total_tokens}") - - return messages - - def clean_cerebras_schema(self, schema: dict) -> dict: - """Simple schema cleaner that removes unsupported fields for Cerebras API.""" - if not isinstance(schema, dict): - return schema - - # Removing the problematic fields that pydantic 
generates - cleaned = {} - unsupported_fields = { - "minItems", - "maxItems", - "uniqueItems", - "exclusiveMinimum", - "exclusiveMaximum", - "minimum", - "maximum", - } - - for key, value in schema.items(): - if key in unsupported_fields: - continue # Skip unsupported fields - elif isinstance(value, dict): - cleaned[key] = self.clean_cerebras_schema(value) - elif isinstance(value, list): - cleaned[key] = [ - self.clean_cerebras_schema(item) if isinstance(item, dict) else item - for item in value - ] - else: - cleaned[key] = value - - return cleaned - - def create_tool_call( - self, name: str = None, arguments: dict = None, call_id: str = None, content: str = None - ): - """Create a tool call object from either direct parameters or JSON content.""" - # If content is provided, parse it as JSON - if content: - logger.info(f"Creating tool call from content: {content}") - try: - content_json = json.loads(content) - if ( - isinstance(content_json, dict) - and "name" in content_json - and "arguments" in content_json - ): - name = content_json["name"] - arguments = content_json["arguments"] - else: - return None - except json.JSONDecodeError: - logger.warning("Content appears to be JSON but failed to parse") - return None - - # Create the tool call object - if name and arguments is not None: - timestamp = int(time.time() * 1000000) # microsecond precision - tool_id = f"call_{timestamp}" - - logger.info(f"Creating tool call with timestamp ID: {tool_id}") - return type( - "ToolCall", - (), - { - "id": tool_id, - "function": type( - "Function", (), {"name": name, "arguments": json.dumps(arguments)} - ), - }, - ) - - return None - - def _send_query(self, messages: list) -> CerebrasResponseMessage: - """Sends the query to Cerebras API using the official Cerebras SDK. - - Args: - messages (list): The prompt messages to send. - - Returns: - The response message from Cerebras wrapped in our CerebrasResponseMessage class. 
- - Raises: - Exception: If no response message is returned from the API. - ConnectionError: If there's an issue connecting to the API. - ValueError: If the messages or other parameters are invalid. - """ - try: - # Prepare API call parameters - api_params = { - "model": self.model_name, - "messages": messages, - # "max_tokens": self.max_output_tokens_per_request, - } - - # Add tools if available - if self.skill_library and self.skill_library.get_tools(): - tools = self.skill_library.get_tools() - for tool in tools: - if "function" in tool and "parameters" in tool["function"]: - tool["function"]["parameters"] = self.clean_cerebras_schema( - tool["function"]["parameters"] - ) - api_params["tools"] = tools - api_params["tool_choice"] = "auto" - - if self.response_model is not None: - api_params["response_format"] = { - "type": "json_object", - "schema": self.response_model, - } - - # Make the API call - response = self.client.chat.completions.create(**api_params) - - raw_message = response.choices[0].message - if raw_message is None: - logger.error("Response message does not exist.") - raise Exception("Response message does not exist.") - - # Process response into final format - content = raw_message.content - tool_calls = getattr(raw_message, "tool_calls", None) - - # If no structured tool calls from API, try parsing content as JSON tool call - if not tool_calls and content and content.strip().startswith("{"): - parsed_tool_call = self.create_tool_call(content=content) - if parsed_tool_call: - tool_calls = [parsed_tool_call] - content = None - - return CerebrasResponseMessage(content=content, tool_calls=tool_calls) - - except ConnectionError as ce: - logger.error(f"Connection error with Cerebras API: {ce}") - raise - except ValueError as ve: - logger.error(f"Invalid parameters for Cerebras API: {ve}") - raise - except Exception as e: - # Print the raw API parameters when an error occurs - logger.error(f"Raw API parameters: {json.dumps(api_params, indent=2)}") - 
logger.error(f"Unexpected error in Cerebras API call: {e}") - raise - - def _observable_query( - self, - observer: Observer, - base64_image: Optional[str] = None, - dimensions: Optional[Tuple[int, int]] = None, - override_token_limit: bool = False, - incoming_query: Optional[str] = None, - reset_conversation: bool = False, - ): - """Main query handler that manages conversation history and Cerebras interactions. - - This method follows ClaudeAgent's pattern for efficient conversation history management. - - Args: - observer (Observer): The observer to emit responses to. - base64_image (str): Optional Base64-encoded image. - dimensions (Tuple[int, int]): Optional image dimensions. - override_token_limit (bool): Whether to override token limits. - incoming_query (str): Optional query to update the agent's query. - reset_conversation (bool): Whether to reset the conversation history. - """ - try: - # Reset conversation history if requested - if reset_conversation: - self.conversation_history = [] - logger.info("Conversation history reset") - - # Create a local copy of conversation history and record its length - messages = copy.deepcopy(self.conversation_history) - - # Update query and get context - self._update_query(incoming_query) - _, condensed_results = self._get_rag_context() - - # Build prompt - messages = self._build_prompt( - messages, base64_image, dimensions, override_token_limit, condensed_results - ) - - while True: - logger.info("Sending Query.") - response_message = self._send_query(messages) - logger.info(f"Received Response: {response_message}") - - if response_message is None: - raise Exception("Response message does not exist.") - - # If no skill library or no tool calls, we're done - if ( - self.skill_library is None - or self.skill_library.get_tools() is None - or response_message.tool_calls is None - ): - final_msg = ( - response_message.parsed - if hasattr(response_message, "parsed") and response_message.parsed - else ( - response_message.content 
- if hasattr(response_message, "content") - else response_message - ) - ) - messages.append(response_message) - break - - logger.info(f"Assistant requested {len(response_message.tool_calls)} tool call(s)") - next_response = self._handle_tooling(response_message, messages) - - if next_response is None: - final_msg = response_message.content or "" - break - - response_message = next_response - - with self._history_lock: - self.conversation_history = messages - logger.info( - f"Updated conversation history (total: {len(self.conversation_history)} messages)" - ) - - # Emit the final message content to the observer - observer.on_next(final_msg) - self.response_subject.on_next(final_msg) - observer.on_completed() - - except Exception as e: - logger.error(f"Query failed in {self.dev_name}: {e}") - observer.on_error(e) - self.response_subject.on_error(e) diff --git a/build/lib/dimos/agents/claude_agent.py b/build/lib/dimos/agents/claude_agent.py deleted file mode 100644 index e87b1f47b4..0000000000 --- a/build/lib/dimos/agents/claude_agent.py +++ /dev/null @@ -1,735 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Claude agent implementation for the DIMOS agent framework. - -This module provides a ClaudeAgent class that implements the LLMAgent interface -for Anthropic's Claude models. It handles conversion between the DIMOS skill format -and Claude's tools format. 
-""" - -from __future__ import annotations - -import json -import os -from typing import Any, Dict, List, Optional, Tuple, Union - -import anthropic -from dotenv import load_dotenv -from pydantic import BaseModel -from reactivex import Observable -from reactivex.scheduler import ThreadPoolScheduler - -# Local imports -from dimos.agents.agent import LLMAgent -from dimos.agents.memory.base import AbstractAgentSemanticMemory -from dimos.agents.prompt_builder.impl import PromptBuilder -from dimos.skills.skills import AbstractSkill, SkillLibrary -from dimos.stream.frame_processor import FrameProcessor -from dimos.utils.logging_config import setup_logger - -# Initialize environment variables -load_dotenv() - -# Initialize logger for the Claude agent -logger = setup_logger("dimos.agents.claude") - - -# Response object compatible with LLMAgent -class ResponseMessage: - def __init__(self, content="", tool_calls=None, thinking_blocks=None): - self.content = content - self.tool_calls = tool_calls or [] - self.thinking_blocks = thinking_blocks or [] - self.parsed = None - - def __str__(self): - # Return a string representation for logging - parts = [] - - # Include content if available - if self.content: - parts.append(self.content) - - # Include tool calls if available - if self.tool_calls: - tool_names = [tc.function.name for tc in self.tool_calls] - parts.append(f"[Tools called: {', '.join(tool_names)}]") - - return "\n".join(parts) if parts else "[No content]" - - -class ClaudeAgent(LLMAgent): - """Claude agent implementation that uses Anthropic's API for processing. - - This class implements the _send_query method to interact with Anthropic's API - and overrides _build_prompt to create Claude-formatted messages directly. 
- """ - - def __init__( - self, - dev_name: str, - agent_type: str = "Vision", - query: str = "What do you see?", - input_query_stream: Optional[Observable] = None, - input_video_stream: Optional[Observable] = None, - input_data_stream: Optional[Observable] = None, - output_dir: str = os.path.join(os.getcwd(), "assets", "agent"), - agent_memory: Optional[AbstractAgentSemanticMemory] = None, - system_query: Optional[str] = None, - max_input_tokens_per_request: int = 128000, - max_output_tokens_per_request: int = 16384, - model_name: str = "claude-3-7-sonnet-20250219", - prompt_builder: Optional[PromptBuilder] = None, - rag_query_n: int = 4, - rag_similarity_threshold: float = 0.45, - skills: Optional[AbstractSkill] = None, - response_model: Optional[BaseModel] = None, - frame_processor: Optional[FrameProcessor] = None, - image_detail: str = "low", - pool_scheduler: Optional[ThreadPoolScheduler] = None, - process_all_inputs: Optional[bool] = None, - thinking_budget_tokens: Optional[int] = 2000, - ): - """ - Initializes a new instance of the ClaudeAgent. - - Args: - dev_name (str): The device name of the agent. - agent_type (str): The type of the agent. - query (str): The default query text. - input_query_stream (Observable): An observable for query input. - input_video_stream (Observable): An observable for video frames. - output_dir (str): Directory for output files. - agent_memory (AbstractAgentSemanticMemory): The memory system. - system_query (str): The system prompt to use with RAG context. - max_input_tokens_per_request (int): Maximum tokens for input. - max_output_tokens_per_request (int): Maximum tokens for output. - model_name (str): The Claude model name to use. - prompt_builder (PromptBuilder): Custom prompt builder (not used in Claude implementation). - rag_query_n (int): Number of results to fetch in RAG queries. - rag_similarity_threshold (float): Minimum similarity for RAG results. - skills (AbstractSkill): Skills available to the agent. 
- response_model (BaseModel): Optional Pydantic model for responses. - frame_processor (FrameProcessor): Custom frame processor. - image_detail (str): Detail level for images ("low", "high", "auto"). - pool_scheduler (ThreadPoolScheduler): The scheduler to use for thread pool operations. - process_all_inputs (bool): Whether to process all inputs or skip when busy. - thinking_budget_tokens (int): Number of tokens to allocate for Claude's thinking. 0 disables thinking. - """ - # Determine appropriate default for process_all_inputs if not provided - if process_all_inputs is None: - # Default to True for text queries, False for video streams - if input_query_stream is not None and input_video_stream is None: - process_all_inputs = True - else: - process_all_inputs = False - - super().__init__( - dev_name=dev_name, - agent_type=agent_type, - agent_memory=agent_memory, - pool_scheduler=pool_scheduler, - process_all_inputs=process_all_inputs, - system_query=system_query, - input_query_stream=input_query_stream, - input_video_stream=input_video_stream, - input_data_stream=input_data_stream, - ) - - self.client = anthropic.Anthropic() - self.query = query - self.output_dir = output_dir - os.makedirs(self.output_dir, exist_ok=True) - - # Claude-specific parameters - self.thinking_budget_tokens = thinking_budget_tokens - self.claude_api_params = {} # Will store params for Claude API calls - - # Configure skills - self.skills = skills - self.skill_library = None # Required for error 'ClaudeAgent' object has no attribute 'skill_library' due to skills refactor - if isinstance(self.skills, SkillLibrary): - self.skill_library = self.skills - elif isinstance(self.skills, list): - self.skill_library = SkillLibrary() - for skill in self.skills: - self.skill_library.add(skill) - elif isinstance(self.skills, AbstractSkill): - self.skill_library = SkillLibrary() - self.skill_library.add(self.skills) - - self.response_model = response_model - self.model_name = model_name - 
self.rag_query_n = rag_query_n - self.rag_similarity_threshold = rag_similarity_threshold - self.image_detail = image_detail - self.max_output_tokens_per_request = max_output_tokens_per_request - self.max_input_tokens_per_request = max_input_tokens_per_request - self.max_tokens_per_request = max_input_tokens_per_request + max_output_tokens_per_request - - # Add static context to memory. - self._add_context_to_memory() - - self.frame_processor = frame_processor or FrameProcessor(delete_on_init=True) - - # Ensure only one input stream is provided. - if self.input_video_stream is not None and self.input_query_stream is not None: - raise ValueError( - "More than one input stream provided. Please provide only one input stream." - ) - - logger.info("Claude Agent Initialized.") - - def _add_context_to_memory(self): - """Adds initial context to the agent's memory.""" - context_data = [ - ( - "id0", - "Optical Flow is a technique used to track the movement of objects in a video sequence.", - ), - ( - "id1", - "Edge Detection is a technique used to identify the boundaries of objects in an image.", - ), - ("id2", "Video is a sequence of frames captured at regular intervals."), - ( - "id3", - "Colors in Optical Flow are determined by the movement of light, and can be used to track the movement of objects.", - ), - ( - "id4", - "Json is a data interchange format that is easy for humans to read and write, and easy for machines to parse and generate.", - ), - ] - for doc_id, text in context_data: - self.agent_memory.add_vector(doc_id, text) - - def _convert_tools_to_claude_format(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """ - Converts DIMOS tools to Claude format. - - Args: - tools: List of tools in DIMOS format. - - Returns: - List of tools in Claude format. 
- """ - if not tools: - return [] - - claude_tools = [] - - for tool in tools: - # Skip if not a function - if tool.get("type") != "function": - continue - - function = tool.get("function", {}) - name = function.get("name") - description = function.get("description", "") - parameters = function.get("parameters", {}) - - claude_tool = { - "name": name, - "description": description, - "input_schema": { - "type": "object", - "properties": parameters.get("properties", {}), - "required": parameters.get("required", []), - }, - } - - claude_tools.append(claude_tool) - - return claude_tools - - def _build_prompt( - self, - messages: list, - base64_image: Optional[Union[str, List[str]]] = None, - dimensions: Optional[Tuple[int, int]] = None, - override_token_limit: bool = False, - rag_results: str = "", - thinking_budget_tokens: int = None, - ) -> list: - """Builds a prompt message specifically for Claude API, using local messages copy.""" - """Builds a prompt message specifically for Claude API. - - This method creates messages in Claude's format directly, without using - any OpenAI-specific formatting or token counting. - - Args: - base64_image (Union[str, List[str]]): Optional Base64-encoded image(s). - dimensions (Tuple[int, int]): Optional image dimensions. - override_token_limit (bool): Whether to override token limits. - rag_results (str): The condensed RAG context. - thinking_budget_tokens (int): Number of tokens to allocate for Claude's thinking. - - Returns: - dict: A dict containing Claude API parameters. 
- """ - - # Append user query to conversation history while handling RAG - if rag_results: - messages.append({"role": "user", "content": f"{rag_results}\n\n{self.query}"}) - logger.info( - f"Added new user message to conversation history with RAG context (now has {len(messages)} messages)" - ) - else: - messages.append({"role": "user", "content": self.query}) - logger.info( - f"Added new user message to conversation history (now has {len(messages)} messages)" - ) - - if base64_image is not None: - # Handle both single image (str) and multiple images (List[str]) - images = [base64_image] if isinstance(base64_image, str) else base64_image - - # Add each image as a separate entry in conversation history - for img in images: - img_content = [ - { - "type": "image", - "source": {"type": "base64", "media_type": "image/jpeg", "data": img}, - } - ] - messages.append({"role": "user", "content": img_content}) - - if images: - logger.info( - f"Added {len(images)} image(s) as separate entries to conversation history" - ) - - # Create Claude parameters with basic settings - claude_params = { - "model": self.model_name, - "max_tokens": self.max_output_tokens_per_request, - "temperature": 0, # Add temperature to make responses more deterministic - "messages": messages, - } - - # Add system prompt as a top-level parameter (not as a message) - if self.system_query: - claude_params["system"] = self.system_query - - # Store the parameters for use in _send_query - self.claude_api_params = claude_params.copy() - - # Add tools if skills are available - if self.skills and self.skills.get_tools(): - tools = self._convert_tools_to_claude_format(self.skills.get_tools()) - if tools: # Only add if we have valid tools - claude_params["tools"] = tools - # Enable tool calling with proper format - claude_params["tool_choice"] = {"type": "auto"} - - # Add thinking if enabled and hard code required temperature = 1 - if thinking_budget_tokens is not None and thinking_budget_tokens != 0: - 
claude_params["thinking"] = {"type": "enabled", "budget_tokens": thinking_budget_tokens} - claude_params["temperature"] = ( - 1 # Required to be 1 when thinking is enabled # Default to 0 for deterministic responses - ) - - # Store the parameters for use in _send_query and return them - self.claude_api_params = claude_params.copy() - return messages, claude_params - - def _send_query(self, messages: list, claude_params: dict) -> Any: - """Sends the query to Anthropic's API using streaming for better thinking visualization. - - Args: - messages: Dict with 'claude_prompt' key containing Claude API parameters. - - Returns: - The response message in a format compatible with LLMAgent's expectations. - """ - try: - # Get Claude parameters - claude_params = claude_params.get("claude_prompt", None) or self.claude_api_params - - # Log request parameters with truncated base64 data - logger.debug(self._debug_api_call(claude_params)) - - # Initialize response containers - text_content = "" - tool_calls = [] - thinking_blocks = [] - - # Log the start of streaming and the query - logger.info("Sending streaming request to Claude API") - - # Log the query to memory.txt - with open(os.path.join(self.output_dir, "memory.txt"), "a") as f: - f.write(f"\n\nQUERY: {self.query}\n\n") - f.flush() - - # Stream the response - with self.client.messages.stream(**claude_params) as stream: - print("\n==== CLAUDE API RESPONSE STREAM STARTED ====") - - # Open the memory file once for the entire stream processing - with open(os.path.join(self.output_dir, "memory.txt"), "a") as memory_file: - # Track the current block being processed - current_block = {"type": None, "id": None, "content": "", "signature": None} - - for event in stream: - # Log each event to console - # print(f"EVENT: {event.type}") - # print(json.dumps(event.model_dump(), indent=2, default=str)) - - if event.type == "content_block_start": - # Initialize a new content block - block_type = event.content_block.type - current_block = { 
- "type": block_type, - "id": event.index, - "content": "", - "signature": None, - } - logger.debug(f"Starting {block_type} block...") - - elif event.type == "content_block_delta": - if event.delta.type == "thinking_delta": - # Accumulate thinking content - current_block["content"] = event.delta.thinking - memory_file.write(f"{event.delta.thinking}") - memory_file.flush() # Ensure content is written immediately - - elif event.delta.type == "text_delta": - # Accumulate text content - text_content += event.delta.text - current_block["content"] += event.delta.text - memory_file.write(f"{event.delta.text}") - memory_file.flush() - - elif event.delta.type == "signature_delta": - # Store signature for thinking blocks - current_block["signature"] = event.delta.signature - memory_file.write( - f"\n[Signature received for block {current_block['id']}]\n" - ) - memory_file.flush() - - elif event.type == "content_block_stop": - # Store completed blocks - if current_block["type"] == "thinking": - # IMPORTANT: Store the complete event.content_block to ensure we preserve - # the exact format that Claude expects in subsequent requests - if hasattr(event, "content_block"): - # Use the exact thinking block as provided by Claude - thinking_blocks.append(event.content_block.model_dump()) - memory_file.write( - f"\nTHINKING COMPLETE: block {current_block['id']}\n" - ) - else: - # Fallback to constructed thinking block if content_block missing - thinking_block = { - "type": "thinking", - "thinking": current_block["content"], - "signature": current_block["signature"], - } - thinking_blocks.append(thinking_block) - memory_file.write( - f"\nTHINKING COMPLETE: block {current_block['id']}\n" - ) - - elif current_block["type"] == "redacted_thinking": - # Handle redacted thinking blocks - if hasattr(event, "content_block") and hasattr( - event.content_block, "data" - ): - redacted_block = { - "type": "redacted_thinking", - "data": event.content_block.data, - } - 
thinking_blocks.append(redacted_block) - - elif current_block["type"] == "tool_use": - # Process tool use blocks when they're complete - if hasattr(event, "content_block"): - tool_block = event.content_block - tool_id = tool_block.id - tool_name = tool_block.name - tool_input = tool_block.input - - # Create a tool call object for LLMAgent compatibility - tool_call_obj = type( - "ToolCall", - (), - { - "id": tool_id, - "function": type( - "Function", - (), - { - "name": tool_name, - "arguments": json.dumps(tool_input), - }, - ), - }, - ) - tool_calls.append(tool_call_obj) - - # Write tool call information to memory.txt - memory_file.write(f"\n\nTOOL CALL: {tool_name}\n") - memory_file.write( - f"ARGUMENTS: {json.dumps(tool_input, indent=2)}\n" - ) - - # Reset current block - current_block = { - "type": None, - "id": None, - "content": "", - "signature": None, - } - memory_file.flush() - - elif ( - event.type == "message_delta" and event.delta.stop_reason == "tool_use" - ): - # When a tool use is detected - logger.info("Tool use stop reason detected in stream") - - # Mark the end of the response in memory.txt - memory_file.write("\n\nRESPONSE COMPLETE\n\n") - memory_file.flush() - - print("\n==== CLAUDE API RESPONSE STREAM COMPLETED ====") - - # Final response - logger.info( - f"Claude streaming complete. 
Text: {len(text_content)} chars, Tool calls: {len(tool_calls)}, Thinking blocks: {len(thinking_blocks)}" - ) - - # Return the complete response with all components - return ResponseMessage( - content=text_content, - tool_calls=tool_calls if tool_calls else None, - thinking_blocks=thinking_blocks if thinking_blocks else None, - ) - - except ConnectionError as ce: - logger.error(f"Connection error with Anthropic API: {ce}") - raise - except ValueError as ve: - logger.error(f"Invalid parameters for Anthropic API: {ve}") - raise - except Exception as e: - logger.error(f"Unexpected error in Anthropic API call: {e}") - logger.exception(e) # This will print the full traceback - raise - - def _observable_query( - self, - observer: Observer, - base64_image: Optional[str] = None, - dimensions: Optional[Tuple[int, int]] = None, - override_token_limit: bool = False, - incoming_query: Optional[str] = None, - reset_conversation: bool = False, - thinking_budget_tokens: int = None, - ): - """Main query handler that manages conversation history and Claude interactions. - - This is the primary method for handling all queries, whether they come through - direct_query or through the observable pattern. It manages the conversation - history, builds prompts, and handles tool calls. 
- - Args: - observer (Observer): The observer to emit responses to - base64_image (Optional[str]): Optional Base64-encoded image - dimensions (Optional[Tuple[int, int]]): Optional image dimensions - override_token_limit (bool): Whether to override token limits - incoming_query (Optional[str]): Optional query to update the agent's query - reset_conversation (bool): Whether to reset the conversation history - """ - - try: - logger.info("_observable_query called in claude") - import copy - - # Reset conversation history if requested - if reset_conversation: - self.conversation_history = [] - - # Create a local copy of conversation history and record its length - messages = copy.deepcopy(self.conversation_history) - base_len = len(messages) - - # Update query and get context - self._update_query(incoming_query) - _, rag_results = self._get_rag_context() - - # Build prompt and get Claude parameters - budget = ( - thinking_budget_tokens - if thinking_budget_tokens is not None - else self.thinking_budget_tokens - ) - messages, claude_params = self._build_prompt( - messages, base64_image, dimensions, override_token_limit, rag_results, budget - ) - - # Send query and get response - response_message = self._send_query(messages, claude_params) - - if response_message is None: - logger.error("Received None response from Claude API") - observer.on_next("") - observer.on_completed() - return - # Add thinking blocks and text content to conversation history - content_blocks = [] - if response_message.thinking_blocks: - content_blocks.extend(response_message.thinking_blocks) - if response_message.content: - content_blocks.append({"type": "text", "text": response_message.content}) - if content_blocks: - messages.append({"role": "assistant", "content": content_blocks}) - - # Handle tool calls if present - if response_message.tool_calls: - self._handle_tooling(response_message, messages) - - # At the end, append only new messages (including tool-use/results) to the global conversation 
history under a lock - import threading - - if not hasattr(self, "_history_lock"): - self._history_lock = threading.Lock() - with self._history_lock: - for msg in messages[base_len:]: - self.conversation_history.append(msg) - - # After merging, run tooling callback (outside lock) - if response_message.tool_calls: - self._tooling_callback(response_message) - - # Send response to observers - result = response_message.content or "" - observer.on_next(result) - self.response_subject.on_next(result) - observer.on_completed() - except Exception as e: - logger.error(f"Query failed in {self.dev_name}: {e}") - # Send a user-friendly error message instead of propagating the error - error_message = "I apologize, but I'm having trouble processing your request right now. Please try again." - observer.on_next(error_message) - self.response_subject.on_next(error_message) - observer.on_completed() - - def _handle_tooling(self, response_message, messages): - """Executes tools and appends tool-use/result blocks to messages.""" - if not hasattr(response_message, "tool_calls") or not response_message.tool_calls: - logger.info("No tool calls found in response message") - return None - - if len(response_message.tool_calls) > 1: - logger.warning( - "Multiple tool calls detected in response message. Not a tested feature." 
- ) - - # Execute all tools first and collect their results - for tool_call in response_message.tool_calls: - logger.info(f"Processing tool call: {tool_call.function.name}") - tool_use_block = { - "type": "tool_use", - "id": tool_call.id, - "name": tool_call.function.name, - "input": json.loads(tool_call.function.arguments), - } - messages.append({"role": "assistant", "content": [tool_use_block]}) - - try: - # Execute the tool - args = json.loads(tool_call.function.arguments) - tool_result = self.skills.call(tool_call.function.name, **args) - - # Check if the result is an error message - if isinstance(tool_result, str) and ( - "Error executing skill" in tool_result or "is not available" in tool_result - ): - # Log the error but provide a user-friendly message - logger.error(f"Tool execution failed: {tool_result}") - tool_result = "I apologize, but I'm having trouble executing that action right now. Please try again or ask for something else." - - # Add tool result to conversation history - if tool_result: - messages.append( - { - "role": "user", - "content": [ - { - "type": "tool_result", - "tool_use_id": tool_call.id, - "content": f"{tool_result}", - } - ], - } - ) - except Exception as e: - logger.error(f"Unexpected error executing tool {tool_call.function.name}: {e}") - # Add error result to conversation history - messages.append( - { - "role": "user", - "content": [ - { - "type": "tool_result", - "tool_use_id": tool_call.id, - "content": "I apologize, but I encountered an error while trying to execute that action. 
Please try again.", - } - ], - } - ) - - def _tooling_callback(self, response_message): - """Runs the observable query for each tool call in the current response_message""" - if not hasattr(response_message, "tool_calls") or not response_message.tool_calls: - return - - try: - for tool_call in response_message.tool_calls: - tool_name = tool_call.function.name - tool_id = tool_call.id - self.run_observable_query( - query_text=f"Tool {tool_name}, ID: {tool_id} execution complete. Please summarize the results and continue.", - thinking_budget_tokens=0, - ).run() - except Exception as e: - logger.error(f"Error in tooling callback: {e}") - # Continue processing even if the callback fails - pass - - def _debug_api_call(self, claude_params: dict): - """Debugging function to log API calls with truncated base64 data.""" - # Remove tools to reduce verbosity - import copy - - log_params = copy.deepcopy(claude_params) - if "tools" in log_params: - del log_params["tools"] - - # Truncate base64 data in images - much cleaner approach - if "messages" in log_params: - for msg in log_params["messages"]: - if "content" in msg: - for content in msg["content"]: - if isinstance(content, dict) and content.get("type") == "image": - source = content.get("source", {}) - if source.get("type") == "base64" and "data" in source: - data = source["data"] - source["data"] = f"{data[:50]}..." - return json.dumps(log_params, indent=2, default=str) diff --git a/build/lib/dimos/agents/memory/__init__.py b/build/lib/dimos/agents/memory/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/build/lib/dimos/agents/memory/base.py b/build/lib/dimos/agents/memory/base.py deleted file mode 100644 index af8cbf689f..0000000000 --- a/build/lib/dimos/agents/memory/base.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from abc import abstractmethod -from dimos.exceptions.agent_memory_exceptions import ( - UnknownConnectionTypeError, - AgentMemoryConnectionError, -) -from dimos.utils.logging_config import setup_logger - -# TODO -# class AbstractAgentMemory(ABC): - -# TODO -# class AbstractAgentSymbolicMemory(AbstractAgentMemory): - - -class AbstractAgentSemanticMemory: # AbstractAgentMemory): - def __init__(self, connection_type="local", **kwargs): - """ - Initialize with dynamic connection parameters. - Args: - connection_type (str): 'local' for a local database, 'remote' for a remote connection. - Raises: - UnknownConnectionTypeError: If an unrecognized connection type is specified. - AgentMemoryConnectionError: If initializing the database connection fails. - """ - self.logger = setup_logger(self.__class__.__name__) - self.logger.info("Initializing AgentMemory with connection type: %s", connection_type) - self.connection_params = kwargs - self.db_connection = ( - None # Holds the conection, whether local or remote, to the database used. - ) - - if connection_type not in ["local", "remote"]: - error = UnknownConnectionTypeError( - f"Invalid connection_type {connection_type}. Expected 'local' or 'remote'." 
- ) - self.logger.error(str(error)) - raise error - - try: - if connection_type == "remote": - self.connect() - elif connection_type == "local": - self.create() - except Exception as e: - self.logger.error("Failed to initialize database connection: %s", str(e), exc_info=True) - raise AgentMemoryConnectionError( - "Initialization failed due to an unexpected error.", cause=e - ) from e - - @abstractmethod - def connect(self): - """Establish a connection to the data store using dynamic parameters specified during initialization.""" - - @abstractmethod - def create(self): - """Create a local instance of the data store tailored to specific requirements.""" - - ## Create ## - @abstractmethod - def add_vector(self, vector_id, vector_data): - """Add a vector to the database. - Args: - vector_id (any): Unique identifier for the vector. - vector_data (any): The actual data of the vector to be stored. - """ - - ## Read ## - @abstractmethod - def get_vector(self, vector_id): - """Retrieve a vector from the database by its identifier. - Args: - vector_id (any): The identifier of the vector to retrieve. - """ - - @abstractmethod - def query(self, query_texts, n_results=4, similarity_threshold=None): - """Performs a semantic search in the vector database. - - Args: - query_texts (Union[str, List[str]]): The query text or list of query texts to search for. - n_results (int, optional): Number of results to return. Defaults to 4. - similarity_threshold (float, optional): Minimum similarity score for results to be included [0.0, 1.0]. Defaults to None. - - Returns: - List[Tuple[Document, Optional[float]]]: A list of tuples containing the search results. Each tuple - contains: - Document: The retrieved document object. - Optional[float]: The similarity score of the match, or None if not applicable. - - Raises: - ValueError: If query_texts is empty or invalid. - ConnectionError: If database connection fails during query. 
- """ - - ## Update ## - @abstractmethod - def update_vector(self, vector_id, new_vector_data): - """Update an existing vector in the database. - Args: - vector_id (any): The identifier of the vector to update. - new_vector_data (any): The new data to replace the existing vector data. - """ - - ## Delete ## - @abstractmethod - def delete_vector(self, vector_id): - """Delete a vector from the database using its identifier. - Args: - vector_id (any): The identifier of the vector to delete. - """ - - -# query(string, metadata/tag, n_rets, kwargs) - -# query by string, timestamp, id, n_rets - -# (some sort of tag/metadata) - -# temporal diff --git a/build/lib/dimos/agents/memory/chroma_impl.py b/build/lib/dimos/agents/memory/chroma_impl.py deleted file mode 100644 index 06f6989355..0000000000 --- a/build/lib/dimos/agents/memory/chroma_impl.py +++ /dev/null @@ -1,167 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from dimos.agents.memory.base import AbstractAgentSemanticMemory - -from langchain_openai import OpenAIEmbeddings -from langchain_chroma import Chroma -import os -import torch - - -class ChromaAgentSemanticMemory(AbstractAgentSemanticMemory): - """Base class for Chroma-based semantic memory implementations.""" - - def __init__(self, collection_name="my_collection"): - """Initialize the connection to the local Chroma DB.""" - self.collection_name = collection_name - self.db_connection = None - self.embeddings = None - super().__init__(connection_type="local") - - def connect(self): - # Stub - return super().connect() - - def create(self): - """Create the embedding function and initialize the Chroma database. - This method must be implemented by child classes.""" - raise NotImplementedError("Child classes must implement this method") - - def add_vector(self, vector_id, vector_data): - """Add a vector to the ChromaDB collection.""" - if not self.db_connection: - raise Exception("Collection not initialized. Call connect() first.") - self.db_connection.add_texts( - ids=[vector_id], - texts=[vector_data], - metadatas=[{"name": vector_id}], - ) - - def get_vector(self, vector_id): - """Retrieve a vector from the ChromaDB by its identifier.""" - result = self.db_connection.get(include=["embeddings"], ids=[vector_id]) - return result - - def query(self, query_texts, n_results=4, similarity_threshold=None): - """Query the collection with a specific text and return up to n results.""" - if not self.db_connection: - raise Exception("Collection not initialized. 
Call connect() first.") - - if similarity_threshold is not None: - if not (0 <= similarity_threshold <= 1): - raise ValueError("similarity_threshold must be between 0 and 1.") - return self.db_connection.similarity_search_with_relevance_scores( - query=query_texts, k=n_results, score_threshold=similarity_threshold - ) - else: - documents = self.db_connection.similarity_search(query=query_texts, k=n_results) - return [(doc, None) for doc in documents] - - def update_vector(self, vector_id, new_vector_data): - # TODO - return super().connect() - - def delete_vector(self, vector_id): - """Delete a vector from the ChromaDB using its identifier.""" - if not self.db_connection: - raise Exception("Collection not initialized. Call connect() first.") - self.db_connection.delete(ids=[vector_id]) - - -class OpenAISemanticMemory(ChromaAgentSemanticMemory): - """Semantic memory implementation using OpenAI's embedding API.""" - - def __init__( - self, collection_name="my_collection", model="text-embedding-3-large", dimensions=1024 - ): - """Initialize OpenAI-based semantic memory. 
- - Args: - collection_name (str): Name of the Chroma collection - model (str): OpenAI embedding model to use - dimensions (int): Dimension of the embedding vectors - """ - self.model = model - self.dimensions = dimensions - super().__init__(collection_name=collection_name) - - def create(self): - """Connect to OpenAI API and create the ChromaDB client.""" - # Get OpenAI key - self.OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") - if not self.OPENAI_API_KEY: - raise Exception("OpenAI key was not specified.") - - # Set embeddings - self.embeddings = OpenAIEmbeddings( - model=self.model, - dimensions=self.dimensions, - api_key=self.OPENAI_API_KEY, - ) - - # Create the database - self.db_connection = Chroma( - collection_name=self.collection_name, - embedding_function=self.embeddings, - collection_metadata={"hnsw:space": "cosine"}, - ) - - -class LocalSemanticMemory(ChromaAgentSemanticMemory): - """Semantic memory implementation using local models.""" - - def __init__( - self, collection_name="my_collection", model_name="sentence-transformers/all-MiniLM-L6-v2" - ): - """Initialize the local semantic memory using SentenceTransformer. 
- - Args: - collection_name (str): Name of the Chroma collection - model_name (str): Embeddings model - """ - - self.model_name = model_name - super().__init__(collection_name=collection_name) - - def create(self): - """Create local embedding model and initialize the ChromaDB client.""" - # Load the sentence transformer model - # Use CUDA if available, otherwise fall back to CPU - device = "cuda" if torch.cuda.is_available() else "cpu" - print(f"Using device: {device}") - self.model = SentenceTransformer(self.model_name, device=device) - - # Create a custom embedding class that implements the embed_query method - class SentenceTransformerEmbeddings: - def __init__(self, model): - self.model = model - - def embed_query(self, text): - """Embed a single query text.""" - return self.model.encode(text, normalize_embeddings=True).tolist() - - def embed_documents(self, texts): - """Embed multiple documents/texts.""" - return self.model.encode(texts, normalize_embeddings=True).tolist() - - # Create an instance of our custom embeddings class - self.embeddings = SentenceTransformerEmbeddings(self.model) - - # Create the database - self.db_connection = Chroma( - collection_name=self.collection_name, - embedding_function=self.embeddings, - collection_metadata={"hnsw:space": "cosine"}, - ) diff --git a/build/lib/dimos/agents/memory/image_embedding.py b/build/lib/dimos/agents/memory/image_embedding.py deleted file mode 100644 index 1ad0e9132d..0000000000 --- a/build/lib/dimos/agents/memory/image_embedding.py +++ /dev/null @@ -1,263 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Image embedding module for converting images to vector embeddings. - -This module provides a class for generating vector embeddings from images -using pre-trained models like CLIP, ResNet, etc. -""" - -import base64 -import io -import os -from typing import Union - -import cv2 -import numpy as np -from PIL import Image - -from dimos.utils.data import get_data -from dimos.utils.logging_config import setup_logger - -logger = setup_logger("dimos.agents.memory.image_embedding") - - -class ImageEmbeddingProvider: - """ - A provider for generating vector embeddings from images. - - This class uses pre-trained models to convert images into vector embeddings - that can be stored in a vector database and used for similarity search. - """ - - def __init__(self, model_name: str = "clip", dimensions: int = 512): - """ - Initialize the image embedding provider. - - Args: - model_name: Name of the embedding model to use ("clip", "resnet", etc.) 
- dimensions: Dimensions of the embedding vectors - """ - self.model_name = model_name - self.dimensions = dimensions - self.model = None - self.processor = None - - self._initialize_model() - - logger.info(f"ImageEmbeddingProvider initialized with model {model_name}") - - def _initialize_model(self): - """Initialize the specified embedding model.""" - try: - import onnxruntime as ort - import torch - from transformers import AutoFeatureExtractor, AutoModel, CLIPProcessor - - if self.model_name == "clip": - model_id = get_data("models_clip") / "model.onnx" - processor_id = "openai/clip-vit-base-patch32" - self.model = ort.InferenceSession(model_id) - self.processor = CLIPProcessor.from_pretrained(processor_id) - logger.info(f"Loaded CLIP model: {model_id}") - elif self.model_name == "resnet": - model_id = "microsoft/resnet-50" - self.model = AutoModel.from_pretrained(model_id) - self.processor = AutoFeatureExtractor.from_pretrained(model_id) - logger.info(f"Loaded ResNet model: {model_id}") - else: - raise ValueError(f"Unsupported model: {self.model_name}") - except ImportError as e: - logger.error(f"Failed to import required modules: {e}") - logger.error("Please install with: pip install transformers torch") - # Initialize with dummy model for type checking - self.model = None - self.processor = None - raise - - def get_embedding(self, image: Union[np.ndarray, str, bytes]) -> np.ndarray: - """ - Generate an embedding vector for the provided image. - - Args: - image: The image to embed, can be a numpy array (OpenCV format), - a file path, or a base64-encoded string - - Returns: - A numpy array containing the embedding vector - """ - if self.model is None or self.processor is None: - logger.error("Model not initialized. 
Using fallback random embedding.") - return np.random.randn(self.dimensions).astype(np.float32) - - pil_image = self._prepare_image(image) - - try: - import torch - - if self.model_name == "clip": - inputs = self.processor(images=pil_image, return_tensors="np") - - with torch.no_grad(): - ort_inputs = { - inp.name: inputs[inp.name] - for inp in self.model.get_inputs() - if inp.name in inputs - } - - # If required, add dummy text inputs - input_names = [i.name for i in self.model.get_inputs()] - batch_size = inputs["pixel_values"].shape[0] - if "input_ids" in input_names: - ort_inputs["input_ids"] = np.zeros((batch_size, 1), dtype=np.int64) - if "attention_mask" in input_names: - ort_inputs["attention_mask"] = np.ones((batch_size, 1), dtype=np.int64) - - # Run inference - ort_outputs = self.model.run(None, ort_inputs) - - # Look up correct output name - output_names = [o.name for o in self.model.get_outputs()] - if "image_embeds" in output_names: - image_embedding = ort_outputs[output_names.index("image_embeds")] - else: - raise RuntimeError(f"No 'image_embeds' found in outputs: {output_names}") - - embedding = image_embedding / np.linalg.norm(image_embedding, axis=1, keepdims=True) - embedding = embedding[0] - - elif self.model_name == "resnet": - inputs = self.processor(images=pil_image, return_tensors="pt") - - with torch.no_grad(): - outputs = self.model(**inputs) - - # Get the [CLS] token embedding - embedding = outputs.last_hidden_state[:, 0, :].numpy()[0] - else: - logger.warning(f"Unsupported model: {self.model_name}. 
Using random embedding.") - embedding = np.random.randn(self.dimensions).astype(np.float32) - - # Normalize and ensure correct dimensions - embedding = embedding / np.linalg.norm(embedding) - - logger.debug(f"Generated embedding with shape {embedding.shape}") - return embedding - - except Exception as e: - logger.error(f"Error generating embedding: {e}") - return np.random.randn(self.dimensions).astype(np.float32) - - def get_text_embedding(self, text: str) -> np.ndarray: - """ - Generate an embedding vector for the provided text. - - Args: - text: The text to embed - - Returns: - A numpy array containing the embedding vector - """ - if self.model is None or self.processor is None: - logger.error("Model not initialized. Using fallback random embedding.") - return np.random.randn(self.dimensions).astype(np.float32) - - if self.model_name != "clip": - logger.warning( - f"Text embeddings are only supported with CLIP model, not {self.model_name}. Using random embedding." - ) - return np.random.randn(self.dimensions).astype(np.float32) - - try: - import torch - - inputs = self.processor(text=[text], return_tensors="np", padding=True) - - with torch.no_grad(): - # Prepare ONNX input dict (handle only what's needed) - ort_inputs = { - inp.name: inputs[inp.name] - for inp in self.model.get_inputs() - if inp.name in inputs - } - # Determine which inputs are expected by the ONNX model - input_names = [i.name for i in self.model.get_inputs()] - batch_size = inputs["input_ids"].shape[0] # pulled from text input - - # If the model expects pixel_values (i.e., fused model), add dummy vision input - if "pixel_values" in input_names: - ort_inputs["pixel_values"] = np.zeros( - (batch_size, 3, 224, 224), dtype=np.float32 - ) - - # Run inference - ort_outputs = self.model.run(None, ort_inputs) - - # Determine correct output (usually 'last_hidden_state' or 'text_embeds') - output_names = [o.name for o in self.model.get_outputs()] - if "text_embeds" in output_names: - text_embedding = 
ort_outputs[output_names.index("text_embeds")] - else: - text_embedding = ort_outputs[0] # fallback to first output - - # Normalize - text_embedding = text_embedding / np.linalg.norm( - text_embedding, axis=1, keepdims=True - ) - text_embedding = text_embedding[0] # shape: (512,) - - logger.debug( - f"Generated text embedding with shape {text_embedding.shape} for text: '{text}'" - ) - return text_embedding - - except Exception as e: - logger.error(f"Error generating text embedding: {e}") - return np.random.randn(self.dimensions).astype(np.float32) - - def _prepare_image(self, image: Union[np.ndarray, str, bytes]) -> Image.Image: - """ - Convert the input image to PIL format required by the models. - - Args: - image: Input image in various formats - - Returns: - PIL Image object - """ - if isinstance(image, np.ndarray): - if len(image.shape) == 3 and image.shape[2] == 3: - image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - else: - image_rgb = image - - return Image.fromarray(image_rgb) - - elif isinstance(image, str): - if os.path.isfile(image): - return Image.open(image) - else: - try: - image_data = base64.b64decode(image) - return Image.open(io.BytesIO(image_data)) - except Exception as e: - logger.error(f"Failed to decode image string: {e}") - raise ValueError("Invalid image string format") - - elif isinstance(image, bytes): - return Image.open(io.BytesIO(image)) - - else: - raise ValueError(f"Unsupported image format: {type(image)}") diff --git a/build/lib/dimos/agents/memory/spatial_vector_db.py b/build/lib/dimos/agents/memory/spatial_vector_db.py deleted file mode 100644 index cf44d0c589..0000000000 --- a/build/lib/dimos/agents/memory/spatial_vector_db.py +++ /dev/null @@ -1,268 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Spatial vector database for storing and querying images with XY locations. - -This module extends the ChromaDB implementation to support storing images with -their XY locations and querying by location or image similarity. -""" - -import numpy as np -from typing import List, Dict, Tuple, Any -import chromadb - -from dimos.agents.memory.visual_memory import VisualMemory -from dimos.utils.logging_config import setup_logger - -logger = setup_logger("dimos.agents.memory.spatial_vector_db") - - -class SpatialVectorDB: - """ - A vector database for storing and querying images mapped to X,Y,theta absolute locations for SpatialMemory. - - This class extends the ChromaDB implementation to support storing images with - their absolute locations and querying by location, text, or image cosine semantic similarity. - """ - - def __init__( - self, collection_name: str = "spatial_memory", chroma_client=None, visual_memory=None - ): - """ - Initialize the spatial vector database. - - Args: - collection_name: Name of the vector database collection - chroma_client: Optional ChromaDB client for persistence. If None, an in-memory client is used. - visual_memory: Optional VisualMemory instance for storing images. If None, a new one is created. 
- """ - self.collection_name = collection_name - - # Use provided client or create in-memory client - self.client = chroma_client if chroma_client is not None else chromadb.Client() - - # Check if collection already exists - in newer ChromaDB versions list_collections returns names directly - existing_collections = self.client.list_collections() - - # Handle different versions of ChromaDB API - try: - collection_exists = collection_name in existing_collections - except: - try: - collection_exists = collection_name in [c.name for c in existing_collections] - except: - try: - self.client.get_collection(name=collection_name) - collection_exists = True - except Exception: - collection_exists = False - - # Get or create the collection - self.image_collection = self.client.get_or_create_collection( - name=collection_name, metadata={"hnsw:space": "cosine"} - ) - - # Use provided visual memory or create a new one - self.visual_memory = visual_memory if visual_memory is not None else VisualMemory() - - # Log initialization info with details about whether using existing collection - client_type = "persistent" if chroma_client is not None else "in-memory" - try: - count = len(self.image_collection.get(include=[])["ids"]) - if collection_exists: - logger.info( - f"Using EXISTING {client_type} collection '{collection_name}' with {count} entries" - ) - else: - logger.info(f"Created NEW {client_type} collection '{collection_name}'") - except Exception as e: - logger.info( - f"Initialized {client_type} collection '{collection_name}' (count error: {str(e)})" - ) - - def add_image_vector( - self, vector_id: str, image: np.ndarray, embedding: np.ndarray, metadata: Dict[str, Any] - ) -> None: - """ - Add an image with its embedding and metadata to the vector database. 
- - Args: - vector_id: Unique identifier for the vector - image: The image to store - embedding: The pre-computed embedding vector for the image - metadata: Metadata for the image, including x, y coordinates - """ - # Store the image in visual memory - self.visual_memory.add(vector_id, image) - - # Add the vector to ChromaDB - self.image_collection.add( - ids=[vector_id], embeddings=[embedding.tolist()], metadatas=[metadata] - ) - - logger.debug(f"Added image vector {vector_id} with metadata: {metadata}") - - def query_by_embedding(self, embedding: np.ndarray, limit: int = 5) -> List[Dict]: - """ - Query the vector database for images similar to the provided embedding. - - Args: - embedding: Query embedding vector - limit: Maximum number of results to return - - Returns: - List of results, each containing the image and its metadata - """ - results = self.image_collection.query( - query_embeddings=[embedding.tolist()], n_results=limit - ) - - return self._process_query_results(results) - - # TODO: implement efficient nearest neighbor search - def query_by_location( - self, x: float, y: float, radius: float = 2.0, limit: int = 5 - ) -> List[Dict]: - """ - Query the vector database for images near the specified location. 
- - Args: - x: X coordinate - y: Y coordinate - radius: Search radius in meters - limit: Maximum number of results to return - - Returns: - List of results, each containing the image and its metadata - """ - results = self.image_collection.get() - - if not results or not results["ids"]: - return [] - - filtered_results = {"ids": [], "metadatas": [], "distances": []} - - for i, metadata in enumerate(results["metadatas"]): - item_x = metadata.get("x") - item_y = metadata.get("y") - - if item_x is not None and item_y is not None: - distance = np.sqrt((x - item_x) ** 2 + (y - item_y) ** 2) - - if distance <= radius: - filtered_results["ids"].append(results["ids"][i]) - filtered_results["metadatas"].append(metadata) - filtered_results["distances"].append(distance) - - sorted_indices = np.argsort(filtered_results["distances"]) - filtered_results["ids"] = [filtered_results["ids"][i] for i in sorted_indices[:limit]] - filtered_results["metadatas"] = [ - filtered_results["metadatas"][i] for i in sorted_indices[:limit] - ] - filtered_results["distances"] = [ - filtered_results["distances"][i] for i in sorted_indices[:limit] - ] - - return self._process_query_results(filtered_results) - - def _process_query_results(self, results) -> List[Dict]: - """Process query results to include decoded images.""" - if not results or not results["ids"]: - return [] - - processed_results = [] - - for i, vector_id in enumerate(results["ids"]): - lookup_id = vector_id[0] if isinstance(vector_id, list) else vector_id - - # Create the result dictionary with metadata regardless of image availability - result = { - "metadata": results["metadatas"][i] if "metadatas" in results else {}, - "id": lookup_id, - } - - # Add distance if available - if "distances" in results: - result["distance"] = ( - results["distances"][i][0] - if isinstance(results["distances"][i], list) - else results["distances"][i] - ) - - # Get the image from visual memory - image = self.visual_memory.get(lookup_id) - 
result["image"] = image - - processed_results.append(result) - - return processed_results - - def query_by_text(self, text: str, limit: int = 5) -> List[Dict]: - """ - Query the vector database for images matching the provided text description. - - This method uses CLIP's text-to-image matching capability to find images - that semantically match the text query (e.g., "where is the kitchen"). - - Args: - text: Text query to search for - limit: Maximum number of results to return - - Returns: - List of results, each containing the image, its metadata, and similarity score - """ - from dimos.agents.memory.image_embedding import ImageEmbeddingProvider - - embedding_provider = ImageEmbeddingProvider(model_name="clip") - - text_embedding = embedding_provider.get_text_embedding(text) - - results = self.image_collection.query( - query_embeddings=[text_embedding.tolist()], - n_results=limit, - include=["documents", "metadatas", "distances"], - ) - - logger.info( - f"Text query: '{text}' returned {len(results['ids'] if 'ids' in results else [])} results" - ) - return self._process_query_results(results) - - def get_all_locations(self) -> List[Tuple[float, float, float]]: - """Get all locations stored in the database.""" - # Get all items from the collection without embeddings - results = self.image_collection.get(include=["metadatas"]) - - if not results or "metadatas" not in results or not results["metadatas"]: - return [] - - # Extract x, y coordinates from metadata - locations = [] - for metadata in results["metadatas"]: - if isinstance(metadata, list) and metadata and isinstance(metadata[0], dict): - metadata = metadata[0] # Handle nested metadata - - if isinstance(metadata, dict) and "x" in metadata and "y" in metadata: - x = metadata.get("x", 0) - y = metadata.get("y", 0) - z = metadata.get("z", 0) if "z" in metadata else 0 - locations.append((x, y, z)) - - return locations - - @property - def image_storage(self): - """Legacy accessor for compatibility with existing 
code.""" - return self.visual_memory.images diff --git a/build/lib/dimos/agents/memory/test_image_embedding.py b/build/lib/dimos/agents/memory/test_image_embedding.py deleted file mode 100644 index c424b950bb..0000000000 --- a/build/lib/dimos/agents/memory/test_image_embedding.py +++ /dev/null @@ -1,212 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Test module for the CLIP image embedding functionality in dimos. -""" - -import os -import time - -import numpy as np -import pytest -import reactivex as rx -from reactivex import operators as ops - -from dimos.agents.memory.image_embedding import ImageEmbeddingProvider -from dimos.stream.video_provider import VideoProvider - - -@pytest.mark.heavy -class TestImageEmbedding: - """Test class for CLIP image embedding functionality.""" - - def test_clip_embedding_initialization(self): - """Test CLIP embedding provider initializes correctly.""" - try: - # Initialize the embedding provider with CLIP model - embedding_provider = ImageEmbeddingProvider(model_name="clip", dimensions=512) - assert embedding_provider.model is not None, "CLIP model failed to initialize" - assert embedding_provider.processor is not None, "CLIP processor failed to initialize" - assert embedding_provider.model_name == "clip", "Model name should be 'clip'" - assert embedding_provider.dimensions == 512, "Embedding dimensions should be 512" - except Exception as e: - pytest.skip(f"Skipping test due to model 
initialization error: {e}") - - def test_clip_embedding_process_video(self): - """Test CLIP embedding provider can process video frames and return embeddings.""" - try: - from dimos.utils.data import get_data - - video_path = get_data("assets") / "trimmed_video_office.mov" - - embedding_provider = ImageEmbeddingProvider(model_name="clip", dimensions=512) - - assert os.path.exists(video_path), f"Test video not found: {video_path}" - video_provider = VideoProvider(dev_name="test_video", video_source=video_path) - - video_stream = video_provider.capture_video_as_observable(realtime=False, fps=15) - - # Use ReactiveX operators to process the stream - def process_frame(frame): - try: - # Process frame with CLIP - embedding = embedding_provider.get_embedding(frame) - print( - f"Generated CLIP embedding with shape: {embedding.shape}, norm: {np.linalg.norm(embedding):.4f}" - ) - - return {"frame": frame, "embedding": embedding} - except Exception as e: - print(f"Error in process_frame: {e}") - return None - - embedding_stream = video_stream.pipe(ops.map(process_frame)) - - results = [] - frames_processed = 0 - target_frames = 10 - - def on_next(result): - nonlocal frames_processed, results - if not result: # Skip None results - return - - results.append(result) - frames_processed += 1 - - # Stop processing after target frames - if frames_processed >= target_frames: - subscription.dispose() - - def on_error(error): - pytest.fail(f"Error in embedding stream: {error}") - - def on_completed(): - pass - - # Subscribe and wait for results - subscription = embedding_stream.subscribe( - on_next=on_next, on_error=on_error, on_completed=on_completed - ) - - timeout = 60.0 - start_time = time.time() - while frames_processed < target_frames and time.time() - start_time < timeout: - time.sleep(0.5) - print(f"Processed {frames_processed}/{target_frames} frames") - - # Clean up subscription - subscription.dispose() - video_provider.dispose_all() - - # Check if we have results - if 
len(results) == 0: - pytest.skip("No embeddings generated, but test connection established correctly") - return - - print(f"Processed {len(results)} frames with CLIP embeddings") - - # Analyze the results - assert len(results) > 0, "No embeddings generated" - - # Check properties of first embedding - first_result = results[0] - assert "embedding" in first_result, "Result doesn't contain embedding" - assert "frame" in first_result, "Result doesn't contain frame" - - # Check embedding shape and normalization - embedding = first_result["embedding"] - assert isinstance(embedding, np.ndarray), "Embedding is not a numpy array" - assert embedding.shape == (512,), ( - f"Embedding has wrong shape: {embedding.shape}, expected (512,)" - ) - assert abs(np.linalg.norm(embedding) - 1.0) < 1e-5, "Embedding is not normalized" - - # Save the first embedding for similarity tests - if len(results) > 1 and "embedding" in results[0]: - # Create a class variable to store embeddings for the similarity test - TestImageEmbedding.test_embeddings = { - "embedding1": results[0]["embedding"], - "embedding2": results[1]["embedding"] if len(results) > 1 else None, - } - print(f"Saved embeddings for similarity testing") - - print("CLIP embedding test passed successfully!") - - except Exception as e: - pytest.fail(f"Test failed with error: {e}") - - def test_clip_embedding_similarity(self): - """Test CLIP embedding similarity search and text-to-image queries.""" - try: - # Skip if previous test didn't generate embeddings - if not hasattr(TestImageEmbedding, "test_embeddings"): - pytest.skip("No embeddings available from previous test") - return - - # Get embeddings from previous test - embedding1 = TestImageEmbedding.test_embeddings["embedding1"] - embedding2 = TestImageEmbedding.test_embeddings["embedding2"] - - # Initialize embedding provider for text embeddings - embedding_provider = ImageEmbeddingProvider(model_name="clip", dimensions=512) - - # Test frame-to-frame similarity - if embedding1 
is not None and embedding2 is not None: - # Compute cosine similarity - similarity = np.dot(embedding1, embedding2) - print(f"Similarity between first two frames: {similarity:.4f}") - - # Should be in range [-1, 1] - assert -1.0 <= similarity <= 1.0, f"Similarity out of valid range: {similarity}" - - # Test text-to-image similarity - if embedding1 is not None: - # Generate a list of text queries to test - text_queries = ["a video frame", "a person", "an outdoor scene", "a kitchen"] - - # Test each text query - for text_query in text_queries: - # Get text embedding - text_embedding = embedding_provider.get_text_embedding(text_query) - - # Check text embedding properties - assert isinstance(text_embedding, np.ndarray), ( - "Text embedding is not a numpy array" - ) - assert text_embedding.shape == (512,), ( - f"Text embedding has wrong shape: {text_embedding.shape}" - ) - assert abs(np.linalg.norm(text_embedding) - 1.0) < 1e-5, ( - "Text embedding is not normalized" - ) - - # Compute similarity between frame and text - text_similarity = np.dot(embedding1, text_embedding) - print(f"Similarity between frame and '{text_query}': {text_similarity:.4f}") - - # Should be in range [-1, 1] - assert -1.0 <= text_similarity <= 1.0, ( - f"Text-image similarity out of range: {text_similarity}" - ) - - print("CLIP embedding similarity tests passed successfully!") - - except Exception as e: - pytest.fail(f"Similarity test failed with error: {e}") - - -if __name__ == "__main__": - pytest.main(["-v", "--disable-warnings", __file__]) diff --git a/build/lib/dimos/agents/memory/visual_memory.py b/build/lib/dimos/agents/memory/visual_memory.py deleted file mode 100644 index 0087a4fe9b..0000000000 --- a/build/lib/dimos/agents/memory/visual_memory.py +++ /dev/null @@ -1,182 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Visual memory storage for managing image data persistence and retrieval -""" - -import os -import pickle -import base64 -import numpy as np -import cv2 - -from typing import Optional -from dimos.utils.logging_config import setup_logger - -logger = setup_logger("dimos.agents.memory.visual_memory") - - -class VisualMemory: - """ - A class for storing and retrieving visual memories (images) with persistence. - - This class handles the storage, encoding, and retrieval of images associated - with vector database entries. It provides persistence mechanisms to save and - load the image data from disk. - """ - - def __init__(self, output_dir: str = None): - """ - Initialize the visual memory system. - - Args: - output_dir: Directory to store the serialized image data - """ - self.images = {} # Maps IDs to encoded images - self.output_dir = output_dir - - if output_dir: - os.makedirs(output_dir, exist_ok=True) - logger.info(f"VisualMemory initialized with output directory: {output_dir}") - else: - logger.info("VisualMemory initialized with no persistence directory") - - def add(self, image_id: str, image: np.ndarray) -> None: - """ - Add an image to visual memory. 
- - Args: - image_id: Unique identifier for the image - image: The image data as a numpy array - """ - # Encode the image to base64 for storage - success, encoded_image = cv2.imencode(".jpg", image) - if not success: - logger.error(f"Failed to encode image {image_id}") - return - - image_bytes = encoded_image.tobytes() - b64_encoded = base64.b64encode(image_bytes).decode("utf-8") - - # Store the encoded image - self.images[image_id] = b64_encoded - logger.debug(f"Added image {image_id} to visual memory") - - def get(self, image_id: str) -> Optional[np.ndarray]: - """ - Retrieve an image from visual memory. - - Args: - image_id: Unique identifier for the image - - Returns: - The decoded image as a numpy array, or None if not found - """ - if image_id not in self.images: - logger.warning( - f"Image not found in storage for ID {image_id}. Incomplete or corrupted image storage." - ) - return None - - try: - encoded_image = self.images[image_id] - image_bytes = base64.b64decode(encoded_image) - image_array = np.frombuffer(image_bytes, dtype=np.uint8) - image = cv2.imdecode(image_array, cv2.IMREAD_COLOR) - return image - except Exception as e: - logger.warning(f"Failed to decode image for ID {image_id}: {str(e)}") - return None - - def contains(self, image_id: str) -> bool: - """ - Check if an image ID exists in visual memory. - - Args: - image_id: Unique identifier for the image - - Returns: - True if the image exists, False otherwise - """ - return image_id in self.images - - def count(self) -> int: - """ - Get the number of images in visual memory. - - Returns: - The number of images stored - """ - return len(self.images) - - def save(self, filename: Optional[str] = None) -> str: - """ - Save the visual memory to disk. - - Args: - filename: Optional filename to save to. If None, uses a default name in the output directory. - - Returns: - The path where the data was saved - """ - if not self.output_dir: - logger.warning("No output directory specified for VisualMemory. 
Cannot save.") - return "" - - if not filename: - filename = "visual_memory.pkl" - - output_path = os.path.join(self.output_dir, filename) - - try: - with open(output_path, "wb") as f: - pickle.dump(self.images, f) - logger.info(f"Saved {len(self.images)} images to {output_path}") - return output_path - except Exception as e: - logger.error(f"Failed to save visual memory: {str(e)}") - return "" - - @classmethod - def load(cls, path: str, output_dir: Optional[str] = None) -> "VisualMemory": - """ - Load visual memory from disk. - - Args: - path: Path to the saved visual memory file - output_dir: Optional output directory for the new instance - - Returns: - A new VisualMemory instance with the loaded data - """ - instance = cls(output_dir=output_dir) - - if not os.path.exists(path): - logger.warning(f"Visual memory file {path} not found") - return instance - - try: - with open(path, "rb") as f: - instance.images = pickle.load(f) - logger.info(f"Loaded {len(instance.images)} images from {path}") - return instance - except Exception as e: - logger.error(f"Failed to load visual memory: {str(e)}") - return instance - - def clear(self) -> None: - """Clear all images from memory.""" - self.images = {} - logger.info("Visual memory cleared") diff --git a/build/lib/dimos/agents/planning_agent.py b/build/lib/dimos/agents/planning_agent.py deleted file mode 100644 index 52971e770a..0000000000 --- a/build/lib/dimos/agents/planning_agent.py +++ /dev/null @@ -1,317 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -import threading -from typing import List, Optional, Literal -from reactivex import Observable -from reactivex import operators as ops -import time -from dimos.skills.skills import AbstractSkill -from dimos.agents.agent import OpenAIAgent -from dimos.utils.logging_config import setup_logger -from textwrap import dedent -from pydantic import BaseModel - -logger = setup_logger("dimos.agents.planning_agent") - - -# For response validation -class PlanningAgentResponse(BaseModel): - type: Literal["dialogue", "plan"] - content: List[str] - needs_confirmation: bool - - -class PlanningAgent(OpenAIAgent): - """Agent that plans and breaks down tasks through dialogue. - - This agent specializes in: - 1. Understanding complex tasks through dialogue - 2. Breaking tasks into concrete, executable steps - 3. Refining plans based on user feedback - 4. Streaming individual steps to ExecutionAgents - - The agent maintains conversation state and can refine plans until - the user confirms they are ready to execute. - """ - - def __init__( - self, - dev_name: str = "PlanningAgent", - model_name: str = "gpt-4", - input_query_stream: Optional[Observable] = None, - use_terminal: bool = False, - skills: Optional[AbstractSkill] = None, - ): - """Initialize the planning agent. - - Args: - dev_name: Name identifier for the agent - model_name: OpenAI model to use - input_query_stream: Observable stream of user queries - use_terminal: Whether to enable terminal input - skills: Available skills/functions for the agent - """ - # Planning state - self.conversation_history = [] - self.current_plan = [] - self.plan_confirmed = False - self.latest_response = None - - # Build system prompt - skills_list = [] - if skills is not None: - skills_list = skills.get_tools() - - system_query = dedent(f""" - You are a Robot planning assistant that helps break down tasks into concrete, executable steps. 
- Your goal is to: - 1. Break down the task into clear, sequential steps - 2. Refine the plan based on user feedback as needed - 3. Only finalize the plan when the user explicitly confirms - - You have the following skills at your disposal: - {skills_list} - - IMPORTANT: You MUST ALWAYS respond with ONLY valid JSON in the following format, with no additional text or explanation: - {{ - "type": "dialogue" | "plan", - "content": string | list[string], - "needs_confirmation": boolean - }} - - Your goal is to: - 1. Understand the user's task through dialogue - 2. Break it down into clear, sequential steps - 3. Refine the plan based on user feedback - 4. Only finalize the plan when the user explicitly confirms - - For dialogue responses, use: - {{ - "type": "dialogue", - "content": "Your message to the user", - "needs_confirmation": false - }} - - For plan proposals, use: - {{ - "type": "plan", - "content": ["Execute", "Execute", ...], - "needs_confirmation": true - }} - - Remember: ONLY output valid JSON, no other text.""") - - # Initialize OpenAIAgent with our configuration - super().__init__( - dev_name=dev_name, - agent_type="Planning", - query="", # Will be set by process_user_input - model_name=model_name, - input_query_stream=input_query_stream, - system_query=system_query, - max_output_tokens_per_request=1000, - response_model=PlanningAgentResponse, - ) - logger.info("Planning agent initialized") - - # Set up terminal mode if requested - self.use_terminal = use_terminal - use_terminal = False - if use_terminal: - # Start terminal interface in a separate thread - logger.info("Starting terminal interface in a separate thread") - terminal_thread = threading.Thread(target=self.start_terminal_interface, daemon=True) - terminal_thread.start() - - def _handle_response(self, response) -> None: - """Handle the agent's response and update state. 
- - Args: - response: ParsedChatCompletionMessage containing PlanningAgentResponse - """ - print("handle response", response) - print("handle response type", type(response)) - - # Extract the PlanningAgentResponse from parsed field if available - planning_response = response.parsed if hasattr(response, "parsed") else response - print("planning response", planning_response) - print("planning response type", type(planning_response)) - # Convert to dict for storage in conversation history - response_dict = planning_response.model_dump() - self.conversation_history.append(response_dict) - - # If it's a plan, update current plan - if planning_response.type == "plan": - logger.info(f"Updating current plan: {planning_response.content}") - self.current_plan = planning_response.content - - # Store latest response - self.latest_response = response_dict - - def _stream_plan(self) -> None: - """Stream each step of the confirmed plan.""" - logger.info("Starting to stream plan steps") - logger.debug(f"Current plan: {self.current_plan}") - - for i, step in enumerate(self.current_plan, 1): - logger.info(f"Streaming step {i}: {step}") - # Add a small delay between steps to ensure they're processed - time.sleep(0.5) - try: - self.response_subject.on_next(str(step)) - logger.debug(f"Successfully emitted step {i} to response_subject") - except Exception as e: - logger.error(f"Error emitting step {i}: {e}") - - logger.info("Plan streaming completed") - self.response_subject.on_completed() - - def _send_query(self, messages: list) -> PlanningAgentResponse: - """Send query to OpenAI and parse the response. - - Extends OpenAIAgent's _send_query to handle planning-specific response formats. 
- - Args: - messages: List of message dictionaries - - Returns: - PlanningAgentResponse: Validated response with type, content, and needs_confirmation - """ - try: - return super()._send_query(messages) - except Exception as e: - logger.error(f"Caught exception in _send_query: {str(e)}") - return PlanningAgentResponse( - type="dialogue", content=f"Error: {str(e)}", needs_confirmation=False - ) - - def process_user_input(self, user_input: str) -> None: - """Process user input and generate appropriate response. - - Args: - user_input: The user's message - """ - if not user_input: - return - - # Check for plan confirmation - if self.current_plan and user_input.lower() in ["yes", "y", "confirm"]: - logger.info("Plan confirmation received") - self.plan_confirmed = True - # Create a proper PlanningAgentResponse with content as a list - confirmation_msg = PlanningAgentResponse( - type="dialogue", - content="Plan confirmed! Streaming steps to execution...", - needs_confirmation=False, - ) - self._handle_response(confirmation_msg) - self._stream_plan() - return - - # Build messages for OpenAI with conversation history - messages = [ - {"role": "system", "content": self.system_query} # Using system_query from OpenAIAgent - ] - - # Add the new user input to conversation history - self.conversation_history.append({"type": "user_message", "content": user_input}) - - # Add complete conversation history including both user and assistant messages - for msg in self.conversation_history: - if msg["type"] == "user_message": - messages.append({"role": "user", "content": msg["content"]}) - elif msg["type"] == "dialogue": - messages.append({"role": "assistant", "content": msg["content"]}) - elif msg["type"] == "plan": - plan_text = "Here's my proposed plan:\n" + "\n".join( - f"{i + 1}. 
{step}" for i, step in enumerate(msg["content"]) - ) - messages.append({"role": "assistant", "content": plan_text}) - - # Get and handle response - response = self._send_query(messages) - self._handle_response(response) - - def start_terminal_interface(self): - """Start the terminal interface for input/output.""" - - time.sleep(5) # buffer time for clean terminal interface printing - print("=" * 50) - print("\nDimOS Action PlanningAgent\n") - print("I have access to your Robot() and Robot Skills()") - print( - "Describe your task and I'll break it down into steps using your skills as a reference." - ) - print("Once you're happy with the plan, type 'yes' to execute it.") - print("Type 'quit' to exit.\n") - - while True: - try: - print("=" * 50) - user_input = input("USER > ") - if user_input.lower() in ["quit", "exit"]: - break - - self.process_user_input(user_input) - - # Display response - if self.latest_response["type"] == "dialogue": - print(f"\nPlanner: {self.latest_response['content']}") - elif self.latest_response["type"] == "plan": - print("\nProposed Plan:") - for i, step in enumerate(self.latest_response["content"], 1): - print(f"{i}. {step}") - if self.latest_response["needs_confirmation"]: - print("\nDoes this plan look good? (yes/no)") - - if self.plan_confirmed: - print("\nPlan confirmed! Streaming steps to execution...") - break - - except KeyboardInterrupt: - print("\nStopping...") - break - except Exception as e: - print(f"\nError: {e}") - break - - def get_response_observable(self) -> Observable: - """Gets an observable that emits responses from this agent. - - This method processes the response stream from the parent class, - extracting content from `PlanningAgentResponse` objects and flattening - any lists of plan steps for emission. - - Returns: - Observable: An observable that emits plan steps from the agent. 
- """ - - def extract_content(response) -> List[str]: - if isinstance(response, PlanningAgentResponse): - if response.type == "plan": - return response.content # List of steps to be emitted individually - else: # dialogue type - return [response.content] # Wrap single dialogue message in a list - else: - return [str(response)] # Wrap non-PlanningAgentResponse in a list - - # Get base observable from parent class - base_observable = super().get_response_observable() - - # Process the stream: extract content and flatten plan lists - return base_observable.pipe( - ops.map(extract_content), - ops.flat_map(lambda items: items), # Flatten the list of items - ) diff --git a/build/lib/dimos/agents/prompt_builder/__init__.py b/build/lib/dimos/agents/prompt_builder/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/build/lib/dimos/agents/prompt_builder/impl.py b/build/lib/dimos/agents/prompt_builder/impl.py deleted file mode 100644 index 0e66191837..0000000000 --- a/build/lib/dimos/agents/prompt_builder/impl.py +++ /dev/null @@ -1,221 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from textwrap import dedent -from typing import Optional -from dimos.agents.tokenizer.base import AbstractTokenizer -from dimos.agents.tokenizer.openai_tokenizer import OpenAITokenizer - -# TODO: Make class more generic when implementing other tokenizers. Presently its OpenAI specific. 
-# TODO: Build out testing and logging - - -class PromptBuilder: - DEFAULT_SYSTEM_PROMPT = dedent(""" - You are an AI assistant capable of understanding and analyzing both visual and textual information. - Your task is to provide accurate and insightful responses based on the data provided to you. - Use the following information to assist the user with their query. Do not rely on any internal - knowledge or make assumptions beyond the provided data. - - Visual Context: You may have been given an image to analyze. Use the visual details to enhance your response. - Textual Context: There may be some text retrieved from a relevant database to assist you - - Instructions: - - Combine insights from both the image and the text to answer the user's question. - - If the information is insufficient to provide a complete answer, acknowledge the limitation. - - Maintain a professional and informative tone in your response. - """) - - def __init__( - self, model_name="gpt-4o", max_tokens=128000, tokenizer: Optional[AbstractTokenizer] = None - ): - """ - Initialize the prompt builder. - Args: - model_name (str): Model used (e.g., 'gpt-4o', 'gpt-4', 'gpt-3.5-turbo'). - max_tokens (int): Maximum tokens allowed in the input prompt. - tokenizer (AbstractTokenizer): The tokenizer to use for token counting and truncation. - """ - self.model_name = model_name - self.max_tokens = max_tokens - self.tokenizer: AbstractTokenizer = tokenizer or OpenAITokenizer(model_name=self.model_name) - - def truncate_tokens(self, text, max_tokens, strategy): - """ - Truncate text to fit within max_tokens using a specified strategy. - Args: - text (str): Input text to truncate. - max_tokens (int): Maximum tokens allowed. - strategy (str): Truncation strategy ('truncate_head', 'truncate_middle', 'truncate_end', 'do_not_truncate'). - Returns: - str: Truncated text. 
- """ - if strategy == "do_not_truncate" or not text: - return text - - tokens = self.tokenizer.tokenize_text(text) - if len(tokens) <= max_tokens: - return text - - if strategy == "truncate_head": - truncated = tokens[-max_tokens:] - elif strategy == "truncate_end": - truncated = tokens[:max_tokens] - elif strategy == "truncate_middle": - half = max_tokens // 2 - truncated = tokens[:half] + tokens[-half:] - else: - raise ValueError(f"Unknown truncation strategy: {strategy}") - - return self.tokenizer.detokenize_text(truncated) - - def build( - self, - system_prompt=None, - user_query=None, - base64_image=None, - image_width=None, - image_height=None, - image_detail="low", - rag_context=None, - budgets=None, - policies=None, - override_token_limit=False, - ): - """ - Builds a dynamic prompt tailored to token limits, respecting budgets and policies. - - Args: - system_prompt (str): System-level instructions. - user_query (str, optional): User's query. - base64_image (str, optional): Base64-encoded image string. - image_width (int, optional): Width of the image. - image_height (int, optional): Height of the image. - image_detail (str, optional): Detail level for the image ("low" or "high"). - rag_context (str, optional): Retrieved context. - budgets (dict, optional): Token budgets for each input type. Defaults to equal allocation. - policies (dict, optional): Truncation policies for each input type. - override_token_limit (bool, optional): Whether to override the token limit. Defaults to False. - - Returns: - dict: Messages array ready to send to the OpenAI API. 
- """ - if user_query is None: - raise ValueError("User query is required.") - - # Debug: - # base64_image = None - - budgets = budgets or { - "system_prompt": self.max_tokens // 4, - "user_query": self.max_tokens // 4, - "image": self.max_tokens // 4, - "rag": self.max_tokens // 4, - } - policies = policies or { - "system_prompt": "truncate_end", - "user_query": "truncate_middle", - "image": "do_not_truncate", - "rag": "truncate_end", - } - - # Validate and sanitize image_detail - if image_detail not in {"low", "high"}: - image_detail = "low" # Default to "low" if invalid or None - - # Determine which system prompt to use - if system_prompt is None: - system_prompt = self.DEFAULT_SYSTEM_PROMPT - - rag_context = rag_context or "" - - # Debug: - # print("system_prompt: ", system_prompt) - # print("rag_context: ", rag_context) - - # region Token Counts - if not override_token_limit: - rag_token_cnt = self.tokenizer.token_count(rag_context) - system_prompt_token_cnt = self.tokenizer.token_count(system_prompt) - user_query_token_cnt = self.tokenizer.token_count(user_query) - image_token_cnt = ( - self.tokenizer.image_token_count(image_width, image_height, image_detail) - if base64_image - else 0 - ) - else: - rag_token_cnt = 0 - system_prompt_token_cnt = 0 - user_query_token_cnt = 0 - image_token_cnt = 0 - # endregion Token Counts - - # Create a component dictionary for dynamic allocation - components = { - "system_prompt": {"text": system_prompt, "tokens": system_prompt_token_cnt}, - "user_query": {"text": user_query, "tokens": user_query_token_cnt}, - "image": {"text": None, "tokens": image_token_cnt}, - "rag": {"text": rag_context, "tokens": rag_token_cnt}, - } - - if not override_token_limit: - # Adjust budgets and apply truncation - total_tokens = sum(comp["tokens"] for comp in components.values()) - excess_tokens = total_tokens - self.max_tokens - if excess_tokens > 0: - for key, component in components.items(): - if excess_tokens <= 0: - break - if policies[key] 
!= "do_not_truncate": - max_allowed = max(0, budgets[key] - excess_tokens) - components[key]["text"] = self.truncate_tokens( - component["text"], max_allowed, policies[key] - ) - tokens_after = self.tokenizer.token_count(components[key]["text"]) - excess_tokens -= component["tokens"] - tokens_after - component["tokens"] = tokens_after - - # Build the `messages` structure (OpenAI specific) - messages = [{"role": "system", "content": components["system_prompt"]["text"]}] - - if components["rag"]["text"]: - user_content = [ - { - "type": "text", - "text": f"{components['rag']['text']}\n\n{components['user_query']['text']}", - } - ] - else: - user_content = [{"type": "text", "text": components["user_query"]["text"]}] - - if base64_image: - user_content.append( - { - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{base64_image}", - "detail": image_detail, - }, - } - ) - messages.append({"role": "user", "content": user_content}) - - # Debug: - # print("system_prompt: ", system_prompt) - # print("user_query: ", user_query) - # print("user_content: ", user_content) - # print(f"Messages: {messages}") - - return messages diff --git a/build/lib/dimos/agents/tokenizer/__init__.py b/build/lib/dimos/agents/tokenizer/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/build/lib/dimos/agents/tokenizer/base.py b/build/lib/dimos/agents/tokenizer/base.py deleted file mode 100644 index b7e96de71f..0000000000 --- a/build/lib/dimos/agents/tokenizer/base.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from abc import ABC, abstractmethod - -# TODO: Add a class for specific tokenizer exceptions -# TODO: Build out testing and logging -# TODO: Create proper doc strings after multiple tokenizers are implemented - - -class AbstractTokenizer(ABC): - @abstractmethod - def tokenize_text(self, text): - pass - - @abstractmethod - def detokenize_text(self, tokenized_text): - pass - - @abstractmethod - def token_count(self, text): - pass - - @abstractmethod - def image_token_count(self, image_width, image_height, image_detail="low"): - pass diff --git a/build/lib/dimos/agents/tokenizer/huggingface_tokenizer.py b/build/lib/dimos/agents/tokenizer/huggingface_tokenizer.py deleted file mode 100644 index 2a7b0d2283..0000000000 --- a/build/lib/dimos/agents/tokenizer/huggingface_tokenizer.py +++ /dev/null @@ -1,88 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from transformers import AutoTokenizer -from dimos.agents.tokenizer.base import AbstractTokenizer -from dimos.utils.logging_config import setup_logger - - -class HuggingFaceTokenizer(AbstractTokenizer): - def __init__(self, model_name: str = "Qwen/Qwen2.5-0.5B", **kwargs): - super().__init__(**kwargs) - - # Initilize the tokenizer for the huggingface models - self.model_name = model_name - try: - self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) - except Exception as e: - raise ValueError( - f"Failed to initialize tokenizer for model {self.model_name}. Error: {str(e)}" - ) - - def tokenize_text(self, text): - """ - Tokenize a text string using the openai tokenizer. - """ - return self.tokenizer.encode(text) - - def detokenize_text(self, tokenized_text): - """ - Detokenize a text string using the openai tokenizer. - """ - try: - return self.tokenizer.decode(tokenized_text, errors="ignore") - except Exception as e: - raise ValueError(f"Failed to detokenize text. Error: {str(e)}") - - def token_count(self, text): - """ - Gets the token count of a text string using the openai tokenizer. - """ - return len(self.tokenize_text(text)) if text else 0 - - @staticmethod - def image_token_count(image_width, image_height, image_detail="high"): - """ - Calculate the number of tokens in an image. Low detail is 85 tokens, high detail is 170 tokens per 512x512 square. - """ - logger = setup_logger("dimos.agents.tokenizer.HuggingFaceTokenizer.image_token_count") - - if image_detail == "low": - return 85 - elif image_detail == "high": - # Image dimensions - logger.debug(f"Image Width: {image_width}, Image Height: {image_height}") - if image_width is None or image_height is None: - raise ValueError( - "Image width and height must be provided for high detail image token count calculation." 
- ) - - # Scale image to fit within 2048 x 2048 - max_dimension = max(image_width, image_height) - if max_dimension > 2048: - scale_factor = 2048 / max_dimension - image_width = int(image_width * scale_factor) - image_height = int(image_height * scale_factor) - - # Scale shortest side to 768px - min_dimension = min(image_width, image_height) - scale_factor = 768 / min_dimension - image_width = int(image_width * scale_factor) - image_height = int(image_height * scale_factor) - - # Calculate number of 512px squares - num_squares = (image_width // 512) * (image_height // 512) - return 170 * num_squares + 85 - else: - raise ValueError("Detail specification of image is not 'low' or 'high'") diff --git a/build/lib/dimos/agents/tokenizer/openai_tokenizer.py b/build/lib/dimos/agents/tokenizer/openai_tokenizer.py deleted file mode 100644 index 7517ae5e72..0000000000 --- a/build/lib/dimos/agents/tokenizer/openai_tokenizer.py +++ /dev/null @@ -1,88 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import tiktoken -from dimos.agents.tokenizer.base import AbstractTokenizer -from dimos.utils.logging_config import setup_logger - - -class OpenAITokenizer(AbstractTokenizer): - def __init__(self, model_name: str = "gpt-4o", **kwargs): - super().__init__(**kwargs) - - # Initilize the tokenizer for the openai set of models - self.model_name = model_name - try: - self.tokenizer = tiktoken.encoding_for_model(self.model_name) - except Exception as e: - raise ValueError( - f"Failed to initialize tokenizer for model {self.model_name}. Error: {str(e)}" - ) - - def tokenize_text(self, text): - """ - Tokenize a text string using the openai tokenizer. - """ - return self.tokenizer.encode(text) - - def detokenize_text(self, tokenized_text): - """ - Detokenize a text string using the openai tokenizer. - """ - try: - return self.tokenizer.decode(tokenized_text, errors="ignore") - except Exception as e: - raise ValueError(f"Failed to detokenize text. Error: {str(e)}") - - def token_count(self, text): - """ - Gets the token count of a text string using the openai tokenizer. - """ - return len(self.tokenize_text(text)) if text else 0 - - @staticmethod - def image_token_count(image_width, image_height, image_detail="high"): - """ - Calculate the number of tokens in an image. Low detail is 85 tokens, high detail is 170 tokens per 512x512 square. - """ - logger = setup_logger("dimos.agents.tokenizer.openai.image_token_count") - - if image_detail == "low": - return 85 - elif image_detail == "high": - # Image dimensions - logger.debug(f"Image Width: {image_width}, Image Height: {image_height}") - if image_width is None or image_height is None: - raise ValueError( - "Image width and height must be provided for high detail image token count calculation." 
- ) - - # Scale image to fit within 2048 x 2048 - max_dimension = max(image_width, image_height) - if max_dimension > 2048: - scale_factor = 2048 / max_dimension - image_width = int(image_width * scale_factor) - image_height = int(image_height * scale_factor) - - # Scale shortest side to 768px - min_dimension = min(image_width, image_height) - scale_factor = 768 / min_dimension - image_width = int(image_width * scale_factor) - image_height = int(image_height * scale_factor) - - # Calculate number of 512px squares - num_squares = (image_width // 512) * (image_height // 512) - return 170 * num_squares + 85 - else: - raise ValueError("Detail specification of image is not 'low' or 'high'") diff --git a/build/lib/dimos/core/__init__.py b/build/lib/dimos/core/__init__.py deleted file mode 100644 index 5df6d4e803..0000000000 --- a/build/lib/dimos/core/__init__.py +++ /dev/null @@ -1,103 +0,0 @@ -from __future__ import annotations - -import multiprocessing as mp -import time -from typing import Optional - -from dask.distributed import Client, LocalCluster -from rich.console import Console - -import dimos.core.colors as colors -from dimos.core.core import In, Out, RemoteOut, rpc -from dimos.core.module import Module, ModuleBase -from dimos.core.transport import LCMTransport, ZenohTransport, pLCMTransport -from dimos.protocol.rpc.lcmrpc import LCMRPC -from dimos.protocol.rpc.spec import RPC - - -def patch_actor(actor, cls): ... 
- - -class RPCClient: - def __init__(self, actor_instance, actor_class): - self.rpc = LCMRPC() - self.actor_class = actor_class - self.remote_name = actor_class.__name__ - self.actor_instance = actor_instance - self.rpcs = actor_class.rpcs.keys() - self.rpc.start() - - def __reduce__(self): - # Return the class and the arguments needed to reconstruct the object - return ( - self.__class__, - (self.actor_instance, self.actor_class), - ) - - # passthrough - def __getattr__(self, name: str): - # Check if accessing a known safe attribute to avoid recursion - if name in { - "__class__", - "__init__", - "__dict__", - "__getattr__", - "rpcs", - "remote_name", - "remote_instance", - "actor_instance", - }: - raise AttributeError(f"{name} is not found.") - - if name in self.rpcs: - return lambda *args: self.rpc.call_sync(f"{self.remote_name}/{name}", args) - - # return super().__getattr__(name) - # Try to avoid recursion by directly accessing attributes that are known - return self.actor_instance.__getattr__(name) - - -def patchdask(dask_client: Client): - def deploy( - actor_class, - *args, - **kwargs, - ): - console = Console() - with console.status(f"deploying [green]{actor_class.__name__}", spinner="arc"): - actor = dask_client.submit( - actor_class, - *args, - **kwargs, - actor=True, - ).result() - - worker = actor.set_ref(actor).result() - print((f"deployed: {colors.green(actor)} @ {colors.blue('worker ' + str(worker))}")) - - return RPCClient(actor, actor_class) - - dask_client.deploy = deploy - return dask_client - - -def start(n: Optional[int] = None) -> Client: - console = Console() - if not n: - n = mp.cpu_count() - with console.status( - f"[green]Initializing dimos local cluster with [bright_blue]{n} workers", spinner="arc" - ) as status: - cluster = LocalCluster( - n_workers=n, - threads_per_worker=4, - ) - client = Client(cluster) - - console.print(f"[green]Initialized dimos local cluster with [bright_blue]{n} workers") - return patchdask(client) - - -def 
stop(client: Client): - client.close() - client.cluster.close() diff --git a/build/lib/dimos/core/colors.py b/build/lib/dimos/core/colors.py deleted file mode 100644 index f137523e67..0000000000 --- a/build/lib/dimos/core/colors.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -def green(text: str) -> str: - """Return the given text in green color.""" - return f"\033[92m{text}\033[0m" - - -def blue(text: str) -> str: - """Return the given text in blue color.""" - return f"\033[94m{text}\033[0m" - - -def red(text: str) -> str: - """Return the given text in red color.""" - return f"\033[91m{text}\033[0m" - - -def yellow(text: str) -> str: - """Return the given text in yellow color.""" - return f"\033[93m{text}\033[0m" - - -def cyan(text: str) -> str: - """Return the given text in cyan color.""" - return f"\033[96m{text}\033[0m" - - -def orange(text: str) -> str: - """Return the given text in orange color.""" - return f"\033[38;5;208m{text}\033[0m" diff --git a/build/lib/dimos/core/core.py b/build/lib/dimos/core/core.py deleted file mode 100644 index 9c57d93559..0000000000 --- a/build/lib/dimos/core/core.py +++ /dev/null @@ -1,260 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import enum -import inspect -import traceback -from typing import ( - Any, - Callable, - Dict, - Generic, - List, - Optional, - Protocol, - TypeVar, - get_args, - get_origin, - get_type_hints, -) - -from dask.distributed import Actor - -import dimos.core.colors as colors -from dimos.core.o3dpickle import register_picklers - -register_picklers() -T = TypeVar("T") - - -class Transport(Protocol[T]): - # used by local Output - def broadcast(self, selfstream: Out[T], value: T): ... - - # used by local Input - def subscribe(self, selfstream: In[T], callback: Callable[[T], any]) -> None: ... 
- - -class DaskTransport(Transport[T]): - subscribers: List[Callable[[T], None]] - _started: bool = False - - def __init__(self): - self.subscribers = [] - - def __str__(self) -> str: - return colors.yellow("DaskTransport") - - def __reduce__(self): - return (DaskTransport, ()) - - def broadcast(self, selfstream: RemoteIn[T], msg: T) -> None: - for subscriber in self.subscribers: - # there is some sort of a bug here with losing worker loop - # print(subscriber.owner, subscriber.owner._worker, subscriber.owner._client) - # subscriber.owner._try_bind_worker_client() - # print(subscriber.owner, subscriber.owner._worker, subscriber.owner._client) - - subscriber.owner.dask_receive_msg(subscriber.name, msg).result() - - def dask_receive_msg(self, msg) -> None: - for subscriber in self.subscribers: - try: - subscriber(msg) - except Exception as e: - print( - colors.red("Error in DaskTransport subscriber callback:"), - e, - traceback.format_exc(), - ) - - # for outputs - def dask_register_subscriber(self, remoteInput: RemoteIn[T]) -> None: - self.subscribers.append(remoteInput) - - # for inputs - def subscribe(self, selfstream: In[T], callback: Callable[[T], None]) -> None: - if not self._started: - selfstream.connection.owner.dask_register_subscriber( - selfstream.connection.name, selfstream - ).result() - self._started = True - self.subscribers.append(callback) - - -class State(enum.Enum): - UNBOUND = "unbound" # descriptor defined but not bound - READY = "ready" # bound to owner but not yet connected - CONNECTED = "connected" # input bound to an output - FLOWING = "flowing" # runtime: data observed - - -class Stream(Generic[T]): - _transport: Optional[Transport] - - def __init__( - self, - type: type[T], - name: str, - owner: Optional[Any] = None, - transport: Optional[Transport] = None, - ): - self.name = name - self.owner = owner - self.type = type - if transport: - self._transport = transport - if not hasattr(self, "_transport"): - self._transport = None - - 
@property - def type_name(self) -> str: - return getattr(self.type, "__name__", repr(self.type)) - - def _color_fn(self) -> Callable[[str], str]: - if self.state == State.UNBOUND: - return colors.orange - if self.state == State.READY: - return colors.blue - if self.state == State.CONNECTED: - return colors.green - return lambda s: s - - def __str__(self) -> str: # noqa: D401 - return ( - self.__class__.__name__ - + " " - + self._color_fn()(f"{self.name}[{self.type_name}]") - + " @ " - + ( - colors.orange(self.owner) - if isinstance(self.owner, Actor) - else colors.green(self.owner) - ) - + ("" if not self._transport else " via " + str(self._transport)) - ) - - -class Out(Stream[T]): - _transport: Transport - - def __init__(self, *argv, **kwargs): - super().__init__(*argv, **kwargs) - if not hasattr(self, "_transport") or self._transport is None: - self._transport = DaskTransport() - - @property - def transport(self) -> Transport[T]: - return self._transport - - @property - def state(self) -> State: # noqa: D401 - return State.UNBOUND if self.owner is None else State.READY - - def __reduce__(self): # noqa: D401 - if self.owner is None or not hasattr(self.owner, "ref"): - raise ValueError("Cannot serialise Out without an owner ref") - return ( - RemoteOut, - ( - self.type, - self.name, - self.owner.ref, - self._transport, - ), - ) - - def publish(self, msg): - self._transport.broadcast(self, msg) - - -class RemoteStream(Stream[T]): - @property - def state(self) -> State: # noqa: D401 - return State.UNBOUND if self.owner is None else State.READY - - # this won't work but nvm - @property - def transport(self) -> Transport[T]: - return self._transport - - @transport.setter - def transport(self, value: Transport[T]) -> None: - self.owner.set_transport(self.name, value).result() - self._transport = value - - -class RemoteOut(RemoteStream[T]): - def connect(self, other: RemoteIn[T]): - return other.connect(self) - - -class In(Stream[T]): - connection: 
Optional[RemoteOut[T]] = None - _transport: Transport - - def __str__(self): - mystr = super().__str__() - - if not self.connection: - return mystr - - return (mystr + " ◀─").ljust(60, "─") + f" {self.connection}" - - def __reduce__(self): # noqa: D401 - if self.owner is None or not hasattr(self.owner, "ref"): - raise ValueError("Cannot serialise Out without an owner ref") - return (RemoteIn, (self.type, self.name, self.owner.ref, self._transport)) - - @property - def transport(self) -> Transport[T]: - if not self._transport: - self._transport = self.connection.transport - return self._transport - - @property - def state(self) -> State: # noqa: D401 - return State.UNBOUND if self.owner is None else State.READY - - def subscribe(self, cb): - self.transport.subscribe(self, cb) - - -class RemoteIn(RemoteStream[T]): - def connect(self, other: RemoteOut[T]) -> None: - return self.owner.connect_stream(self.name, other).result() - - # this won't work but that's ok - @property - def transport(self) -> Transport[T]: - return self._transport - - def publish(self, msg): - self.transport.broadcast(self, msg) - - @transport.setter - def transport(self, value: Transport[T]) -> None: - self.owner.set_transport(self.name, value).result() - self._transport = value - - -def rpc(fn: Callable[..., Any]) -> Callable[..., Any]: - fn.__rpc__ = True # type: ignore[attr-defined] - return fn - - -daskTransport = DaskTransport() # singleton instance for use in Out/RemoteOut diff --git a/build/lib/dimos/core/module.py b/build/lib/dimos/core/module.py deleted file mode 100644 index c232e613c2..0000000000 --- a/build/lib/dimos/core/module.py +++ /dev/null @@ -1,172 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import inspect -from typing import ( - Any, - Callable, - get_args, - get_origin, - get_type_hints, -) - -from dask.distributed import Actor, get_worker - -from dimos.core import colors -from dimos.core.core import In, Out, RemoteIn, RemoteOut, T, Transport -from dimos.protocol.rpc.lcmrpc import LCMRPC - - -class ModuleBase: - def __init__(self, *args, **kwargs): - try: - get_worker() - self.rpc = LCMRPC() - self.rpc.serve_module_rpc(self) - self.rpc.start() - except ValueError: - return - - @property - def outputs(self) -> dict[str, Out]: - return { - name: s - for name, s in self.__dict__.items() - if isinstance(s, Out) and not name.startswith("_") - } - - @property - def inputs(self) -> dict[str, In]: - return { - name: s - for name, s in self.__dict__.items() - if isinstance(s, In) and not name.startswith("_") - } - - @classmethod - @property - def rpcs(cls) -> dict[str, Callable]: - return { - name: getattr(cls, name) - for name in dir(cls) - if not name.startswith("_") - and name != "rpcs" # Exclude the rpcs property itself to prevent recursion - and callable(getattr(cls, name, None)) - and hasattr(getattr(cls, name), "__rpc__") - } - - def io(self) -> str: - def _box(name: str) -> str: - return [ - f"┌┴" + "─" * (len(name) + 1) + "┐", - f"│ {name} │", - f"└┬" + "─" * (len(name) + 1) + "┘", - ] - - # can't modify __str__ on a function like we are doing for I/O - # so we have a separate repr function here - def repr_rpc(fn: Callable) -> str: - sig = inspect.signature(fn) - # Remove 'self' parameter - params = [p for name, p in sig.parameters.items() 
if name != "self"] - - # Format parameters with colored types - param_strs = [] - for param in params: - param_str = param.name - if param.annotation != inspect.Parameter.empty: - type_name = getattr(param.annotation, "__name__", str(param.annotation)) - param_str += ": " + colors.green(type_name) - if param.default != inspect.Parameter.empty: - param_str += f" = {param.default}" - param_strs.append(param_str) - - # Format return type - return_annotation = "" - if sig.return_annotation != inspect.Signature.empty: - return_type = getattr(sig.return_annotation, "__name__", str(sig.return_annotation)) - return_annotation = " -> " + colors.green(return_type) - - return ( - "RPC " + colors.blue(fn.__name__) + f"({', '.join(param_strs)})" + return_annotation - ) - - ret = [ - *(f" ├─ {stream}" for stream in self.inputs.values()), - *_box(self.__class__.__name__), - *(f" ├─ {stream}" for stream in self.outputs.values()), - " │", - *(f" ├─ {repr_rpc(rpc)}" for rpc in self.rpcs.values()), - ] - - return "\n".join(ret) - - -class DaskModule(ModuleBase): - ref: Actor - worker: int - - def __init__(self, *args, **kwargs): - self.ref = None - - for name, ann in get_type_hints(self, include_extras=True).items(): - origin = get_origin(ann) - if origin is Out: - inner, *_ = get_args(ann) or (Any,) - stream = Out(inner, name, self) - setattr(self, name, stream) - elif origin is In: - inner, *_ = get_args(ann) or (Any,) - stream = In(inner, name, self) - setattr(self, name, stream) - super().__init__(*args, **kwargs) - - def set_ref(self, ref) -> int: - worker = get_worker() - self.ref = ref - self.worker = worker.name - return worker.name - - def __str__(self): - return f"{self.__class__.__name__}" - - # called from remote - def set_transport(self, stream_name: str, transport: Transport): - stream = getattr(self, stream_name, None) - if not stream: - raise ValueError(f"{stream_name} not found in {self.__class__.__name__}") - - if not isinstance(stream, Out) and not 
isinstance(stream, In): - raise TypeError(f"Output {stream_name} is not a valid stream") - - stream._transport = transport - return True - - # called from remote - def connect_stream(self, input_name: str, remote_stream: RemoteOut[T]): - input_stream = getattr(self, input_name, None) - if not input_stream: - raise ValueError(f"{input_name} not found in {self.__class__.__name__}") - if not isinstance(input_stream, In): - raise TypeError(f"Input {input_name} is not a valid stream") - input_stream.connection = remote_stream - - def dask_receive_msg(self, input_name: str, msg: Any): - getattr(self, input_name).transport.dask_receive_msg(msg) - - def dask_register_subscriber(self, output_name: str, subscriber: RemoteIn[T]): - getattr(self, output_name).transport.dask_register_subscriber(subscriber) - - -# global setting -Module = DaskModule diff --git a/build/lib/dimos/core/o3dpickle.py b/build/lib/dimos/core/o3dpickle.py deleted file mode 100644 index a18916a06c..0000000000 --- a/build/lib/dimos/core/o3dpickle.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import copyreg - -import numpy as np -import open3d as o3d - - -def reduce_external(obj): - # Convert Vector3dVector to numpy array for pickling - points_array = np.asarray(obj.points) - return (reconstruct_pointcloud, (points_array,)) - - -def reconstruct_pointcloud(points_array): - # Create new PointCloud and assign the points - pc = o3d.geometry.PointCloud() - pc.points = o3d.utility.Vector3dVector(points_array) - return pc - - -def register_picklers(): - # Register for the actual PointCloud class that gets instantiated - # We need to create a dummy PointCloud to get its actual class - _dummy_pc = o3d.geometry.PointCloud() - copyreg.pickle(_dummy_pc.__class__, reduce_external) diff --git a/build/lib/dimos/core/test_core.py b/build/lib/dimos/core/test_core.py deleted file mode 100644 index ace435b54b..0000000000 --- a/build/lib/dimos/core/test_core.py +++ /dev/null @@ -1,199 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import time -from threading import Event, Thread - -import pytest - -from dimos.core import ( - In, - LCMTransport, - Module, - Out, - RemoteOut, - ZenohTransport, - pLCMTransport, - rpc, - start, - stop, -) -from dimos.robot.unitree_webrtc.type.lidar import LidarMessage -from dimos.robot.unitree_webrtc.type.odometry import Odometry -from dimos.types.vector import Vector -from dimos.utils.testing import SensorReplay - -# never delete this line - - -@pytest.fixture -def dimos(): - """Fixture to create a Dimos client for testing.""" - client = start(2) - yield client - stop(client) - - -class RobotClient(Module): - odometry: Out[Odometry] = None - lidar: Out[LidarMessage] = None - mov: In[Vector] = None - - mov_msg_count = 0 - - def mov_callback(self, msg): - self.mov_msg_count += 1 - - def __init__(self): - super().__init__() - self._stop_event = Event() - self._thread = None - - def start(self): - self._thread = Thread(target=self.odomloop) - self._thread.start() - self.mov.subscribe(self.mov_callback) - - def odomloop(self): - odomdata = SensorReplay("raw_odometry_rotate_walk", autocast=Odometry.from_msg) - lidardata = SensorReplay("office_lidar", autocast=LidarMessage.from_msg) - - lidariter = lidardata.iterate() - self._stop_event.clear() - while not self._stop_event.is_set(): - for odom in odomdata.iterate(): - if self._stop_event.is_set(): - return - print(odom) - odom.pubtime = time.perf_counter() - self.odometry.publish(odom) - - lidarmsg = next(lidariter) - lidarmsg.pubtime = time.perf_counter() - self.lidar.publish(lidarmsg) - time.sleep(0.1) - - def stop(self): - self._stop_event.set() - if self._thread and self._thread.is_alive(): - self._thread.join(timeout=1.0) # Wait up to 1 second for clean shutdown - - -class Navigation(Module): - mov: Out[Vector] = None - lidar: In[LidarMessage] = None - target_position: In[Vector] = None - odometry: In[Odometry] = None - - odom_msg_count = 0 - lidar_msg_count = 0 - - @rpc - def navigate_to(self, target: Vector) 
-> bool: ... - - def __init__(self): - super().__init__() - - @rpc - def start(self): - def _odom(msg): - self.odom_msg_count += 1 - print("RCV:", (time.perf_counter() - msg.pubtime) * 1000, msg) - self.mov.publish(msg.position) - - self.odometry.subscribe(_odom) - - def _lidar(msg): - self.lidar_msg_count += 1 - if hasattr(msg, "pubtime"): - print("RCV:", (time.perf_counter() - msg.pubtime) * 1000, msg) - else: - print("RCV: unknown time", msg) - - self.lidar.subscribe(_lidar) - - -def test_classmethods(): - # Test class property access - class_rpcs = Navigation.rpcs - print("Class rpcs:", class_rpcs) - - # Test instance property access - nav = Navigation() - instance_rpcs = nav.rpcs - print("Instance rpcs:", instance_rpcs) - - # Assertions - assert isinstance(class_rpcs, dict), "Class rpcs should be a dictionary" - assert isinstance(instance_rpcs, dict), "Instance rpcs should be a dictionary" - assert class_rpcs == instance_rpcs, "Class and instance rpcs should be identical" - - # Check that we have the expected RPC methods - assert "navigate_to" in class_rpcs, "navigate_to should be in rpcs" - assert "start" in class_rpcs, "start should be in rpcs" - assert len(class_rpcs) == 2, "Should have exactly 2 RPC methods" - - # Check that the values are callable - assert callable(class_rpcs["navigate_to"]), "navigate_to should be callable" - assert callable(class_rpcs["start"]), "start should be callable" - - # Check that they have the __rpc__ attribute - assert hasattr(class_rpcs["navigate_to"], "__rpc__"), ( - "navigate_to should have __rpc__ attribute" - ) - assert hasattr(class_rpcs["start"], "__rpc__"), "start should have __rpc__ attribute" - - -@pytest.mark.tool -def test_deployment(dimos): - robot = dimos.deploy(RobotClient) - target_stream = RemoteOut[Vector](Vector, "target") - - print("\n") - print("lidar stream", robot.lidar) - print("target stream", target_stream) - print("odom stream", robot.odometry) - - nav = dimos.deploy(Navigation) - - # this one 
encodes proper LCM messages - robot.lidar.transport = LCMTransport("/lidar", LidarMessage) - # odometry & mov using just a pickle over LCM - robot.odometry.transport = pLCMTransport("/odom") - nav.mov.transport = pLCMTransport("/mov") - - nav.lidar.connect(robot.lidar) - nav.odometry.connect(robot.odometry) - robot.mov.connect(nav.mov) - - print("\n" + robot.io().result() + "\n") - print("\n" + nav.io().result() + "\n") - robot.start().result() - nav.start().result() - - time.sleep(1) - robot.stop().result() - - print("robot.mov_msg_count", robot.mov_msg_count) - print("nav.odom_msg_count", nav.odom_msg_count) - print("nav.lidar_msg_count", nav.lidar_msg_count) - - assert robot.mov_msg_count >= 8 - assert nav.odom_msg_count >= 8 - assert nav.lidar_msg_count >= 8 - - -if __name__ == "__main__": - client = start(1) # single process for CI memory - test_deployment(client) diff --git a/build/lib/dimos/core/transport.py b/build/lib/dimos/core/transport.py deleted file mode 100644 index 5457517b28..0000000000 --- a/build/lib/dimos/core/transport.py +++ /dev/null @@ -1,102 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import annotations - -import traceback -from typing import ( - Any, - Callable, - Dict, - Generic, - List, - Optional, - Protocol, - TypeVar, - get_args, - get_origin, - get_type_hints, -) - -import dimos.core.colors as colors -from dimos.core.core import In, Transport -from dimos.protocol.pubsub.lcmpubsub import LCM, PickleLCM -from dimos.protocol.pubsub.lcmpubsub import Topic as LCMTopic - -T = TypeVar("T") - - -class PubSubTransport(Transport[T]): - topic: any - - def __init__(self, topic: any): - self.topic = topic - - def __str__(self) -> str: - return ( - colors.green(f"{self.__class__.__name__}(") - + colors.blue(self.topic) - + colors.green(")") - ) - - -class pLCMTransport(PubSubTransport[T]): - _started: bool = False - - def __init__(self, topic: str, **kwargs): - super().__init__(topic) - self.lcm = PickleLCM(**kwargs) - - def __reduce__(self): - return (pLCMTransport, (self.topic,)) - - def broadcast(self, _, msg): - if not self._started: - self.lcm.start() - self._started = True - - self.lcm.publish(self.topic, msg) - - def subscribe(self, selfstream: In[T], callback: Callable[[T], None]) -> None: - if not self._started: - self.lcm.start() - self._started = True - self.lcm.subscribe(self.topic, lambda msg, topic: callback(msg)) - - -class LCMTransport(PubSubTransport[T]): - _started: bool = False - - def __init__(self, topic: str, type: type, **kwargs): - super().__init__(LCMTopic(topic, type)) - self.lcm = LCM(**kwargs) - - def __reduce__(self): - return (LCMTransport, (self.topic.topic, self.topic.lcm_type)) - - def broadcast(self, _, msg): - if not self._started: - self.lcm.start() - self._started = True - - self.lcm.publish(self.topic, msg) - - def subscribe(self, selfstream: In[T], callback: Callable[[T], None]) -> None: - if not self._started: - self.lcm.start() - self._started = True - self.lcm.subscribe(self.topic, lambda msg, topic: callback(msg)) - - -class ZenohTransport(PubSubTransport[T]): ... 
diff --git a/build/lib/dimos/environment/__init__.py b/build/lib/dimos/environment/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/build/lib/dimos/environment/agent_environment.py b/build/lib/dimos/environment/agent_environment.py deleted file mode 100644 index 861a1f429b..0000000000 --- a/build/lib/dimos/environment/agent_environment.py +++ /dev/null @@ -1,139 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import cv2 -import numpy as np -from pathlib import Path -from typing import List, Union -from .environment import Environment - - -class AgentEnvironment(Environment): - def __init__(self): - super().__init__() - self.environment_type = "agent" - self.frames = [] - self.current_frame_idx = 0 - self._depth_maps = [] - self._segmentations = [] - self._point_clouds = [] - - def initialize_from_images(self, images: Union[List[str], List[np.ndarray]]) -> bool: - """Initialize environment from a list of image paths or numpy arrays. 
- - Args: - images: List of image paths or numpy arrays representing frames - - Returns: - bool: True if initialization successful, False otherwise - """ - try: - self.frames = [] - for img in images: - if isinstance(img, str): - frame = cv2.imread(img) - if frame is None: - raise ValueError(f"Failed to load image: {img}") - self.frames.append(frame) - elif isinstance(img, np.ndarray): - self.frames.append(img.copy()) - else: - raise ValueError(f"Unsupported image type: {type(img)}") - return True - except Exception as e: - print(f"Failed to initialize from images: {e}") - return False - - def initialize_from_file(self, file_path: str) -> bool: - """Initialize environment from a video file. - - Args: - file_path: Path to the video file - - Returns: - bool: True if initialization successful, False otherwise - """ - try: - if not Path(file_path).exists(): - raise FileNotFoundError(f"Video file not found: {file_path}") - - cap = cv2.VideoCapture(file_path) - self.frames = [] - - while cap.isOpened(): - ret, frame = cap.read() - if not ret: - break - self.frames.append(frame) - - cap.release() - return len(self.frames) > 0 - except Exception as e: - print(f"Failed to initialize from video: {e}") - return False - - def initialize_from_directory(self, directory_path: str) -> bool: - """Initialize environment from a directory of images.""" - # TODO: Implement directory initialization - raise NotImplementedError("Directory initialization not yet implemented") - - def label_objects(self) -> List[str]: - """Implementation of abstract method to label objects.""" - # TODO: Implement object labeling using a detection model - raise NotImplementedError("Object labeling not yet implemented") - - def generate_segmentations( - self, model: str = None, objects: List[str] = None, *args, **kwargs - ) -> List[np.ndarray]: - """Generate segmentations for the current frame.""" - # TODO: Implement segmentation generation using specified model - raise NotImplementedError("Segmentation 
generation not yet implemented") - - def get_segmentations(self) -> List[np.ndarray]: - """Return pre-computed segmentations for the current frame.""" - if self._segmentations: - return self._segmentations[self.current_frame_idx] - return [] - - def generate_point_cloud(self, object: str = None, *args, **kwargs) -> np.ndarray: - """Generate point cloud from the current frame.""" - # TODO: Implement point cloud generation - raise NotImplementedError("Point cloud generation not yet implemented") - - def get_point_cloud(self, object: str = None) -> np.ndarray: - """Return pre-computed point cloud.""" - if self._point_clouds: - return self._point_clouds[self.current_frame_idx] - return np.array([]) - - def generate_depth_map( - self, stereo: bool = None, monocular: bool = None, model: str = None, *args, **kwargs - ) -> np.ndarray: - """Generate depth map for the current frame.""" - # TODO: Implement depth map generation using specified method - raise NotImplementedError("Depth map generation not yet implemented") - - def get_depth_map(self) -> np.ndarray: - """Return pre-computed depth map for the current frame.""" - if self._depth_maps: - return self._depth_maps[self.current_frame_idx] - return np.array([]) - - def get_frame_count(self) -> int: - """Return the total number of frames.""" - return len(self.frames) - - def get_current_frame_index(self) -> int: - """Return the current frame index.""" - return self.current_frame_idx diff --git a/build/lib/dimos/environment/colmap_environment.py b/build/lib/dimos/environment/colmap_environment.py deleted file mode 100644 index 9981e50098..0000000000 --- a/build/lib/dimos/environment/colmap_environment.py +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# UNDER DEVELOPMENT 🚧🚧🚧 - -import cv2 -import pycolmap -from pathlib import Path -from dimos.environment.environment import Environment - - -class COLMAPEnvironment(Environment): - def initialize_from_images(self, image_dir): - """Initialize the environment from a set of image frames or video.""" - image_dir = Path(image_dir) - output_path = Path("colmap_output") - output_path.mkdir(exist_ok=True) - mvs_path = output_path / "mvs" - database_path = output_path / "database.db" - - # Step 1: Feature extraction - pycolmap.extract_features(database_path, image_dir) - - # Step 2: Feature matching - pycolmap.match_exhaustive(database_path) - - # Step 3: Sparse reconstruction - maps = pycolmap.incremental_mapping(database_path, image_dir, output_path) - maps[0].write(output_path) - - # Step 4: Dense reconstruction (optional) - pycolmap.undistort_images(mvs_path, output_path, image_dir) - pycolmap.patch_match_stereo(mvs_path) # Requires compilation with CUDA - pycolmap.stereo_fusion(mvs_path / "dense.ply", mvs_path) - - return maps - - def initialize_from_video(self, video_path, frame_output_dir): - """Extract frames from a video and initialize the environment.""" - video_path = Path(video_path) - frame_output_dir = Path(frame_output_dir) - frame_output_dir.mkdir(exist_ok=True) - - # Extract frames from the video - self._extract_frames_from_video(video_path, frame_output_dir) - - # Initialize from the extracted frames - return self.initialize_from_images(frame_output_dir) - - def _extract_frames_from_video(self, video_path, frame_output_dir): - """Extract frames 
from a video and save them to a directory.""" - cap = cv2.VideoCapture(str(video_path)) - frame_count = 0 - - while cap.isOpened(): - ret, frame = cap.read() - if not ret: - break - frame_filename = frame_output_dir / f"frame_{frame_count:04d}.jpg" - cv2.imwrite(str(frame_filename), frame) - frame_count += 1 - - cap.release() - - def label_objects(self): - pass - - def get_visualization(self, format_type): - pass - - def get_segmentations(self): - pass - - def get_point_cloud(self, object_id=None): - pass - - def get_depth_map(self): - pass diff --git a/build/lib/dimos/environment/environment.py b/build/lib/dimos/environment/environment.py deleted file mode 100644 index 0770b0f2ce..0000000000 --- a/build/lib/dimos/environment/environment.py +++ /dev/null @@ -1,172 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from abc import ABC, abstractmethod -import numpy as np - - -class Environment(ABC): - def __init__(self): - self.environment_type = None - self.graph = None - - @abstractmethod - def label_objects(self) -> list[str]: - """ - Label all objects in the environment. - - Returns: - A list of string labels representing the objects in the environment. 
- """ - pass - - @abstractmethod - def get_visualization(self, format_type): - """Return different visualization formats like images, NERFs, or other 3D file types.""" - pass - - @abstractmethod - def generate_segmentations( - self, model: str = None, objects: list[str] = None, *args, **kwargs - ) -> list[np.ndarray]: - """ - Generate object segmentations of objects[] using neural methods. - - Args: - model (str, optional): The string of the desired segmentation model (SegmentAnything, etc.) - objects (list[str], optional): The list of strings of the specific objects to segment. - *args: Variable length argument list. - **kwargs: Arbitrary keyword arguments. - - Returns: - list of numpy.ndarray: A list where each element is a numpy array - representing a binary mask for a segmented area of an object in the environment. - - Note: - The specific arguments and their usage will depend on the subclass implementation. - """ - pass - - @abstractmethod - def get_segmentations(self) -> list[np.ndarray]: - """ - Get segmentations using a method like 'segment anything'. - - Returns: - list of numpy.ndarray: A list where each element is a numpy array - representing a binary mask for a segmented area of an object in the environment. - """ - pass - - @abstractmethod - def generate_point_cloud(self, object: str = None, *args, **kwargs) -> np.ndarray: - """ - Generate a point cloud for the entire environment or a specific object. - - Args: - object (str, optional): The string of the specific object to get the point cloud for. - If None, returns the point cloud for the entire environment. - *args: Variable length argument list. - **kwargs: Arbitrary keyword arguments. - - Returns: - np.ndarray: A numpy array representing the generated point cloud. - Shape: (n, 3) where n is the number of points and each point is [x, y, z]. - - Note: - The specific arguments and their usage will depend on the subclass implementation. 
- """ - pass - - @abstractmethod - def get_point_cloud(self, object: str = None) -> np.ndarray: - """ - Return point clouds of the entire environment or a specific object. - - Args: - object (str, optional): The string of the specific object to get the point cloud for. If None, returns the point cloud for the entire environment. - - Returns: - np.ndarray: A numpy array representing the point cloud. - Shape: (n, 3) where n is the number of points and each point is [x, y, z]. - """ - pass - - @abstractmethod - def generate_depth_map( - self, stereo: bool = None, monocular: bool = None, model: str = None, *args, **kwargs - ) -> np.ndarray: - """ - Generate a depth map using monocular or stereo camera methods. - - Args: - stereo (bool, optional): Whether to stereo camera is avaliable for ground truth depth map generation. - monocular (bool, optional): Whether to use monocular camera for neural depth map generation. - model (str, optional): The string of the desired monocular depth model (Metric3D, ZoeDepth, etc.) - *args: Variable length argument list. - **kwargs: Arbitrary keyword arguments. - - Returns: - np.ndarray: A 2D numpy array representing the generated depth map. - Shape: (height, width) where each value represents the depth - at that pixel location. - - Note: - The specific arguments and their usage will depend on the subclass implementation. - """ - pass - - @abstractmethod - def get_depth_map(self) -> np.ndarray: - """ - Return a depth map of the environment. - - Returns: - np.ndarray: A 2D numpy array representing the depth map. - Shape: (height, width) where each value represents the depth - at that pixel location. Typically, closer objects have smaller - values and farther objects have larger values. - - Note: - The exact range and units of the depth values may vary depending on the - specific implementation and the sensor or method used to generate the depth map. 
- """ - pass - - def initialize_from_images(self, images): - """Initialize the environment from a set of image frames or video.""" - raise NotImplementedError("This method is not implemented for this environment type.") - - def initialize_from_file(self, file_path): - """Initialize the environment from a spatial file type. - - Supported file types include: - - GLTF/GLB (GL Transmission Format) - - FBX (Filmbox) - - OBJ (Wavefront Object) - - USD/USDA/USDC (Universal Scene Description) - - STL (Stereolithography) - - COLLADA (DAE) - - Alembic (ABC) - - PLY (Polygon File Format) - - 3DS (3D Studio) - - VRML/X3D (Virtual Reality Modeling Language) - - Args: - file_path (str): Path to the spatial file. - - Raises: - NotImplementedError: If the method is not implemented for this environment type. - """ - raise NotImplementedError("This method is not implemented for this environment type.") diff --git a/build/lib/dimos/exceptions/__init__.py b/build/lib/dimos/exceptions/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/build/lib/dimos/exceptions/agent_memory_exceptions.py b/build/lib/dimos/exceptions/agent_memory_exceptions.py deleted file mode 100644 index cbf3460754..0000000000 --- a/build/lib/dimos/exceptions/agent_memory_exceptions.py +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import traceback - - -class AgentMemoryError(Exception): - """ - Base class for all exceptions raised by AgentMemory operations. - All custom exceptions related to AgentMemory should inherit from this class. - - Args: - message (str): Human-readable message describing the error. - """ - - def __init__(self, message="Error in AgentMemory operation"): - super().__init__(message) - - -class AgentMemoryConnectionError(AgentMemoryError): - """ - Exception raised for errors attempting to connect to the database. - This includes failures due to network issues, authentication errors, or incorrect connection parameters. - - Args: - message (str): Human-readable message describing the error. - cause (Exception, optional): Original exception, if any, that led to this error. - """ - - def __init__(self, message="Failed to connect to the database", cause=None): - super().__init__(message) - if cause: - self.cause = cause - self.traceback = traceback.format_exc() if cause else None - - def __str__(self): - return f"{self.message}\nCaused by: {repr(self.cause)}" if self.cause else self.message - - -class UnknownConnectionTypeError(AgentMemoryConnectionError): - """ - Exception raised when an unknown or unsupported connection type is specified during AgentMemory setup. - - Args: - message (str): Human-readable message explaining that an unknown connection type was used. - """ - - def __init__(self, message="Unknown connection type used in AgentMemory connection"): - super().__init__(message) - - -class DataRetrievalError(AgentMemoryError): - """ - Exception raised for errors retrieving data from the database. - This could occur due to query failures, timeouts, or corrupt data issues. - - Args: - message (str): Human-readable message describing the data retrieval error. 
- """ - - def __init__(self, message="Error in retrieving data during AgentMemory operation"): - super().__init__(message) - - -class DataNotFoundError(DataRetrievalError): - """ - Exception raised when the requested data is not found in the database. - This is used when a query completes successfully but returns no result for the specified identifier. - - Args: - vector_id (int or str): The identifier for the vector that was not found. - message (str, optional): Human-readable message providing more detail. If not provided, a default message is generated. - """ - - def __init__(self, vector_id, message=None): - message = message or f"Requested data for vector ID {vector_id} was not found." - super().__init__(message) - self.vector_id = vector_id diff --git a/build/lib/dimos/hardware/__init__.py b/build/lib/dimos/hardware/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/build/lib/dimos/hardware/camera.py b/build/lib/dimos/hardware/camera.py deleted file mode 100644 index 07c74ce508..0000000000 --- a/build/lib/dimos/hardware/camera.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from dimos.hardware.sensor import AbstractSensor - - -class Camera(AbstractSensor): - def __init__(self, resolution=None, focal_length=None, sensor_size=None, sensor_type="Camera"): - super().__init__(sensor_type) - self.resolution = resolution # (width, height) in pixels - self.focal_length = focal_length # in millimeters - self.sensor_size = sensor_size # (width, height) in millimeters - - def get_sensor_type(self): - return self.sensor_type - - def calculate_intrinsics(self): - if not self.resolution or not self.focal_length or not self.sensor_size: - raise ValueError("Resolution, focal length, and sensor size must be provided") - - # Calculate pixel size - pixel_size_x = self.sensor_size[0] / self.resolution[0] - pixel_size_y = self.sensor_size[1] / self.resolution[1] - - # Calculate the principal point (assuming it's at the center of the image) - principal_point_x = self.resolution[0] / 2 - principal_point_y = self.resolution[1] / 2 - - # Calculate the focal length in pixels - focal_length_x = self.focal_length / pixel_size_x - focal_length_y = self.focal_length / pixel_size_y - - return { - "focal_length_x": focal_length_x, - "focal_length_y": focal_length_y, - "principal_point_x": principal_point_x, - "principal_point_y": principal_point_y, - } - - def get_intrinsics(self): - return self.calculate_intrinsics() diff --git a/build/lib/dimos/hardware/end_effector.py b/build/lib/dimos/hardware/end_effector.py deleted file mode 100644 index 373408003d..0000000000 --- a/build/lib/dimos/hardware/end_effector.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -class EndEffector: - def __init__(self, effector_type=None): - self.effector_type = effector_type - - def get_effector_type(self): - return self.effector_type diff --git a/build/lib/dimos/hardware/interface.py b/build/lib/dimos/hardware/interface.py deleted file mode 100644 index 9d7797a569..0000000000 --- a/build/lib/dimos/hardware/interface.py +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from dimos.hardware.end_effector import EndEffector -from dimos.hardware.camera import Camera -from dimos.hardware.stereo_camera import StereoCamera -from dimos.hardware.ufactory import UFactory7DOFArm - - -class HardwareInterface: - def __init__( - self, - end_effector: EndEffector = None, - sensors: list = None, - arm_architecture: UFactory7DOFArm = None, - ): - self.end_effector = end_effector - self.sensors = sensors if sensors is not None else [] - self.arm_architecture = arm_architecture - - def get_configuration(self): - """Return the current hardware configuration.""" - return { - "end_effector": self.end_effector, - "sensors": [sensor.get_sensor_type() for sensor in self.sensors], - "arm_architecture": self.arm_architecture, - } - - def set_configuration(self, configuration): - """Set the hardware configuration.""" - self.end_effector = configuration.get("end_effector", self.end_effector) - self.sensors = configuration.get("sensors", self.sensors) - self.arm_architecture = configuration.get("arm_architecture", self.arm_architecture) - - def add_sensor(self, sensor): - """Add a sensor to the hardware interface.""" - if isinstance(sensor, (Camera, StereoCamera)): - self.sensors.append(sensor) - else: - raise ValueError("Sensor must be a Camera or StereoCamera instance.") diff --git a/build/lib/dimos/hardware/piper_arm.py b/build/lib/dimos/hardware/piper_arm.py deleted file mode 100644 index 5ff6357237..0000000000 --- a/build/lib/dimos/hardware/piper_arm.py +++ /dev/null @@ -1,372 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# dimos/hardware/piper_arm.py - -from typing import ( - Optional, -) -from piper_sdk import * # from the official Piper SDK -import numpy as np -import time -import subprocess -import kinpy as kp -import sys -import termios -import tty -import select - -import random -import threading - -import pytest - -import dimos.core as core -import dimos.protocol.service.lcmservice as lcmservice -from dimos.core import In, Module, Out, rpc -from dimos.msgs.geometry_msgs import Pose, Vector3, Twist - - -class PiperArm: - def __init__(self, arm_name: str = "arm"): - self.init_can() - self.arm = C_PiperInterface_V2() - self.arm.ConnectPort() - time.sleep(0.1) - self.resetArm() - time.sleep(0.1) - self.enable() - self.gotoZero() - time.sleep(1) - self.init_vel_controller() - - def init_can(self): - result = subprocess.run( - [ - "bash", - "dimos/hardware/can_activate.sh", - ], # pass the script path directly if it has a shebang and execute perms - stdout=subprocess.PIPE, # capture stdout - stderr=subprocess.PIPE, # capture stderr - text=True, # return strings instead of bytes - ) - - def enable(self): - while not self.arm.EnablePiper(): - pass - time.sleep(0.01) - print(f"[PiperArm] Enabled") - # self.arm.ModeCtrl( - # ctrl_mode=0x01, # CAN command mode - # move_mode=0x01, # “Move-J”, but ignored in MIT - # move_spd_rate_ctrl=100, # doesn’t matter in MIT - # is_mit_mode=0xAD # <-- the magic flag - # ) - self.arm.MotionCtrl_2(0x01, 0x01, 80, 0xAD) - - def gotoZero(self): - factor = 1000 - position = [57.0, 0.0, 250.0, 0, 85.0, .0, 0] - X = round(position[0] * factor) - 
Y = round(position[1] * factor) - Z = round(position[2] * factor) - RX = round(position[3] * factor) - RY = round(position[4] * factor) - RZ = round(position[5] * factor) - joint_6 = round(position[6] * factor) - print(X, Y, Z, RX, RY, RZ) - self.arm.MotionCtrl_2(0x01, 0x00, 100, 0x00) - self.arm.EndPoseCtrl(X, Y, Z, RX, RY, RZ) - self.arm.GripperCtrl(abs(joint_6), 1000, 0x01, 0) - - def softStop(self): - self.gotoZero() - time.sleep(1) - self.arm.MotionCtrl_2(0x01, 0x00, 100, ) - self.arm.MotionCtrl_1(0x01, 0, 0) - time.sleep(5) - - def cmd_EE_pose(self, x, y, z, r, p, y_): - """Command end-effector to target pose in space (position + Euler angles)""" - factor = 1000 - pose = [x * factor, y * factor, z * factor, r * factor, p * factor, y_ * factor] - self.arm.MotionCtrl_2(0x01, 0x00, 100, 0xAD) - self.arm.EndPoseCtrl( - int(pose[0]), int(pose[1]), int(pose[2]), int(pose[3]), int(pose[4]), int(pose[5]) - ) - - def get_EE_pose(self): - """Return the current end-effector pose as (x, y, z, r, p, y)""" - pose = self.arm.GetArmEndPoseMsgs() - # Extract individual pose values and convert to base units - # Position values are divided by 1000 to convert from SDK units to mm - # Rotation values are divided by 1000 to convert from SDK units to degrees - x = pose.end_pose.X_axis / 1000.0 - y = pose.end_pose.Y_axis / 1000.0 - z = pose.end_pose.Z_axis / 1000.0 - r = pose.end_pose.RX_axis / 1000.0 - p = pose.end_pose.RY_axis / 1000.0 - y_rot = pose.end_pose.RZ_axis / 1000.0 - - return (x, y, z, r, p, y_rot) - - def cmd_gripper_ctrl(self, position): - """Command end-effector gripper""" - position = position * 1000 - - self.arm.GripperCtrl(abs(round(position)), 1000, 0x01, 0) - print(f"[PiperArm] Commanding gripper position: {position}") - - def resetArm(self): - self.arm.MotionCtrl_1(0x02, 0, 0) - self.arm.MotionCtrl_2(0, 0, 0, 0xAD) - print(f"[PiperArm] Resetting arm") - - def init_vel_controller(self): - self.chain = kp.build_serial_chain_from_urdf( - 
open("dimos/hardware/piper_description.urdf"), "gripper_base" - ) - self.J = self.chain.jacobian(np.zeros(6)) - self.J_pinv = np.linalg.pinv(self.J) - self.dt = 0.01 - - def cmd_vel(self, x_dot, y_dot, z_dot, R_dot, P_dot, Y_dot): - - - joint_state = self.arm.GetArmJointMsgs().joint_state - # print(f"[PiperArm] Current Joints (direct): {joint_state}", type(joint_state)) - joint_angles = np.array( - [ - joint_state.joint_1, - joint_state.joint_2, - joint_state.joint_3, - joint_state.joint_4, - joint_state.joint_5, - joint_state.joint_6, - ] - ) - # print(f"[PiperArm] Current Joints: {joint_angles}", type(joint_angles)) - factor = 57295.7795 # 1000*180/3.1415926 - joint_angles = joint_angles / factor # convert to radians - - q = np.array( - [ - joint_angles[0], - joint_angles[1], - joint_angles[2], - joint_angles[3], - joint_angles[4], - joint_angles[5], - ] - ) - J = self.chain.jacobian(q) - self.J_pinv = np.linalg.pinv(J) - dq = self.J_pinv @ np.array([x_dot, y_dot, z_dot, R_dot, P_dot, Y_dot]) * self.dt - newq = q + dq - - - - newq = newq * factor - - self.arm.MotionCtrl_2(0x01, 0x01, 100, 0xAD) - self.arm.JointCtrl( - int(round(newq[0])), - int(round(newq[1])), - int(round(newq[2])), - int(round(newq[3])), - int(round(newq[4])), - int(round(newq[5])), - ) - time.sleep(self.dt) - # print(f"[PiperArm] Moving to Joints to : {newq}") - - def cmd_vel_ee(self, x_dot, y_dot, z_dot, RX_dot, PY_dot, YZ_dot): - factor = 1000 - x_dot = x_dot * factor - y_dot = y_dot * factor - z_dot = z_dot * factor - RX_dot = RX_dot * factor - PY_dot = PY_dot * factor - YZ_dot = YZ_dot * factor - - current_pose = self.get_EE_pose() - current_pose = np.array(current_pose) - current_pose = current_pose - current_pose = current_pose + np.array([x_dot, y_dot, z_dot, RX_dot, PY_dot, YZ_dot]) * self.dt - current_pose = current_pose - self.cmd_EE_pose( - current_pose[0], - current_pose[1], - current_pose[2], - current_pose[3], - current_pose[4], - current_pose[5], - ) - time.sleep(self.dt) - - 
def disable(self): - self.softStop() - - while self.arm.DisablePiper(): - pass - time.sleep(0.01) - self.arm.DisconnectPort() - -class VelocityController(Module): - - cmd_vel: In[Twist] = None - - def __init__(self, arm, period=0.01, *args, **kwargs): - super().__init__(*args, **kwargs) - self.arm = arm - self.period = period - self.latest_cmd = None - - - @rpc - def start(self): - self.cmd_vel.subscribe(self.handle_cmd_vel) - - def control_loop(): - - while True: - - cmd_vel = self.latest_cmd - - joint_state = self.arm.GetArmJointMsgs().joint_state - # print(f"[PiperArm] Current Joints (direct): {joint_state}", type(joint_state)) - joint_angles = np.array( - [ - joint_state.joint_1, - joint_state.joint_2, - joint_state.joint_3, - joint_state.joint_4, - joint_state.joint_5, - joint_state.joint_6, - ] - ) - factor = 57295.7795 # 1000*180/3.1415926 - joint_angles = joint_angles / factor # convert to radians - q = np.array( - [ - joint_angles[0], - joint_angles[1], - joint_angles[2], - joint_angles[3], - joint_angles[4], - joint_angles[5], - ] - ) - - J = self.chain.jacobian(q) - self.J_pinv = np.linalg.pinv(J) - dq = self.J_pinv @ np.array([cmd_vel.linear.X, cmd_vel.linear.y, cmd_vel.linear.z, cmd_vel.angular.x, cmd_vel.angular.y, cmd_vel.angular.z]) * self.dt - newq = q + dq - - newq = newq * factor #convert radians to scaled degree units for joint control - - self.arm.MotionCtrl_2(0x01, 0x01, 100, 0xAD) - self.arm.JointCtrl( - int(round(newq[0])), - int(round(newq[1])), - int(round(newq[2])), - int(round(newq[3])), - int(round(newq[4])), - int(round(newq[5])), - ) - time.sleep(self.period) - - thread = threading.Thread(target=control_loop, daemon=True) - thread.start() - - def handle_cmd_vel(self, cmd_vel: Twist): - self.latest_cmd = cmd_vel - -@pytest.mark.tool -def run_velocity_controller(): - lcmservice.autoconf() - dimos = core.start(2) - - velocity_controller = dimos.deploy(VelocityController, arm=arm, period=0.01) - velocity_controller.cmd_vel.transport = 
core.LCMTransport("/cmd_vel", Twist) - - velocity_controller.start() - - print("Velocity controller started") - while True: - time.sleep(1) - - - -if __name__ == "__main__": - arm = PiperArm() - - print("get_EE_pose") - arm.get_EE_pose() - - def get_key(timeout=0.1): - """Non-blocking key reader for arrow keys.""" - fd = sys.stdin.fileno() - old_settings = termios.tcgetattr(fd) - try: - tty.setraw(fd) - rlist, _, _ = select.select([fd], [], [], timeout) - if rlist: - ch1 = sys.stdin.read(1) - if ch1 == "\x1b": # Arrow keys start with ESC - ch2 = sys.stdin.read(1) - if ch2 == "[": - ch3 = sys.stdin.read(1) - return ch1 + ch2 + ch3 - else: - return ch1 - return None - finally: - termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) - - def teleop_linear_vel(arm): - print("Use arrow keys to control linear velocity (x/y/z). Press 'q' to quit.") - print("Up/Down: +x/-x, Left/Right: +y/-y, 'w'/'s': +z/-z") - x_dot, y_dot, z_dot = 0.0, 0.0, 0.0 - while True: - key = get_key(timeout=0.1) - if key == "\x1b[A": # Up arrow - x_dot += 0.01 - elif key == "\x1b[B": # Down arrow - x_dot -= 0.01 - elif key == "\x1b[C": # Right arrow - y_dot += 0.01 - elif key == "\x1b[D": # Left arrow - y_dot -= 0.01 - elif key == "w": - z_dot += 0.01 - elif key == "s": - z_dot -= 0.01 - elif key == "q": - print("Exiting teleop.") - arm.disable() - break - - # Optionally, clamp velocities to reasonable limits - x_dot = max(min(x_dot, 0.5), -0.5) - y_dot = max(min(y_dot, 0.5), -0.5) - z_dot = max(min(z_dot, 0.5), -0.5) - - # Only linear velocities, angular set to zero - arm.cmd_vel_ee(x_dot, y_dot, z_dot, 0, 0, 0) - print( - f"Current linear velocity: x={x_dot:.3f} m/s, y={y_dot:.3f} m/s, z={z_dot:.3f} m/s" - ) - - run_velocity_controller() diff --git a/build/lib/dimos/hardware/sensor.py b/build/lib/dimos/hardware/sensor.py deleted file mode 100644 index 3dc7b3850e..0000000000 --- a/build/lib/dimos/hardware/sensor.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2025 Dimensional Inc. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from abc import ABC, abstractmethod - - -class AbstractSensor(ABC): - def __init__(self, sensor_type=None): - self.sensor_type = sensor_type - - @abstractmethod - def get_sensor_type(self): - """Return the type of sensor.""" - pass - - @abstractmethod - def calculate_intrinsics(self): - """Calculate the sensor's intrinsics.""" - pass - - @abstractmethod - def get_intrinsics(self): - """Return the sensor's intrinsics.""" - pass diff --git a/build/lib/dimos/hardware/stereo_camera.py b/build/lib/dimos/hardware/stereo_camera.py deleted file mode 100644 index 4ffdc51811..0000000000 --- a/build/lib/dimos/hardware/stereo_camera.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from dimos.hardware.camera import Camera - - -class StereoCamera(Camera): - def __init__(self, baseline=None, **kwargs): - super().__init__(**kwargs) - self.baseline = baseline - - def get_intrinsics(self): - intrinsics = super().get_intrinsics() - intrinsics["baseline"] = self.baseline - return intrinsics diff --git a/build/lib/dimos/hardware/test_simple_module(1).py b/build/lib/dimos/hardware/test_simple_module(1).py deleted file mode 100644 index 759b627ac6..0000000000 --- a/build/lib/dimos/hardware/test_simple_module(1).py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import random -import threading -import time - -import pytest - -import dimos.core as core -import dimos.protocol.service.lcmservice as lcmservice -from dimos.core import In, Module, Out, rpc -from dimos.msgs.geometry_msgs import Pose, Vector3 - - -class MyComponent(Module): - ctrl: In[Vector3] = None - current_pose: Out[Vector3] = None - - @rpc - def start(self): - # at start you have self.ctrl and self.current_pose available - self.ctrl.subscribe(self.handle_ctrl) - - def handle_ctrl(self, target: Vector3): - print("handling control command:", target) - self.current_pose.publish(target) - - @rpc - def some_service_call(self, x: int) -> int: - return 3 + x - - -class Controller(Module): - cmd: Out[Vector3] = None - - # we can accept some parameters in the constructor - # but make sure to call super().__init__(*args, **kwargs) - def __init__(self, period=1, *args, **kwargs): - super().__init__(*args, **kwargs) - self.period = period - - @rpc - def start(self): - def send_loop(): - while True: - time.sleep(self.period) - vector = Vector3(0, 0, random.uniform(-1, 1)) - print("sending", vector) - self.cmd.publish(vector) - - thread = threading.Thread(target=send_loop, daemon=True) - thread.start() - - -@pytest.mark.tool -def test_my_component(): - # configures underlying system - lcmservice.autoconf() - dimos = core.start(2) - - controller = dimos.deploy(Controller, period=2) - component = dimos.deploy(MyComponent) - - controller.cmd.transport = core.LCMTransport("/cmd", Vector3) - component.current_pose.transport = core.LCMTransport("/pos", Vector3) - - controller.cmd.connect(component.ctrl) - controller.start() - component.start() - - print("service call result is", component.some_service_call(3)) - - while True: - time.sleep(1) - - -if __name__ == "__main__": - test_my_component() diff --git a/build/lib/dimos/hardware/ufactory.py b/build/lib/dimos/hardware/ufactory.py deleted file mode 100644 index cf4e139ccb..0000000000 --- 
a/build/lib/dimos/hardware/ufactory.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dimos.hardware.end_effector import EndEffector - - -class UFactoryEndEffector(EndEffector): - def __init__(self, model=None, **kwargs): - super().__init__(**kwargs) - self.model = model - - def get_model(self): - return self.model - - -class UFactory7DOFArm: - def __init__(self, arm_length=None): - self.arm_length = arm_length - - def get_arm_length(self): - return self.arm_length diff --git a/build/lib/dimos/hardware/zed_camera.py b/build/lib/dimos/hardware/zed_camera.py deleted file mode 100644 index a2ceeba54e..0000000000 --- a/build/lib/dimos/hardware/zed_camera.py +++ /dev/null @@ -1,514 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import numpy as np -import cv2 -import open3d as o3d -from typing import Optional, Tuple, Dict, Any -import logging - -try: - import pyzed.sl as sl -except ImportError: - sl = None - logging.warning("ZED SDK not found. Please install pyzed to use ZED camera functionality.") - -from dimos.hardware.stereo_camera import StereoCamera - -logger = logging.getLogger(__name__) - - -class ZEDCamera(StereoCamera): - """ZED Camera capture node with neural depth processing.""" - - def __init__( - self, - camera_id: int = 0, - resolution: sl.RESOLUTION = sl.RESOLUTION.HD720, - depth_mode: sl.DEPTH_MODE = sl.DEPTH_MODE.NEURAL, - fps: int = 30, - **kwargs, - ): - """ - Initialize ZED Camera. - - Args: - camera_id: Camera ID (0 for first ZED) - resolution: ZED camera resolution - depth_mode: Depth computation mode - fps: Camera frame rate (default: 30) - """ - if sl is None: - raise ImportError("ZED SDK not installed. Please install pyzed package.") - - super().__init__(**kwargs) - - self.camera_id = camera_id - self.resolution = resolution - self.depth_mode = depth_mode - self.fps = fps - - # Initialize ZED camera - self.zed = sl.Camera() - self.init_params = sl.InitParameters() - self.init_params.camera_resolution = resolution - self.init_params.depth_mode = depth_mode - self.init_params.coordinate_system = sl.COORDINATE_SYSTEM.RIGHT_HANDED_Z_UP_X_FWD - self.init_params.coordinate_units = sl.UNIT.METER - self.init_params.camera_fps = fps - - # Set camera ID using the correct parameter name - if hasattr(self.init_params, "set_from_camera_id"): - self.init_params.set_from_camera_id(camera_id) - elif hasattr(self.init_params, "input"): - self.init_params.input.set_from_camera_id(camera_id) - - # Use enable_fill_mode instead of SENSING_MODE.STANDARD - self.runtime_params = sl.RuntimeParameters() - self.runtime_params.enable_fill_mode = True # False = STANDARD mode, True = FILL mode - - # Image containers - self.image_left = sl.Mat() - self.image_right = sl.Mat() - self.depth_map 
= sl.Mat() - self.point_cloud = sl.Mat() - self.confidence_map = sl.Mat() - - # Positional tracking - self.tracking_enabled = False - self.tracking_params = sl.PositionalTrackingParameters() - self.camera_pose = sl.Pose() - self.sensors_data = sl.SensorsData() - - self.is_opened = False - - def open(self) -> bool: - """Open the ZED camera.""" - try: - err = self.zed.open(self.init_params) - if err != sl.ERROR_CODE.SUCCESS: - logger.error(f"Failed to open ZED camera: {err}") - return False - - self.is_opened = True - logger.info("ZED camera opened successfully") - - # Get camera information - info = self.zed.get_camera_information() - logger.info(f"ZED Camera Model: {info.camera_model}") - logger.info(f"Serial Number: {info.serial_number}") - logger.info(f"Firmware: {info.camera_configuration.firmware_version}") - - return True - - except Exception as e: - logger.error(f"Error opening ZED camera: {e}") - return False - - def enable_positional_tracking( - self, - enable_area_memory: bool = False, - enable_pose_smoothing: bool = True, - enable_imu_fusion: bool = True, - set_floor_as_origin: bool = False, - initial_world_transform: Optional[sl.Transform] = None, - ) -> bool: - """ - Enable positional tracking on the ZED camera. 
- - Args: - enable_area_memory: Enable area learning to correct tracking drift - enable_pose_smoothing: Enable pose smoothing - enable_imu_fusion: Enable IMU fusion if available - set_floor_as_origin: Set the floor as origin (useful for robotics) - initial_world_transform: Initial world transform - - Returns: - True if tracking enabled successfully - """ - if not self.is_opened: - logger.error("ZED camera not opened") - return False - - try: - # Configure tracking parameters - self.tracking_params.enable_area_memory = enable_area_memory - self.tracking_params.enable_pose_smoothing = enable_pose_smoothing - self.tracking_params.enable_imu_fusion = enable_imu_fusion - self.tracking_params.set_floor_as_origin = set_floor_as_origin - - if initial_world_transform is not None: - self.tracking_params.initial_world_transform = initial_world_transform - - # Enable tracking - err = self.zed.enable_positional_tracking(self.tracking_params) - if err != sl.ERROR_CODE.SUCCESS: - logger.error(f"Failed to enable positional tracking: {err}") - return False - - self.tracking_enabled = True - logger.info("Positional tracking enabled successfully") - return True - - except Exception as e: - logger.error(f"Error enabling positional tracking: {e}") - return False - - def disable_positional_tracking(self): - """Disable positional tracking.""" - if self.tracking_enabled: - self.zed.disable_positional_tracking() - self.tracking_enabled = False - logger.info("Positional tracking disabled") - - def get_pose( - self, reference_frame: sl.REFERENCE_FRAME = sl.REFERENCE_FRAME.WORLD - ) -> Optional[Dict[str, Any]]: - """ - Get the current camera pose. 
- - Args: - reference_frame: Reference frame (WORLD or CAMERA) - - Returns: - Dictionary containing: - - position: [x, y, z] in meters - - rotation: [x, y, z, w] quaternion - - euler_angles: [roll, pitch, yaw] in radians - - timestamp: Pose timestamp in nanoseconds - - confidence: Tracking confidence (0-100) - - valid: Whether pose is valid - """ - if not self.tracking_enabled: - logger.error("Positional tracking not enabled") - return None - - try: - # Get current pose - tracking_state = self.zed.get_position(self.camera_pose, reference_frame) - - if tracking_state == sl.POSITIONAL_TRACKING_STATE.OK: - # Extract translation - translation = self.camera_pose.get_translation().get() - - # Extract rotation (quaternion) - rotation = self.camera_pose.get_orientation().get() - - # Get Euler angles - euler = self.camera_pose.get_euler_angles() - - return { - "position": translation.tolist(), - "rotation": rotation.tolist(), # [x, y, z, w] - "euler_angles": euler.tolist(), # [roll, pitch, yaw] - "timestamp": self.camera_pose.timestamp.get_nanoseconds(), - "confidence": self.camera_pose.pose_confidence, - "valid": True, - "tracking_state": str(tracking_state), - } - else: - logger.warning(f"Tracking state: {tracking_state}") - return {"valid": False, "tracking_state": str(tracking_state)} - - except Exception as e: - logger.error(f"Error getting pose: {e}") - return None - - def get_imu_data(self) -> Optional[Dict[str, Any]]: - """ - Get IMU sensor data if available. 
- - Returns: - Dictionary containing: - - orientation: IMU orientation quaternion [x, y, z, w] - - angular_velocity: [x, y, z] in rad/s - - linear_acceleration: [x, y, z] in m/s² - - timestamp: IMU data timestamp - """ - if not self.is_opened: - logger.error("ZED camera not opened") - return None - - try: - # Get sensors data synchronized with images - if ( - self.zed.get_sensors_data(self.sensors_data, sl.TIME_REFERENCE.IMAGE) - == sl.ERROR_CODE.SUCCESS - ): - imu = self.sensors_data.get_imu_data() - - # Get IMU orientation - imu_orientation = imu.get_pose().get_orientation().get() - - # Get angular velocity - angular_vel = imu.get_angular_velocity() - - # Get linear acceleration - linear_accel = imu.get_linear_acceleration() - - return { - "orientation": imu_orientation.tolist(), - "angular_velocity": angular_vel.tolist(), - "linear_acceleration": linear_accel.tolist(), - "timestamp": self.sensors_data.timestamp.get_nanoseconds(), - "temperature": self.sensors_data.temperature.get(sl.SENSOR_LOCATION.IMU), - } - else: - return None - - except Exception as e: - logger.error(f"Error getting IMU data: {e}") - return None - - def capture_frame( - self, - ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]: - """ - Capture a frame from ZED camera. 
- - Returns: - Tuple of (left_image, right_image, depth_map) as numpy arrays - """ - if not self.is_opened: - logger.error("ZED camera not opened") - return None, None, None - - try: - # Grab frame - if self.zed.grab(self.runtime_params) == sl.ERROR_CODE.SUCCESS: - # Retrieve left image - self.zed.retrieve_image(self.image_left, sl.VIEW.LEFT) - left_img = self.image_left.get_data()[:, :, :3] # Remove alpha channel - - # Retrieve right image - self.zed.retrieve_image(self.image_right, sl.VIEW.RIGHT) - right_img = self.image_right.get_data()[:, :, :3] # Remove alpha channel - - # Retrieve depth map - self.zed.retrieve_measure(self.depth_map, sl.MEASURE.DEPTH) - depth = self.depth_map.get_data() - - return left_img, right_img, depth - else: - logger.warning("Failed to grab frame from ZED camera") - return None, None, None - - except Exception as e: - logger.error(f"Error capturing frame: {e}") - return None, None, None - - def capture_pointcloud(self) -> Optional[o3d.geometry.PointCloud]: - """ - Capture point cloud from ZED camera. 
- - Returns: - Open3D point cloud with XYZ coordinates and RGB colors - """ - if not self.is_opened: - logger.error("ZED camera not opened") - return None - - try: - if self.zed.grab(self.runtime_params) == sl.ERROR_CODE.SUCCESS: - # Retrieve point cloud with RGBA data - self.zed.retrieve_measure(self.point_cloud, sl.MEASURE.XYZRGBA) - point_cloud_data = self.point_cloud.get_data() - - # Convert to numpy array format - height, width = point_cloud_data.shape[:2] - points = point_cloud_data.reshape(-1, 4) - - # Extract XYZ coordinates - xyz = points[:, :3] - - # Extract and unpack RGBA color data from 4th channel - rgba_packed = points[:, 3].view(np.uint32) - - # Unpack RGBA: each 32-bit value contains 4 bytes (R, G, B, A) - colors_rgba = np.zeros((len(rgba_packed), 4), dtype=np.uint8) - colors_rgba[:, 0] = rgba_packed & 0xFF # R - colors_rgba[:, 1] = (rgba_packed >> 8) & 0xFF # G - colors_rgba[:, 2] = (rgba_packed >> 16) & 0xFF # B - colors_rgba[:, 3] = (rgba_packed >> 24) & 0xFF # A - - # Extract RGB (ignore alpha) and normalize to [0, 1] - colors_rgb = colors_rgba[:, :3].astype(np.float64) / 255.0 - - # Filter out invalid points (NaN or inf) - valid = np.isfinite(xyz).all(axis=1) - valid_xyz = xyz[valid] - valid_colors = colors_rgb[valid] - - # Create Open3D point cloud - pcd = o3d.geometry.PointCloud() - - if len(valid_xyz) > 0: - pcd.points = o3d.utility.Vector3dVector(valid_xyz) - pcd.colors = o3d.utility.Vector3dVector(valid_colors) - - return pcd - else: - logger.warning("Failed to grab frame for point cloud") - return None - - except Exception as e: - logger.error(f"Error capturing point cloud: {e}") - return None - - def capture_frame_with_pose( - self, - ) -> Tuple[ - Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray], Optional[Dict[str, Any]] - ]: - """ - Capture a frame with synchronized pose data. 
- - Returns: - Tuple of (left_image, right_image, depth_map, pose_data) - """ - if not self.is_opened: - logger.error("ZED camera not opened") - return None, None, None, None - - try: - # Grab frame - if self.zed.grab(self.runtime_params) == sl.ERROR_CODE.SUCCESS: - # Get images and depth - left_img, right_img, depth = self.capture_frame() - - # Get synchronized pose if tracking is enabled - pose_data = None - if self.tracking_enabled: - pose_data = self.get_pose() - - return left_img, right_img, depth, pose_data - else: - logger.warning("Failed to grab frame from ZED camera") - return None, None, None, None - - except Exception as e: - logger.error(f"Error capturing frame with pose: {e}") - return None, None, None, None - - def close(self): - """Close the ZED camera.""" - if self.is_opened: - # Disable tracking if enabled - if self.tracking_enabled: - self.disable_positional_tracking() - - self.zed.close() - self.is_opened = False - logger.info("ZED camera closed") - - def get_camera_info(self) -> Dict[str, Any]: - """Get ZED camera information and calibration parameters.""" - if not self.is_opened: - return {} - - try: - info = self.zed.get_camera_information() - calibration = info.camera_configuration.calibration_parameters - - # In ZED SDK 4.0+, the baseline calculation has changed - # Try to get baseline from the stereo parameters - try: - # Method 1: Try to get from stereo parameters if available - if hasattr(calibration, "getCameraBaseline"): - baseline = calibration.getCameraBaseline() - else: - # Method 2: Calculate from left and right camera positions - # The baseline is the distance between left and right cameras - left_cam = calibration.left_cam - right_cam = calibration.right_cam - - # Try different ways to get baseline in SDK 4.0+ - if hasattr(info.camera_configuration, "calibration_parameters_raw"): - # Use raw calibration if available - raw_calib = info.camera_configuration.calibration_parameters_raw - if hasattr(raw_calib, "T"): - baseline = 
abs(raw_calib.T[0]) - else: - baseline = 0.12 # Default ZED-M baseline approximation - else: - # Use default baseline for ZED-M - baseline = 0.12 # ZED-M baseline is approximately 120mm - except: - baseline = 0.12 # Fallback to approximate ZED-M baseline - - return { - "model": str(info.camera_model), - "serial_number": info.serial_number, - "firmware": info.camera_configuration.firmware_version, - "resolution": { - "width": info.camera_configuration.resolution.width, - "height": info.camera_configuration.resolution.height, - }, - "fps": info.camera_configuration.fps, - "left_cam": { - "fx": calibration.left_cam.fx, - "fy": calibration.left_cam.fy, - "cx": calibration.left_cam.cx, - "cy": calibration.left_cam.cy, - "k1": calibration.left_cam.disto[0], - "k2": calibration.left_cam.disto[1], - "p1": calibration.left_cam.disto[2], - "p2": calibration.left_cam.disto[3], - "k3": calibration.left_cam.disto[4], - }, - "right_cam": { - "fx": calibration.right_cam.fx, - "fy": calibration.right_cam.fy, - "cx": calibration.right_cam.cx, - "cy": calibration.right_cam.cy, - "k1": calibration.right_cam.disto[0], - "k2": calibration.right_cam.disto[1], - "p1": calibration.right_cam.disto[2], - "p2": calibration.right_cam.disto[3], - "k3": calibration.right_cam.disto[4], - }, - "baseline": baseline, - } - except Exception as e: - logger.error(f"Error getting camera info: {e}") - return {} - - def calculate_intrinsics(self): - """Calculate camera intrinsics from ZED calibration.""" - info = self.get_camera_info() - if not info: - return super().calculate_intrinsics() - - left_cam = info.get("left_cam", {}) - resolution = info.get("resolution", {}) - - return { - "focal_length_x": left_cam.get("fx", 0), - "focal_length_y": left_cam.get("fy", 0), - "principal_point_x": left_cam.get("cx", 0), - "principal_point_y": left_cam.get("cy", 0), - "baseline": info.get("baseline", 0), - "resolution_width": resolution.get("width", 0), - "resolution_height": resolution.get("height", 0), - } - - 
def __enter__(self): - """Context manager entry.""" - if not self.open(): - raise RuntimeError("Failed to open ZED camera") - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - """Context manager exit.""" - self.close() diff --git a/build/lib/dimos/manipulation/__init__.py b/build/lib/dimos/manipulation/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/build/lib/dimos/manipulation/manip_aio_pipeline.py b/build/lib/dimos/manipulation/manip_aio_pipeline.py deleted file mode 100644 index 22e3f5d49e..0000000000 --- a/build/lib/dimos/manipulation/manip_aio_pipeline.py +++ /dev/null @@ -1,590 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Asynchronous, reactive manipulation pipeline for realtime detection, filtering, and grasp generation. 
-""" - -import asyncio -import json -import logging -import threading -import time -import traceback -import websockets -from typing import Dict, List, Optional, Any -import numpy as np -import reactivex as rx -import reactivex.operators as ops -from dimos.utils.logging_config import setup_logger -from dimos.perception.detection2d.detic_2d_det import Detic2DDetector -from dimos.perception.pointcloud.pointcloud_filtering import PointcloudFiltering -from dimos.perception.object_detection_stream import ObjectDetectionStream -from dimos.perception.grasp_generation.utils import draw_grasps_on_image -from dimos.perception.pointcloud.utils import create_point_cloud_overlay_visualization -from dimos.perception.common.utils import colorize_depth -from dimos.utils.logging_config import setup_logger -import cv2 - -logger = setup_logger("dimos.perception.manip_aio_pipeline") - - -class ManipulationPipeline: - """ - Clean separated stream pipeline with frame buffering. - - - Object detection runs independently on RGB stream - - Point cloud processing subscribes to both detection and ZED streams separately - - Simple frame buffering to match RGB+depth+objects - """ - - def __init__( - self, - camera_intrinsics: List[float], # [fx, fy, cx, cy] - min_confidence: float = 0.6, - max_objects: int = 10, - vocabulary: Optional[str] = None, - grasp_server_url: Optional[str] = None, - enable_grasp_generation: bool = False, - ): - """ - Initialize the manipulation pipeline. 
- - Args: - camera_intrinsics: [fx, fy, cx, cy] camera parameters - min_confidence: Minimum detection confidence threshold - max_objects: Maximum number of objects to process - vocabulary: Optional vocabulary for Detic detector - grasp_server_url: Optional WebSocket URL for AnyGrasp server - enable_grasp_generation: Whether to enable async grasp generation - """ - self.camera_intrinsics = camera_intrinsics - self.min_confidence = min_confidence - - # Grasp generation settings - self.grasp_server_url = grasp_server_url - self.enable_grasp_generation = enable_grasp_generation - - # Asyncio event loop for WebSocket communication - self.grasp_loop = None - self.grasp_loop_thread = None - - # Storage for grasp results and filtered objects - self.latest_grasps: List[dict] = [] # Simplified: just a list of grasps - self.grasps_consumed = False - self.latest_filtered_objects = [] - self.latest_rgb_for_grasps = None # Store RGB image for grasp overlay - self.grasp_lock = threading.Lock() - - # Track pending requests - simplified to single task - self.grasp_task: Optional[asyncio.Task] = None - - # Reactive subjects for streaming filtered objects and grasps - self.filtered_objects_subject = rx.subject.Subject() - self.grasps_subject = rx.subject.Subject() - self.grasp_overlay_subject = rx.subject.Subject() # Add grasp overlay subject - - # Initialize grasp client if enabled - if self.enable_grasp_generation and self.grasp_server_url: - self._start_grasp_loop() - - # Initialize object detector - self.detector = Detic2DDetector(vocabulary=vocabulary, threshold=min_confidence) - - # Initialize point cloud processor - self.pointcloud_filter = PointcloudFiltering( - color_intrinsics=camera_intrinsics, - depth_intrinsics=camera_intrinsics, # ZED uses same intrinsics - max_num_objects=max_objects, - ) - - logger.info(f"Initialized ManipulationPipeline with confidence={min_confidence}") - - def create_streams(self, zed_stream: rx.Observable) -> Dict[str, rx.Observable]: - """ - 
Create streams using exact old main logic. - """ - # Create ZED streams (from old main) - zed_frame_stream = zed_stream.pipe(ops.share()) - - # RGB stream for object detection (from old main) - video_stream = zed_frame_stream.pipe( - ops.map(lambda x: x.get("rgb") if x is not None else None), - ops.filter(lambda x: x is not None), - ops.share(), - ) - object_detector = ObjectDetectionStream( - camera_intrinsics=self.camera_intrinsics, - min_confidence=self.min_confidence, - class_filter=None, - detector=self.detector, - video_stream=video_stream, - disable_depth=True, - ) - - # Store latest frames for point cloud processing (from old main) - latest_rgb = None - latest_depth = None - latest_point_cloud_overlay = None - frame_lock = threading.Lock() - - # Subscribe to combined ZED frames (from old main) - def on_zed_frame(zed_data): - nonlocal latest_rgb, latest_depth - if zed_data is not None: - with frame_lock: - latest_rgb = zed_data.get("rgb") - latest_depth = zed_data.get("depth") - - # Depth stream for point cloud filtering (from old main) - def get_depth_or_overlay(zed_data): - if zed_data is None: - return None - - # Check if we have a point cloud overlay available - with frame_lock: - overlay = latest_point_cloud_overlay - - if overlay is not None: - return overlay - else: - # Return regular colorized depth - return colorize_depth(zed_data.get("depth"), max_depth=10.0) - - depth_stream = zed_frame_stream.pipe( - ops.map(get_depth_or_overlay), ops.filter(lambda x: x is not None), ops.share() - ) - - # Process object detection results with point cloud filtering (from old main) - def on_detection_next(result): - nonlocal latest_point_cloud_overlay - if "objects" in result and result["objects"]: - # Get latest RGB and depth frames - with frame_lock: - rgb = latest_rgb - depth = latest_depth - - if rgb is not None and depth is not None: - try: - filtered_objects = self.pointcloud_filter.process_images( - rgb, depth, result["objects"] - ) - - if filtered_objects: 
- # Store filtered objects - with self.grasp_lock: - self.latest_filtered_objects = filtered_objects - self.filtered_objects_subject.on_next(filtered_objects) - - # Create base image (colorized depth) - base_image = colorize_depth(depth, max_depth=10.0) - - # Create point cloud overlay visualization - overlay_viz = create_point_cloud_overlay_visualization( - base_image=base_image, - objects=filtered_objects, - intrinsics=self.camera_intrinsics, - ) - - # Store the overlay for the stream - with frame_lock: - latest_point_cloud_overlay = overlay_viz - - # Request grasps if enabled - if self.enable_grasp_generation and len(filtered_objects) > 0: - # Save RGB image for later grasp overlay - with frame_lock: - self.latest_rgb_for_grasps = rgb.copy() - - task = self.request_scene_grasps(filtered_objects) - if task: - # Check for results after a delay - def check_grasps_later(): - time.sleep(2.0) # Wait for grasp processing - # Wait for task to complete - if hasattr(self, "grasp_task") and self.grasp_task: - try: - result = self.grasp_task.result( - timeout=3.0 - ) # Get result with timeout - except Exception as e: - logger.warning(f"Grasp task failed or timeout: {e}") - - # Try to get latest grasps and create overlay - with self.grasp_lock: - grasps = self.latest_grasps - - if grasps and hasattr(self, "latest_rgb_for_grasps"): - # Create grasp overlay on the saved RGB image - try: - bgr_image = cv2.cvtColor( - self.latest_rgb_for_grasps, cv2.COLOR_RGB2BGR - ) - result_bgr = draw_grasps_on_image( - bgr_image, - grasps, - self.camera_intrinsics, - max_grasps=-1, # Show all grasps - ) - result_rgb = cv2.cvtColor( - result_bgr, cv2.COLOR_BGR2RGB - ) - - # Emit grasp overlay immediately - self.grasp_overlay_subject.on_next(result_rgb) - - except Exception as e: - logger.error(f"Error creating grasp overlay: {e}") - - # Emit grasps to stream - self.grasps_subject.on_next(grasps) - - threading.Thread(target=check_grasps_later, daemon=True).start() - else: - 
logger.warning("Failed to create grasp task") - except Exception as e: - logger.error(f"Error in point cloud filtering: {e}") - with frame_lock: - latest_point_cloud_overlay = None - - def on_error(error): - logger.error(f"Error in stream: {error}") - - def on_completed(): - logger.info("Stream completed") - - def start_subscriptions(): - """Start subscriptions in background thread (from old main)""" - # Subscribe to combined ZED frames - zed_frame_stream.subscribe(on_next=on_zed_frame) - - # Start subscriptions in background thread (from old main) - subscription_thread = threading.Thread(target=start_subscriptions, daemon=True) - subscription_thread.start() - time.sleep(2) # Give subscriptions time to start - - # Subscribe to object detection stream (from old main) - object_detector.get_stream().subscribe( - on_next=on_detection_next, on_error=on_error, on_completed=on_completed - ) - - # Create visualization stream for web interface (from old main) - viz_stream = object_detector.get_stream().pipe( - ops.map(lambda x: x["viz_frame"] if x is not None else None), - ops.filter(lambda x: x is not None), - ) - - # Create filtered objects stream - filtered_objects_stream = self.filtered_objects_subject - - # Create grasps stream - grasps_stream = self.grasps_subject - - # Create grasp overlay subject for immediate emission - grasp_overlay_stream = self.grasp_overlay_subject - - return { - "detection_viz": viz_stream, - "pointcloud_viz": depth_stream, - "objects": object_detector.get_stream().pipe(ops.map(lambda x: x.get("objects", []))), - "filtered_objects": filtered_objects_stream, - "grasps": grasps_stream, - "grasp_overlay": grasp_overlay_stream, - } - - def _start_grasp_loop(self): - """Start asyncio event loop in a background thread for WebSocket communication.""" - - def run_loop(): - self.grasp_loop = asyncio.new_event_loop() - asyncio.set_event_loop(self.grasp_loop) - self.grasp_loop.run_forever() - - self.grasp_loop_thread = threading.Thread(target=run_loop, 
daemon=True) - self.grasp_loop_thread.start() - - # Wait for loop to start - while self.grasp_loop is None: - time.sleep(0.01) - - async def _send_grasp_request( - self, points: np.ndarray, colors: Optional[np.ndarray] - ) -> Optional[List[dict]]: - """Send grasp request to AnyGrasp server.""" - try: - # Comprehensive client-side validation to prevent server errors - - # Validate points array - if points is None: - logger.error("Points array is None") - return None - if not isinstance(points, np.ndarray): - logger.error(f"Points is not numpy array: {type(points)}") - return None - if points.size == 0: - logger.error("Points array is empty") - return None - if len(points.shape) != 2 or points.shape[1] != 3: - logger.error(f"Points has invalid shape {points.shape}, expected (N, 3)") - return None - if points.shape[0] < 100: # Minimum points for stable grasp detection - logger.error(f"Insufficient points for grasp detection: {points.shape[0]} < 100") - return None - - # Validate and prepare colors - if colors is not None: - if not isinstance(colors, np.ndarray): - colors = None - elif colors.size == 0: - colors = None - elif len(colors.shape) != 2 or colors.shape[1] != 3: - colors = None - elif colors.shape[0] != points.shape[0]: - colors = None - - # If no valid colors, create default colors (required by server) - if colors is None: - # Create default white colors for all points - colors = np.ones((points.shape[0], 3), dtype=np.float32) * 0.5 - - # Ensure data types are correct (server expects float32) - points = points.astype(np.float32) - colors = colors.astype(np.float32) - - # Validate ranges (basic sanity checks) - if np.any(np.isnan(points)) or np.any(np.isinf(points)): - logger.error("Points contain NaN or Inf values") - return None - if np.any(np.isnan(colors)) or np.any(np.isinf(colors)): - logger.error("Colors contain NaN or Inf values") - return None - - # Clamp color values to valid range [0, 1] - colors = np.clip(colors, 0.0, 1.0) - - async with 
websockets.connect(self.grasp_server_url) as websocket: - request = { - "points": points.tolist(), - "colors": colors.tolist(), # Always send colors array - "lims": [-0.19, 0.12, 0.02, 0.15, 0.0, 1.0], # Default workspace limits - } - - await websocket.send(json.dumps(request)) - - response = await websocket.recv() - grasps = json.loads(response) - - # Handle server response validation - if isinstance(grasps, dict) and "error" in grasps: - logger.error(f"Server returned error: {grasps['error']}") - return None - elif isinstance(grasps, (int, float)) and grasps == 0: - return None - elif not isinstance(grasps, list): - logger.error( - f"Server returned unexpected response type: {type(grasps)}, value: {grasps}" - ) - return None - elif len(grasps) == 0: - return None - - converted_grasps = self._convert_grasp_format(grasps) - with self.grasp_lock: - self.latest_grasps = converted_grasps - self.grasps_consumed = False # Reset consumed flag - - # Emit to reactive stream - self.grasps_subject.on_next(self.latest_grasps) - - return converted_grasps - except websockets.exceptions.ConnectionClosed as e: - logger.error(f"WebSocket connection closed: {e}") - except websockets.exceptions.WebSocketException as e: - logger.error(f"WebSocket error: {e}") - except json.JSONDecodeError as e: - logger.error(f"Failed to parse server response as JSON: {e}") - except Exception as e: - logger.error(f"Error requesting grasps: {e}") - - return None - - def request_scene_grasps(self, objects: List[dict]) -> Optional[asyncio.Task]: - """Request grasps for entire scene by combining all object point clouds.""" - if not self.grasp_loop or not objects: - return None - - all_points = [] - all_colors = [] - valid_objects = 0 - - for i, obj in enumerate(objects): - # Validate point cloud data - if "point_cloud_numpy" not in obj or obj["point_cloud_numpy"] is None: - continue - - points = obj["point_cloud_numpy"] - if not isinstance(points, np.ndarray) or points.size == 0: - continue - - # Ensure 
points have correct shape (N, 3) - if len(points.shape) != 2 or points.shape[1] != 3: - continue - - # Validate colors if present - colors = None - if "colors_numpy" in obj and obj["colors_numpy"] is not None: - colors = obj["colors_numpy"] - if isinstance(colors, np.ndarray) and colors.size > 0: - # Ensure colors match points count and have correct shape - if colors.shape[0] != points.shape[0]: - colors = None # Ignore colors for this object - elif len(colors.shape) != 2 or colors.shape[1] != 3: - colors = None # Ignore colors for this object - - all_points.append(points) - if colors is not None: - all_colors.append(colors) - valid_objects += 1 - - if not all_points: - return None - - try: - combined_points = np.vstack(all_points) - - # Only combine colors if ALL objects have valid colors - combined_colors = None - if len(all_colors) == valid_objects and len(all_colors) > 0: - combined_colors = np.vstack(all_colors) - - # Validate final combined data - if combined_points.size == 0: - logger.warning("Combined point cloud is empty") - return None - - if combined_colors is not None and combined_colors.shape[0] != combined_points.shape[0]: - logger.warning( - f"Color/point count mismatch: {combined_colors.shape[0]} colors vs {combined_points.shape[0]} points, dropping colors" - ) - combined_colors = None - - except Exception as e: - logger.error(f"Failed to combine point clouds: {e}") - return None - - try: - # Check if there's already a grasp task running - if hasattr(self, "grasp_task") and self.grasp_task and not self.grasp_task.done(): - return self.grasp_task - - task = asyncio.run_coroutine_threadsafe( - self._send_grasp_request(combined_points, combined_colors), self.grasp_loop - ) - - self.grasp_task = task - return task - except Exception as e: - logger.warning("Failed to create grasp task") - return None - - def get_latest_grasps(self, timeout: float = 5.0) -> Optional[List[dict]]: - """Get latest grasp results, waiting for new ones if current ones have been 
consumed.""" - # Mark current grasps as consumed and get a reference - with self.grasp_lock: - current_grasps = self.latest_grasps - self.grasps_consumed = True - - # If we already have grasps and they haven't been consumed, return them - if current_grasps is not None and not getattr(self, "grasps_consumed", False): - return current_grasps - - # Wait for new grasps - start_time = time.time() - while time.time() - start_time < timeout: - with self.grasp_lock: - # Check if we have new grasps (different from what we marked as consumed) - if self.latest_grasps is not None and not getattr(self, "grasps_consumed", False): - return self.latest_grasps - time.sleep(0.1) # Check every 100ms - - return None # Timeout reached - - def clear_grasps(self) -> None: - """Clear all stored grasp results.""" - with self.grasp_lock: - self.latest_grasps = [] - - def _prepare_colors(self, colors: Optional[np.ndarray]) -> Optional[np.ndarray]: - """Prepare colors array, converting from various formats if needed.""" - if colors is None: - return None - - if colors.max() > 1.0: - colors = colors / 255.0 - - return colors - - def _convert_grasp_format(self, anygrasp_grasps: List[dict]) -> List[dict]: - """Convert AnyGrasp format to our visualization format.""" - converted = [] - - for i, grasp in enumerate(anygrasp_grasps): - rotation_matrix = np.array(grasp.get("rotation_matrix", np.eye(3))) - euler_angles = self._rotation_matrix_to_euler(rotation_matrix) - - converted_grasp = { - "id": f"grasp_{i}", - "score": grasp.get("score", 0.0), - "width": grasp.get("width", 0.0), - "height": grasp.get("height", 0.0), - "depth": grasp.get("depth", 0.0), - "translation": grasp.get("translation", [0, 0, 0]), - "rotation_matrix": rotation_matrix.tolist(), - "euler_angles": euler_angles, - } - converted.append(converted_grasp) - - converted.sort(key=lambda x: x["score"], reverse=True) - - return converted - - def _rotation_matrix_to_euler(self, rotation_matrix: np.ndarray) -> Dict[str, float]: - 
"""Convert rotation matrix to Euler angles (in radians).""" - sy = np.sqrt(rotation_matrix[0, 0] ** 2 + rotation_matrix[1, 0] ** 2) - - singular = sy < 1e-6 - - if not singular: - x = np.arctan2(rotation_matrix[2, 1], rotation_matrix[2, 2]) - y = np.arctan2(-rotation_matrix[2, 0], sy) - z = np.arctan2(rotation_matrix[1, 0], rotation_matrix[0, 0]) - else: - x = np.arctan2(-rotation_matrix[1, 2], rotation_matrix[1, 1]) - y = np.arctan2(-rotation_matrix[2, 0], sy) - z = 0 - - return {"roll": x, "pitch": y, "yaw": z} - - def cleanup(self): - """Clean up resources.""" - if hasattr(self.detector, "cleanup"): - self.detector.cleanup() - - if self.grasp_loop and self.grasp_loop_thread: - self.grasp_loop.call_soon_threadsafe(self.grasp_loop.stop) - self.grasp_loop_thread.join(timeout=1.0) - - if hasattr(self.pointcloud_filter, "cleanup"): - self.pointcloud_filter.cleanup() - logger.info("ManipulationPipeline cleaned up") diff --git a/build/lib/dimos/manipulation/manip_aio_processer.py b/build/lib/dimos/manipulation/manip_aio_processer.py deleted file mode 100644 index a8afc96a7c..0000000000 --- a/build/lib/dimos/manipulation/manip_aio_processer.py +++ /dev/null @@ -1,411 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Sequential manipulation processor for single-frame processing without reactive streams. 
-""" - -import logging -import time -from typing import Dict, List, Optional, Any, Tuple -import numpy as np -import cv2 - -from dimos.utils.logging_config import setup_logger -from dimos.perception.detection2d.detic_2d_det import Detic2DDetector -from dimos.perception.pointcloud.pointcloud_filtering import PointcloudFiltering -from dimos.perception.segmentation.sam_2d_seg import Sam2DSegmenter -from dimos.perception.grasp_generation.grasp_generation import AnyGraspGenerator -from dimos.perception.grasp_generation.utils import create_grasp_overlay -from dimos.perception.pointcloud.utils import ( - create_point_cloud_overlay_visualization, - extract_and_cluster_misc_points, - overlay_point_clouds_on_image, -) -from dimos.perception.common.utils import ( - colorize_depth, - detection_results_to_object_data, - combine_object_data, -) - -logger = setup_logger("dimos.perception.manip_aio_processor") - - -class ManipulationProcessor: - """ - Sequential manipulation processor for single-frame processing. - - Processes RGB-D frames through object detection, point cloud filtering, - and AnyGrasp grasp generation in a single thread without reactive streams. - """ - - def __init__( - self, - camera_intrinsics: List[float], # [fx, fy, cx, cy] - min_confidence: float = 0.6, - max_objects: int = 20, - vocabulary: Optional[str] = None, - enable_grasp_generation: bool = False, - grasp_server_url: Optional[str] = None, # Required when enable_grasp_generation=True - enable_segmentation: bool = True, - ): - """ - Initialize the manipulation processor. 
- - Args: - camera_intrinsics: [fx, fy, cx, cy] camera parameters - min_confidence: Minimum detection confidence threshold - max_objects: Maximum number of objects to process - vocabulary: Optional vocabulary for Detic detector - enable_grasp_generation: Whether to enable grasp generation - grasp_server_url: WebSocket URL for AnyGrasp server (required when enable_grasp_generation=True) - enable_segmentation: Whether to enable semantic segmentation - segmentation_model: Segmentation model to use (SAM 2 or FastSAM) - """ - self.camera_intrinsics = camera_intrinsics - self.min_confidence = min_confidence - self.max_objects = max_objects - self.enable_grasp_generation = enable_grasp_generation - self.grasp_server_url = grasp_server_url - self.enable_segmentation = enable_segmentation - - # Validate grasp generation requirements - if enable_grasp_generation and not grasp_server_url: - raise ValueError("grasp_server_url is required when enable_grasp_generation=True") - - # Initialize object detector - self.detector = Detic2DDetector(vocabulary=vocabulary, threshold=min_confidence) - - # Initialize point cloud processor - self.pointcloud_filter = PointcloudFiltering( - color_intrinsics=camera_intrinsics, - depth_intrinsics=camera_intrinsics, # ZED uses same intrinsics - max_num_objects=max_objects, - ) - - # Initialize semantic segmentation - self.segmenter = None - if self.enable_segmentation: - self.segmenter = Sam2DSegmenter( - device="cuda", - use_tracker=False, # Disable tracker for simple segmentation - use_analyzer=False, # Disable analyzer for simple segmentation - ) - - # Initialize grasp generator if enabled - self.grasp_generator = None - if self.enable_grasp_generation: - try: - self.grasp_generator = AnyGraspGenerator(server_url=grasp_server_url) - logger.info("AnyGrasp generator initialized successfully") - except Exception as e: - logger.error(f"Failed to initialize AnyGrasp generator: {e}") - self.grasp_generator = None - self.enable_grasp_generation = 
False - - logger.info( - f"Initialized ManipulationProcessor with confidence={min_confidence}, " - f"grasp_generation={enable_grasp_generation}" - ) - - def process_frame( - self, rgb_image: np.ndarray, depth_image: np.ndarray, generate_grasps: bool = None - ) -> Dict[str, Any]: - """ - Process a single RGB-D frame through the complete pipeline. - - Args: - rgb_image: RGB image (H, W, 3) - depth_image: Depth image (H, W) in meters - generate_grasps: Override grasp generation setting for this frame - - Returns: - Dictionary containing: - - detection_viz: Visualization of object detection - - pointcloud_viz: Visualization of point cloud overlay - - segmentation_viz: Visualization of semantic segmentation (if enabled) - - detection2d_objects: Raw detection results as ObjectData - - segmentation2d_objects: Raw segmentation results as ObjectData (if enabled) - - detected_objects: Detection (Object Detection) objects with point clouds filtered - - all_objects: Combined objects with intelligent duplicate removal - - full_pointcloud: Complete scene point cloud (if point cloud processing enabled) - - misc_clusters: List of clustered background/miscellaneous point clouds (DBSCAN) - - misc_voxel_grid: Open3D voxel grid approximating all misc/background points - - misc_pointcloud_viz: Visualization of misc/background cluster overlay - - grasps: Grasp results (AnyGrasp list of dictionaries, if enabled) - - grasp_overlay: Grasp visualization overlay (if enabled) - - processing_time: Total processing time - """ - start_time = time.time() - results = {} - - try: - # Step 1: Object Detection - step_start = time.time() - detection_results = self.run_object_detection(rgb_image) - results["detection2d_objects"] = detection_results.get("objects", []) - results["detection_viz"] = detection_results.get("viz_frame") - detection_time = time.time() - step_start - - # Step 2: Semantic Segmentation (if enabled) - segmentation_time = 0 - if self.enable_segmentation: - step_start = time.time() 
- segmentation_results = self.run_segmentation(rgb_image) - results["segmentation2d_objects"] = segmentation_results.get("objects", []) - results["segmentation_viz"] = segmentation_results.get("viz_frame") - segmentation_time = time.time() - step_start - - # Step 3: Point Cloud Processing - pointcloud_time = 0 - detection2d_objects = results.get("detection2d_objects", []) - segmentation2d_objects = results.get("segmentation2d_objects", []) - - # Process detection objects if available - detected_objects = [] - if detection2d_objects: - step_start = time.time() - detected_objects = self.run_pointcloud_filtering( - rgb_image, depth_image, detection2d_objects - ) - pointcloud_time += time.time() - step_start - - # Process segmentation objects if available - segmentation_filtered_objects = [] - if segmentation2d_objects: - step_start = time.time() - segmentation_filtered_objects = self.run_pointcloud_filtering( - rgb_image, depth_image, segmentation2d_objects - ) - pointcloud_time += time.time() - step_start - - # Combine all objects using intelligent duplicate removal - all_objects = combine_object_data( - detected_objects, segmentation_filtered_objects, overlap_threshold=0.8 - ) - - # Get full point cloud - full_pcd = self.pointcloud_filter.get_full_point_cloud() - - # Extract misc/background points and create voxel grid - misc_start = time.time() - misc_clusters, misc_voxel_grid = extract_and_cluster_misc_points( - full_pcd, - all_objects, - eps=0.03, - min_points=100, - enable_filtering=True, - voxel_size=0.02, - ) - misc_time = time.time() - misc_start - - # Store results - results.update( - { - "detected_objects": detected_objects, - "all_objects": all_objects, - "full_pointcloud": full_pcd, - "misc_clusters": misc_clusters, - "misc_voxel_grid": misc_voxel_grid, - } - ) - - # Create point cloud visualizations - base_image = colorize_depth(depth_image, max_depth=10.0) - - # Create visualizations - results["pointcloud_viz"] = ( - 
create_point_cloud_overlay_visualization( - base_image=base_image, - objects=all_objects, - intrinsics=self.camera_intrinsics, - ) - if all_objects - else base_image - ) - - results["detected_pointcloud_viz"] = ( - create_point_cloud_overlay_visualization( - base_image=base_image, - objects=detected_objects, - intrinsics=self.camera_intrinsics, - ) - if detected_objects - else base_image - ) - - if misc_clusters: - # Generate consistent colors for clusters - cluster_colors = [ - tuple((np.random.RandomState(i + 100).rand(3) * 255).astype(int)) - for i in range(len(misc_clusters)) - ] - results["misc_pointcloud_viz"] = overlay_point_clouds_on_image( - base_image=base_image, - point_clouds=misc_clusters, - camera_intrinsics=self.camera_intrinsics, - colors=cluster_colors, - point_size=2, - alpha=0.6, - ) - else: - results["misc_pointcloud_viz"] = base_image - - # Step 4: Grasp Generation (if enabled) - should_generate_grasps = ( - generate_grasps if generate_grasps is not None else self.enable_grasp_generation - ) - - if should_generate_grasps and all_objects and full_pcd: - grasps = self.run_grasp_generation(all_objects, full_pcd) - results["grasps"] = grasps - if grasps: - results["grasp_overlay"] = create_grasp_overlay( - rgb_image, grasps, self.camera_intrinsics - ) - - except Exception as e: - logger.error(f"Error processing frame: {e}") - results["error"] = str(e) - - # Add timing information - total_time = time.time() - start_time - results.update( - { - "processing_time": total_time, - "timing_breakdown": { - "detection": detection_time if "detection_time" in locals() else 0, - "segmentation": segmentation_time if "segmentation_time" in locals() else 0, - "pointcloud": pointcloud_time if "pointcloud_time" in locals() else 0, - "misc_extraction": misc_time if "misc_time" in locals() else 0, - "total": total_time, - }, - } - ) - - return results - - def run_object_detection(self, rgb_image: np.ndarray) -> Dict[str, Any]: - """Run object detection on RGB 
image.""" - try: - # Convert RGB to BGR for Detic detector - bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR) - - # Use process_image method from Detic detector - bboxes, track_ids, class_ids, confidences, names, masks = self.detector.process_image( - bgr_image - ) - - # Convert to ObjectData format using utility function - objects = detection_results_to_object_data( - bboxes=bboxes, - track_ids=track_ids, - class_ids=class_ids, - confidences=confidences, - names=names, - masks=masks, - source="detection", - ) - - # Create visualization using detector's built-in method - viz_frame = self.detector.visualize_results( - rgb_image, bboxes, track_ids, class_ids, confidences, names - ) - - return {"objects": objects, "viz_frame": viz_frame} - - except Exception as e: - logger.error(f"Object detection failed: {e}") - return {"objects": [], "viz_frame": rgb_image.copy()} - - def run_pointcloud_filtering( - self, rgb_image: np.ndarray, depth_image: np.ndarray, objects: List[Dict] - ) -> List[Dict]: - """Run point cloud filtering on detected objects.""" - try: - filtered_objects = self.pointcloud_filter.process_images( - rgb_image, depth_image, objects - ) - return filtered_objects if filtered_objects else [] - except Exception as e: - logger.error(f"Point cloud filtering failed: {e}") - return [] - - def run_segmentation(self, rgb_image: np.ndarray) -> Dict[str, Any]: - """Run semantic segmentation on RGB image.""" - if not self.segmenter: - return {"objects": [], "viz_frame": rgb_image.copy()} - - try: - # Convert RGB to BGR for segmenter - bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR) - - # Get segmentation results - masks, bboxes, track_ids, probs, names = self.segmenter.process_image(bgr_image) - - # Convert to ObjectData format using utility function - objects = detection_results_to_object_data( - bboxes=bboxes, - track_ids=track_ids, - class_ids=list(range(len(bboxes))), # Use indices as class IDs for segmentation - confidences=probs, - names=names, - 
masks=masks, - source="segmentation", - ) - - # Create visualization - if masks: - viz_bgr = self.segmenter.visualize_results( - bgr_image, masks, bboxes, track_ids, probs, names - ) - # Convert back to RGB - viz_frame = cv2.cvtColor(viz_bgr, cv2.COLOR_BGR2RGB) - else: - viz_frame = rgb_image.copy() - - return {"objects": objects, "viz_frame": viz_frame} - - except Exception as e: - logger.error(f"Segmentation failed: {e}") - return {"objects": [], "viz_frame": rgb_image.copy()} - - def run_grasp_generation(self, filtered_objects: List[Dict], full_pcd) -> Optional[List[Dict]]: - """Run grasp generation using the configured generator (AnyGrasp).""" - if not self.grasp_generator: - logger.warning("Grasp generation requested but no generator available") - return None - - try: - # Generate grasps using the configured generator - grasps = self.grasp_generator.generate_grasps_from_objects(filtered_objects, full_pcd) - - # Return parsed results directly (list of grasp dictionaries) - return grasps - - except Exception as e: - logger.error(f"AnyGrasp grasp generation failed: {e}") - return None - - def cleanup(self): - """Clean up resources.""" - if hasattr(self.detector, "cleanup"): - self.detector.cleanup() - if hasattr(self.pointcloud_filter, "cleanup"): - self.pointcloud_filter.cleanup() - if self.segmenter and hasattr(self.segmenter, "cleanup"): - self.segmenter.cleanup() - if self.grasp_generator and hasattr(self.grasp_generator, "cleanup"): - self.grasp_generator.cleanup() - logger.info("ManipulationProcessor cleaned up") diff --git a/build/lib/dimos/manipulation/manipulation_history.py b/build/lib/dimos/manipulation/manipulation_history.py deleted file mode 100644 index 8404b225c1..0000000000 --- a/build/lib/dimos/manipulation/manipulation_history.py +++ /dev/null @@ -1,418 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# [http://www.apache.org/licenses/LICENSE-2.0](http://www.apache.org/licenses/LICENSE-2.0) -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Module for manipulation history tracking and search.""" - -from typing import Dict, List, Optional, Any, Tuple, Union, Set, Callable -from dataclasses import dataclass, field -import time -from datetime import datetime -import os -import json -import pickle -import uuid - -from dimos.types.manipulation import ( - ManipulationTask, - AbstractConstraint, - ManipulationTaskConstraint, - ManipulationMetadata, -) -from dimos.utils.logging_config import setup_logger - -logger = setup_logger("dimos.types.manipulation_history") - - -@dataclass -class ManipulationHistoryEntry: - """An entry in the manipulation history. 
- - Attributes: - task: The manipulation task executed - timestamp: When the manipulation was performed - result: Result of the manipulation (success/failure) - manipulation_response: Response from the motion planner/manipulation executor - """ - - task: ManipulationTask - timestamp: float = field(default_factory=time.time) - result: Dict[str, Any] = field(default_factory=dict) - manipulation_response: Optional[str] = ( - None # Any elaborative response from the motion planner / manipulation executor - ) - - def __str__(self) -> str: - status = self.result.get("status", "unknown") - return f"ManipulationHistoryEntry(task='{self.task.description}', status={status}, time={datetime.fromtimestamp(self.timestamp).strftime('%H:%M:%S')})" - - -class ManipulationHistory: - """A simplified, dictionary-based storage for manipulation history. - - This class provides an efficient way to store and query manipulation tasks, - focusing on quick lookups and flexible search capabilities. - """ - - def __init__(self, output_dir: str = None, new_memory: bool = False): - """Initialize a new manipulation history. 
- - Args: - output_dir: Directory to save history to - new_memory: If True, creates a new memory instead of loading existing one - """ - self._history: List[ManipulationHistoryEntry] = [] - self._output_dir = output_dir - - if output_dir and not new_memory: - self.load_from_dir(output_dir) - elif output_dir: - os.makedirs(output_dir, exist_ok=True) - logger.info(f"Created new manipulation history at {output_dir}") - - def __len__(self) -> int: - """Return the number of entries in the history.""" - return len(self._history) - - def __str__(self) -> str: - """Return a string representation of the history.""" - if not self._history: - return "ManipulationHistory(empty)" - - return ( - f"ManipulationHistory(entries={len(self._history)}, " - f"time_range={datetime.fromtimestamp(self._history[0].timestamp).strftime('%Y-%m-%d %H:%M:%S')} to " - f"{datetime.fromtimestamp(self._history[-1].timestamp).strftime('%Y-%m-%d %H:%M:%S')})" - ) - - def clear(self) -> None: - """Clear all entries from the history.""" - self._history.clear() - logger.info("Cleared manipulation history") - - if self._output_dir: - self.save_history() - - def add_entry(self, entry: ManipulationHistoryEntry) -> None: - """Add an entry to the history. 
- - Args: - entry: The entry to add - """ - self._history.append(entry) - self._history.sort(key=lambda e: e.timestamp) - - if self._output_dir: - self.save_history() - - def save_history(self) -> None: - """Save the history to the output directory.""" - if not self._output_dir: - logger.warning("Cannot save history: no output directory specified") - return - - os.makedirs(self._output_dir, exist_ok=True) - history_path = os.path.join(self._output_dir, "manipulation_history.pickle") - - with open(history_path, "wb") as f: - pickle.dump(self._history, f) - - logger.info(f"Saved manipulation history to {history_path}") - - # Also save a JSON representation for easier inspection - json_path = os.path.join(self._output_dir, "manipulation_history.json") - try: - history_data = [ - { - "task": { - "description": entry.task.description, - "target_object": entry.task.target_object, - "target_point": entry.task.target_point, - "timestamp": entry.task.timestamp, - "task_id": entry.task.task_id, - "metadata": entry.task.metadata, - }, - "result": entry.result, - "timestamp": entry.timestamp, - "manipulation_response": entry.manipulation_response, - } - for entry in self._history - ] - - with open(json_path, "w") as f: - json.dump(history_data, f, indent=2) - - logger.info(f"Saved JSON representation to {json_path}") - except Exception as e: - logger.error(f"Failed to save JSON representation: {e}") - - def load_from_dir(self, directory: str) -> None: - """Load history from the specified directory. 
- - Args: - directory: Directory to load history from - """ - history_path = os.path.join(directory, "manipulation_history.pickle") - - if not os.path.exists(history_path): - logger.warning(f"No history found at {history_path}") - return - - try: - with open(history_path, "rb") as f: - self._history = pickle.load(f) - - logger.info( - f"Loaded manipulation history from {history_path} with {len(self._history)} entries" - ) - except Exception as e: - logger.error(f"Failed to load history: {e}") - - def get_all_entries(self) -> List[ManipulationHistoryEntry]: - """Get all entries in chronological order. - - Returns: - List of all manipulation history entries - """ - return self._history.copy() - - def get_entry_by_index(self, index: int) -> Optional[ManipulationHistoryEntry]: - """Get an entry by its index. - - Args: - index: Index of the entry to retrieve - - Returns: - The entry at the specified index or None if index is out of bounds - """ - if 0 <= index < len(self._history): - return self._history[index] - return None - - def get_entries_by_timerange( - self, start_time: float, end_time: float - ) -> List[ManipulationHistoryEntry]: - """Get entries within a specific time range. - - Args: - start_time: Start time (UNIX timestamp) - end_time: End time (UNIX timestamp) - - Returns: - List of entries within the specified time range - """ - return [entry for entry in self._history if start_time <= entry.timestamp <= end_time] - - def get_entries_by_object(self, object_name: str) -> List[ManipulationHistoryEntry]: - """Get entries related to a specific object. - - Args: - object_name: Name of the object to search for - - Returns: - List of entries related to the specified object - """ - return [entry for entry in self._history if entry.task.target_object == object_name] - - def create_task_entry( - self, task: ManipulationTask, result: Dict[str, Any] = None, agent_response: str = None - ) -> ManipulationHistoryEntry: - """Create a new manipulation history entry. 
- - Args: - task: The manipulation task - result: Result of the manipulation - agent_response: Response from the agent about this manipulation - - Returns: - The created history entry - """ - entry = ManipulationHistoryEntry( - task=task, result=result or {}, manipulation_response=agent_response - ) - self.add_entry(entry) - return entry - - def search(self, **kwargs) -> List[ManipulationHistoryEntry]: - """Flexible search method that can search by any field in ManipulationHistoryEntry using dot notation. - - This method supports dot notation to access nested fields. String values automatically use - substring matching (contains), while all other types use exact matching. - - Examples: - # Time-based searches: - - search(**{"task.metadata.timestamp": ('>', start_time)}) - entries after start_time - - search(**{"task.metadata.timestamp": ('>=', time - 1800)}) - entries in last 30 mins - - # Constraint searches: - - search(**{"task.constraints.*.reference_point.x": 2.5}) - tasks with x=2.5 reference point - - search(**{"task.constraints.*.end_angle.x": 90}) - tasks with 90-degree x rotation - - search(**{"task.constraints.*.lock_x": True}) - tasks with x-axis translation locked - - # Object and result searches: - - search(**{"task.metadata.objects.*.label": "cup"}) - tasks involving cups - - search(**{"result.status": "success"}) - successful tasks - - search(**{"result.error": "Collision"}) - tasks that had collisions - - Args: - **kwargs: Key-value pairs for searching using dot notation for field paths. 
- - Returns: - List of matching entries - """ - if not kwargs: - return self._history.copy() - - results = self._history.copy() - - for key, value in kwargs.items(): - # For all searches, automatically determine if we should use contains for strings - results = [e for e in results if self._check_field_match(e, key, value)] - - return results - - def _check_field_match(self, entry, field_path, value) -> bool: - """Check if a field matches the value, with special handling for strings, collections and comparisons. - - For string values, we automatically use substring matching (contains). - For collections (returned by * path), we check if any element matches. - For numeric values (like timestamps), supports >, <, >= and <= comparisons. - For all other types, we use exact matching. - - Args: - entry: The entry to check - field_path: Dot-separated path to the field - value: Value to match against. For comparisons, use tuples like: - ('>', timestamp) - greater than - ('<', timestamp) - less than - ('>=', timestamp) - greater or equal - ('<=', timestamp) - less or equal - - Returns: - True if the field matches the value, False otherwise - """ - try: - field_value = self._get_value_by_path(entry, field_path) - - # Handle comparison operators for timestamps and numbers - if isinstance(value, tuple) and len(value) == 2: - op, compare_value = value - if op == ">": - return field_value > compare_value - elif op == "<": - return field_value < compare_value - elif op == ">=": - return field_value >= compare_value - elif op == "<=": - return field_value <= compare_value - - # Handle lists (from collection searches) - if isinstance(field_value, list): - for item in field_value: - # String values use contains matching - if isinstance(item, str) and isinstance(value, str): - if value in item: - return True - # All other types use exact matching - elif item == value: - return True - return False - - # String values use contains matching - elif isinstance(field_value, str) and 
isinstance(value, str): - return value in field_value - # All other types use exact matching - else: - return field_value == value - - except (AttributeError, KeyError): - return False - - def _get_value_by_path(self, obj, path): - """Get a value from an object using a dot-separated path. - - This method handles three special cases: - 1. Regular attribute access (obj.attr) - 2. Dictionary key access (dict[key]) - 3. Collection search (dict.*.attr) - when * is used, it searches all values in the collection - - Args: - obj: Object to get value from - path: Dot-separated path to the field (e.g., "task.metadata.robot") - - Returns: - Value at the specified path or list of values for collection searches - - Raises: - AttributeError: If an attribute in the path doesn't exist - KeyError: If a dictionary key in the path doesn't exist - """ - current = obj - parts = path.split(".") - - for i, part in enumerate(parts): - # Collection search (*.attr) - search across all items in a collection - if part == "*": - # Get remaining path parts - remaining_path = ".".join(parts[i + 1 :]) - - # Handle different collection types - if isinstance(current, dict): - items = current.values() - if not remaining_path: # If * is the last part, return all values - return list(items) - elif isinstance(current, list): - items = current - if not remaining_path: # If * is the last part, return all items - return items - else: # Not a collection - raise AttributeError( - f"Cannot use wildcard on non-collection type: {type(current)}" - ) - - # Apply remaining path to each item in the collection - results = [] - for item in items: - try: - # Recursively get values from each item - value = self._get_value_by_path(item, remaining_path) - if isinstance(value, list): # Flatten nested lists - results.extend(value) - else: - results.append(value) - except (AttributeError, KeyError): - # Skip items that don't have the attribute - pass - return results - - # Regular attribute/key access - elif 
isinstance(current, dict): - current = current[part] - else: - current = getattr(current, part) - - return current diff --git a/build/lib/dimos/manipulation/manipulation_interface.py b/build/lib/dimos/manipulation/manipulation_interface.py deleted file mode 100644 index 68d3924a99..0000000000 --- a/build/lib/dimos/manipulation/manipulation_interface.py +++ /dev/null @@ -1,292 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -ManipulationInterface provides a unified interface for accessing manipulation history. - -This module defines the ManipulationInterface class, which serves as an access point -for the robot's manipulation history, agent-generated constraints, and manipulation -metadata streams. 
-""" - -from typing import Dict, List, Optional, Any, Tuple, Union -from dataclasses import dataclass -import os -import time -from datetime import datetime -from reactivex.disposable import Disposable -from dimos.perception.object_detection_stream import ObjectDetectionStream -from dimos.types.manipulation import ( - AbstractConstraint, - TranslationConstraint, - RotationConstraint, - ForceConstraint, - ManipulationTaskConstraint, - ManipulationTask, - ManipulationMetadata, - ObjectData, -) -from dimos.manipulation.manipulation_history import ( - ManipulationHistory, - ManipulationHistoryEntry, -) -from dimos.utils.logging_config import setup_logger - -logger = setup_logger("dimos.robot.manipulation_interface") - - -class ManipulationInterface: - """ - Interface for accessing and managing robot manipulation data. - - This class provides a unified interface for managing manipulation tasks and constraints. - It maintains a list of constraints generated by the Agent and provides methods to - add and manage manipulation tasks. - """ - - def __init__( - self, - output_dir: str, - new_memory: bool = False, - perception_stream: ObjectDetectionStream = None, - ): - """ - Initialize a new ManipulationInterface instance. 
- - Args: - output_dir: Directory for storing manipulation data - new_memory: If True, creates a new manipulation history from scratch - perception_stream: ObjectDetectionStream instance for real-time object data - """ - self.output_dir = output_dir - - # Create manipulation history directory - manipulation_dir = os.path.join(output_dir, "manipulation_history") - os.makedirs(manipulation_dir, exist_ok=True) - - # Initialize manipulation history - self.manipulation_history: ManipulationHistory = ManipulationHistory( - output_dir=manipulation_dir, new_memory=new_memory - ) - - # List of constraints generated by the Agent via constraint generation skills - self.agent_constraints: List[AbstractConstraint] = [] - - # Initialize object detection stream and related properties - self.perception_stream = perception_stream - self.latest_objects: List[ObjectData] = [] - self.stream_subscription: Optional[Disposable] = None - - # Set up subscription to perception stream if available - self._setup_perception_subscription() - - logger.info("ManipulationInterface initialized") - - def add_constraint(self, constraint: AbstractConstraint) -> None: - """ - Add a constraint generated by the Agent via a constraint generation skill. - - Args: - constraint: The constraint to add to agent_constraints - """ - self.agent_constraints.append(constraint) - logger.info(f"Added agent constraint: {constraint}") - - def get_constraints(self) -> List[AbstractConstraint]: - """ - Get all constraints generated by the Agent via constraint generation skills. - - Returns: - List of all constraints created by the Agent - """ - return self.agent_constraints - - def get_constraint(self, constraint_id: str) -> Optional[AbstractConstraint]: - """ - Get a specific constraint by its ID. 
- - Args: - constraint_id: ID of the constraint to retrieve - - Returns: - The matching constraint or None if not found - """ - # Find constraint with matching ID - for constraint in self.agent_constraints: - if constraint.id == constraint_id: - return constraint - - logger.warning(f"Constraint with ID {constraint_id} not found") - return None - - def add_manipulation_task( - self, task: ManipulationTask, manipulation_response: Optional[str] = None - ) -> None: - """ - Add a manipulation task to ManipulationHistory. - - Args: - task: The ManipulationTask to add - manipulation_response: Optional response from the motion planner/executor - - """ - # Add task to history - self.manipulation_history.add_entry( - task=task, result=None, notes=None, manipulation_response=manipulation_response - ) - - def get_manipulation_task(self, task_id: str) -> Optional[ManipulationTask]: - """ - Get a manipulation task by its ID. - - Args: - task_id: ID of the task to retrieve - - Returns: - The task object or None if not found - """ - return self.history.get_manipulation_task(task_id) - - def get_all_manipulation_tasks(self) -> List[ManipulationTask]: - """ - Get all manipulation tasks. - - Returns: - List of all manipulation tasks - """ - return self.history.get_all_manipulation_tasks() - - def update_task_status( - self, task_id: str, status: str, result: Optional[Dict[str, Any]] = None - ) -> Optional[ManipulationTask]: - """ - Update the status and result of a manipulation task. - - Args: - task_id: ID of the task to update - status: New status for the task (e.g., 'completed', 'failed') - result: Optional dictionary with result data - - Returns: - The updated task or None if task not found - """ - return self.history.update_task_status(task_id, status, result) - - # === Perception stream methods === - - def _setup_perception_subscription(self): - """ - Set up subscription to perception stream if available. 
- """ - if self.perception_stream: - # Subscribe to the stream and update latest_objects - self.stream_subscription = self.perception_stream.get_stream().subscribe( - on_next=self._update_latest_objects, - on_error=lambda e: logger.error(f"Error in perception stream: {e}"), - ) - logger.info("Subscribed to perception stream") - - def _update_latest_objects(self, data): - """ - Update the latest detected objects. - - Args: - data: Data from the object detection stream - """ - if "objects" in data: - self.latest_objects = data["objects"] - - def get_latest_objects(self) -> List[ObjectData]: - """ - Get the latest detected objects from the stream. - - Returns: - List of the most recently detected objects - """ - return self.latest_objects - - def get_object_by_id(self, object_id: int) -> Optional[ObjectData]: - """ - Get a specific object by its tracking ID. - - Args: - object_id: Tracking ID of the object - - Returns: - The object data or None if not found - """ - for obj in self.latest_objects: - if obj["object_id"] == object_id: - return obj - return None - - def get_objects_by_label(self, label: str) -> List[ObjectData]: - """ - Get all objects with a specific label. - - Args: - label: Class label to filter objects by - - Returns: - List of objects matching the label - """ - return [obj for obj in self.latest_objects if obj["label"] == label] - - def set_perception_stream(self, perception_stream): - """ - Set or update the perception stream. - - Args: - perception_stream: The PerceptionStream instance - """ - # Clean up existing subscription if any - self.cleanup_perception_subscription() - - # Set new stream and subscribe - self.perception_stream = perception_stream - self._setup_perception_subscription() - - def cleanup_perception_subscription(self): - """ - Clean up the stream subscription. 
- """ - if self.stream_subscription: - self.stream_subscription.dispose() - self.stream_subscription = None - - # === Utility methods === - - def clear_history(self) -> None: - """ - Clear all manipulation history data and agent constraints. - """ - self.manipulation_history.clear() - self.agent_constraints.clear() - logger.info("Cleared manipulation history and agent constraints") - - def __str__(self) -> str: - """ - String representation of the manipulation interface. - - Returns: - String representation with key stats - """ - has_stream = self.perception_stream is not None - return f"ManipulationInterface(history={self.manipulation_history}, agent_constraints={len(self.agent_constraints)}, perception_stream={has_stream}, detected_objects={len(self.latest_objects)})" - - def __del__(self): - """ - Clean up resources on deletion. - """ - self.cleanup_perception_subscription() diff --git a/build/lib/dimos/manipulation/test_manipulation_history.py b/build/lib/dimos/manipulation/test_manipulation_history.py deleted file mode 100644 index 239a04a86f..0000000000 --- a/build/lib/dimos/manipulation/test_manipulation_history.py +++ /dev/null @@ -1,461 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# [http://www.apache.org/licenses/LICENSE-2.0](http://www.apache.org/licenses/LICENSE-2.0) -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import time -import tempfile -import pytest -from typing import Dict, List, Optional, Any, Tuple - -from dimos.manipulation.manipulation_history import ManipulationHistory, ManipulationHistoryEntry -from dimos.types.manipulation import ( - ManipulationTask, - AbstractConstraint, - TranslationConstraint, - RotationConstraint, - ForceConstraint, - ManipulationTaskConstraint, - ManipulationMetadata, -) -from dimos.types.vector import Vector - - -@pytest.fixture -def sample_task(): - """Create a sample manipulation task for testing.""" - return ManipulationTask( - description="Pick up the cup", - target_object="cup", - target_point=(100, 200), - task_id="task1", - metadata={ - "timestamp": time.time(), - "objects": { - "cup1": { - "object_id": 1, - "label": "cup", - "confidence": 0.95, - "position": {"x": 1.5, "y": 2.0, "z": 0.5}, - }, - "table1": { - "object_id": 2, - "label": "table", - "confidence": 0.98, - "position": {"x": 0.0, "y": 0.0, "z": 0.0}, - }, - }, - }, - ) - - -@pytest.fixture -def sample_task_with_constraints(): - """Create a sample manipulation task with constraints for testing.""" - task = ManipulationTask( - description="Rotate the bottle", - target_object="bottle", - target_point=(150, 250), - task_id="task2", - metadata={ - "timestamp": time.time(), - "objects": { - "bottle1": { - "object_id": 3, - "label": "bottle", - "confidence": 0.92, - "position": {"x": 2.5, "y": 1.0, "z": 0.3}, - } - }, - }, - ) - - # Add rich translation constraint - translation_constraint = 
TranslationConstraint( - translation_axis="y", - reference_point=Vector(2.5, 1.0, 0.3), - bounds_min=Vector(2.0, 0.5, 0.3), - bounds_max=Vector(3.0, 1.5, 0.3), - target_point=Vector(2.7, 1.2, 0.3), - description="Constrained translation along Y-axis only", - ) - task.add_constraint(translation_constraint) - - # Add rich rotation constraint - rotation_constraint = RotationConstraint( - rotation_axis="roll", - start_angle=Vector(0, 0, 0), - end_angle=Vector(90, 0, 0), - pivot_point=Vector(2.5, 1.0, 0.3), - secondary_pivot_point=Vector(2.5, 1.0, 0.5), - description="Constrained rotation around X-axis (roll only)", - ) - task.add_constraint(rotation_constraint) - - # Add force constraint - force_constraint = ForceConstraint( - min_force=2.0, - max_force=5.0, - force_direction=Vector(0, 0, -1), - description="Apply moderate downward force during manipulation", - ) - task.add_constraint(force_constraint) - - return task - - -@pytest.fixture -def temp_output_dir(): - """Create a temporary directory for testing history saving/loading.""" - with tempfile.TemporaryDirectory() as temp_dir: - yield temp_dir - - -@pytest.fixture -def populated_history(sample_task, sample_task_with_constraints): - """Create a populated history with multiple entries for testing.""" - history = ManipulationHistory() - - # Add first entry - entry1 = ManipulationHistoryEntry( - task=sample_task, - result={"status": "success", "execution_time": 2.5}, - manipulation_response="Successfully picked up the cup", - ) - history.add_entry(entry1) - - # Add second entry - entry2 = ManipulationHistoryEntry( - task=sample_task_with_constraints, - result={"status": "failure", "error": "Collision detected"}, - manipulation_response="Failed to rotate the bottle due to collision", - ) - history.add_entry(entry2) - - return history - - -def test_manipulation_history_init(): - """Test initialization of ManipulationHistory.""" - # Default initialization - history = ManipulationHistory() - assert len(history) == 0 - 
assert str(history) == "ManipulationHistory(empty)" - - # With output directory - with tempfile.TemporaryDirectory() as temp_dir: - history = ManipulationHistory(output_dir=temp_dir, new_memory=True) - assert len(history) == 0 - assert os.path.exists(temp_dir) - - -def test_manipulation_history_add_entry(sample_task): - """Test adding entries to ManipulationHistory.""" - history = ManipulationHistory() - - # Create and add entry - entry = ManipulationHistoryEntry( - task=sample_task, result={"status": "success"}, manipulation_response="Task completed" - ) - history.add_entry(entry) - - assert len(history) == 1 - assert history.get_entry_by_index(0) == entry - - -def test_manipulation_history_create_task_entry(sample_task): - """Test creating a task entry directly.""" - history = ManipulationHistory() - - entry = history.create_task_entry( - task=sample_task, result={"status": "success"}, agent_response="Task completed" - ) - - assert len(history) == 1 - assert entry.task == sample_task - assert entry.result["status"] == "success" - assert entry.manipulation_response == "Task completed" - - -def test_manipulation_history_save_load(temp_output_dir, sample_task): - """Test saving and loading history from disk.""" - # Create history and add entry - history = ManipulationHistory(output_dir=temp_output_dir) - entry = history.create_task_entry( - task=sample_task, result={"status": "success"}, agent_response="Task completed" - ) - - # Check that files were created - pickle_path = os.path.join(temp_output_dir, "manipulation_history.pickle") - json_path = os.path.join(temp_output_dir, "manipulation_history.json") - assert os.path.exists(pickle_path) - assert os.path.exists(json_path) - - # Create new history that loads from the saved files - loaded_history = ManipulationHistory(output_dir=temp_output_dir) - assert len(loaded_history) == 1 - assert loaded_history.get_entry_by_index(0).task.description == sample_task.description - - -def 
test_manipulation_history_clear(populated_history): - """Test clearing the history.""" - assert len(populated_history) > 0 - - populated_history.clear() - assert len(populated_history) == 0 - assert str(populated_history) == "ManipulationHistory(empty)" - - -def test_manipulation_history_get_methods(populated_history): - """Test various getter methods of ManipulationHistory.""" - # get_all_entries - entries = populated_history.get_all_entries() - assert len(entries) == 2 - - # get_entry_by_index - entry = populated_history.get_entry_by_index(0) - assert entry.task.task_id == "task1" - - # Out of bounds index - assert populated_history.get_entry_by_index(100) is None - - # get_entries_by_timerange - start_time = time.time() - 3600 # 1 hour ago - end_time = time.time() + 3600 # 1 hour from now - entries = populated_history.get_entries_by_timerange(start_time, end_time) - assert len(entries) == 2 - - # get_entries_by_object - cup_entries = populated_history.get_entries_by_object("cup") - assert len(cup_entries) == 1 - assert cup_entries[0].task.task_id == "task1" - - bottle_entries = populated_history.get_entries_by_object("bottle") - assert len(bottle_entries) == 1 - assert bottle_entries[0].task.task_id == "task2" - - -def test_manipulation_history_search_basic(populated_history): - """Test basic search functionality.""" - # Search by exact match on top-level fields - results = populated_history.search(timestamp=populated_history.get_entry_by_index(0).timestamp) - assert len(results) == 1 - - # Search by task fields - results = populated_history.search(**{"task.task_id": "task1"}) - assert len(results) == 1 - assert results[0].task.target_object == "cup" - - # Search by result fields - results = populated_history.search(**{"result.status": "success"}) - assert len(results) == 1 - assert results[0].task.task_id == "task1" - - # Search by manipulation_response (substring match for strings) - results = populated_history.search(manipulation_response="picked up") - 
assert len(results) == 1 - assert results[0].task.task_id == "task1" - - -def test_manipulation_history_search_nested(populated_history): - """Test search with nested field paths.""" - # Search by nested metadata fields - results = populated_history.search( - **{ - "task.metadata.timestamp": populated_history.get_entry_by_index(0).task.metadata[ - "timestamp" - ] - } - ) - assert len(results) == 1 - - # Search by nested object fields - results = populated_history.search(**{"task.metadata.objects.cup1.label": "cup"}) - assert len(results) == 1 - assert results[0].task.task_id == "task1" - - # Search by position values - results = populated_history.search(**{"task.metadata.objects.cup1.position.x": 1.5}) - assert len(results) == 1 - assert results[0].task.task_id == "task1" - - -def test_manipulation_history_search_wildcards(populated_history): - """Test search with wildcard patterns.""" - # Search for any object with label "cup" - results = populated_history.search(**{"task.metadata.objects.*.label": "cup"}) - assert len(results) == 1 - assert results[0].task.task_id == "task1" - - # Search for any object with confidence > 0.95 - results = populated_history.search(**{"task.metadata.objects.*.confidence": 0.98}) - assert len(results) == 1 - assert results[0].task.task_id == "task1" - - # Search for any object position with x=2.5 - results = populated_history.search(**{"task.metadata.objects.*.position.x": 2.5}) - assert len(results) == 1 - assert results[0].task.task_id == "task2" - - -def test_manipulation_history_search_constraints(populated_history): - """Test search by constraint properties.""" - # Find entries with any TranslationConstraint with y-axis - results = populated_history.search(**{"task.constraints.*.translation_axis": "y"}) - assert len(results) == 1 - assert results[0].task.task_id == "task2" - - # Find entries with any RotationConstraint with roll axis - results = populated_history.search(**{"task.constraints.*.rotation_axis": "roll"}) - assert 
len(results) == 1 - assert results[0].task.task_id == "task2" - - -def test_manipulation_history_search_string_contains(populated_history): - """Test string contains searching.""" - # Basic string contains - results = populated_history.search(**{"task.description": "Pick"}) - assert len(results) == 1 - assert results[0].task.task_id == "task1" - - # Nested string contains - results = populated_history.search(manipulation_response="collision") - assert len(results) == 1 - assert results[0].task.task_id == "task2" - - -def test_manipulation_history_search_multiple_criteria(populated_history): - """Test search with multiple criteria.""" - # Multiple criteria - all must match - results = populated_history.search(**{"task.target_object": "cup", "result.status": "success"}) - assert len(results) == 1 - assert results[0].task.task_id == "task1" - - # Multiple criteria with no matches - results = populated_history.search(**{"task.target_object": "cup", "result.status": "failure"}) - assert len(results) == 0 - - # Combination of direct and wildcard paths - results = populated_history.search( - **{"task.target_object": "bottle", "task.metadata.objects.*.position.z": 0.3} - ) - assert len(results) == 1 - assert results[0].task.task_id == "task2" - - -def test_manipulation_history_search_nonexistent_fields(populated_history): - """Test search with fields that don't exist.""" - # Search by nonexistent field - results = populated_history.search(nonexistent_field="value") - assert len(results) == 0 - - # Search by nonexistent nested field - results = populated_history.search(**{"task.nonexistent_field": "value"}) - assert len(results) == 0 - - # Search by nonexistent object - results = populated_history.search(**{"task.metadata.objects.nonexistent_object": "value"}) - assert len(results) == 0 - - -def test_manipulation_history_search_timestamp_ranges(populated_history): - """Test searching by timestamp ranges.""" - # Get reference timestamps - entry1_time = 
populated_history.get_entry_by_index(0).task.metadata["timestamp"] - entry2_time = populated_history.get_entry_by_index(1).task.metadata["timestamp"] - mid_time = (entry1_time + entry2_time) / 2 - - # Search for timestamps before second entry - results = populated_history.search(**{"task.metadata.timestamp": ("<", entry2_time)}) - assert len(results) == 1 - assert results[0].task.task_id == "task1" - - # Search for timestamps after first entry - results = populated_history.search(**{"task.metadata.timestamp": (">", entry1_time)}) - assert len(results) == 1 - assert results[0].task.task_id == "task2" - - # Search within a time window using >= and <= - results = populated_history.search(**{"task.metadata.timestamp": (">=", mid_time - 1800)}) - assert len(results) == 2 - assert results[0].task.task_id == "task1" - assert results[1].task.task_id == "task2" - - -def test_manipulation_history_search_vector_fields(populated_history): - """Test searching by vector components in constraints.""" - # Search by reference point components - results = populated_history.search(**{"task.constraints.*.reference_point.x": 2.5}) - assert len(results) == 1 - assert results[0].task.task_id == "task2" - - # Search by target point components - results = populated_history.search(**{"task.constraints.*.target_point.z": 0.3}) - assert len(results) == 1 - assert results[0].task.task_id == "task2" - - # Search by rotation angles - results = populated_history.search(**{"task.constraints.*.end_angle.x": 90}) - assert len(results) == 1 - assert results[0].task.task_id == "task2" - - -def test_manipulation_history_search_execution_details(populated_history): - """Test searching by execution time and error patterns.""" - # Search by execution time - results = populated_history.search(**{"result.execution_time": 2.5}) - assert len(results) == 1 - assert results[0].task.task_id == "task1" - - # Search by error message pattern - results = populated_history.search(**{"result.error": "Collision"}) - 
assert len(results) == 1 - assert results[0].task.task_id == "task2" - - # Search by status - results = populated_history.search(**{"result.status": "success"}) - assert len(results) == 1 - assert results[0].task.task_id == "task1" - - -def test_manipulation_history_search_multiple_criteria(populated_history): - """Test search with multiple criteria.""" - # Multiple criteria - all must match - results = populated_history.search(**{"task.target_object": "cup", "result.status": "success"}) - assert len(results) == 1 - assert results[0].task.task_id == "task1" - - # Multiple criteria with no matches - results = populated_history.search(**{"task.target_object": "cup", "result.status": "failure"}) - assert len(results) == 0 - - # Combination of direct and wildcard paths - results = populated_history.search( - **{"task.target_object": "bottle", "task.metadata.objects.*.position.z": 0.3} - ) - assert len(results) == 1 - assert results[0].task.task_id == "task2" diff --git a/build/lib/dimos/models/__init__.py b/build/lib/dimos/models/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/build/lib/dimos/models/depth/__init__.py b/build/lib/dimos/models/depth/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/build/lib/dimos/models/depth/metric3d.py b/build/lib/dimos/models/depth/metric3d.py deleted file mode 100644 index 58cb63f640..0000000000 --- a/build/lib/dimos/models/depth/metric3d.py +++ /dev/null @@ -1,173 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -from PIL import Image -import cv2 -import numpy as np - -# May need to add this back for import to work -# external_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'external', 'Metric3D')) -# if external_path not in sys.path: -# sys.path.append(external_path) - - -class Metric3D: - def __init__(self, gt_depth_scale=256.0): - # self.conf = get_config("zoedepth", "infer") - # self.depth_model = build_model(self.conf) - self.depth_model = torch.hub.load( - "yvanyin/metric3d", "metric3d_vit_small", pretrain=True - ).cuda() - if torch.cuda.device_count() > 1: - print(f"Using {torch.cuda.device_count()} GPUs!") - # self.depth_model = torch.nn.DataParallel(self.depth_model) - self.depth_model.eval() - - self.intrinsic = [707.0493, 707.0493, 604.0814, 180.5066] - self.intrinsic_scaled = None - self.gt_depth_scale = gt_depth_scale # And this - self.pad_info = None - self.rgb_origin = None - - """ - Input: Single image in RGB format - Output: Depth map - """ - - def update_intrinsic(self, intrinsic): - """ - Update the intrinsic parameters dynamically. - Ensure that the input intrinsic is valid. - """ - if len(intrinsic) != 4: - raise ValueError("Intrinsic must be a list or tuple with 4 values: [fx, fy, cx, cy]") - self.intrinsic = intrinsic - print(f"Intrinsics updated to: {self.intrinsic}") - - def infer_depth(self, img, debug=False): - if debug: - print(f"Input image: {img}") - try: - if isinstance(img, str): - print(f"Image type string: {type(img)}") - self.rgb_origin = cv2.imread(img)[:, :, ::-1] - else: - # print(f"Image type not string: {type(img)}, cv2 conversion assumed to be handled. 
If not, this will throw an error") - self.rgb_origin = img - except Exception as e: - print(f"Error parsing into infer_depth: {e}") - - img = self.rescale_input(img, self.rgb_origin) - - with torch.no_grad(): - pred_depth, confidence, output_dict = self.depth_model.inference({"input": img}) - - # Convert to PIL format - depth_image = self.unpad_transform_depth(pred_depth) - out_16bit_numpy = (depth_image.squeeze().cpu().numpy() * self.gt_depth_scale).astype( - np.uint16 - ) - depth_map_pil = Image.fromarray(out_16bit_numpy) - - return depth_map_pil - - def save_depth(self, pred_depth): - # Save the depth map to a file - pred_depth_np = pred_depth.cpu().numpy() - output_depth_file = "output_depth_map.png" - cv2.imwrite(output_depth_file, pred_depth_np) - print(f"Depth map saved to {output_depth_file}") - - # Adjusts input size to fit pretrained ViT model - def rescale_input(self, rgb, rgb_origin): - #### ajust input size to fit pretrained model - # keep ratio resize - input_size = (616, 1064) # for vit model - # input_size = (544, 1216) # for convnext model - h, w = rgb_origin.shape[:2] - scale = min(input_size[0] / h, input_size[1] / w) - rgb = cv2.resize( - rgb_origin, (int(w * scale), int(h * scale)), interpolation=cv2.INTER_LINEAR - ) - # remember to scale intrinsic, hold depth - self.intrinsic_scaled = [ - self.intrinsic[0] * scale, - self.intrinsic[1] * scale, - self.intrinsic[2] * scale, - self.intrinsic[3] * scale, - ] - # padding to input_size - padding = [123.675, 116.28, 103.53] - h, w = rgb.shape[:2] - pad_h = input_size[0] - h - pad_w = input_size[1] - w - pad_h_half = pad_h // 2 - pad_w_half = pad_w // 2 - rgb = cv2.copyMakeBorder( - rgb, - pad_h_half, - pad_h - pad_h_half, - pad_w_half, - pad_w - pad_w_half, - cv2.BORDER_CONSTANT, - value=padding, - ) - self.pad_info = [pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half] - - #### normalize - mean = torch.tensor([123.675, 116.28, 103.53]).float()[:, None, None] - std = torch.tensor([58.395, 
57.12, 57.375]).float()[:, None, None] - rgb = torch.from_numpy(rgb.transpose((2, 0, 1))).float() - rgb = torch.div((rgb - mean), std) - rgb = rgb[None, :, :, :].cuda() - return rgb - - def unpad_transform_depth(self, pred_depth): - # un pad - pred_depth = pred_depth.squeeze() - pred_depth = pred_depth[ - self.pad_info[0] : pred_depth.shape[0] - self.pad_info[1], - self.pad_info[2] : pred_depth.shape[1] - self.pad_info[3], - ] - - # upsample to original size - pred_depth = torch.nn.functional.interpolate( - pred_depth[None, None, :, :], self.rgb_origin.shape[:2], mode="bilinear" - ).squeeze() - ###################### canonical camera space ###################### - - #### de-canonical transform - canonical_to_real_scale = ( - self.intrinsic_scaled[0] / 1000.0 - ) # 1000.0 is the focal length of canonical camera - pred_depth = pred_depth * canonical_to_real_scale # now the depth is metric - pred_depth = torch.clamp(pred_depth, 0, 1000) - return pred_depth - - """Set new intrinsic value.""" - - def update_intrinsic(self, intrinsic): - self.intrinsic = intrinsic - - def eval_predicted_depth(self, depth_file, pred_depth): - if depth_file is not None: - gt_depth = cv2.imread(depth_file, -1) - gt_depth = gt_depth / self.gt_depth_scale - gt_depth = torch.from_numpy(gt_depth).float().cuda() - assert gt_depth.shape == pred_depth.shape - - mask = gt_depth > 1e-8 - abs_rel_err = (torch.abs(pred_depth[mask] - gt_depth[mask]) / gt_depth[mask]).mean() - print("abs_rel_err:", abs_rel_err.item()) diff --git a/build/lib/dimos/models/labels/__init__.py b/build/lib/dimos/models/labels/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/build/lib/dimos/models/labels/llava-34b.py b/build/lib/dimos/models/labels/llava-34b.py deleted file mode 100644 index c59a5c8aa9..0000000000 --- a/build/lib/dimos/models/labels/llava-34b.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright 2025 Dimensional Inc. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -import os - -# llava v1.6 -from llama_cpp import Llama -from llama_cpp.llama_chat_format import Llava15ChatHandler - -from vqasynth.datasets.utils import image_to_base64_data_uri - - -class Llava: - def __init__( - self, - mmproj=f"{os.getcwd()}/models/mmproj-model-f16.gguf", - model_path=f"{os.getcwd()}/models/llava-v1.6-34b.Q4_K_M.gguf", - gpu=True, - ): - chat_handler = Llava15ChatHandler(clip_model_path=mmproj, verbose=True) - n_gpu_layers = 0 - if gpu: - n_gpu_layers = -1 - self.llm = Llama( - model_path=model_path, - chat_handler=chat_handler, - n_ctx=2048, - logits_all=True, - n_gpu_layers=n_gpu_layers, - ) - - def run_inference(self, image, prompt, return_json=True): - data_uri = image_to_base64_data_uri(image) - res = self.llm.create_chat_completion( - messages=[ - { - "role": "system", - "content": "You are an assistant who perfectly describes images.", - }, - { - "role": "user", - "content": [ - {"type": "image_url", "image_url": {"url": data_uri}}, - {"type": "text", "text": prompt}, - ], - }, - ] - ) - if return_json: - return list( - set( - self.extract_descriptions_from_incomplete_json( - res["choices"][0]["message"]["content"] - ) - ) - ) - - return res["choices"][0]["message"]["content"] - - def extract_descriptions_from_incomplete_json(self, json_like_str): - last_object_idx = json_like_str.rfind(',"object') - - if last_object_idx != -1: - json_str = json_like_str[:last_object_idx] + "}" 
- else: - json_str = json_like_str.strip() - if not json_str.endswith("}"): - json_str += "}" - - try: - json_obj = json.loads(json_str) - descriptions = [ - details["description"].replace(".", "") - for key, details in json_obj.items() - if "description" in details - ] - - return descriptions - except json.JSONDecodeError as e: - raise ValueError(f"Error parsing JSON: {e}") diff --git a/build/lib/dimos/models/manipulation/__init__.py b/build/lib/dimos/models/manipulation/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/build/lib/dimos/models/pointcloud/__init__.py b/build/lib/dimos/models/pointcloud/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/build/lib/dimos/models/pointcloud/pointcloud_utils.py b/build/lib/dimos/models/pointcloud/pointcloud_utils.py deleted file mode 100644 index c0951f44f2..0000000000 --- a/build/lib/dimos/models/pointcloud/pointcloud_utils.py +++ /dev/null @@ -1,214 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -import open3d as o3d -import random - - -def save_pointcloud(pcd, file_path): - """ - Save a point cloud to a file using Open3D. 
- """ - o3d.io.write_point_cloud(file_path, pcd) - - -def restore_pointclouds(pointcloud_paths): - restored_pointclouds = [] - for path in pointcloud_paths: - restored_pointclouds.append(o3d.io.read_point_cloud(path)) - return restored_pointclouds - - -def create_point_cloud_from_rgbd(rgb_image, depth_image, intrinsic_parameters): - rgbd_image = o3d.geometry.RGBDImage.create_from_color_and_depth( - o3d.geometry.Image(rgb_image), - o3d.geometry.Image(depth_image), - depth_scale=0.125, # 1000.0, - depth_trunc=10.0, # 10.0, - convert_rgb_to_intensity=False, - ) - intrinsic = o3d.camera.PinholeCameraIntrinsic() - intrinsic.set_intrinsics( - intrinsic_parameters["width"], - intrinsic_parameters["height"], - intrinsic_parameters["fx"], - intrinsic_parameters["fy"], - intrinsic_parameters["cx"], - intrinsic_parameters["cy"], - ) - pcd = o3d.geometry.PointCloud.create_from_rgbd_image(rgbd_image, intrinsic) - return pcd - - -def canonicalize_point_cloud(pcd, canonicalize_threshold=0.3): - # Segment the largest plane, assumed to be the floor - plane_model, inliers = pcd.segment_plane( - distance_threshold=0.01, ransac_n=3, num_iterations=1000 - ) - - canonicalized = False - if len(inliers) / len(pcd.points) > canonicalize_threshold: - canonicalized = True - - # Ensure the plane normal points upwards - if np.dot(plane_model[:3], [0, 1, 0]) < 0: - plane_model = -plane_model - - # Normalize the plane normal vector - normal = plane_model[:3] / np.linalg.norm(plane_model[:3]) - - # Compute the new basis vectors - new_y = normal - new_x = np.cross(new_y, [0, 0, -1]) - new_x /= np.linalg.norm(new_x) - new_z = np.cross(new_x, new_y) - - # Create the transformation matrix - transformation = np.identity(4) - transformation[:3, :3] = np.vstack((new_x, new_y, new_z)).T - transformation[:3, 3] = -np.dot(transformation[:3, :3], pcd.points[inliers[0]]) - - # Apply the transformation - pcd.transform(transformation) - - # Additional 180-degree rotation around the Z-axis - rotation_z_180 = 
np.array( - [[np.cos(np.pi), -np.sin(np.pi), 0], [np.sin(np.pi), np.cos(np.pi), 0], [0, 0, 1]] - ) - pcd.rotate(rotation_z_180, center=(0, 0, 0)) - - return pcd, canonicalized, transformation - else: - return pcd, canonicalized, None - - -# Distance calculations -def human_like_distance(distance_meters): - # Define the choices with units included, focusing on the 0.1 to 10 meters range - if distance_meters < 1: # For distances less than 1 meter - choices = [ - ( - round(distance_meters * 100, 2), - "centimeters", - 0.2, - ), # Centimeters for very small distances - ( - round(distance_meters * 39.3701, 2), - "inches", - 0.8, - ), # Inches for the majority of cases under 1 meter - ] - elif distance_meters < 3: # For distances less than 3 meters - choices = [ - (round(distance_meters, 2), "meters", 0.5), - ( - round(distance_meters * 3.28084, 2), - "feet", - 0.5, - ), # Feet as a common unit within indoor spaces - ] - else: # For distances from 3 up to 10 meters - choices = [ - ( - round(distance_meters, 2), - "meters", - 0.7, - ), # Meters for clarity and international understanding - ( - round(distance_meters * 3.28084, 2), - "feet", - 0.3, - ), # Feet for additional context - ] - - # Normalize probabilities and make a selection - total_probability = sum(prob for _, _, prob in choices) - cumulative_distribution = [] - cumulative_sum = 0 - for value, unit, probability in choices: - cumulative_sum += probability / total_probability # Normalize probabilities - cumulative_distribution.append((cumulative_sum, value, unit)) - - # Randomly choose based on the cumulative distribution - r = random.random() - for cumulative_prob, value, unit in cumulative_distribution: - if r < cumulative_prob: - return f"{value} {unit}" - - # Fallback to the last choice if something goes wrong - return f"{choices[-1][0]} {choices[-1][1]}" - - -def calculate_distances_between_point_clouds(A, B): - dist_pcd1_to_pcd2 = np.asarray(A.compute_point_cloud_distance(B)) - dist_pcd2_to_pcd1 = 
np.asarray(B.compute_point_cloud_distance(A)) - combined_distances = np.concatenate((dist_pcd1_to_pcd2, dist_pcd2_to_pcd1)) - avg_dist = np.mean(combined_distances) - return human_like_distance(avg_dist) - - -def calculate_centroid(pcd): - """Calculate the centroid of a point cloud.""" - points = np.asarray(pcd.points) - centroid = np.mean(points, axis=0) - return centroid - - -def calculate_relative_positions(centroids): - """Calculate the relative positions between centroids of point clouds.""" - num_centroids = len(centroids) - relative_positions_info = [] - - for i in range(num_centroids): - for j in range(i + 1, num_centroids): - relative_vector = centroids[j] - centroids[i] - - distance = np.linalg.norm(relative_vector) - relative_positions_info.append( - {"pcd_pair": (i, j), "relative_vector": relative_vector, "distance": distance} - ) - - return relative_positions_info - - -def get_bounding_box_height(pcd): - """ - Compute the height of the bounding box for a given point cloud. - - Parameters: - pcd (open3d.geometry.PointCloud): The input point cloud. - - Returns: - float: The height of the bounding box. - """ - aabb = pcd.get_axis_aligned_bounding_box() - return aabb.get_extent()[1] # Assuming the Y-axis is the up-direction - - -def compare_bounding_box_height(pcd_i, pcd_j): - """ - Compare the bounding box heights of two point clouds. - - Parameters: - pcd_i (open3d.geometry.PointCloud): The first point cloud. - pcd_j (open3d.geometry.PointCloud): The second point cloud. - - Returns: - bool: True if the bounding box of pcd_i is taller than that of pcd_j, False otherwise. 
- """ - height_i = get_bounding_box_height(pcd_i) - height_j = get_bounding_box_height(pcd_j) - - return height_i > height_j diff --git a/build/lib/dimos/models/segmentation/__init__.py b/build/lib/dimos/models/segmentation/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/build/lib/dimos/models/segmentation/clipseg.py b/build/lib/dimos/models/segmentation/clipseg.py deleted file mode 100644 index 043cd194b0..0000000000 --- a/build/lib/dimos/models/segmentation/clipseg.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from transformers import AutoProcessor, CLIPSegForImageSegmentation - - -class CLIPSeg: - def __init__(self, model_name="CIDAS/clipseg-rd64-refined"): - self.clipseg_processor = AutoProcessor.from_pretrained(model_name) - self.clipseg_model = CLIPSegForImageSegmentation.from_pretrained(model_name) - - def run_inference(self, image, text_descriptions): - inputs = self.clipseg_processor( - text=text_descriptions, - images=[image] * len(text_descriptions), - padding=True, - return_tensors="pt", - ) - outputs = self.clipseg_model(**inputs) - logits = outputs.logits - return logits.detach().unsqueeze(1) diff --git a/build/lib/dimos/models/segmentation/sam.py b/build/lib/dimos/models/segmentation/sam.py deleted file mode 100644 index 1efb07c484..0000000000 --- a/build/lib/dimos/models/segmentation/sam.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from transformers import SamModel, SamProcessor -import torch - - -class SAM: - def __init__(self, model_name="facebook/sam-vit-huge", device="cuda"): - self.device = device - self.sam_model = SamModel.from_pretrained(model_name).to(self.device) - self.sam_processor = SamProcessor.from_pretrained(model_name) - - def run_inference_from_points(self, image, points): - sam_inputs = self.sam_processor(image, input_points=points, return_tensors="pt").to( - self.device - ) - with torch.no_grad(): - sam_outputs = self.sam_model(**sam_inputs) - return self.sam_processor.image_processor.post_process_masks( - sam_outputs.pred_masks.cpu(), - sam_inputs["original_sizes"].cpu(), - sam_inputs["reshaped_input_sizes"].cpu(), - ) diff --git a/build/lib/dimos/models/segmentation/segment_utils.py b/build/lib/dimos/models/segmentation/segment_utils.py deleted file mode 100644 index 9808f5d4e4..0000000000 --- a/build/lib/dimos/models/segmentation/segment_utils.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import numpy as np - - -def find_medoid_and_closest_points(points, num_closest=5): - """ - Find the medoid from a collection of points and the closest points to the medoid. - - Parameters: - points (np.array): A numpy array of shape (N, D) where N is the number of points and D is the dimensionality. - num_closest (int): Number of closest points to return. - - Returns: - np.array: The medoid point. 
- np.array: The closest points to the medoid. - """ - distances = np.sqrt(((points[:, np.newaxis, :] - points[np.newaxis, :, :]) ** 2).sum(axis=-1)) - distance_sums = distances.sum(axis=1) - medoid_idx = np.argmin(distance_sums) - medoid = points[medoid_idx] - sorted_indices = np.argsort(distances[medoid_idx]) - closest_indices = sorted_indices[1 : num_closest + 1] - return medoid, points[closest_indices] - - -def sample_points_from_heatmap(heatmap, original_size, num_points=5, percentile=0.95): - """ - Sample points from the given heatmap, focusing on areas with higher values. - """ - width, height = original_size - threshold = np.percentile(heatmap.numpy(), percentile) - masked_heatmap = torch.where(heatmap > threshold, heatmap, torch.tensor(0.0)) - probabilities = torch.softmax(masked_heatmap.flatten(), dim=0) - - attn = torch.sigmoid(heatmap) - w = attn.shape[0] - sampled_indices = torch.multinomial( - torch.tensor(probabilities.ravel()), num_points, replacement=True - ) - - sampled_coords = np.array(np.unravel_index(sampled_indices, attn.shape)).T - medoid, sampled_coords = find_medoid_and_closest_points(sampled_coords) - pts = [] - for pt in sampled_coords.tolist(): - x, y = pt - x = height * x / w - y = width * y / w - pts.append([y, x]) - return pts - - -def apply_mask_to_image(image, mask): - """ - Apply a binary mask to an image. The mask should be a binary array where the regions to keep are True. - """ - masked_image = image.copy() - for c in range(masked_image.shape[2]): - masked_image[:, :, c] = masked_image[:, :, c] * mask - return masked_image diff --git a/build/lib/dimos/msgs/__init__.py b/build/lib/dimos/msgs/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/build/lib/dimos/msgs/geometry_msgs/Pose.py b/build/lib/dimos/msgs/geometry_msgs/Pose.py deleted file mode 100644 index 74b534fefa..0000000000 --- a/build/lib/dimos/msgs/geometry_msgs/Pose.py +++ /dev/null @@ -1,181 +0,0 @@ -# Copyright 2025 Dimensional Inc. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import struct -import traceback -from io import BytesIO -from typing import BinaryIO, TypeAlias - -from dimos_lcm.geometry_msgs import Pose as LCMPose -from plum import dispatch - -from dimos.msgs.geometry_msgs.Quaternion import Quaternion, QuaternionConvertable -from dimos.msgs.geometry_msgs.Vector3 import Vector3, VectorConvertable - -# Types that can be converted to/from Pose -PoseConvertable: TypeAlias = ( - tuple[VectorConvertable, QuaternionConvertable] - | LCMPose - | dict[str, VectorConvertable | QuaternionConvertable] -) - - -class Pose(LCMPose): - position: Vector3 - orientation: Quaternion - msg_name = "geometry_msgs.Pose" - - @classmethod - def lcm_decode(cls, data: bytes | BinaryIO): - if not hasattr(data, "read"): - data = BytesIO(data) - if data.read(8) != cls._get_packed_fingerprint(): - traceback.print_exc() - raise ValueError("Decode error") - return cls._lcm_decode_one(data) - - @classmethod - def _lcm_decode_one(cls, buf): - return cls(Vector3._decode_one(buf), Quaternion._decode_one(buf)) - - def lcm_encode(self) -> bytes: - return super().encode() - - @dispatch - def __init__(self) -> None: - """Initialize a pose at origin with identity orientation.""" - self.position = Vector3(0.0, 0.0, 0.0) - self.orientation = Quaternion(0.0, 0.0, 0.0, 1.0) - - @dispatch - def __init__(self, x: int | float, y: int | float, z: int | float) -> None: - """Initialize a pose 
with position and identity orientation.""" - self.position = Vector3(x, y, z) - self.orientation = Quaternion(0.0, 0.0, 0.0, 1.0) - - @dispatch - def __init__( - self, - x: int | float, - y: int | float, - z: int | float, - qx: int | float, - qy: int | float, - qz: int | float, - qw: int | float, - ) -> None: - """Initialize a pose with position and orientation.""" - self.position = Vector3(x, y, z) - self.orientation = Quaternion(qx, qy, qz, qw) - - @dispatch - def __init__( - self, - position: VectorConvertable | Vector3 = [0, 0, 0], - orientation: QuaternionConvertable | Quaternion = [0, 0, 0, 1], - ) -> None: - """Initialize a pose with position and orientation.""" - self.position = Vector3(position) - self.orientation = Quaternion(orientation) - - @dispatch - def __init__(self, pose_tuple: tuple[VectorConvertable, QuaternionConvertable]) -> None: - """Initialize from a tuple of (position, orientation).""" - self.position = Vector3(pose_tuple[0]) - self.orientation = Quaternion(pose_tuple[1]) - - @dispatch - def __init__(self, pose_dict: dict[str, VectorConvertable | QuaternionConvertable]) -> None: - """Initialize from a dictionary with 'position' and 'orientation' keys.""" - self.position = Vector3(pose_dict["position"]) - self.orientation = Quaternion(pose_dict["orientation"]) - - @dispatch - def __init__(self, pose: Pose) -> None: - """Initialize from another Pose (copy constructor).""" - self.position = Vector3(pose.position) - self.orientation = Quaternion(pose.orientation) - - @dispatch - def __init__(self, lcm_pose: LCMPose) -> None: - """Initialize from an LCM Pose.""" - self.position = Vector3(lcm_pose.position.x, lcm_pose.position.y, lcm_pose.position.z) - self.orientation = Quaternion( - lcm_pose.orientation.x, - lcm_pose.orientation.y, - lcm_pose.orientation.z, - lcm_pose.orientation.w, - ) - - @property - def x(self) -> float: - """X coordinate of position.""" - return self.position.x - - @property - def y(self) -> float: - """Y coordinate of 
position.""" - return self.position.y - - @property - def z(self) -> float: - """Z coordinate of position.""" - return self.position.z - - @property - def roll(self) -> float: - """Roll angle in radians.""" - return self.orientation.to_euler().roll - - @property - def pitch(self) -> float: - """Pitch angle in radians.""" - return self.orientation.to_euler().pitch - - @property - def yaw(self) -> float: - """Yaw angle in radians.""" - return self.orientation.to_euler().yaw - - def __repr__(self) -> str: - return f"Pose(position={self.position!r}, orientation={self.orientation!r})" - - def __str__(self) -> str: - return ( - f"Pose(pos=[{self.x:.3f}, {self.y:.3f}, {self.z:.3f}], " - f"euler=[{self.roll:.3f}, {self.pitch:.3f}, {self.yaw:.3f}])" - ) - - def __eq__(self, other) -> bool: - """Check if two poses are equal.""" - if not isinstance(other, Pose): - return False - return self.position == other.position and self.orientation == other.orientation - - -@dispatch -def to_pose(value: "Pose") -> Pose: - """Pass through Pose objects.""" - return value - - -@dispatch -def to_pose(value: PoseConvertable | Pose) -> Pose: - """Convert a pose-compatible value to a Pose object.""" - return Pose(value) - - -PoseLike: TypeAlias = PoseConvertable | Pose diff --git a/build/lib/dimos/msgs/geometry_msgs/PoseStamped.py b/build/lib/dimos/msgs/geometry_msgs/PoseStamped.py deleted file mode 100644 index 3871072d32..0000000000 --- a/build/lib/dimos/msgs/geometry_msgs/PoseStamped.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import struct -import time -from io import BytesIO -from typing import BinaryIO, TypeAlias - -from dimos_lcm.geometry_msgs import PoseStamped as LCMPoseStamped -from dimos_lcm.std_msgs import Header as LCMHeader -from dimos_lcm.std_msgs import Time as LCMTime -from plum import dispatch - -from dimos.msgs.geometry_msgs.Pose import Pose -from dimos.msgs.geometry_msgs.Quaternion import Quaternion, QuaternionConvertable -from dimos.msgs.geometry_msgs.Vector3 import Vector3, VectorConvertable -from dimos.types.timestamped import Timestamped - -# Types that can be converted to/from Pose -PoseConvertable: TypeAlias = ( - tuple[VectorConvertable, QuaternionConvertable] - | LCMPoseStamped - | dict[str, VectorConvertable | QuaternionConvertable] -) - - -def sec_nsec(ts): - s = int(ts) - return [s, int((ts - s) * 1_000_000_000)] - - -class PoseStamped(Pose, Timestamped): - msg_name = "geometry_msgs.PoseStamped" - ts: float - frame_id: str - - @dispatch - def __init__(self, ts: float = 0.0, frame_id: str = "", **kwargs) -> None: - self.frame_id = frame_id - self.ts = ts if ts != 0 else time.time() - super().__init__(**kwargs) - - def lcm_encode(self) -> bytes: - lcm_mgs = LCMPoseStamped() - lcm_mgs.pose = self - [lcm_mgs.header.stamp.sec, lcm_mgs.header.stamp.sec] = sec_nsec(self.ts) - lcm_mgs.header.frame_id = self.frame_id - return lcm_mgs.encode() - - @classmethod - def lcm_decode(cls, data: bytes | BinaryIO) -> PoseStamped: - lcm_msg = LCMPoseStamped.decode(data) - return cls( - ts=lcm_msg.header.stamp.sec + 
(lcm_msg.header.stamp.nsec / 1_000_000_000), - frame_id=lcm_msg.header.frame_id, - position=[lcm_msg.pose.position.x, lcm_msg.pose.position.y, lcm_msg.pose.position.z], - orientation=[ - lcm_msg.pose.orientation.x, - lcm_msg.pose.orientation.y, - lcm_msg.pose.orientation.z, - lcm_msg.pose.orientation.w, - ], # noqa: E501, - ) diff --git a/build/lib/dimos/msgs/geometry_msgs/Quaternion.py b/build/lib/dimos/msgs/geometry_msgs/Quaternion.py deleted file mode 100644 index ccb3328510..0000000000 --- a/build/lib/dimos/msgs/geometry_msgs/Quaternion.py +++ /dev/null @@ -1,167 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import annotations - -import struct -from collections.abc import Sequence -from io import BytesIO -from typing import BinaryIO, TypeAlias - -import numpy as np -from dimos_lcm.geometry_msgs import Quaternion as LCMQuaternion -from plum import dispatch - -from dimos.msgs.geometry_msgs.Vector3 import Vector3 - -# Types that can be converted to/from Quaternion -QuaternionConvertable: TypeAlias = Sequence[int | float] | LCMQuaternion | np.ndarray - - -class Quaternion(LCMQuaternion): - x: float = 0.0 - y: float = 0.0 - z: float = 0.0 - w: float = 1.0 - msg_name = "geometry_msgs.Quaternion" - - @classmethod - def lcm_decode(cls, data: bytes | BinaryIO): - if not hasattr(data, "read"): - data = BytesIO(data) - if data.read(8) != cls._get_packed_fingerprint(): - raise ValueError("Decode error") - return cls._lcm_decode_one(data) - - @classmethod - def _lcm_decode_one(cls, buf): - return cls(struct.unpack(">dddd", buf.read(32))) - - def lcm_encode(self): - return super().encode() - - @dispatch - def __init__(self) -> None: ... 
- - @dispatch - def __init__(self, x: int | float, y: int | float, z: int | float, w: int | float) -> None: - self.x = float(x) - self.y = float(y) - self.z = float(z) - self.w = float(w) - - @dispatch - def __init__(self, sequence: Sequence[int | float] | np.ndarray) -> None: - if isinstance(sequence, np.ndarray): - if sequence.size != 4: - raise ValueError("Quaternion requires exactly 4 components [x, y, z, w]") - else: - if len(sequence) != 4: - raise ValueError("Quaternion requires exactly 4 components [x, y, z, w]") - - self.x = sequence[0] - self.y = sequence[1] - self.z = sequence[2] - self.w = sequence[3] - - @dispatch - def __init__(self, quaternion: "Quaternion") -> None: - """Initialize from another Quaternion (copy constructor).""" - self.x, self.y, self.z, self.w = quaternion.x, quaternion.y, quaternion.z, quaternion.w - - @dispatch - def __init__(self, lcm_quaternion: LCMQuaternion) -> None: - """Initialize from an LCM Quaternion.""" - self.x, self.y, self.z, self.w = ( - lcm_quaternion.x, - lcm_quaternion.y, - lcm_quaternion.z, - lcm_quaternion.w, - ) - - def to_tuple(self) -> tuple[float, float, float, float]: - """Tuple representation of the quaternion (x, y, z, w).""" - return (self.x, self.y, self.z, self.w) - - def to_list(self) -> list[float]: - """List representation of the quaternion (x, y, z, w).""" - return [self.x, self.y, self.z, self.w] - - def to_numpy(self) -> np.ndarray: - """Numpy array representation of the quaternion (x, y, z, w).""" - return np.array([self.x, self.y, self.z, self.w]) - - @property - def euler(self) -> Vector3: - return self.to_euler() - - @property - def radians(self) -> Vector3: - return self.to_euler() - - def to_radians(self) -> Vector3: - """Radians representation of the quaternion (x, y, z, w).""" - return self.to_euler() - - def to_euler(self) -> Vector3: - """Convert quaternion to Euler angles (roll, pitch, yaw) in radians. 
- - Returns: - Vector3: Euler angles as (roll, pitch, yaw) in radians - """ - # Convert quaternion to Euler angles using ZYX convention (yaw, pitch, roll) - # Source: https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles - - # Roll (x-axis rotation) - sinr_cosp = 2 * (self.w * self.x + self.y * self.z) - cosr_cosp = 1 - 2 * (self.x * self.x + self.y * self.y) - roll = np.arctan2(sinr_cosp, cosr_cosp) - - # Pitch (y-axis rotation) - sinp = 2 * (self.w * self.y - self.z * self.x) - if abs(sinp) >= 1: - pitch = np.copysign(np.pi / 2, sinp) # Use 90 degrees if out of range - else: - pitch = np.arcsin(sinp) - - # Yaw (z-axis rotation) - siny_cosp = 2 * (self.w * self.z + self.x * self.y) - cosy_cosp = 1 - 2 * (self.y * self.y + self.z * self.z) - yaw = np.arctan2(siny_cosp, cosy_cosp) - - return Vector3(roll, pitch, yaw) - - def __getitem__(self, idx: int) -> float: - """Allow indexing into quaternion components: 0=x, 1=y, 2=z, 3=w.""" - if idx == 0: - return self.x - elif idx == 1: - return self.y - elif idx == 2: - return self.z - elif idx == 3: - return self.w - else: - raise IndexError(f"Quaternion index {idx} out of range [0-3]") - - def __repr__(self) -> str: - return f"Quaternion({self.x:.6f}, {self.y:.6f}, {self.z:.6f}, {self.w:.6f})" - - def __str__(self) -> str: - return self.__repr__() - - def __eq__(self, other) -> bool: - if not isinstance(other, Quaternion): - return False - return self.x == other.x and self.y == other.y and self.z == other.z and self.w == other.w diff --git a/build/lib/dimos/msgs/geometry_msgs/Twist.py b/build/lib/dimos/msgs/geometry_msgs/Twist.py deleted file mode 100644 index 581c1d2e5f..0000000000 --- a/build/lib/dimos/msgs/geometry_msgs/Twist.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""LCM type definitions -This file automatically generated by lcm. -DO NOT MODIFY BY HAND!!!! -""" - - -from io import BytesIO -import struct - -from . import * -from .Vector3 import Vector3 -class Twist(object): - - __slots__ = ["linear", "angular"] - - __typenames__ = ["Vector3", "Vector3"] - - __dimensions__ = [None, None] - - def __init__(self): - self.linear = Vector3() - """ LCM Type: Vector3 """ - self.angular = Vector3() - """ LCM Type: Vector3 """ - - def encode(self): - buf = BytesIO() - buf.write(Twist._get_packed_fingerprint()) - self._encode_one(buf) - return buf.getvalue() - - def _encode_one(self, buf): - assert self.linear._get_packed_fingerprint() == Vector3._get_packed_fingerprint() - self.linear._encode_one(buf) - assert self.angular._get_packed_fingerprint() == Vector3._get_packed_fingerprint() - self.angular._encode_one(buf) - - @classmethod - def decode(cls, data: bytes): - if hasattr(data, 'read'): - buf = data - else: - buf = BytesIO(data) - if buf.read(8) != cls._get_packed_fingerprint(): - raise ValueError("Decode error") - return cls._decode_one(buf) - - @classmethod - def _decode_one(cls, buf): - self = Twist() - self.linear = Vector3._decode_one(buf) - self.angular = Vector3._decode_one(buf) - return self - - @classmethod - def _get_hash_recursive(cls, parents): - if cls in parents: return 0 - newparents = parents + [cls] - tmphash = (0x3a4144772922add7+ Vector3._get_hash_recursive(newparents)+ Vector3._get_hash_recursive(newparents)) & 0xffffffffffffffff - tmphash = (((tmphash<<1)&0xffffffffffffffff) + (tmphash>>63)) & 
0xffffffffffffffff - return tmphash - _packed_fingerprint = None - - @classmethod - def _get_packed_fingerprint(cls): - if cls._packed_fingerprint is None: - cls._packed_fingerprint = struct.pack(">Q", cls._get_hash_recursive([])) - return cls._packed_fingerprint - - def get_hash(self): - """Get the LCM hash of the struct""" - return struct.unpack(">Q", cls._get_packed_fingerprint())[0] - diff --git a/build/lib/dimos/msgs/geometry_msgs/Vector3.py b/build/lib/dimos/msgs/geometry_msgs/Vector3.py deleted file mode 100644 index 7f839f2773..0000000000 --- a/build/lib/dimos/msgs/geometry_msgs/Vector3.py +++ /dev/null @@ -1,467 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import annotations - -import struct -from collections.abc import Sequence -from io import BytesIO -from typing import BinaryIO, TypeAlias - -import numpy as np -from dimos_lcm.geometry_msgs import Vector3 as LCMVector3 -from plum import dispatch - -# Types that can be converted to/from Vector -VectorConvertable: TypeAlias = Sequence[int | float] | LCMVector3 | np.ndarray - - -def _ensure_3d(data: np.ndarray) -> np.ndarray: - """Ensure the data array is exactly 3D by padding with zeros or raising an exception if too long.""" - if len(data) == 3: - return data - elif len(data) < 3: - padded = np.zeros(3, dtype=float) - padded[: len(data)] = data - return padded - else: - raise ValueError( - f"Vector3 cannot be initialized with more than 3 components. Got {len(data)} components." - ) - - -class Vector3(LCMVector3): - x: float = 0.0 - y: float = 0.0 - z: float = 0.0 - msg_name = "geometry_msgs.Vector3" - - @classmethod - def lcm_decode(cls, data: bytes | BinaryIO): - if not hasattr(data, "read"): - data = BytesIO(data) - if data.read(8) != cls._get_packed_fingerprint(): - raise ValueError("Decode error") - return cls._lcm_decode_one(data) - - @classmethod - def _lcm_decode_one(cls, buf): - return cls(struct.unpack(">ddd", buf.read(24))) - - def lcm_encode(self) -> bytes: - return super().encode() - - @dispatch - def __init__(self) -> None: - """Initialize a zero 3D vector.""" - self.x = 0.0 - self.y = 0.0 - self.z = 0.0 - - @dispatch - def __init__(self, x: int | float) -> None: - """Initialize a 3D vector from a single numeric value (x, 0, 0).""" - self.x = float(x) - self.y = 0.0 - self.z = 0.0 - - @dispatch - def __init__(self, x: int | float, y: int | float) -> None: - """Initialize a 3D vector from x, y components (z=0).""" - self.x = float(x) - self.y = float(y) - self.z = 0.0 - - @dispatch - def __init__(self, x: int | float, y: int | float, z: int | float) -> None: - """Initialize a 3D vector from x, y, z components.""" - self.x = float(x) - 
self.y = float(y) - self.z = float(z) - - @dispatch - def __init__(self, sequence: Sequence[int | float]) -> None: - """Initialize from a sequence (list, tuple) of numbers, ensuring 3D.""" - data = _ensure_3d(np.array(sequence, dtype=float)) - self.x = float(data[0]) - self.y = float(data[1]) - self.z = float(data[2]) - - @dispatch - def __init__(self, array: np.ndarray) -> None: - """Initialize from a numpy array, ensuring 3D.""" - data = _ensure_3d(np.array(array, dtype=float)) - self.x = float(data[0]) - self.y = float(data[1]) - self.z = float(data[2]) - - @dispatch - def __init__(self, vector: "Vector3") -> None: - """Initialize from another Vector3 (copy constructor).""" - self.x = vector.x - self.y = vector.y - self.z = vector.z - - @dispatch - def __init__(self, lcm_vector: LCMVector3) -> None: - """Initialize from an LCM Vector3.""" - self.x = float(lcm_vector.x) - self.y = float(lcm_vector.y) - self.z = float(lcm_vector.z) - - @property - def as_tuple(self) -> tuple[float, float, float]: - return (self.x, self.y, self.z) - - @property - def yaw(self) -> float: - return self.z - - @property - def pitch(self) -> float: - return self.y - - @property - def roll(self) -> float: - return self.x - - @property - def data(self) -> np.ndarray: - """Get the underlying numpy array.""" - return np.array([self.x, self.y, self.z], dtype=float) - - def __getitem__(self, idx): - if idx == 0: - return self.x - elif idx == 1: - return self.y - elif idx == 2: - return self.z - else: - raise IndexError(f"Vector3 index {idx} out of range [0-2]") - - def __repr__(self) -> str: - return f"Vector({self.data})" - - def __str__(self) -> str: - def getArrow(): - repr = ["←", "↖", "↑", "↗", "→", "↘", "↓", "↙"] - - if self.x == 0 and self.y == 0: - return "·" - - # Calculate angle in radians and convert to directional index - angle = np.arctan2(self.y, self.x) - # Map angle to 0-7 index (8 directions) with proper orientation - dir_index = int(((angle + np.pi) * 4 / np.pi) % 8) - # Get 
directional arrow symbol - return repr[dir_index] - - return f"{getArrow()} Vector {self.__repr__()}" - - def serialize(self) -> dict: - """Serialize the vector to a tuple.""" - return {"type": "vector", "c": (self.x, self.y, self.z)} - - def __eq__(self, other) -> bool: - """Check if two vectors are equal using numpy's allclose for floating point comparison.""" - if not isinstance(other, Vector3): - return False - return np.allclose([self.x, self.y, self.z], [other.x, other.y, other.z]) - - def __add__(self, other: VectorConvertable | Vector3) -> Vector3: - other_vector: Vector3 = to_vector(other) - return self.__class__( - self.x + other_vector.x, self.y + other_vector.y, self.z + other_vector.z - ) - - def __sub__(self, other: VectorConvertable | Vector3) -> Vector3: - other_vector = to_vector(other) - return self.__class__( - self.x - other_vector.x, self.y - other_vector.y, self.z - other_vector.z - ) - - def __mul__(self, scalar: float) -> Vector3: - return self.__class__(self.x * scalar, self.y * scalar, self.z * scalar) - - def __rmul__(self, scalar: float) -> Vector3: - return self.__mul__(scalar) - - def __truediv__(self, scalar: float) -> Vector3: - return self.__class__(self.x / scalar, self.y / scalar, self.z / scalar) - - def __neg__(self) -> Vector3: - return self.__class__(-self.x, -self.y, -self.z) - - def dot(self, other: VectorConvertable | Vector3) -> float: - """Compute dot product.""" - other_vector = to_vector(other) - return self.x * other_vector.x + self.y * other_vector.y + self.z * other_vector.z - - def cross(self, other: VectorConvertable | Vector3) -> Vector3: - """Compute cross product (3D vectors only).""" - other_vector = to_vector(other) - return self.__class__( - self.y * other_vector.z - self.z * other_vector.y, - self.z * other_vector.x - self.x * other_vector.z, - self.x * other_vector.y - self.y * other_vector.x, - ) - - def length(self) -> float: - """Compute the Euclidean length (magnitude) of the vector.""" - return 
float(np.sqrt(self.x * self.x + self.y * self.y + self.z * self.z)) - - def length_squared(self) -> float: - """Compute the squared length of the vector (faster than length()).""" - return float(self.x * self.x + self.y * self.y + self.z * self.z) - - def normalize(self) -> Vector3: - """Return a normalized unit vector in the same direction.""" - length = self.length() - if length < 1e-10: # Avoid division by near-zero - return self.__class__(0.0, 0.0, 0.0) - return self.__class__(self.x / length, self.y / length, self.z / length) - - def to_2d(self) -> Vector3: - """Convert a vector to a 2D vector by taking only the x and y components (z=0).""" - return self.__class__(self.x, self.y, 0.0) - - def distance(self, other: VectorConvertable | Vector3) -> float: - """Compute Euclidean distance to another vector.""" - other_vector = to_vector(other) - dx = self.x - other_vector.x - dy = self.y - other_vector.y - dz = self.z - other_vector.z - return float(np.sqrt(dx * dx + dy * dy + dz * dz)) - - def distance_squared(self, other: VectorConvertable | Vector3) -> float: - """Compute squared Euclidean distance to another vector (faster than distance()).""" - other_vector = to_vector(other) - dx = self.x - other_vector.x - dy = self.y - other_vector.y - dz = self.z - other_vector.z - return float(dx * dx + dy * dy + dz * dz) - - def angle(self, other: VectorConvertable | Vector3) -> float: - """Compute the angle (in radians) between this vector and another.""" - other_vector = to_vector(other) - this_length = self.length() - other_length = other_vector.length() - - if this_length < 1e-10 or other_length < 1e-10: - return 0.0 - - cos_angle = np.clip( - self.dot(other_vector) / (this_length * other_length), - -1.0, - 1.0, - ) - return float(np.arccos(cos_angle)) - - def project(self, onto: VectorConvertable | Vector3) -> Vector3: - """Project this vector onto another vector.""" - onto_vector = to_vector(onto) - onto_length_sq = ( - onto_vector.x * onto_vector.x - + 
onto_vector.y * onto_vector.y - + onto_vector.z * onto_vector.z - ) - if onto_length_sq < 1e-10: - return self.__class__(0.0, 0.0, 0.0) - - scalar_projection = self.dot(onto_vector) / onto_length_sq - return self.__class__( - scalar_projection * onto_vector.x, - scalar_projection * onto_vector.y, - scalar_projection * onto_vector.z, - ) - - # this is here to test ros_observable_topic - # doesn't happen irl afaik that we want a vector from ros message - @classmethod - def from_msg(cls, msg) -> Vector3: - return cls(*msg) - - @classmethod - def zeros(cls) -> Vector3: - """Create a zero 3D vector.""" - return cls() - - @classmethod - def ones(cls) -> Vector3: - """Create a 3D vector of ones.""" - return cls(1.0, 1.0, 1.0) - - @classmethod - def unit_x(cls) -> Vector3: - """Create a unit vector in the x direction.""" - return cls(1.0, 0.0, 0.0) - - @classmethod - def unit_y(cls) -> Vector3: - """Create a unit vector in the y direction.""" - return cls(0.0, 1.0, 0.0) - - @classmethod - def unit_z(cls) -> Vector3: - """Create a unit vector in the z direction.""" - return cls(0.0, 0.0, 1.0) - - def to_list(self) -> list[float]: - """Convert the vector to a list.""" - return [self.x, self.y, self.z] - - def to_tuple(self) -> tuple[float, float, float]: - """Convert the vector to a tuple.""" - return (self.x, self.y, self.z) - - def to_numpy(self) -> np.ndarray: - """Convert the vector to a numpy array.""" - return np.array([self.x, self.y, self.z], dtype=float) - - def is_zero(self) -> bool: - """Check if this is a zero vector (all components are zero). - - Returns: - True if all components are zero, False otherwise - """ - return np.allclose([self.x, self.y, self.z], 0.0) - - @property - def quaternion(self): - return self.to_quaternion() - - def to_quaternion(self): - """Convert Vector3 representing Euler angles (roll, pitch, yaw) to a Quaternion. 
- - Assumes this Vector3 contains Euler angles in radians: - - x component: roll (rotation around x-axis) - - y component: pitch (rotation around y-axis) - - z component: yaw (rotation around z-axis) - - Returns: - Quaternion: The equivalent quaternion representation - """ - # Import here to avoid circular imports - from dimos.msgs.geometry_msgs.Quaternion import Quaternion - - # Extract Euler angles - roll = self.x - pitch = self.y - yaw = self.z - - # Convert Euler angles to quaternion using ZYX convention - # Source: https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles - - # Compute half angles - cy = np.cos(yaw * 0.5) - sy = np.sin(yaw * 0.5) - cp = np.cos(pitch * 0.5) - sp = np.sin(pitch * 0.5) - cr = np.cos(roll * 0.5) - sr = np.sin(roll * 0.5) - - # Compute quaternion components - w = cr * cp * cy + sr * sp * sy - x = sr * cp * cy - cr * sp * sy - y = cr * sp * cy + sr * cp * sy - z = cr * cp * sy - sr * sp * cy - - return Quaternion(x, y, z, w) - - def __bool__(self) -> bool: - """Boolean conversion for Vector. - - A Vector is considered False if it's a zero vector (all components are zero), - and True otherwise. 
- - Returns: - False if vector is zero, True otherwise - """ - return not self.is_zero() - - -@dispatch -def to_numpy(value: "Vector3") -> np.ndarray: - """Convert a Vector3 to a numpy array.""" - return value.to_numpy() - - -@dispatch -def to_numpy(value: np.ndarray) -> np.ndarray: - """Pass through numpy arrays.""" - return value - - -@dispatch -def to_numpy(value: Sequence[int | float]) -> np.ndarray: - """Convert a sequence to a numpy array.""" - return np.array(value, dtype=float) - - -@dispatch -def to_vector(value: "Vector3") -> Vector3: - """Pass through Vector3 objects.""" - return value - - -@dispatch -def to_vector(value: VectorConvertable | Vector3) -> Vector3: - """Convert a vector-compatible value to a Vector3 object.""" - return Vector3(value) - - -@dispatch -def to_tuple(value: Vector3) -> tuple[float, float, float]: - """Convert a Vector3 to a tuple.""" - return value.to_tuple() - - -@dispatch -def to_tuple(value: np.ndarray) -> tuple[float, ...]: - """Convert a numpy array to a tuple.""" - return tuple(value.tolist()) - - -@dispatch -def to_tuple(value: Sequence[int | float]) -> tuple[float, ...]: - """Convert a sequence to a tuple.""" - if isinstance(value, tuple): - return value - else: - return tuple(value) - - -@dispatch -def to_list(value: Vector3) -> list[float]: - """Convert a Vector3 to a list.""" - return value.to_list() - - -@dispatch -def to_list(value: np.ndarray) -> list[float]: - """Convert a numpy array to a list.""" - return value.tolist() - - -@dispatch -def to_list(value: Sequence[int | float]) -> list[float]: - """Convert a sequence to a list.""" - if isinstance(value, list): - return value - else: - return list(value) - - -VectorLike: TypeAlias = VectorConvertable | Vector3 diff --git a/build/lib/dimos/msgs/geometry_msgs/__init__.py b/build/lib/dimos/msgs/geometry_msgs/__init__.py deleted file mode 100644 index 2af44a7ff5..0000000000 --- a/build/lib/dimos/msgs/geometry_msgs/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from 
dimos.msgs.geometry_msgs.Pose import Pose -from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped -from dimos.msgs.geometry_msgs.Quaternion import Quaternion -from dimos.msgs.geometry_msgs.Vector3 import Vector3 diff --git a/build/lib/dimos/msgs/geometry_msgs/test_Pose.py b/build/lib/dimos/msgs/geometry_msgs/test_Pose.py deleted file mode 100644 index 590a17549c..0000000000 --- a/build/lib/dimos/msgs/geometry_msgs/test_Pose.py +++ /dev/null @@ -1,555 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import pickle - -import numpy as np -import pytest -from dimos_lcm.geometry_msgs import Pose as LCMPose - -from dimos.msgs.geometry_msgs.Pose import Pose, to_pose -from dimos.msgs.geometry_msgs.Quaternion import Quaternion -from dimos.msgs.geometry_msgs.Vector3 import Vector3 - - -def test_pose_default_init(): - """Test that default initialization creates a pose at origin with identity orientation.""" - pose = Pose() - - # Position should be at origin - assert pose.position.x == 0.0 - assert pose.position.y == 0.0 - assert pose.position.z == 0.0 - - # Orientation should be identity quaternion - assert pose.orientation.x == 0.0 - assert pose.orientation.y == 0.0 - assert pose.orientation.z == 0.0 - assert pose.orientation.w == 1.0 - - # Test convenience properties - assert pose.x == 0.0 - assert pose.y == 0.0 - assert pose.z == 0.0 - - -def test_pose_position_init(): - """Test initialization with position coordinates only (identity orientation).""" - pose = Pose(1.0, 2.0, 3.0) - - # Position should be as specified - assert pose.position.x == 1.0 - assert pose.position.y == 2.0 - assert pose.position.z == 3.0 - - # Orientation should be identity quaternion - assert pose.orientation.x == 0.0 - assert pose.orientation.y == 0.0 - assert pose.orientation.z == 0.0 - assert pose.orientation.w == 1.0 - - # Test convenience properties - assert pose.x == 1.0 - assert pose.y == 2.0 - assert pose.z == 3.0 - - -def test_pose_full_init(): - """Test initialization with position and orientation coordinates.""" - pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) - - # Position should be as specified - assert pose.position.x == 1.0 - assert pose.position.y == 2.0 - assert pose.position.z == 3.0 - - # Orientation should be as specified - assert pose.orientation.x == 0.1 - assert pose.orientation.y == 0.2 - assert pose.orientation.z == 0.3 - assert pose.orientation.w == 0.9 - - # Test convenience properties - assert pose.x == 1.0 - assert pose.y == 2.0 - assert pose.z == 3.0 - - -def 
test_pose_vector_position_init(): - """Test initialization with Vector3 position (identity orientation).""" - position = Vector3(4.0, 5.0, 6.0) - pose = Pose(position) - - # Position should match the vector - assert pose.position.x == 4.0 - assert pose.position.y == 5.0 - assert pose.position.z == 6.0 - - # Orientation should be identity - assert pose.orientation.x == 0.0 - assert pose.orientation.y == 0.0 - assert pose.orientation.z == 0.0 - assert pose.orientation.w == 1.0 - - -def test_pose_vector_quaternion_init(): - """Test initialization with Vector3 position and Quaternion orientation.""" - position = Vector3(1.0, 2.0, 3.0) - orientation = Quaternion(0.1, 0.2, 0.3, 0.9) - pose = Pose(position, orientation) - - # Position should match the vector - assert pose.position.x == 1.0 - assert pose.position.y == 2.0 - assert pose.position.z == 3.0 - - # Orientation should match the quaternion - assert pose.orientation.x == 0.1 - assert pose.orientation.y == 0.2 - assert pose.orientation.z == 0.3 - assert pose.orientation.w == 0.9 - - -def test_pose_list_init(): - """Test initialization with lists for position and orientation.""" - position_list = [1.0, 2.0, 3.0] - orientation_list = [0.1, 0.2, 0.3, 0.9] - pose = Pose(position_list, orientation_list) - - # Position should match the list - assert pose.position.x == 1.0 - assert pose.position.y == 2.0 - assert pose.position.z == 3.0 - - # Orientation should match the list - assert pose.orientation.x == 0.1 - assert pose.orientation.y == 0.2 - assert pose.orientation.z == 0.3 - assert pose.orientation.w == 0.9 - - -def test_pose_tuple_init(): - """Test initialization from a tuple of (position, orientation).""" - position = [1.0, 2.0, 3.0] - orientation = [0.1, 0.2, 0.3, 0.9] - pose_tuple = (position, orientation) - pose = Pose(pose_tuple) - - # Position should match - assert pose.position.x == 1.0 - assert pose.position.y == 2.0 - assert pose.position.z == 3.0 - - # Orientation should match - assert pose.orientation.x == 
0.1 - assert pose.orientation.y == 0.2 - assert pose.orientation.z == 0.3 - assert pose.orientation.w == 0.9 - - -def test_pose_dict_init(): - """Test initialization from a dictionary with 'position' and 'orientation' keys.""" - pose_dict = {"position": [1.0, 2.0, 3.0], "orientation": [0.1, 0.2, 0.3, 0.9]} - pose = Pose(pose_dict) - - # Position should match - assert pose.position.x == 1.0 - assert pose.position.y == 2.0 - assert pose.position.z == 3.0 - - # Orientation should match - assert pose.orientation.x == 0.1 - assert pose.orientation.y == 0.2 - assert pose.orientation.z == 0.3 - assert pose.orientation.w == 0.9 - - -def test_pose_copy_init(): - """Test initialization from another Pose (copy constructor).""" - original = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) - copy = Pose(original) - - # Position should match - assert copy.position.x == 1.0 - assert copy.position.y == 2.0 - assert copy.position.z == 3.0 - - # Orientation should match - assert copy.orientation.x == 0.1 - assert copy.orientation.y == 0.2 - assert copy.orientation.z == 0.3 - assert copy.orientation.w == 0.9 - - # Should be a copy, not the same object - assert copy is not original - assert copy == original - - -def test_pose_lcm_init(): - """Test initialization from an LCM Pose.""" - # Create LCM pose - lcm_pose = LCMPose() - lcm_pose.position.x = 1.0 - lcm_pose.position.y = 2.0 - lcm_pose.position.z = 3.0 - lcm_pose.orientation.x = 0.1 - lcm_pose.orientation.y = 0.2 - lcm_pose.orientation.z = 0.3 - lcm_pose.orientation.w = 0.9 - - pose = Pose(lcm_pose) - - # Position should match - assert pose.position.x == 1.0 - assert pose.position.y == 2.0 - assert pose.position.z == 3.0 - - # Orientation should match - assert pose.orientation.x == 0.1 - assert pose.orientation.y == 0.2 - assert pose.orientation.z == 0.3 - assert pose.orientation.w == 0.9 - - -def test_pose_properties(): - """Test pose property access.""" - pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) - - # Test position properties - 
assert pose.x == 1.0 - assert pose.y == 2.0 - assert pose.z == 3.0 - - # Test orientation properties (through quaternion's to_euler method) - euler = pose.orientation.to_euler() - assert pose.roll == euler.x - assert pose.pitch == euler.y - assert pose.yaw == euler.z - - -def test_pose_euler_properties_identity(): - """Test pose Euler angle properties with identity orientation.""" - pose = Pose(1.0, 2.0, 3.0) # Identity orientation - - # Identity quaternion should give zero Euler angles - assert np.isclose(pose.roll, 0.0, atol=1e-10) - assert np.isclose(pose.pitch, 0.0, atol=1e-10) - assert np.isclose(pose.yaw, 0.0, atol=1e-10) - - # Euler property should also be zeros - assert np.isclose(pose.orientation.euler.x, 0.0, atol=1e-10) - assert np.isclose(pose.orientation.euler.y, 0.0, atol=1e-10) - assert np.isclose(pose.orientation.euler.z, 0.0, atol=1e-10) - - -def test_pose_repr(): - """Test pose string representation.""" - pose = Pose(1.234, 2.567, 3.891, 0.1, 0.2, 0.3, 0.9) - - repr_str = repr(pose) - - # Should contain position and orientation info - assert "Pose" in repr_str - assert "position" in repr_str - assert "orientation" in repr_str - - # Should contain the actual values (approximately) - assert "1.234" in repr_str or "1.23" in repr_str - assert "2.567" in repr_str or "2.57" in repr_str - - -def test_pose_str(): - """Test pose string formatting.""" - pose = Pose(1.234, 2.567, 3.891, 0.1, 0.2, 0.3, 0.9) - - str_repr = str(pose) - - # Should contain position coordinates - assert "1.234" in str_repr - assert "2.567" in str_repr - assert "3.891" in str_repr - - # Should contain Euler angles - assert "euler" in str_repr - - # Should be formatted with specified precision - assert str_repr.count("Pose") == 1 - - -def test_pose_equality(): - """Test pose equality comparison.""" - pose1 = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) - pose2 = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) - pose3 = Pose(1.1, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) # Different position - pose4 = 
Pose(1.0, 2.0, 3.0, 0.11, 0.2, 0.3, 0.9) # Different orientation - - # Equal poses - assert pose1 == pose2 - assert pose2 == pose1 - - # Different poses - assert pose1 != pose3 - assert pose1 != pose4 - assert pose3 != pose4 - - # Different types - assert pose1 != "not a pose" - assert pose1 != [1.0, 2.0, 3.0] - assert pose1 != None - - -def test_pose_with_numpy_arrays(): - """Test pose initialization with numpy arrays.""" - position_array = np.array([1.0, 2.0, 3.0]) - orientation_array = np.array([0.1, 0.2, 0.3, 0.9]) - - pose = Pose(position_array, orientation_array) - - # Position should match - assert pose.position.x == 1.0 - assert pose.position.y == 2.0 - assert pose.position.z == 3.0 - - # Orientation should match - assert pose.orientation.x == 0.1 - assert pose.orientation.y == 0.2 - assert pose.orientation.z == 0.3 - assert pose.orientation.w == 0.9 - - -def test_pose_with_mixed_types(): - """Test pose initialization with mixed input types.""" - # Position as tuple, orientation as list - pose1 = Pose((1.0, 2.0, 3.0), [0.1, 0.2, 0.3, 0.9]) - - # Position as numpy array, orientation as Vector3/Quaternion - position = np.array([1.0, 2.0, 3.0]) - orientation = Quaternion(0.1, 0.2, 0.3, 0.9) - pose2 = Pose(position, orientation) - - # Both should result in the same pose - assert pose1.position.x == pose2.position.x - assert pose1.position.y == pose2.position.y - assert pose1.position.z == pose2.position.z - assert pose1.orientation.x == pose2.orientation.x - assert pose1.orientation.y == pose2.orientation.y - assert pose1.orientation.z == pose2.orientation.z - assert pose1.orientation.w == pose2.orientation.w - - -def test_to_pose_passthrough(): - """Test to_pose function with Pose input (passthrough).""" - original = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) - result = to_pose(original) - - # Should be the same object (passthrough) - assert result is original - - -def test_to_pose_conversion(): - """Test to_pose function with convertible inputs.""" - # Note: The 
to_pose conversion function has type checking issues in the current implementation - # Test direct construction instead to verify the intended functionality - - # Test the intended functionality by creating poses directly - pose_tuple = ([1.0, 2.0, 3.0], [0.1, 0.2, 0.3, 0.9]) - result1 = Pose(pose_tuple) - - assert isinstance(result1, Pose) - assert result1.position.x == 1.0 - assert result1.position.y == 2.0 - assert result1.position.z == 3.0 - assert result1.orientation.x == 0.1 - assert result1.orientation.y == 0.2 - assert result1.orientation.z == 0.3 - assert result1.orientation.w == 0.9 - - # Test with dictionary - pose_dict = {"position": [1.0, 2.0, 3.0], "orientation": [0.1, 0.2, 0.3, 0.9]} - result2 = Pose(pose_dict) - - assert isinstance(result2, Pose) - assert result2.position.x == 1.0 - assert result2.position.y == 2.0 - assert result2.position.z == 3.0 - assert result2.orientation.x == 0.1 - assert result2.orientation.y == 0.2 - assert result2.orientation.z == 0.3 - assert result2.orientation.w == 0.9 - - -def test_pose_euler_roundtrip(): - """Test conversion from Euler angles to quaternion and back.""" - # Start with known Euler angles (small angles to avoid gimbal lock) - roll = 0.1 - pitch = 0.2 - yaw = 0.3 - - # Create quaternion from Euler angles - euler_vector = Vector3(roll, pitch, yaw) - quaternion = euler_vector.to_quaternion() - - # Create pose with this quaternion - pose = Pose(Vector3(0, 0, 0), quaternion) - - # Convert back to Euler angles - result_euler = pose.orientation.euler - - # Should get back the original Euler angles (within tolerance) - assert np.isclose(result_euler.x, roll, atol=1e-6) - assert np.isclose(result_euler.y, pitch, atol=1e-6) - assert np.isclose(result_euler.z, yaw, atol=1e-6) - - -def test_pose_zero_position(): - """Test pose with zero position vector.""" - # Use manual construction since Vector3.zeros has signature issues - pose = Pose(0.0, 0.0, 0.0) # Position at origin with identity orientation - - assert pose.x 
== 0.0 - assert pose.y == 0.0 - assert pose.z == 0.0 - assert np.isclose(pose.roll, 0.0, atol=1e-10) - assert np.isclose(pose.pitch, 0.0, atol=1e-10) - assert np.isclose(pose.yaw, 0.0, atol=1e-10) - - -def test_pose_unit_vectors(): - """Test pose with unit vector positions.""" - # Test unit x vector position - pose_x = Pose(Vector3.unit_x()) - assert pose_x.x == 1.0 - assert pose_x.y == 0.0 - assert pose_x.z == 0.0 - - # Test unit y vector position - pose_y = Pose(Vector3.unit_y()) - assert pose_y.x == 0.0 - assert pose_y.y == 1.0 - assert pose_y.z == 0.0 - - # Test unit z vector position - pose_z = Pose(Vector3.unit_z()) - assert pose_z.x == 0.0 - assert pose_z.y == 0.0 - assert pose_z.z == 1.0 - - -def test_pose_negative_coordinates(): - """Test pose with negative coordinates.""" - pose = Pose(-1.0, -2.0, -3.0, -0.1, -0.2, -0.3, 0.9) - - # Position should be negative - assert pose.x == -1.0 - assert pose.y == -2.0 - assert pose.z == -3.0 - - # Orientation should be as specified - assert pose.orientation.x == -0.1 - assert pose.orientation.y == -0.2 - assert pose.orientation.z == -0.3 - assert pose.orientation.w == 0.9 - - -def test_pose_large_coordinates(): - """Test pose with large coordinate values.""" - large_value = 1000.0 - pose = Pose(large_value, large_value, large_value) - - assert pose.x == large_value - assert pose.y == large_value - assert pose.z == large_value - - # Orientation should still be identity - assert pose.orientation.x == 0.0 - assert pose.orientation.y == 0.0 - assert pose.orientation.z == 0.0 - assert pose.orientation.w == 1.0 - - -@pytest.mark.parametrize( - "x,y,z", - [(0.0, 0.0, 0.0), (1.0, 2.0, 3.0), (-1.0, -2.0, -3.0), (0.5, -0.5, 1.5), (100.0, -100.0, 0.0)], -) -def test_pose_parametrized_positions(x, y, z): - """Parametrized test for various position values.""" - pose = Pose(x, y, z) - - assert pose.x == x - assert pose.y == y - assert pose.z == z - - # Should have identity orientation - assert pose.orientation.x == 0.0 - assert 
pose.orientation.y == 0.0 - assert pose.orientation.z == 0.0 - assert pose.orientation.w == 1.0 - - -@pytest.mark.parametrize( - "qx,qy,qz,qw", - [ - (0.0, 0.0, 0.0, 1.0), # Identity - (1.0, 0.0, 0.0, 0.0), # 180° around x - (0.0, 1.0, 0.0, 0.0), # 180° around y - (0.0, 0.0, 1.0, 0.0), # 180° around z - (0.5, 0.5, 0.5, 0.5), # Equal components - ], -) -def test_pose_parametrized_orientations(qx, qy, qz, qw): - """Parametrized test for various orientation values.""" - pose = Pose(0.0, 0.0, 0.0, qx, qy, qz, qw) - - # Position should be at origin - assert pose.x == 0.0 - assert pose.y == 0.0 - assert pose.z == 0.0 - - # Orientation should match - assert pose.orientation.x == qx - assert pose.orientation.y == qy - assert pose.orientation.z == qz - assert pose.orientation.w == qw - - -def test_lcm_encode_decode(): - """Test encoding and decoding of Pose to/from binary LCM format.""" - - def encodepass(): - pose_source = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) - binary_msg = pose_source.lcm_encode() - pose_dest = Pose.lcm_decode(binary_msg) - assert isinstance(pose_dest, Pose) - assert pose_dest is not pose_source - assert pose_dest == pose_source - - import timeit - - print(f"{timeit.timeit(encodepass, number=1000)} ms per cycle") - - -def test_pickle_encode_decode(): - """Test encoding and decoding of Pose to/from binary LCM format.""" - - def encodepass(): - pose_source = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) - binary_msg = pickle.dumps(pose_source) - pose_dest = pickle.loads(binary_msg) - assert isinstance(pose_dest, Pose) - assert pose_dest is not pose_source - assert pose_dest == pose_source - - import timeit - - print(f"{timeit.timeit(encodepass, number=1000)} ms per cycle") diff --git a/build/lib/dimos/msgs/geometry_msgs/test_Quaternion.py b/build/lib/dimos/msgs/geometry_msgs/test_Quaternion.py deleted file mode 100644 index ab049f809f..0000000000 --- a/build/lib/dimos/msgs/geometry_msgs/test_Quaternion.py +++ /dev/null @@ -1,210 +0,0 @@ -# Copyright 2025 
Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -import pytest -from dimos_lcm.geometry_msgs import Quaternion as LCMQuaternion - -from dimos.msgs.geometry_msgs.Quaternion import Quaternion - - -def test_quaternion_default_init(): - """Test that default initialization creates an identity quaternion (w=1, x=y=z=0).""" - q = Quaternion() - assert q.x == 0.0 - assert q.y == 0.0 - assert q.z == 0.0 - assert q.w == 1.0 - assert q.to_tuple() == (0.0, 0.0, 0.0, 1.0) - - -def test_quaternion_component_init(): - """Test initialization with four float components (x, y, z, w).""" - q = Quaternion(0.5, 0.5, 0.5, 0.5) - assert q.x == 0.5 - assert q.y == 0.5 - assert q.z == 0.5 - assert q.w == 0.5 - - # Test with different values - q2 = Quaternion(1.0, 2.0, 3.0, 4.0) - assert q2.x == 1.0 - assert q2.y == 2.0 - assert q2.z == 3.0 - assert q2.w == 4.0 - - # Test with negative values - q3 = Quaternion(-1.0, -2.0, -3.0, -4.0) - assert q3.x == -1.0 - assert q3.y == -2.0 - assert q3.z == -3.0 - assert q3.w == -4.0 - - # Test with integers (should convert to float) - q4 = Quaternion(1, 2, 3, 4) - assert q4.x == 1.0 - assert q4.y == 2.0 - assert q4.z == 3.0 - assert q4.w == 4.0 - assert isinstance(q4.x, float) - - -def test_quaternion_sequence_init(): - """Test initialization from sequence (list, tuple) of 4 numbers.""" - # From list - q1 = Quaternion([0.1, 0.2, 0.3, 0.4]) - assert q1.x == 0.1 - assert q1.y == 0.2 - assert q1.z == 0.3 - assert q1.w == 
0.4 - - # From tuple - q2 = Quaternion((0.5, 0.6, 0.7, 0.8)) - assert q2.x == 0.5 - assert q2.y == 0.6 - assert q2.z == 0.7 - assert q2.w == 0.8 - - # Test with integers in sequence - q3 = Quaternion([1, 2, 3, 4]) - assert q3.x == 1.0 - assert q3.y == 2.0 - assert q3.z == 3.0 - assert q3.w == 4.0 - - # Test error with wrong length - with pytest.raises(ValueError, match="Quaternion requires exactly 4 components"): - Quaternion([1, 2, 3]) # Only 3 components - - with pytest.raises(ValueError, match="Quaternion requires exactly 4 components"): - Quaternion([1, 2, 3, 4, 5]) # Too many components - - -def test_quaternion_numpy_init(): - """Test initialization from numpy array.""" - # From numpy array - arr = np.array([0.1, 0.2, 0.3, 0.4]) - q1 = Quaternion(arr) - assert q1.x == 0.1 - assert q1.y == 0.2 - assert q1.z == 0.3 - assert q1.w == 0.4 - - # Test with different dtypes - arr_int = np.array([1, 2, 3, 4], dtype=int) - q2 = Quaternion(arr_int) - assert q2.x == 1.0 - assert q2.y == 2.0 - assert q2.z == 3.0 - assert q2.w == 4.0 - - # Test error with wrong size - with pytest.raises(ValueError, match="Quaternion requires exactly 4 components"): - Quaternion(np.array([1, 2, 3])) # Only 3 elements - - with pytest.raises(ValueError, match="Quaternion requires exactly 4 components"): - Quaternion(np.array([1, 2, 3, 4, 5])) # Too many elements - - -def test_quaternion_copy_init(): - """Test initialization from another Quaternion (copy constructor).""" - original = Quaternion(0.1, 0.2, 0.3, 0.4) - copy = Quaternion(original) - - assert copy.x == 0.1 - assert copy.y == 0.2 - assert copy.z == 0.3 - assert copy.w == 0.4 - - # Verify it's a copy, not the same object - assert copy is not original - assert copy == original - - -def test_quaternion_lcm_init(): - """Test initialization from LCM Quaternion.""" - lcm_quat = LCMQuaternion() - lcm_quat.x = 0.1 - lcm_quat.y = 0.2 - lcm_quat.z = 0.3 - lcm_quat.w = 0.4 - - q = Quaternion(lcm_quat) - assert q.x == 0.1 - assert q.y == 0.2 - 
assert q.z == 0.3 - assert q.w == 0.4 - - -def test_quaternion_properties(): - """Test quaternion component properties.""" - q = Quaternion(1.0, 2.0, 3.0, 4.0) - - # Test property access - assert q.x == 1.0 - assert q.y == 2.0 - assert q.z == 3.0 - assert q.w == 4.0 - - # Test as_tuple property - assert q.to_tuple() == (1.0, 2.0, 3.0, 4.0) - - -def test_quaternion_indexing(): - """Test quaternion indexing support.""" - q = Quaternion(1.0, 2.0, 3.0, 4.0) - - # Test indexing - assert q[0] == 1.0 - assert q[1] == 2.0 - assert q[2] == 3.0 - assert q[3] == 4.0 - - -def test_quaternion_euler(): - """Test quaternion to Euler angles conversion.""" - - # Test identity quaternion (should give zero angles) - q_identity = Quaternion() - angles = q_identity.to_euler() - assert np.isclose(angles.x, 0.0, atol=1e-10) # roll - assert np.isclose(angles.y, 0.0, atol=1e-10) # pitch - assert np.isclose(angles.z, 0.0, atol=1e-10) # yaw - - # Test 90 degree rotation around Z-axis (yaw) - q_z90 = Quaternion(0, 0, np.sin(np.pi / 4), np.cos(np.pi / 4)) - angles_z90 = q_z90.to_euler() - assert np.isclose(angles_z90.roll, 0.0, atol=1e-10) # roll should be 0 - assert np.isclose(angles_z90.pitch, 0.0, atol=1e-10) # pitch should be 0 - assert np.isclose(angles_z90.yaw, np.pi / 2, atol=1e-10) # yaw should be π/2 (90 degrees) - - # Test 90 degree rotation around X-axis (roll) - q_x90 = Quaternion(np.sin(np.pi / 4), 0, 0, np.cos(np.pi / 4)) - angles_x90 = q_x90.to_euler() - assert np.isclose(angles_x90.x, np.pi / 2, atol=1e-10) # roll should be π/2 - assert np.isclose(angles_x90.y, 0.0, atol=1e-10) # pitch should be 0 - assert np.isclose(angles_x90.z, 0.0, atol=1e-10) # yaw should be 0 - - -def test_lcm_encode_decode(): - """Test encoding and decoding of Quaternion to/from binary LCM format.""" - q_source = Quaternion(1.0, 2.0, 3.0, 4.0) - - binary_msg = q_source.lcm_encode() - - q_dest = Quaternion.lcm_decode(binary_msg) - - assert isinstance(q_dest, Quaternion) - assert q_dest is not q_source - 
assert q_dest == q_source diff --git a/build/lib/dimos/msgs/geometry_msgs/test_Vector3.py b/build/lib/dimos/msgs/geometry_msgs/test_Vector3.py deleted file mode 100644 index 81325286f9..0000000000 --- a/build/lib/dimos/msgs/geometry_msgs/test_Vector3.py +++ /dev/null @@ -1,462 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -import pytest - -from dimos.msgs.geometry_msgs.Vector3 import Vector3 - - -def test_vector_default_init(): - """Test that default initialization of Vector() has x,y,z components all zero.""" - v = Vector3() - assert v.x == 0.0 - assert v.y == 0.0 - assert v.z == 0.0 - assert len(v.data) == 3 - assert v.to_list() == [0.0, 0.0, 0.0] - assert v.is_zero() == True # Zero vector should be considered zero - - -def test_vector_specific_init(): - """Test initialization with specific values and different input types.""" - - v1 = Vector3(1.0, 2.0) # 2D vector (now becomes 3D with z=0) - assert v1.x == 1.0 - assert v1.y == 2.0 - assert v1.z == 0.0 - - v2 = Vector3(3.0, 4.0, 5.0) # 3D vector - assert v2.x == 3.0 - assert v2.y == 4.0 - assert v2.z == 5.0 - - v3 = Vector3([6.0, 7.0, 8.0]) - assert v3.x == 6.0 - assert v3.y == 7.0 - assert v3.z == 8.0 - - v4 = Vector3((9.0, 10.0, 11.0)) - assert v4.x == 9.0 - assert v4.y == 10.0 - assert v4.z == 11.0 - - v5 = Vector3(np.array([12.0, 13.0, 14.0])) - assert v5.x == 12.0 - assert v5.y == 13.0 - assert v5.z == 14.0 - - original = Vector3([15.0, 16.0, 
17.0]) - v6 = Vector3(original) - assert v6.x == 15.0 - assert v6.y == 16.0 - assert v6.z == 17.0 - - assert v6 is not original - assert v6 == original - - -def test_vector_addition(): - """Test vector addition.""" - v1 = Vector3(1.0, 2.0, 3.0) - v2 = Vector3(4.0, 5.0, 6.0) - - v_add = v1 + v2 - assert v_add.x == 5.0 - assert v_add.y == 7.0 - assert v_add.z == 9.0 - - -def test_vector_subtraction(): - """Test vector subtraction.""" - v1 = Vector3(1.0, 2.0, 3.0) - v2 = Vector3(4.0, 5.0, 6.0) - - v_sub = v2 - v1 - assert v_sub.x == 3.0 - assert v_sub.y == 3.0 - assert v_sub.z == 3.0 - - -def test_vector_scalar_multiplication(): - """Test vector multiplication by a scalar.""" - v1 = Vector3(1.0, 2.0, 3.0) - - v_mul = v1 * 2.0 - assert v_mul.x == 2.0 - assert v_mul.y == 4.0 - assert v_mul.z == 6.0 - - # Test right multiplication - v_rmul = 2.0 * v1 - assert v_rmul.x == 2.0 - assert v_rmul.y == 4.0 - assert v_rmul.z == 6.0 - - -def test_vector_scalar_division(): - """Test vector division by a scalar.""" - v2 = Vector3(4.0, 5.0, 6.0) - - v_div = v2 / 2.0 - assert v_div.x == 2.0 - assert v_div.y == 2.5 - assert v_div.z == 3.0 - - -def test_vector_dot_product(): - """Test vector dot product.""" - v1 = Vector3(1.0, 2.0, 3.0) - v2 = Vector3(4.0, 5.0, 6.0) - - dot = v1.dot(v2) - assert dot == 32.0 - - -def test_vector_length(): - """Test vector length calculation.""" - # 2D vector with length 5 (now 3D with z=0) - v1 = Vector3(3.0, 4.0) - assert v1.length() == 5.0 - - # 3D vector - v2 = Vector3(2.0, 3.0, 6.0) - assert v2.length() == pytest.approx(7.0, 0.001) - - # Test length_squared - assert v1.length_squared() == 25.0 - assert v2.length_squared() == 49.0 - - -def test_vector_normalize(): - """Test vector normalization.""" - v = Vector3(2.0, 3.0, 6.0) - assert v.is_zero() == False - - v_norm = v.normalize() - length = v.length() - expected_x = 2.0 / length - expected_y = 3.0 / length - expected_z = 6.0 / length - - assert np.isclose(v_norm.x, expected_x) - assert 
np.isclose(v_norm.y, expected_y) - assert np.isclose(v_norm.z, expected_z) - assert np.isclose(v_norm.length(), 1.0) - assert v_norm.is_zero() == False - - # Test normalizing a zero vector - v_zero = Vector3(0.0, 0.0, 0.0) - assert v_zero.is_zero() == True - v_zero_norm = v_zero.normalize() - assert v_zero_norm.x == 0.0 - assert v_zero_norm.y == 0.0 - assert v_zero_norm.z == 0.0 - assert v_zero_norm.is_zero() == True - - -def test_vector_to_2d(): - """Test conversion to 2D vector.""" - v = Vector3(2.0, 3.0, 6.0) - - v_2d = v.to_2d() - assert v_2d.x == 2.0 - assert v_2d.y == 3.0 - assert v_2d.z == 0.0 # z should be 0 for 2D conversion - - # Already 2D vector (z=0) - v2 = Vector3(4.0, 5.0) - v2_2d = v2.to_2d() - assert v2_2d.x == 4.0 - assert v2_2d.y == 5.0 - assert v2_2d.z == 0.0 - - -def test_vector_distance(): - """Test distance calculations between vectors.""" - v1 = Vector3(1.0, 2.0, 3.0) - v2 = Vector3(4.0, 6.0, 8.0) - - # Distance - dist = v1.distance(v2) - expected_dist = np.sqrt(9.0 + 16.0 + 25.0) # sqrt((4-1)² + (6-2)² + (8-3)²) - assert dist == pytest.approx(expected_dist) - - # Distance squared - dist_sq = v1.distance_squared(v2) - assert dist_sq == 50.0 # 9 + 16 + 25 - - -def test_vector_cross_product(): - """Test vector cross product.""" - v1 = Vector3(1.0, 0.0, 0.0) # Unit x vector - v2 = Vector3(0.0, 1.0, 0.0) # Unit y vector - - # v1 × v2 should be unit z vector - cross = v1.cross(v2) - assert cross.x == 0.0 - assert cross.y == 0.0 - assert cross.z == 1.0 - - # Test with more complex vectors - a = Vector3(2.0, 3.0, 4.0) - b = Vector3(5.0, 6.0, 7.0) - c = a.cross(b) - - # Cross product manually calculated: - # (3*7-4*6, 4*5-2*7, 2*6-3*5) - assert c.x == -3.0 - assert c.y == 6.0 - assert c.z == -3.0 - - # Test with vectors that have z=0 (still works as they're 3D) - v_2d1 = Vector3(1.0, 2.0) # (1, 2, 0) - v_2d2 = Vector3(3.0, 4.0) # (3, 4, 0) - cross_2d = v_2d1.cross(v_2d2) - # (2*0-0*4, 0*3-1*0, 1*4-2*3) = (0, 0, -2) - assert cross_2d.x == 0.0 - 
assert cross_2d.y == 0.0 - assert cross_2d.z == -2.0 - - -def test_vector_zeros(): - """Test Vector3.zeros class method.""" - # 3D zero vector - v_zeros = Vector3.zeros() - assert v_zeros.x == 0.0 - assert v_zeros.y == 0.0 - assert v_zeros.z == 0.0 - assert v_zeros.is_zero() == True - - -def test_vector_ones(): - """Test Vector3.ones class method.""" - # 3D ones vector - v_ones = Vector3.ones() - assert v_ones.x == 1.0 - assert v_ones.y == 1.0 - assert v_ones.z == 1.0 - - -def test_vector_conversion_methods(): - """Test vector conversion methods (to_list, to_tuple, to_numpy).""" - v = Vector3(1.0, 2.0, 3.0) - - # to_list - assert v.to_list() == [1.0, 2.0, 3.0] - - # to_tuple - assert v.to_tuple() == (1.0, 2.0, 3.0) - - # to_numpy - np_array = v.to_numpy() - assert isinstance(np_array, np.ndarray) - assert np.array_equal(np_array, np.array([1.0, 2.0, 3.0])) - - -def test_vector_equality(): - """Test vector equality.""" - v1 = Vector3(1, 2, 3) - v2 = Vector3(1, 2, 3) - v3 = Vector3(4, 5, 6) - - assert v1 == v2 - assert v1 != v3 - assert v1 != Vector3(1, 2) # Now (1, 2, 0) vs (1, 2, 3) - assert v1 != Vector3(1.1, 2, 3) # Different values - assert v1 != [1, 2, 3] - - -def test_vector_is_zero(): - """Test is_zero method for vectors.""" - # Default zero vector - v0 = Vector3() - assert v0.is_zero() == True - - # Explicit zero vector - v1 = Vector3(0.0, 0.0, 0.0) - assert v1.is_zero() == True - - # Zero vector with different initialization (now always 3D) - v2 = Vector3(0.0, 0.0) # Becomes (0, 0, 0) - assert v2.is_zero() == True - - # Non-zero vectors - v3 = Vector3(1.0, 0.0, 0.0) - assert v3.is_zero() == False - - v4 = Vector3(0.0, 2.0, 0.0) - assert v4.is_zero() == False - - v5 = Vector3(0.0, 0.0, 3.0) - assert v5.is_zero() == False - - # Almost zero (within tolerance) - v6 = Vector3(1e-10, 1e-10, 1e-10) - assert v6.is_zero() == True - - # Almost zero (outside tolerance) - v7 = Vector3(1e-6, 1e-6, 1e-6) - assert v7.is_zero() == False - - -def 
test_vector_bool_conversion(): - """Test boolean conversion of vectors.""" - # Zero vectors should be False - v0 = Vector3() - assert bool(v0) == False - - v1 = Vector3(0.0, 0.0, 0.0) - assert bool(v1) == False - - # Almost zero vectors should be False - v2 = Vector3(1e-10, 1e-10, 1e-10) - assert bool(v2) == False - - # Non-zero vectors should be True - v3 = Vector3(1.0, 0.0, 0.0) - assert bool(v3) == True - - v4 = Vector3(0.0, 2.0, 0.0) - assert bool(v4) == True - - v5 = Vector3(0.0, 0.0, 3.0) - assert bool(v5) == True - - # Direct use in if statements - if v0: - assert False, "Zero vector should be False in boolean context" - else: - pass # Expected path - - if v3: - pass # Expected path - else: - assert False, "Non-zero vector should be True in boolean context" - - -def test_vector_add(): - """Test vector addition operator.""" - v1 = Vector3(1.0, 2.0, 3.0) - v2 = Vector3(4.0, 5.0, 6.0) - - # Using __add__ method - v_add = v1.__add__(v2) - assert v_add.x == 5.0 - assert v_add.y == 7.0 - assert v_add.z == 9.0 - - # Using + operator - v_add_op = v1 + v2 - assert v_add_op.x == 5.0 - assert v_add_op.y == 7.0 - assert v_add_op.z == 9.0 - - # Adding zero vector should return original vector - v_zero = Vector3.zeros() - assert (v1 + v_zero) == v1 - - -def test_vector_add_dim_mismatch(): - """Test vector addition with different input dimensions (now all vectors are 3D).""" - v1 = Vector3(1.0, 2.0) # Becomes (1, 2, 0) - v2 = Vector3(4.0, 5.0, 6.0) # (4, 5, 6) - - # Using + operator - should work fine now since both are 3D - v_add_op = v1 + v2 - assert v_add_op.x == 5.0 # 1 + 4 - assert v_add_op.y == 7.0 # 2 + 5 - assert v_add_op.z == 6.0 # 0 + 6 - - -def test_yaw_pitch_roll_accessors(): - """Test yaw, pitch, and roll accessor properties.""" - # Test with a 3D vector - v = Vector3(1.0, 2.0, 3.0) - - # According to standard convention: - # roll = rotation around x-axis = x component - # pitch = rotation around y-axis = y component - # yaw = rotation around z-axis = z 
component - assert v.roll == 1.0 # Should return x component - assert v.pitch == 2.0 # Should return y component - assert v.yaw == 3.0 # Should return z component - - # Test with a 2D vector (z should be 0.0) - v_2d = Vector3(4.0, 5.0) - assert v_2d.roll == 4.0 # Should return x component - assert v_2d.pitch == 5.0 # Should return y component - assert v_2d.yaw == 0.0 # Should return z component (defaults to 0 for 2D) - - # Test with empty vector (all should be 0.0) - v_empty = Vector3() - assert v_empty.roll == 0.0 - assert v_empty.pitch == 0.0 - assert v_empty.yaw == 0.0 - - # Test with negative values - v_neg = Vector3(-1.5, -2.5, -3.5) - assert v_neg.roll == -1.5 - assert v_neg.pitch == -2.5 - assert v_neg.yaw == -3.5 - - -def test_vector_to_quaternion(): - """Test vector to quaternion conversion.""" - # Test with zero Euler angles (should produce identity quaternion) - v_zero = Vector3(0.0, 0.0, 0.0) - q_identity = v_zero.to_quaternion() - - # Identity quaternion should have w=1, x=y=z=0 - assert np.isclose(q_identity.x, 0.0, atol=1e-10) - assert np.isclose(q_identity.y, 0.0, atol=1e-10) - assert np.isclose(q_identity.z, 0.0, atol=1e-10) - assert np.isclose(q_identity.w, 1.0, atol=1e-10) - - # Test with small angles (to avoid gimbal lock issues) - v_small = Vector3(0.1, 0.2, 0.3) # Small roll, pitch, yaw - q_small = v_small.to_quaternion() - - # Quaternion should be normalized (magnitude = 1) - magnitude = np.sqrt(q_small.x**2 + q_small.y**2 + q_small.z**2 + q_small.w**2) - assert np.isclose(magnitude, 1.0, atol=1e-10) - - # Test conversion back to Euler (should be close to original) - v_back = q_small.to_euler() - assert np.isclose(v_back.x, 0.1, atol=1e-6) - assert np.isclose(v_back.y, 0.2, atol=1e-6) - assert np.isclose(v_back.z, 0.3, atol=1e-6) - - # Test with π/2 rotation around x-axis - v_x_90 = Vector3(np.pi / 2, 0.0, 0.0) - q_x_90 = v_x_90.to_quaternion() - - # Should be approximately (sin(π/4), 0, 0, cos(π/4)) = (√2/2, 0, 0, √2/2) - expected = 
np.sqrt(2) / 2 - assert np.isclose(q_x_90.x, expected, atol=1e-10) - assert np.isclose(q_x_90.y, 0.0, atol=1e-10) - assert np.isclose(q_x_90.z, 0.0, atol=1e-10) - assert np.isclose(q_x_90.w, expected, atol=1e-10) - - -def test_lcm_encode_decode(): - v_source = Vector3(1.0, 2.0, 3.0) - - binary_msg = v_source.lcm_encode() - - v_dest = Vector3.lcm_decode(binary_msg) - - assert isinstance(v_dest, Vector3) - assert v_dest is not v_source - assert v_dest == v_source diff --git a/build/lib/dimos/msgs/geometry_msgs/test_publish.py b/build/lib/dimos/msgs/geometry_msgs/test_publish.py deleted file mode 100644 index 4e364dc19a..0000000000 --- a/build/lib/dimos/msgs/geometry_msgs/test_publish.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import time - -import lcm -import pytest - -from dimos.msgs.geometry_msgs import Vector3 - - -@pytest.mark.tool -def test_runpublish(): - for i in range(10): - msg = Vector3(-5 + i, -5 + i, i) - lc = lcm.LCM() - lc.publish("thing1_vector3#geometry_msgs.Vector3", msg.encode()) - time.sleep(0.1) - print(f"Published: {msg}") - - -@pytest.mark.tool -def test_receive(): - lc = lcm.LCM() - - def receive(bla, msg): - # print("receive", bla, msg) - print(Vector3.decode(msg)) - - lc.subscribe("thing1_vector3#geometry_msgs.Vector3", receive) - - def _loop(): - while True: - """LCM message handling loop""" - try: - lc.handle() - # loop 10000 times - for _ in range(10000000): - 3 + 3 - except Exception as e: - print(f"Error in LCM handling: {e}") - - _loop() diff --git a/build/lib/dimos/msgs/sensor_msgs/Image.py b/build/lib/dimos/msgs/sensor_msgs/Image.py deleted file mode 100644 index 2ac53a2fd7..0000000000 --- a/build/lib/dimos/msgs/sensor_msgs/Image.py +++ /dev/null @@ -1,372 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import time -from dataclasses import dataclass, field -from enum import Enum -from typing import Optional, Tuple - -import cv2 -import numpy as np - -# Import LCM types -from dimos_lcm.sensor_msgs.Image import Image as LCMImage -from dimos_lcm.std_msgs.Header import Header - -from dimos.types.timestamped import Timestamped - - -class ImageFormat(Enum): - """Supported image formats.""" - - BGR = "bgr8" - RGB = "rgb8" - RGBA = "rgba8" - BGRA = "bgra8" - GRAY = "mono8" - GRAY16 = "mono16" - - -@dataclass -class Image(Timestamped): - """Standardized image type with LCM integration.""" - - msg_name = "sensor_msgs.Image" - data: np.ndarray - format: ImageFormat = field(default=ImageFormat.BGR) - frame_id: str = field(default="") - ts: float = field(default_factory=time.time) - - def __post_init__(self): - """Validate image data and format.""" - if self.data is None: - raise ValueError("Image data cannot be None") - - if not isinstance(self.data, np.ndarray): - raise ValueError("Image data must be a numpy array") - - if len(self.data.shape) < 2: - raise ValueError("Image data must be at least 2D") - - # Ensure data is contiguous for efficient operations - if not self.data.flags["C_CONTIGUOUS"]: - self.data = np.ascontiguousarray(self.data) - - @property - def height(self) -> int: - """Get image height.""" - return self.data.shape[0] - - @property - def width(self) -> int: - """Get image width.""" - return self.data.shape[1] - - @property - def channels(self) -> int: - """Get number of channels.""" - if len(self.data.shape) == 2: - return 1 - elif len(self.data.shape) == 3: - return self.data.shape[2] - else: - raise ValueError("Invalid image dimensions") - - @property - def shape(self) -> Tuple[int, ...]: - """Get image shape.""" - return self.data.shape - - @property - def dtype(self) -> np.dtype: - """Get image data type.""" - return self.data.dtype - - def copy(self) -> "Image": - """Create a deep copy of the image.""" - return self.__class__( - 
data=self.data.copy(), - format=self.format, - frame_id=self.frame_id, - ts=self.ts, - ) - - @classmethod - def from_opencv( - cls, cv_image: np.ndarray, format: ImageFormat = ImageFormat.BGR, **kwargs - ) -> "Image": - """Create Image from OpenCV image array.""" - return cls(data=cv_image, format=format, **kwargs) - - @classmethod - def from_numpy( - cls, np_image: np.ndarray, format: ImageFormat = ImageFormat.BGR, **kwargs - ) -> "Image": - """Create Image from numpy array.""" - return cls(data=np_image, format=format, **kwargs) - - @classmethod - def from_file(cls, filepath: str, format: ImageFormat = ImageFormat.BGR) -> "Image": - """Load image from file.""" - # OpenCV loads as BGR by default - cv_image = cv2.imread(filepath, cv2.IMREAD_UNCHANGED) - if cv_image is None: - raise ValueError(f"Could not load image from {filepath}") - - # Detect format based on channels - if len(cv_image.shape) == 2: - detected_format = ImageFormat.GRAY - elif cv_image.shape[2] == 3: - detected_format = ImageFormat.BGR # OpenCV default - elif cv_image.shape[2] == 4: - detected_format = ImageFormat.BGRA - else: - detected_format = format - - return cls(data=cv_image, format=detected_format) - - def to_opencv(self) -> np.ndarray: - """Convert to OpenCV-compatible array (BGR format).""" - if self.format == ImageFormat.BGR: - return self.data - elif self.format == ImageFormat.RGB: - return cv2.cvtColor(self.data, cv2.COLOR_RGB2BGR) - elif self.format == ImageFormat.RGBA: - return cv2.cvtColor(self.data, cv2.COLOR_RGBA2BGR) - elif self.format == ImageFormat.BGRA: - return cv2.cvtColor(self.data, cv2.COLOR_BGRA2BGR) - elif self.format == ImageFormat.GRAY: - return self.data - elif self.format == ImageFormat.GRAY16: - return self.data - else: - raise ValueError(f"Unsupported format conversion: {self.format}") - - def to_rgb(self) -> "Image": - """Convert image to RGB format.""" - if self.format == ImageFormat.RGB: - return self.copy() - elif self.format == ImageFormat.BGR: - rgb_data = 
cv2.cvtColor(self.data, cv2.COLOR_BGR2RGB) - elif self.format == ImageFormat.RGBA: - return self.copy() # Already RGB with alpha - elif self.format == ImageFormat.BGRA: - rgb_data = cv2.cvtColor(self.data, cv2.COLOR_BGRA2RGBA) - elif self.format == ImageFormat.GRAY: - rgb_data = cv2.cvtColor(self.data, cv2.COLOR_GRAY2RGB) - elif self.format == ImageFormat.GRAY16: - # Convert 16-bit grayscale to 8-bit then to RGB - gray8 = (self.data / 256).astype(np.uint8) - rgb_data = cv2.cvtColor(gray8, cv2.COLOR_GRAY2RGB) - else: - raise ValueError(f"Unsupported format conversion from {self.format} to RGB") - - return self.__class__( - data=rgb_data, - format=ImageFormat.RGB if self.format != ImageFormat.BGRA else ImageFormat.RGBA, - frame_id=self.frame_id, - ts=self.ts, - ) - - def to_bgr(self) -> "Image": - """Convert image to BGR format.""" - if self.format == ImageFormat.BGR: - return self.copy() - elif self.format == ImageFormat.RGB: - bgr_data = cv2.cvtColor(self.data, cv2.COLOR_RGB2BGR) - elif self.format == ImageFormat.RGBA: - bgr_data = cv2.cvtColor(self.data, cv2.COLOR_RGBA2BGR) - elif self.format == ImageFormat.BGRA: - bgr_data = cv2.cvtColor(self.data, cv2.COLOR_BGRA2BGR) - elif self.format == ImageFormat.GRAY: - bgr_data = cv2.cvtColor(self.data, cv2.COLOR_GRAY2BGR) - elif self.format == ImageFormat.GRAY16: - # Convert 16-bit grayscale to 8-bit then to BGR - gray8 = (self.data / 256).astype(np.uint8) - bgr_data = cv2.cvtColor(gray8, cv2.COLOR_GRAY2BGR) - else: - raise ValueError(f"Unsupported format conversion from {self.format} to BGR") - - return self.__class__( - data=bgr_data, - format=ImageFormat.BGR, - frame_id=self.frame_id, - ts=self.ts, - ) - - def to_grayscale(self) -> "Image": - """Convert image to grayscale.""" - if self.format == ImageFormat.GRAY: - return self.copy() - elif self.format == ImageFormat.GRAY16: - return self.copy() - elif self.format == ImageFormat.BGR: - gray_data = cv2.cvtColor(self.data, cv2.COLOR_BGR2GRAY) - elif self.format == 
ImageFormat.RGB: - gray_data = cv2.cvtColor(self.data, cv2.COLOR_RGB2GRAY) - elif self.format == ImageFormat.RGBA: - gray_data = cv2.cvtColor(self.data, cv2.COLOR_RGBA2GRAY) - elif self.format == ImageFormat.BGRA: - gray_data = cv2.cvtColor(self.data, cv2.COLOR_BGRA2GRAY) - else: - raise ValueError(f"Unsupported format conversion from {self.format} to grayscale") - - return self.__class__( - data=gray_data, - format=ImageFormat.GRAY, - frame_id=self.frame_id, - ts=self.ts, - ) - - def resize(self, width: int, height: int, interpolation: int = cv2.INTER_LINEAR) -> "Image": - """Resize the image to the specified dimensions.""" - resized_data = cv2.resize(self.data, (width, height), interpolation=interpolation) - - return self.__class__( - data=resized_data, - format=self.format, - frame_id=self.frame_id, - ts=self.ts, - ) - - def crop(self, x: int, y: int, width: int, height: int) -> "Image": - """Crop the image to the specified region.""" - # Ensure crop region is within image bounds - x = max(0, min(x, self.width)) - y = max(0, min(y, self.height)) - x2 = min(x + width, self.width) - y2 = min(y + height, self.height) - - cropped_data = self.data[y:y2, x:x2] - - return self.__class__( - data=cropped_data, - format=self.format, - frame_id=self.frame_id, - ts=self.ts, - ) - - def save(self, filepath: str) -> bool: - """Save image to file.""" - # Convert to OpenCV format for saving - cv_image = self.to_opencv() - return cv2.imwrite(filepath, cv_image) - - def lcm_encode(self, frame_id: Optional[str] = None) -> LCMImage: - """Convert to LCM Image message.""" - msg = LCMImage() - - # Header - msg.header = Header() - msg.header.seq = 0 # Initialize sequence number - msg.header.frame_id = frame_id or self.frame_id - - # Set timestamp properly as Time object - if self.ts is not None: - msg.header.stamp.sec = int(self.ts) - msg.header.stamp.nsec = int((self.ts - int(self.ts)) * 1e9) - else: - current_time = time.time() - msg.header.stamp.sec = int(current_time) - 
msg.header.stamp.nsec = int((current_time - int(current_time)) * 1e9) - - # Image properties - msg.height = self.height - msg.width = self.width - msg.encoding = self.format.value - msg.is_bigendian = False # Use little endian - msg.step = self._get_row_step() - - # Image data - image_bytes = self.data.tobytes() - msg.data_length = len(image_bytes) - msg.data = image_bytes - - return msg.encode() - - @classmethod - def lcm_decode(cls, data: bytes, **kwargs) -> "Image": - """Create Image from LCM Image message.""" - # Parse encoding to determine format and data type - msg = LCMImage.decode(data) - format_info = cls._parse_encoding(msg.encoding) - - # Convert bytes back to numpy array - data = np.frombuffer(msg.data, dtype=format_info["dtype"]) - - # Reshape to image dimensions - if format_info["channels"] == 1: - data = data.reshape((msg.height, msg.width)) - else: - data = data.reshape((msg.height, msg.width, format_info["channels"])) - - return cls( - data=data, - format=format_info["format"], - frame_id=msg.header.frame_id if hasattr(msg, "header") else "", - ts=msg.header.stamp.sec + msg.header.stamp.nsec / 1e9 - if hasattr(msg, "header") and msg.header.stamp.sec > 0 - else time.time(), - **kwargs, - ) - - def _get_row_step(self) -> int: - """Calculate row step (bytes per row).""" - bytes_per_pixel = self._get_bytes_per_pixel() - return self.width * bytes_per_pixel - - def _get_bytes_per_pixel(self) -> int: - """Calculate bytes per pixel based on format and data type.""" - bytes_per_element = self.data.dtype.itemsize - return self.channels * bytes_per_element - - @staticmethod - def _parse_encoding(encoding: str) -> dict: - """Parse LCM image encoding string to determine format and data type.""" - encoding_map = { - "mono8": {"format": ImageFormat.GRAY, "dtype": np.uint8, "channels": 1}, - "mono16": {"format": ImageFormat.GRAY16, "dtype": np.uint16, "channels": 1}, - "rgb8": {"format": ImageFormat.RGB, "dtype": np.uint8, "channels": 3}, - "rgba8": {"format": 
ImageFormat.RGBA, "dtype": np.uint8, "channels": 4}, - "bgr8": {"format": ImageFormat.BGR, "dtype": np.uint8, "channels": 3}, - "bgra8": {"format": ImageFormat.BGRA, "dtype": np.uint8, "channels": 4}, - } - - if encoding not in encoding_map: - raise ValueError(f"Unsupported encoding: {encoding}") - - return encoding_map[encoding] - - def __repr__(self) -> str: - """String representation.""" - return ( - f"Image(shape={self.shape}, format={self.format.value}, " - f"dtype={self.dtype}, frame_id='{self.frame_id}', ts={self.ts})" - ) - - def __eq__(self, other) -> bool: - """Check equality with another Image.""" - if not isinstance(other, Image): - return False - - return ( - np.array_equal(self.data, other.data) - and self.format == other.format - and self.frame_id == other.frame_id - and abs(self.ts - other.ts) < 1e-6 - ) - - def __len__(self) -> int: - """Return total number of pixels.""" - return self.height * self.width diff --git a/build/lib/dimos/msgs/sensor_msgs/PointCloud2.py b/build/lib/dimos/msgs/sensor_msgs/PointCloud2.py deleted file mode 100644 index 4c4455a473..0000000000 --- a/build/lib/dimos/msgs/sensor_msgs/PointCloud2.py +++ /dev/null @@ -1,213 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import annotations - -import struct -import time -from typing import Optional - -import numpy as np -import open3d as o3d - -# Import LCM types -from dimos_lcm.sensor_msgs.PointCloud2 import ( - PointCloud2 as LCMPointCloud2, -) -from dimos_lcm.sensor_msgs.PointField import PointField -from dimos_lcm.std_msgs.Header import Header - -from dimos.types.timestamped import Timestamped - - -# TODO: encode/decode need to be updated to work with full spectrum of pointcloud2 fields -class PointCloud2(Timestamped): - msg_name = "sensor_msgs.PointCloud2" - - def __init__( - self, - pointcloud: o3d.geometry.PointCloud = None, - frame_id: str = "", - ts: Optional[float] = None, - ): - self.ts = ts if ts is not None else time.time() - self.pointcloud = pointcloud if pointcloud is not None else o3d.geometry.PointCloud() - self.frame_id = frame_id - - # TODO what's the usual storage here? is it already numpy? - def as_numpy(self) -> np.ndarray: - """Get points as numpy array.""" - return np.asarray(self.pointcloud.points) - - def lcm_encode(self, frame_id: Optional[str] = None) -> bytes: - """Convert to LCM PointCloud2 message.""" - msg = LCMPointCloud2() - - # Header - msg.header = Header() - msg.header.seq = 0 # Initialize sequence number - msg.header.frame_id = frame_id or self.frame_id - - msg.header.stamp.sec = int(self.ts) - msg.header.stamp.nsec = int((self.ts - int(self.ts)) * 1e9) - - points = self.as_numpy() - if len(points) == 0: - # Empty point cloud - msg.height = 0 - msg.width = 0 - msg.point_step = 16 # 4 floats * 4 bytes (x, y, z, intensity) - msg.row_step = 0 - msg.data_length = 0 - msg.data = b"" - msg.is_dense = True - msg.is_bigendian = False - msg.fields_length = 4 # x, y, z, intensity - msg.fields = self._create_xyz_field() - return msg.encode() - - # Point cloud dimensions - msg.height = 1 # Unorganized point cloud - msg.width = len(points) - - # Define fields (X, Y, Z, intensity as float32) - msg.fields_length = 4 # x, y, z, intensity - 
msg.fields = self._create_xyz_field() - - # Point step and row step - msg.point_step = 16 # 4 floats * 4 bytes each (x, y, z, intensity) - msg.row_step = msg.point_step * msg.width - - # Convert points to bytes with intensity padding (little endian float32) - # Add intensity column (zeros) to make it 4 columns: x, y, z, intensity - points_with_intensity = np.column_stack( - [ - points, # x, y, z columns - np.zeros(len(points), dtype=np.float32), # intensity column (padding) - ] - ) - data_bytes = points_with_intensity.astype(np.float32).tobytes() - msg.data_length = len(data_bytes) - msg.data = data_bytes - - # Properties - msg.is_dense = True # No invalid points - msg.is_bigendian = False # Little endian - - return msg.encode() - - @classmethod - def lcm_decode(cls, data: bytes) -> "PointCloud2": - msg = LCMPointCloud2.decode(data) - - if msg.width == 0 or msg.height == 0: - # Empty point cloud - pc = o3d.geometry.PointCloud() - return cls( - pointcloud=pc, - frame_id=msg.header.frame_id if hasattr(msg, "header") else "", - ts=msg.header.stamp.sec + msg.header.stamp.nsec / 1e9 - if hasattr(msg, "header") and msg.header.stamp.sec > 0 - else None, - ) - - # Parse field information to find X, Y, Z offsets - x_offset = y_offset = z_offset = None - for msgfield in msg.fields: - if msgfield.name == "x": - x_offset = msgfield.offset - elif msgfield.name == "y": - y_offset = msgfield.offset - elif msgfield.name == "z": - z_offset = msgfield.offset - - if any(offset is None for offset in [x_offset, y_offset, z_offset]): - raise ValueError("PointCloud2 message missing X, Y, or Z msgfields") - - # Extract points from binary data - num_points = msg.width * msg.height - points = np.zeros((num_points, 3), dtype=np.float32) - - data = msg.data - point_step = msg.point_step - - for i in range(num_points): - base_offset = i * point_step - - # Extract X, Y, Z (assuming float32, little endian) - x_bytes = data[base_offset + x_offset : base_offset + x_offset + 4] - y_bytes = 
data[base_offset + y_offset : base_offset + y_offset + 4] - z_bytes = data[base_offset + z_offset : base_offset + z_offset + 4] - - points[i, 0] = struct.unpack(" 0 - else None, - ) - - def _create_xyz_field(self) -> list: - """Create standard X, Y, Z field definitions for LCM PointCloud2.""" - fields = [] - - # X field - x_field = PointField() - x_field.name = "x" - x_field.offset = 0 - x_field.datatype = 7 # FLOAT32 - x_field.count = 1 - fields.append(x_field) - - # Y field - y_field = PointField() - y_field.name = "y" - y_field.offset = 4 - y_field.datatype = 7 # FLOAT32 - y_field.count = 1 - fields.append(y_field) - - # Z field - z_field = PointField() - z_field.name = "z" - z_field.offset = 8 - z_field.datatype = 7 # FLOAT32 - z_field.count = 1 - fields.append(z_field) - - # I field - i_field = PointField() - i_field.name = "intensity" - i_field.offset = 12 - i_field.datatype = 7 # FLOAT32 - i_field.count = 1 - fields.append(i_field) - - return fields - - def __len__(self) -> int: - """Return number of points.""" - return len(self.pointcloud.points) - - def __repr__(self) -> str: - """String representation.""" - return f"PointCloud(points={len(self)}, frame_id='{self.frame_id}', ts={self.ts})" diff --git a/build/lib/dimos/msgs/sensor_msgs/__init__.py b/build/lib/dimos/msgs/sensor_msgs/__init__.py deleted file mode 100644 index 170587e286..0000000000 --- a/build/lib/dimos/msgs/sensor_msgs/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from dimos.msgs.sensor_msgs.Image import Image -from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 diff --git a/build/lib/dimos/msgs/sensor_msgs/test_PointCloud2.py b/build/lib/dimos/msgs/sensor_msgs/test_PointCloud2.py deleted file mode 100644 index eee1778680..0000000000 --- a/build/lib/dimos/msgs/sensor_msgs/test_PointCloud2.py +++ /dev/null @@ -1,81 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025 Dimensional Inc. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np - -from dimos.msgs.sensor_msgs import PointCloud2 -from dimos.robot.unitree_webrtc.type.lidar import LidarMessage -from dimos.utils.testing import SensorReplay - - -def test_lcm_encode_decode(): - """Test LCM encode/decode preserves pointcloud data.""" - replay = SensorReplay("office_lidar", autocast=LidarMessage.from_msg) - lidar_msg: LidarMessage = replay.load_one("lidar_data_021") - - binary_msg = lidar_msg.lcm_encode() - decoded = PointCloud2.lcm_decode(binary_msg) - - # 1. Check number of points - original_points = lidar_msg.as_numpy() - decoded_points = decoded.as_numpy() - - print(f"Original points: {len(original_points)}") - print(f"Decoded points: {len(decoded_points)}") - assert len(original_points) == len(decoded_points), ( - f"Point count mismatch: {len(original_points)} vs {len(decoded_points)}" - ) - - # 2. Check point coordinates are preserved (within floating point tolerance) - if len(original_points) > 0: - np.testing.assert_allclose( - original_points, - decoded_points, - rtol=1e-6, - atol=1e-6, - err_msg="Point coordinates don't match between original and decoded", - ) - print(f"✓ All {len(original_points)} point coordinates match within tolerance") - - # 3. Check frame_id is preserved - assert lidar_msg.frame_id == decoded.frame_id, ( - f"Frame ID mismatch: '{lidar_msg.frame_id}' vs '{decoded.frame_id}'" - ) - print(f"✓ Frame ID preserved: '{decoded.frame_id}'") - - # 4. 
Check timestamp is preserved (within reasonable tolerance for float precision) - if lidar_msg.ts is not None and decoded.ts is not None: - assert abs(lidar_msg.ts - decoded.ts) < 1e-6, ( - f"Timestamp mismatch: {lidar_msg.ts} vs {decoded.ts}" - ) - print(f"✓ Timestamp preserved: {decoded.ts}") - - # 5. Check pointcloud properties - assert len(lidar_msg.pointcloud.points) == len(decoded.pointcloud.points), ( - "Open3D pointcloud size mismatch" - ) - - # 6. Additional detailed checks - print("✓ Original pointcloud summary:") - print(f" - Points: {len(original_points)}") - print(f" - Bounds: {original_points.min(axis=0)} to {original_points.max(axis=0)}") - print(f" - Mean: {original_points.mean(axis=0)}") - - print("✓ Decoded pointcloud summary:") - print(f" - Points: {len(decoded_points)}") - print(f" - Bounds: {decoded_points.min(axis=0)} to {decoded_points.max(axis=0)}") - print(f" - Mean: {decoded_points.mean(axis=0)}") - - print("✓ LCM encode/decode test passed - all properties preserved!") diff --git a/build/lib/dimos/msgs/sensor_msgs/test_image.py b/build/lib/dimos/msgs/sensor_msgs/test_image.py deleted file mode 100644 index 8e4e0a413f..0000000000 --- a/build/lib/dimos/msgs/sensor_msgs/test_image.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import numpy as np -import pytest - -from dimos.msgs.sensor_msgs.Image import Image, ImageFormat -from dimos.utils.data import get_data - - -@pytest.fixture -def img(): - image_file_path = get_data("cafe.jpg") - return Image.from_file(str(image_file_path)) - - -def test_file_load(img: Image): - assert isinstance(img.data, np.ndarray) - assert img.width == 1024 - assert img.height == 771 - assert img.channels == 3 - assert img.shape == (771, 1024, 3) - assert img.data.dtype == np.uint8 - assert img.format == ImageFormat.BGR - assert img.frame_id == "" - assert isinstance(img.ts, float) - assert img.ts > 0 - assert img.data.flags["C_CONTIGUOUS"] - - -def test_lcm_encode_decode(img: Image): - binary_msg = img.lcm_encode() - decoded_img = Image.lcm_decode(binary_msg) - - assert isinstance(decoded_img, Image) - assert decoded_img is not img - assert decoded_img == img - - -def test_rgb_bgr_conversion(img: Image): - rgb = img.to_rgb() - assert not rgb == img - assert rgb.to_bgr() == img - - -def test_opencv_conversion(img: Image): - ocv = img.to_opencv() - decoded_img = Image.from_opencv(ocv) - - # artificially patch timestamp - decoded_img.ts = img.ts - assert decoded_img == img diff --git a/build/lib/dimos/perception/__init__.py b/build/lib/dimos/perception/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/build/lib/dimos/perception/common/__init__.py b/build/lib/dimos/perception/common/__init__.py deleted file mode 100644 index ad815d3f46..0000000000 --- a/build/lib/dimos/perception/common/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .detection2d_tracker import target2dTracker, get_tracked_results -from .cuboid_fit import * -from .ibvs import * diff --git a/build/lib/dimos/perception/common/cuboid_fit.py b/build/lib/dimos/perception/common/cuboid_fit.py deleted file mode 100644 index 9848332c06..0000000000 --- a/build/lib/dimos/perception/common/cuboid_fit.py +++ /dev/null @@ -1,331 +0,0 @@ -# Copyright 2025 Dimensional Inc. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -from sklearn.decomposition import PCA -import matplotlib.pyplot as plt -import cv2 - - -def depth_to_point_cloud(depth_image, camera_matrix, subsample_factor=4): - """ - Convert depth image to point cloud using camera intrinsics. - Subsamples points to reduce density. - - Args: - depth_image: HxW depth image in meters - camera_matrix: 3x3 camera intrinsic matrix - subsample_factor: Factor to subsample points (higher = fewer points) - - Returns: - Nx3 array of 3D points - """ - # Get focal length and principal point from camera matrix - fx = camera_matrix[0, 0] - fy = camera_matrix[1, 1] - cx = camera_matrix[0, 2] - cy = camera_matrix[1, 2] - - # Create pixel coordinate grid - rows, cols = depth_image.shape - x_grid, y_grid = np.meshgrid( - np.arange(0, cols, subsample_factor), np.arange(0, rows, subsample_factor) - ) - - # Flatten grid and depth - x = x_grid.flatten() - y = y_grid.flatten() - z = depth_image[y_grid, x_grid].flatten() - - # Remove points with invalid depth - valid = z > 0 - x = x[valid] - y = y[valid] - z = z[valid] - - # Convert to 3D points - X = (x - cx) * z / fx - Y = (y - cy) * z / fy - Z = z - - return np.column_stack([X, Y, Z]) - - -def fit_cuboid(points, n_iterations=5, inlier_thresh=2.0): - """ - Fit a cuboid to a point cloud using iteratively refined PCA. 
- - Args: - points: Nx3 array of points - n_iterations: Number of refinement iterations - inlier_thresh: Threshold for inlier detection in standard deviations - - Returns: - dict containing: - - center: 3D center point - - dimensions: 3D dimensions - - rotation: 3x3 rotation matrix - - error: fitting error - """ - points = np.asarray(points) - if len(points) < 4: - return None - - # Initial center estimate using median for robustness - best_error = float("inf") - best_params = None - center = np.median(points, axis=0) - current_points = points - center - - for iteration in range(n_iterations): - if len(current_points) < 4: # Need at least 4 points for PCA - break - - # Perform PCA - pca = PCA(n_components=3) - pca.fit(current_points) - - # Get rotation matrix from PCA - rotation = pca.components_ - - # Transform points to PCA space - local_points = current_points @ rotation.T - - # Initialize mask for this iteration - inlier_mask = np.ones(len(current_points), dtype=bool) - dimensions = np.zeros(3) - - # Filter points along each dimension - for dim in range(3): - points_1d = local_points[inlier_mask, dim] - if len(points_1d) < 4: - break - - median = np.median(points_1d) - mad = np.median(np.abs(points_1d - median)) - sigma = mad * 1.4826 # Convert MAD to standard deviation estimate - - # Avoid issues with constant values - if sigma < 1e-6: - continue - - # Update mask for this dimension - dim_inliers = np.abs(points_1d - median) < (inlier_thresh * sigma) - inlier_mask[inlier_mask] = dim_inliers - - # Calculate dimension based on robust statistics - valid_points = points_1d[dim_inliers] - if len(valid_points) > 0: - dimensions[dim] = np.max(valid_points) - np.min(valid_points) - - # Skip if we don't have enough inliers - if np.sum(inlier_mask) < 4: - continue - - # Calculate error for this iteration - # Mean squared distance from points to cuboid surface - half_dims = dimensions / 2 - dx = np.abs(local_points[:, 0]) - half_dims[0] - dy = np.abs(local_points[:, 1]) 
- half_dims[1] - dz = np.abs(local_points[:, 2]) - half_dims[2] - - outside_dist = np.sqrt( - np.maximum(dx, 0) ** 2 + np.maximum(dy, 0) ** 2 + np.maximum(dz, 0) ** 2 - ) - inside_dist = np.minimum(np.maximum(np.maximum(dx, dy), dz), 0) - distances = outside_dist + inside_dist - error = np.mean(distances**2) - - if error < best_error: - best_error = error - best_params = { - "center": center, - "rotation": rotation, - "dimensions": dimensions, - "error": error, - } - - # Update points for next iteration - current_points = current_points[inlier_mask] - - return best_params - - -def compute_fitting_error(local_points, dimensions): - """Compute mean squared distance from points to cuboid surface.""" - half_dims = dimensions / 2 - dx = np.abs(local_points[:, 0]) - half_dims[0] - dy = np.abs(local_points[:, 1]) - half_dims[1] - dz = np.abs(local_points[:, 2]) - half_dims[2] - - outside_dist = np.sqrt(np.maximum(dx, 0) ** 2 + np.maximum(dy, 0) ** 2 + np.maximum(dz, 0) ** 2) - inside_dist = np.minimum(np.maximum(np.maximum(dx, dy), dz), 0) - - distances = outside_dist + inside_dist - return np.mean(distances**2) - - -def get_cuboid_corners(center, dimensions, rotation): - """Get the 8 corners of a cuboid.""" - half_dims = dimensions / 2 - corners_local = ( - np.array( - [ - [-1, -1, -1], # 0: left bottom back - [-1, -1, 1], # 1: left bottom front - [-1, 1, -1], # 2: left top back - [-1, 1, 1], # 3: left top front - [1, -1, -1], # 4: right bottom back - [1, -1, 1], # 5: right bottom front - [1, 1, -1], # 6: right top back - [1, 1, 1], # 7: right top front - ] - ) - * half_dims - ) - - return corners_local @ rotation + center - - -def visualize_fit(image, cuboid_params, camera_matrix, R=None, t=None): - """ - Draw the fitted cuboid on the image. 
- """ - # Get corners in world coordinates - corners = get_cuboid_corners( - cuboid_params["center"], cuboid_params["dimensions"], cuboid_params["rotation"] - ) - - # Transform corners if R and t are provided - if R is not None and t is not None: - corners = (R @ corners.T).T + t - - # Project corners to image space - corners_img = ( - cv2.projectPoints( - corners, - np.zeros(3), - np.zeros(3), # Already in camera frame - camera_matrix, - None, - )[0] - .reshape(-1, 2) - .astype(int) - ) - - # Define edges for visualization - edges = [ - # Bottom face - (0, 1), - (1, 5), - (5, 4), - (4, 0), - # Top face - (2, 3), - (3, 7), - (7, 6), - (6, 2), - # Vertical edges - (0, 2), - (1, 3), - (5, 7), - (4, 6), - ] - - # Draw edges - vis_img = image.copy() - for i, j in edges: - cv2.line(vis_img, tuple(corners_img[i]), tuple(corners_img[j]), (0, 255, 0), 2) - - # Add text with dimensions - dims = cuboid_params["dimensions"] - dim_text = f"Dims: {dims[0]:.3f} x {dims[1]:.3f} x {dims[2]:.3f}" - cv2.putText(vis_img, dim_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2) - - return vis_img - - -def plot_3d_fit(points, cuboid_params, title="3D Cuboid Fit"): - """Plot points and fitted cuboid in 3D.""" - fig = plt.figure(figsize=(10, 10)) - ax = fig.add_subplot(111, projection="3d") - - # Plot points - ax.scatter( - points[:, 0], points[:, 1], points[:, 2], c="b", marker=".", alpha=0.1, label="Points" - ) - - # Plot fitted cuboid - corners = get_cuboid_corners( - cuboid_params["center"], cuboid_params["dimensions"], cuboid_params["rotation"] - ) - - # Define edges - edges = [ - # Bottom face - (0, 1), - (1, 5), - (5, 4), - (4, 0), - # Top face - (2, 3), - (3, 7), - (7, 6), - (6, 2), - # Vertical edges - (0, 2), - (1, 3), - (5, 7), - (4, 6), - ] - - # Plot edges - for i, j in edges: - ax.plot3D( - [corners[i, 0], corners[j, 0]], - [corners[i, 1], corners[j, 1]], - [corners[i, 2], corners[j, 2]], - "r-", - ) - - # Set labels and title - ax.set_xlabel("X") - 
ax.set_ylabel("Y") - ax.set_zlabel("Z") - ax.set_title(title) - - # Make scaling uniform - all_points = np.vstack([points, corners]) - max_range = ( - np.array( - [ - all_points[:, 0].max() - all_points[:, 0].min(), - all_points[:, 1].max() - all_points[:, 1].min(), - all_points[:, 2].max() - all_points[:, 2].min(), - ] - ).max() - / 2.0 - ) - - mid_x = (all_points[:, 0].max() + all_points[:, 0].min()) * 0.5 - mid_y = (all_points[:, 1].max() + all_points[:, 1].min()) * 0.5 - mid_z = (all_points[:, 2].max() + all_points[:, 2].min()) * 0.5 - - ax.set_xlim(mid_x - max_range, mid_x + max_range) - ax.set_ylim(mid_y - max_range, mid_y + max_range) - ax.set_zlim(mid_z - max_range, mid_z + max_range) - - ax.set_box_aspect([1, 1, 1]) - plt.legend() - return fig, ax diff --git a/build/lib/dimos/perception/common/detection2d_tracker.py b/build/lib/dimos/perception/common/detection2d_tracker.py deleted file mode 100644 index 2e4582cc00..0000000000 --- a/build/lib/dimos/perception/common/detection2d_tracker.py +++ /dev/null @@ -1,385 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -from collections import deque - - -def compute_iou(bbox1, bbox2): - """ - Compute Intersection over Union (IoU) of two bounding boxes. - Each bbox is [x1, y1, x2, y2]. 
- """ - x1 = max(bbox1[0], bbox2[0]) - y1 = max(bbox1[1], bbox2[1]) - x2 = min(bbox1[2], bbox2[2]) - y2 = min(bbox1[3], bbox2[3]) - - inter_area = max(0, x2 - x1) * max(0, y2 - y1) - area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1]) - area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1]) - - union_area = area1 + area2 - inter_area - if union_area == 0: - return 0 - return inter_area / union_area - - -def get_tracked_results(tracked_targets): - """ - Extract tracked results from a list of target2d objects. - - Args: - tracked_targets (list[target2d]): List of target2d objects (published targets) - returned by the tracker's update() function. - - Returns: - tuple: (tracked_masks, tracked_bboxes, tracked_track_ids, tracked_probs, tracked_names) - where each is a list of the corresponding attribute from each target. - """ - tracked_masks = [] - tracked_bboxes = [] - tracked_track_ids = [] - tracked_probs = [] - tracked_names = [] - - for target in tracked_targets: - # Extract the latest values stored in each target. - tracked_masks.append(target.latest_mask) - tracked_bboxes.append(target.latest_bbox) - # Here we use the most recent detection's track ID. - tracked_track_ids.append(target.target_id) - # Use the latest probability from the history. - tracked_probs.append(target.score) - # Use the stored name (if any). If not available, you can use a default value. - tracked_names.append(target.name) - - return tracked_masks, tracked_bboxes, tracked_track_ids, tracked_probs, tracked_names - - -class target2d: - """ - Represents a tracked 2D target. - Stores the latest bounding box and mask along with a short history of track IDs, - detection probabilities, and computed texture values. - """ - - def __init__( - self, - initial_mask, - initial_bbox, - track_id, - prob, - name, - texture_value, - target_id, - history_size=10, - ): - """ - Args: - initial_mask (torch.Tensor): Latest segmentation mask. - initial_bbox (list): Bounding box in [x1, y1, x2, y2] format. 
- track_id (int): Detection’s track ID (may be -1 if not provided). - prob (float): Detection probability. - name (str): Object class name. - texture_value (float): Computed average texture value for this detection. - target_id (int): Unique identifier assigned by the tracker. - history_size (int): Maximum number of frames to keep in the history. - """ - self.target_id = target_id - self.latest_mask = initial_mask - self.latest_bbox = initial_bbox - self.name = name - self.score = 1.0 - - self.track_id = track_id - self.probs_history = deque(maxlen=history_size) - self.texture_history = deque(maxlen=history_size) - - self.frame_count = deque(maxlen=history_size) # Total frames this target has been seen. - self.missed_frames = 0 # Consecutive frames when no detection was assigned. - self.history_size = history_size - - def update(self, mask, bbox, track_id, prob, name, texture_value): - """ - Update the target with a new detection. - """ - self.latest_mask = mask - self.latest_bbox = bbox - self.name = name - - self.track_id = track_id - self.probs_history.append(prob) - self.texture_history.append(texture_value) - - self.frame_count.append(1) - self.missed_frames = 0 - - def mark_missed(self): - """ - Increment the count of consecutive frames where this target was not updated. - """ - self.missed_frames += 1 - self.frame_count.append(0) - - def compute_score( - self, - frame_shape, - min_area_ratio, - max_area_ratio, - texture_range=(0.0, 1.0), - border_safe_distance=50, - weights=None, - ): - """ - Compute a combined score for the target based on several factors. - - Factors: - - **Detection probability:** Average over recent frames. - - **Temporal stability:** How consistently the target has appeared. - - **Texture quality:** Normalized using the provided min and max values. - - **Border proximity:** Computed from the minimum distance from the bbox to the frame edges. - - **Size:** How the object's area (relative to the frame) compares to acceptable bounds. 
- - Args: - frame_shape (tuple): (height, width) of the frame. - min_area_ratio (float): Minimum acceptable ratio (bbox area / frame area). - max_area_ratio (float): Maximum acceptable ratio. - texture_range (tuple): (min_texture, max_texture) expected values. - border_safe_distance (float): Distance (in pixels) considered safe from the border. - weights (dict): Weights for each component. Expected keys: - 'prob', 'temporal', 'texture', 'border', and 'size'. - - Returns: - float: The combined (normalized) score in the range [0, 1]. - """ - # Default weights if none provided. - if weights is None: - weights = {"prob": 1.0, "temporal": 1.0, "texture": 1.0, "border": 1.0, "size": 1.0} - - h, w = frame_shape - x1, y1, x2, y2 = self.latest_bbox - bbox_area = (x2 - x1) * (y2 - y1) - frame_area = w * h - area_ratio = bbox_area / frame_area - - # Detection probability factor. - avg_prob = np.mean(self.probs_history) - # Temporal stability factor: normalized by history size. - temporal_stability = np.mean(self.frame_count) - # Texture factor: normalize average texture using the provided range. - avg_texture = np.mean(self.texture_history) if self.texture_history else 0.0 - min_texture, max_texture = texture_range - if max_texture == min_texture: - normalized_texture = avg_texture - else: - normalized_texture = (avg_texture - min_texture) / (max_texture - min_texture) - normalized_texture = max(0.0, min(normalized_texture, 1.0)) - - # Border factor: compute the minimum distance from the bbox to any frame edge. - left_dist = x1 - top_dist = y1 - right_dist = w - x2 - min_border_dist = min(left_dist, top_dist, right_dist) - # Normalize the border distance: full score (1.0) if at least border_safe_distance away. - border_factor = min(1.0, min_border_dist / border_safe_distance) - - # Size factor: penalize objects that are too small or too big. 
- if area_ratio < min_area_ratio: - size_factor = area_ratio / min_area_ratio - elif area_ratio > max_area_ratio: - # Here we compute a linear penalty if the area exceeds max_area_ratio. - if 1 - max_area_ratio > 0: - size_factor = max(0, (1 - area_ratio) / (1 - max_area_ratio)) - else: - size_factor = 0.0 - else: - size_factor = 1.0 - - # Combine factors using a weighted sum (each factor is assumed in [0, 1]). - w_prob = weights.get("prob", 1.0) - w_temporal = weights.get("temporal", 1.0) - w_texture = weights.get("texture", 1.0) - w_border = weights.get("border", 1.0) - w_size = weights.get("size", 1.0) - total_weight = w_prob + w_temporal + w_texture + w_border + w_size - - # print(f"track_id: {self.target_id}, avg_prob: {avg_prob:.2f}, temporal_stability: {temporal_stability:.2f}, normalized_texture: {normalized_texture:.2f}, border_factor: {border_factor:.2f}, size_factor: {size_factor:.2f}") - - final_score = ( - w_prob * avg_prob - + w_temporal * temporal_stability - + w_texture * normalized_texture - + w_border * border_factor - + w_size * size_factor - ) / total_weight - - self.score = final_score - - return final_score - - -class target2dTracker: - """ - Tracker that maintains a history of targets across frames. - New segmentation detections (frame, masks, bboxes, track_ids, probabilities, - and computed texture values) are matched to existing targets or used to create new ones. - - The tracker uses a scoring system that incorporates: - - **Detection probability** - - **Temporal stability** - - **Texture quality** (normalized within a specified range) - - **Proximity to image borders** (a continuous penalty based on the distance) - - **Object size** relative to the frame - - Targets are published if their score exceeds the start threshold and are removed if their score - falls below the stop threshold or if they are missed for too many consecutive frames. 
- """ - - def __init__( - self, - history_size=10, - score_threshold_start=0.5, - score_threshold_stop=0.3, - min_frame_count=10, - max_missed_frames=3, - min_area_ratio=0.001, - max_area_ratio=0.1, - texture_range=(0.0, 1.0), - border_safe_distance=50, - weights=None, - ): - """ - Args: - history_size (int): Maximum history length (number of frames) per target. - score_threshold_start (float): Minimum score for a target to be published. - score_threshold_stop (float): If a target’s score falls below this, it is removed. - min_frame_count (int): Minimum number of frames a target must be seen to be published. - max_missed_frames (int): Maximum consecutive frames a target can be missing before deletion. - min_area_ratio (float): Minimum acceptable bbox area relative to the frame. - max_area_ratio (float): Maximum acceptable bbox area relative to the frame. - texture_range (tuple): (min_texture, max_texture) expected values. - border_safe_distance (float): Distance (in pixels) considered safe from the border. - weights (dict): Weights for the scoring components (keys: 'prob', 'temporal', - 'texture', 'border', 'size'). - """ - self.history_size = history_size - self.score_threshold_start = score_threshold_start - self.score_threshold_stop = score_threshold_stop - self.min_frame_count = min_frame_count - self.max_missed_frames = max_missed_frames - self.min_area_ratio = min_area_ratio - self.max_area_ratio = max_area_ratio - self.texture_range = texture_range - self.border_safe_distance = border_safe_distance - # Default weights if none are provided. - if weights is None: - weights = {"prob": 1.0, "temporal": 1.0, "texture": 1.0, "border": 1.0, "size": 1.0} - self.weights = weights - - self.targets = {} # Dictionary mapping target_id -> target2d instance. - self.next_target_id = 0 - - def update(self, frame, masks, bboxes, track_ids, probs, names, texture_values): - """ - Update the tracker with new detections from the current frame. 
- - Args: - frame (np.ndarray): Current BGR frame. - masks (list[torch.Tensor]): List of segmentation masks. - bboxes (list): List of bounding boxes [x1, y1, x2, y2]. - track_ids (list): List of detection track IDs. - probs (list): List of detection probabilities. - names (list): List of class names. - texture_values (list): List of computed texture values. - - Returns: - published_targets (list[target2d]): Targets that are active and have scores above - the start threshold. - """ - updated_target_ids = set() - frame_shape = frame.shape[:2] # (height, width) - - # For each detection, try to match with an existing target. - for mask, bbox, det_tid, prob, name, texture in zip( - masks, bboxes, track_ids, probs, names, texture_values - ): - matched_target = None - - # First, try matching by detection track ID if valid. - if det_tid != -1: - for target in self.targets.values(): - if target.track_id == det_tid: - matched_target = target - break - - # Otherwise, try matching using IoU. - if matched_target is None: - best_iou = 0 - for target in self.targets.values(): - iou = compute_iou(bbox, target.latest_bbox) - if iou > 0.5 and iou > best_iou: - best_iou = iou - matched_target = target - - # Update existing target or create a new one. - if matched_target is not None: - matched_target.update(mask, bbox, det_tid, prob, name, texture) - updated_target_ids.add(matched_target.target_id) - else: - new_target = target2d( - mask, bbox, det_tid, prob, name, texture, self.next_target_id, self.history_size - ) - self.targets[self.next_target_id] = new_target - updated_target_ids.add(self.next_target_id) - self.next_target_id += 1 - - # Mark targets that were not updated. - for target_id, target in list(self.targets.items()): - if target_id not in updated_target_ids: - target.mark_missed() - if target.missed_frames > self.max_missed_frames: - del self.targets[target_id] - continue # Skip further checks for this target. 
- # Remove targets whose score falls below the stop threshold. - score = target.compute_score( - frame_shape, - self.min_area_ratio, - self.max_area_ratio, - texture_range=self.texture_range, - border_safe_distance=self.border_safe_distance, - weights=self.weights, - ) - if score < self.score_threshold_stop: - del self.targets[target_id] - - # Publish targets with scores above the start threshold. - published_targets = [] - for target in self.targets.values(): - score = target.compute_score( - frame_shape, - self.min_area_ratio, - self.max_area_ratio, - texture_range=self.texture_range, - border_safe_distance=self.border_safe_distance, - weights=self.weights, - ) - if ( - score >= self.score_threshold_start - and sum(target.frame_count) >= self.min_frame_count - and target.missed_frames <= 5 - ): - published_targets.append(target) - - return published_targets diff --git a/build/lib/dimos/perception/common/export_tensorrt.py b/build/lib/dimos/perception/common/export_tensorrt.py deleted file mode 100644 index 9c021eb0a0..0000000000 --- a/build/lib/dimos/perception/common/export_tensorrt.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import argparse -from ultralytics import YOLO, FastSAM - - -def parse_args(): - parser = argparse.ArgumentParser(description="Export YOLO/FastSAM models to different formats") - parser.add_argument("--model_path", type=str, required=True, help="Path to the model weights") - parser.add_argument( - "--model_type", - type=str, - choices=["yolo", "fastsam"], - required=True, - help="Type of model to export", - ) - parser.add_argument( - "--precision", - type=str, - choices=["fp32", "fp16", "int8"], - default="fp32", - help="Precision for export", - ) - parser.add_argument( - "--format", type=str, choices=["onnx", "engine"], default="onnx", help="Export format" - ) - return parser.parse_args() - - -def main(): - args = parse_args() - half = args.precision == "fp16" - int8 = args.precision == "int8" - # Load the appropriate model - if args.model_type == "yolo": - model = YOLO(args.model_path) - else: - model = FastSAM(args.model_path) - - # Export the model - model.export(format=args.format, half=half, int8=int8) - - -if __name__ == "__main__": - main() diff --git a/build/lib/dimos/perception/common/ibvs.py b/build/lib/dimos/perception/common/ibvs.py deleted file mode 100644 index d580c71b23..0000000000 --- a/build/lib/dimos/perception/common/ibvs.py +++ /dev/null @@ -1,280 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import numpy as np - - -class PersonDistanceEstimator: - def __init__(self, K, camera_pitch, camera_height): - """ - Initialize the distance estimator using ground plane constraint. - - Args: - K: 3x3 Camera intrinsic matrix in OpenCV format - (Assumed to be already for an undistorted image) - camera_pitch: Upward pitch of the camera (in radians), in the robot frame - Positive means looking up, negative means looking down - camera_height: Height of the camera above the ground (in meters) - """ - self.K = K - self.camera_height = camera_height - - # Precompute the inverse intrinsic matrix - self.K_inv = np.linalg.inv(K) - - # Transform from camera to robot frame (z-forward to x-forward) - self.T = np.array([[0, 0, 1], [-1, 0, 0], [0, -1, 0]]) - - # Pitch rotation matrix (positive is upward) - theta = -camera_pitch # Negative since positive pitch is negative rotation about robot Y - self.R_pitch = np.array( - [[np.cos(theta), 0, np.sin(theta)], [0, 1, 0], [-np.sin(theta), 0, np.cos(theta)]] - ) - - # Combined transform from camera to robot frame - self.A = self.R_pitch @ self.T - - # Store focal length and principal point for angle calculation - self.fx = K[0, 0] - self.cx = K[0, 2] - - def estimate_distance_angle(self, bbox: tuple, robot_pitch: float = None): - """ - Estimate distance and angle to person using ground plane constraint. 
- - Args: - bbox: tuple (x_min, y_min, x_max, y_max) - where y_max represents the feet position - robot_pitch: Current pitch of the robot body (in radians) - If provided, this will be combined with the camera's fixed pitch - - Returns: - depth: distance to person along camera's z-axis (meters) - angle: horizontal angle in camera frame (radians, positive right) - """ - x_min, _, x_max, y_max = bbox - - # Get center point of feet - u_c = (x_min + x_max) / 2.0 - v_feet = y_max - - # Create homogeneous feet point and get ray direction - p_feet = np.array([u_c, v_feet, 1.0]) - d_feet_cam = self.K_inv @ p_feet - - # If robot_pitch is provided, recalculate the transformation matrix - if robot_pitch is not None: - # Combined pitch (fixed camera pitch + current robot pitch) - total_pitch = -camera_pitch - robot_pitch # Both negated for correct rotation direction - R_total_pitch = np.array( - [ - [np.cos(total_pitch), 0, np.sin(total_pitch)], - [0, 1, 0], - [-np.sin(total_pitch), 0, np.cos(total_pitch)], - ] - ) - # Use the updated transformation matrix - A = R_total_pitch @ self.T - else: - # Use the precomputed transformation matrix - A = self.A - - # Convert ray to robot frame using appropriate transformation - d_feet_robot = A @ d_feet_cam - - # Ground plane intersection (z=0) - # camera_height + t * d_feet_robot[2] = 0 - if abs(d_feet_robot[2]) < 1e-6: - raise ValueError("Feet ray is parallel to ground plane") - - # Solve for scaling factor t - t = -self.camera_height / d_feet_robot[2] - - # Get 3D feet position in robot frame - p_feet_robot = t * d_feet_robot - - # Convert back to camera frame - p_feet_cam = self.A.T @ p_feet_robot - - # Extract depth (z-coordinate in camera frame) - depth = p_feet_cam[2] - - # Calculate horizontal angle from image center - angle = np.arctan((u_c - self.cx) / self.fx) - - return depth, angle - - -class ObjectDistanceEstimator: - """ - Estimate distance to an object using the ground plane constraint. 
- This class assumes the camera is mounted on a robot and uses the - camera's intrinsic parameters to estimate the distance to a detected object. - """ - - def __init__(self, K, camera_pitch, camera_height): - """ - Initialize the distance estimator using ground plane constraint. - - Args: - K: 3x3 Camera intrinsic matrix in OpenCV format - (Assumed to be already for an undistorted image) - camera_pitch: Upward pitch of the camera (in radians) - Positive means looking up, negative means looking down - camera_height: Height of the camera above the ground (in meters) - """ - self.K = K - self.camera_height = camera_height - - # Precompute the inverse intrinsic matrix - self.K_inv = np.linalg.inv(K) - - # Transform from camera to robot frame (z-forward to x-forward) - self.T = np.array([[0, 0, 1], [-1, 0, 0], [0, -1, 0]]) - - # Pitch rotation matrix (positive is upward) - theta = -camera_pitch # Negative since positive pitch is negative rotation about robot Y - self.R_pitch = np.array( - [[np.cos(theta), 0, np.sin(theta)], [0, 1, 0], [-np.sin(theta), 0, np.cos(theta)]] - ) - - # Combined transform from camera to robot frame - self.A = self.R_pitch @ self.T - - # Store focal length and principal point for angle calculation - self.fx = K[0, 0] - self.fy = K[1, 1] - self.cx = K[0, 2] - self.estimated_object_size = None - - def estimate_object_size(self, bbox: tuple, distance: float): - """ - Estimate the physical size of an object based on its bbox and known distance. 
- - Args: - bbox: tuple (x_min, y_min, x_max, y_max) bounding box in the image - distance: Known distance to the object (in meters) - robot_pitch: Current pitch of the robot body (in radians), if any - - Returns: - estimated_size: Estimated physical height of the object (in meters) - """ - x_min, y_min, x_max, y_max = bbox - - # Calculate object height in pixels - object_height_px = y_max - y_min - - # Calculate the physical height using the known distance and focal length - estimated_size = object_height_px * distance / self.fy - self.estimated_object_size = estimated_size - - return estimated_size - - def set_estimated_object_size(self, size: float): - """ - Set the estimated object size for future distance calculations. - - Args: - size: Estimated physical size of the object (in meters) - """ - self.estimated_object_size = size - - def estimate_distance_angle(self, bbox: tuple): - """ - Estimate distance and angle to object using size-based estimation. - - Args: - bbox: tuple (x_min, y_min, x_max, y_max) - where y_max represents the bottom of the object - robot_pitch: Current pitch of the robot body (in radians) - If provided, this will be combined with the camera's fixed pitch - initial_distance: Initial distance estimate for the object (in meters) - Used to calibrate object size if not previously known - - Returns: - depth: distance to object along camera's z-axis (meters) - angle: horizontal angle in camera frame (radians, positive right) - or None, None if estimation not possible - """ - # If we don't have estimated object size and no initial distance is provided, - # we can't estimate the distance - if self.estimated_object_size is None: - return None, None - - x_min, y_min, x_max, y_max = bbox - - # Calculate center of the object for angle calculation - u_c = (x_min + x_max) / 2.0 - - # If we have an initial distance estimate and no object size yet, - # calculate and store the object size using the initial distance - object_height_px = y_max - y_min - 
depth = self.estimated_object_size * self.fy / object_height_px - - # Calculate horizontal angle from image center - angle = np.arctan((u_c - self.cx) / self.fx) - - return depth, angle - - -# Example usage: -if __name__ == "__main__": - # Example camera calibration - K = np.array([[600, 0, 320], [0, 600, 240], [0, 0, 1]], dtype=np.float32) - - # Camera mounted 1.2m high, pitched down 10 degrees - camera_pitch = np.deg2rad(0) # negative for downward pitch - camera_height = 1.0 # meters - - estimator = PersonDistanceEstimator(K, camera_pitch, camera_height) - object_estimator = ObjectDistanceEstimator(K, camera_pitch, camera_height) - - # Example detection - bbox = (300, 100, 380, 400) # x1, y1, x2, y2 - - depth, angle = estimator.estimate_distance_angle(bbox) - # Estimate object size based on the known distance - object_size = object_estimator.estimate_object_size(bbox, depth) - depth_obj, angle_obj = object_estimator.estimate_distance_angle(bbox) - - print(f"Estimated person depth: {depth:.2f} m") - print(f"Estimated person angle: {np.rad2deg(angle):.1f}°") - print(f"Estimated object depth: {depth_obj:.2f} m") - print(f"Estimated object angle: {np.rad2deg(angle_obj):.1f}°") - - # Shrink the bbox by 30 pixels while keeping the same center - x_min, y_min, x_max, y_max = bbox - width = x_max - x_min - height = y_max - y_min - center_x = (x_min + x_max) // 2 - center_y = (y_min + y_max) // 2 - - new_width = max(width - 20, 2) # Ensure width is at least 2 pixels - new_height = max(height - 20, 2) # Ensure height is at least 2 pixels - - x_min = center_x - new_width // 2 - x_max = center_x + new_width // 2 - y_min = center_y - new_height // 2 - y_max = center_y + new_height // 2 - - bbox = (x_min, y_min, x_max, y_max) - - # Re-estimate distance and angle with the new bbox - depth, angle = estimator.estimate_distance_angle(bbox) - depth_obj, angle_obj = object_estimator.estimate_distance_angle(bbox) - - print(f"New estimated person depth: {depth:.2f} m") - print(f"New 
estimated person angle: {np.rad2deg(angle):.1f}°") - print(f"New estimated object depth: {depth_obj:.2f} m") - print(f"New estimated object angle: {np.rad2deg(angle_obj):.1f}°") diff --git a/build/lib/dimos/perception/common/utils.py b/build/lib/dimos/perception/common/utils.py deleted file mode 100644 index fc50e042ad..0000000000 --- a/build/lib/dimos/perception/common/utils.py +++ /dev/null @@ -1,364 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import cv2 -import numpy as np -from typing import List, Tuple, Optional, Any -from dimos.types.manipulation import ObjectData -from dimos.types.vector import Vector -from dimos.utils.logging_config import setup_logger -import torch - -logger = setup_logger("dimos.perception.common.utils") - - -def colorize_depth(depth_img: np.ndarray, max_depth: float = 5.0) -> Optional[np.ndarray]: - """ - Normalize and colorize depth image using COLORMAP_JET. 
- - Args: - depth_img: Depth image (H, W) in meters - max_depth: Maximum depth value for normalization - - Returns: - Colorized depth image (H, W, 3) in RGB format, or None if input is None - """ - if depth_img is None: - return None - - valid_mask = np.isfinite(depth_img) & (depth_img > 0) - depth_norm = np.zeros_like(depth_img) - depth_norm[valid_mask] = np.clip(depth_img[valid_mask] / max_depth, 0, 1) - depth_colored = cv2.applyColorMap((depth_norm * 255).astype(np.uint8), cv2.COLORMAP_JET) - depth_rgb = cv2.cvtColor(depth_colored, cv2.COLOR_BGR2RGB) - - # Make the depth image less bright by scaling down the values - depth_rgb = (depth_rgb * 0.6).astype(np.uint8) - - return depth_rgb - - -def draw_bounding_box( - image: np.ndarray, - bbox: List[float], - color: Tuple[int, int, int] = (0, 255, 0), - thickness: int = 2, - label: Optional[str] = None, - confidence: Optional[float] = None, - object_id: Optional[int] = None, - font_scale: float = 0.6, -) -> np.ndarray: - """ - Draw a bounding box with optional label on an image. 
- - Args: - image: Image to draw on (H, W, 3) - bbox: Bounding box [x1, y1, x2, y2] - color: RGB color tuple for the box - thickness: Line thickness for the box - label: Optional class label - confidence: Optional confidence score - object_id: Optional object ID - font_scale: Font scale for text - - Returns: - Image with bounding box drawn - """ - x1, y1, x2, y2 = map(int, bbox) - - # Draw bounding box - cv2.rectangle(image, (x1, y1), (x2, y2), color, thickness) - - # Create label text - text_parts = [] - if label is not None: - text_parts.append(str(label)) - if object_id is not None: - text_parts.append(f"ID: {object_id}") - if confidence is not None: - text_parts.append(f"({confidence:.2f})") - - if text_parts: - text = ", ".join(text_parts) - - # Draw text background - text_size = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, 1)[0] - cv2.rectangle( - image, - (x1, y1 - text_size[1] - 5), - (x1 + text_size[0], y1), - (0, 0, 0), - -1, - ) - - # Draw text - cv2.putText( - image, - text, - (x1, y1 - 5), - cv2.FONT_HERSHEY_SIMPLEX, - font_scale, - (255, 255, 255), - 1, - ) - - return image - - -def draw_segmentation_mask( - image: np.ndarray, - mask: np.ndarray, - color: Tuple[int, int, int] = (0, 200, 200), - alpha: float = 0.5, - draw_contours: bool = True, - contour_thickness: int = 2, -) -> np.ndarray: - """ - Draw segmentation mask overlay on an image. 
- - Args: - image: Image to draw on (H, W, 3) - mask: Segmentation mask (H, W) - boolean or uint8 - color: RGB color for the mask - alpha: Transparency factor (0.0 = transparent, 1.0 = opaque) - draw_contours: Whether to draw mask contours - contour_thickness: Thickness of contour lines - - Returns: - Image with mask overlay drawn - """ - if mask is None: - return image - - try: - # Ensure mask is uint8 - mask = mask.astype(np.uint8) - - # Create colored mask overlay - colored_mask = np.zeros_like(image) - colored_mask[mask > 0] = color - - # Apply the mask with transparency - mask_area = mask > 0 - image[mask_area] = cv2.addWeighted( - image[mask_area], 1 - alpha, colored_mask[mask_area], alpha, 0 - ) - - # Draw mask contours if requested - if draw_contours: - contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) - cv2.drawContours(image, contours, -1, color, contour_thickness) - - except Exception as e: - logger.warning(f"Error drawing segmentation mask: {e}") - - return image - - -def draw_object_detection_visualization( - image: np.ndarray, - objects: List[ObjectData], - draw_masks: bool = False, - bbox_color: Tuple[int, int, int] = (0, 255, 0), - mask_color: Tuple[int, int, int] = (0, 200, 200), - font_scale: float = 0.6, -) -> np.ndarray: - """ - Create object detection visualization with bounding boxes and optional masks. 
- - Args: - image: Base image to draw on (H, W, 3) - objects: List of ObjectData with detection information - draw_masks: Whether to draw segmentation masks - bbox_color: Default color for bounding boxes - mask_color: Default color for segmentation masks - font_scale: Font scale for text labels - - Returns: - Image with detection visualization - """ - viz_image = image.copy() - - for obj in objects: - try: - # Draw segmentation mask first (if enabled and available) - if draw_masks and "segmentation_mask" in obj and obj["segmentation_mask"] is not None: - viz_image = draw_segmentation_mask( - viz_image, obj["segmentation_mask"], color=mask_color, alpha=0.5 - ) - - # Draw bounding box - if "bbox" in obj and obj["bbox"] is not None: - # Use object's color if available, otherwise default - color = bbox_color - if "color" in obj and obj["color"] is not None: - obj_color = obj["color"] - if isinstance(obj_color, np.ndarray): - color = tuple(int(c) for c in obj_color) - elif isinstance(obj_color, (list, tuple)): - color = tuple(int(c) for c in obj_color[:3]) - - viz_image = draw_bounding_box( - viz_image, - obj["bbox"], - color=color, - label=obj.get("label"), - confidence=obj.get("confidence"), - object_id=obj.get("object_id"), - font_scale=font_scale, - ) - - except Exception as e: - logger.warning(f"Error drawing object visualization: {e}") - - return viz_image - - -def detection_results_to_object_data( - bboxes: List[List[float]], - track_ids: List[int], - class_ids: List[int], - confidences: List[float], - names: List[str], - masks: Optional[List[np.ndarray]] = None, - source: str = "detection", -) -> List[ObjectData]: - """ - Convert detection/segmentation results to ObjectData format. 
- - Args: - bboxes: List of bounding boxes [x1, y1, x2, y2] - track_ids: List of tracking IDs - class_ids: List of class indices - confidences: List of detection confidences - names: List of class names - masks: Optional list of segmentation masks - source: Source type ("detection" or "segmentation") - - Returns: - List of ObjectData dictionaries - """ - objects = [] - - for i in range(len(bboxes)): - # Calculate basic properties from bbox - bbox = bboxes[i] - width = bbox[2] - bbox[0] - height = bbox[3] - bbox[1] - center_x = bbox[0] + width / 2 - center_y = bbox[1] + height / 2 - - # Create ObjectData - object_data: ObjectData = { - "object_id": track_ids[i] if i < len(track_ids) else i, - "bbox": bbox, - "depth": -1.0, # Will be populated by depth estimation or point cloud processing - "confidence": confidences[i] if i < len(confidences) else 1.0, - "class_id": class_ids[i] if i < len(class_ids) else 0, - "label": names[i] if i < len(names) else f"{source}_object", - "movement_tolerance": 1.0, # Default to freely movable - "segmentation_mask": masks[i].cpu().numpy() - if masks and i < len(masks) and isinstance(masks[i], torch.Tensor) - else masks[i] - if masks and i < len(masks) - else None, - # Initialize 3D properties (will be populated by point cloud processing) - "position": Vector(0, 0, 0), - "rotation": Vector(0, 0, 0), - "size": { - "width": 0.0, - "height": 0.0, - "depth": 0.0, - }, - } - objects.append(object_data) - - return objects - - -def combine_object_data( - list1: List[ObjectData], list2: List[ObjectData], overlap_threshold: float = 0.8 -) -> List[ObjectData]: - """ - Combine two ObjectData lists, removing duplicates based on segmentation mask overlap. 
- """ - combined = list1.copy() - used_ids = set(obj.get("object_id", 0) for obj in list1) - next_id = max(used_ids) + 1 if used_ids else 1 - - for obj2 in list2: - obj_copy = obj2.copy() - - # Handle duplicate object_id - if obj_copy.get("object_id", 0) in used_ids: - obj_copy["object_id"] = next_id - next_id += 1 - used_ids.add(obj_copy["object_id"]) - - # Check mask overlap - mask2 = obj2.get("segmentation_mask") - if mask2 is None or np.sum(mask2 > 0) == 0: - combined.append(obj_copy) - continue - - mask2_area = np.sum(mask2 > 0) - is_duplicate = False - - for obj1 in list1: - mask1 = obj1.get("segmentation_mask") - if mask1 is None: - continue - - intersection = np.sum((mask1 > 0) & (mask2 > 0)) - if intersection / mask2_area >= overlap_threshold: - is_duplicate = True - break - - if not is_duplicate: - combined.append(obj_copy) - - return combined - - -def point_in_bbox(point: Tuple[int, int], bbox: List[float]) -> bool: - """ - Check if a point is inside a bounding box. - - Args: - point: (x, y) coordinates - bbox: Bounding box [x1, y1, x2, y2] - - Returns: - True if point is inside bbox - """ - x, y = point - x1, y1, x2, y2 = bbox - return x1 <= x <= x2 and y1 <= y <= y2 - - -def find_clicked_object(click_point: Tuple[int, int], objects: List[Any]) -> Optional[Any]: - """ - Find which object was clicked based on bounding boxes. 
- - Args: - click_point: (x, y) coordinates of mouse click - objects: List of objects with 'bbox' field - - Returns: - Clicked object or None - """ - for obj in objects: - if "bbox" in obj and point_in_bbox(click_point, obj["bbox"]): - return obj - return None diff --git a/build/lib/dimos/perception/detection2d/__init__.py b/build/lib/dimos/perception/detection2d/__init__.py deleted file mode 100644 index a43c5da6ce..0000000000 --- a/build/lib/dimos/perception/detection2d/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .utils import * -from .yolo_2d_det import * diff --git a/build/lib/dimos/perception/detection2d/detic_2d_det.py b/build/lib/dimos/perception/detection2d/detic_2d_det.py deleted file mode 100644 index fc81526ad2..0000000000 --- a/build/lib/dimos/perception/detection2d/detic_2d_det.py +++ /dev/null @@ -1,414 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import numpy as np -import os -import sys - -# Add Detic to Python path -detic_path = os.path.join(os.path.dirname(__file__), "..", "..", "models", "Detic") -if detic_path not in sys.path: - sys.path.append(detic_path) - sys.path.append(os.path.join(detic_path, "third_party/CenterNet2")) - -# PIL patch for compatibility -import PIL.Image - -if not hasattr(PIL.Image, "LINEAR") and hasattr(PIL.Image, "BILINEAR"): - PIL.Image.LINEAR = PIL.Image.BILINEAR - -# Detectron2 imports -from detectron2.config import get_cfg -from detectron2.data import MetadataCatalog - - -# Simple tracking implementation -class SimpleTracker: - """Simple IOU-based tracker implementation without external dependencies""" - - def __init__(self, iou_threshold=0.3, max_age=5): - self.iou_threshold = iou_threshold - self.max_age = max_age - self.next_id = 1 - self.tracks = {} # id -> {bbox, class_id, age, mask, etc} - - def _calculate_iou(self, bbox1, bbox2): - """Calculate IoU between two bboxes in format [x1,y1,x2,y2]""" - x1 = max(bbox1[0], bbox2[0]) - y1 = max(bbox1[1], bbox2[1]) - x2 = min(bbox1[2], bbox2[2]) - y2 = min(bbox1[3], bbox2[3]) - - if x2 < x1 or y2 < y1: - return 0.0 - - intersection = (x2 - x1) * (y2 - y1) - area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1]) - area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1]) - union = area1 + area2 - intersection - - return intersection / union if union > 0 else 0 - - def update(self, detections, masks): - """Update tracker with new detections - - Args: - detections: List of [x1,y1,x2,y2,score,class_id] - masks: List of segmentation masks corresponding to detections - - Returns: - List of [track_id, bbox, score, class_id, mask] - """ - if len(detections) == 0: - # Age existing tracks - for track_id in list(self.tracks.keys()): - self.tracks[track_id]["age"] += 1 - # Remove old tracks - if self.tracks[track_id]["age"] > self.max_age: - del self.tracks[track_id] - return [] - - # Convert to numpy for easier handling - if not 
isinstance(detections, np.ndarray): - detections = np.array(detections) - - result = [] - matched_indices = set() - - # Update existing tracks - for track_id, track in list(self.tracks.items()): - track["age"] += 1 - - if track["age"] > self.max_age: - del self.tracks[track_id] - continue - - # Find best matching detection for this track - best_iou = self.iou_threshold - best_idx = -1 - - for i, det in enumerate(detections): - if i in matched_indices: - continue - - # Check class match - if det[5] != track["class_id"]: - continue - - iou = self._calculate_iou(track["bbox"], det[:4]) - if iou > best_iou: - best_iou = iou - best_idx = i - - # If we found a match, update the track - if best_idx >= 0: - self.tracks[track_id]["bbox"] = detections[best_idx][:4] - self.tracks[track_id]["score"] = detections[best_idx][4] - self.tracks[track_id]["age"] = 0 - self.tracks[track_id]["mask"] = masks[best_idx] - matched_indices.add(best_idx) - - # Add to results with mask - result.append( - [ - track_id, - detections[best_idx][:4], - detections[best_idx][4], - int(detections[best_idx][5]), - self.tracks[track_id]["mask"], - ] - ) - - # Create new tracks for unmatched detections - for i, det in enumerate(detections): - if i in matched_indices: - continue - - # Create new track - new_id = self.next_id - self.next_id += 1 - - self.tracks[new_id] = { - "bbox": det[:4], - "score": det[4], - "class_id": int(det[5]), - "age": 0, - "mask": masks[i], - } - - # Add to results with mask directly from the track - result.append([new_id, det[:4], det[4], int(det[5]), masks[i]]) - - return result - - -class Detic2DDetector: - def __init__(self, model_path=None, device="cuda", vocabulary=None, threshold=0.5): - """ - Initialize the Detic detector with open vocabulary support. 
- - Args: - model_path (str): Path to a custom Detic model weights (optional) - device (str): Device to run inference on ('cuda' or 'cpu') - vocabulary (list): Custom vocabulary (list of class names) or 'lvis', 'objects365', 'openimages', 'coco' - threshold (float): Detection confidence threshold - """ - self.device = device - self.threshold = threshold - - # Set up Detic paths - already added to sys.path at module level - - # Import Detic modules - from centernet.config import add_centernet_config - from detic.config import add_detic_config - from detic.modeling.utils import reset_cls_test - from detic.modeling.text.text_encoder import build_text_encoder - - # Keep reference to these functions for later use - self.reset_cls_test = reset_cls_test - self.build_text_encoder = build_text_encoder - - # Setup model configuration - self.cfg = get_cfg() - add_centernet_config(self.cfg) - add_detic_config(self.cfg) - - # Use default Detic config - self.cfg.merge_from_file( - os.path.join( - detic_path, "configs/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.yaml" - ) - ) - - # Set default weights if not provided - if model_path is None: - self.cfg.MODEL.WEIGHTS = "https://dl.fbaipublicfiles.com/detic/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.pth" - else: - self.cfg.MODEL.WEIGHTS = model_path - - # Set device - if device == "cpu": - self.cfg.MODEL.DEVICE = "cpu" - - # Set detection threshold - self.cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = threshold - self.cfg.MODEL.ROI_BOX_HEAD.ZEROSHOT_WEIGHT_PATH = "rand" - self.cfg.MODEL.ROI_HEADS.ONE_CLASS_PER_PROPOSAL = True - - # Built-in datasets for Detic - use absolute paths with detic_path - self.builtin_datasets = { - "lvis": { - "metadata": "lvis_v1_val", - "classifier": os.path.join( - detic_path, "datasets/metadata/lvis_v1_clip_a+cname.npy" - ), - }, - "objects365": { - "metadata": "objects365_v2_val", - "classifier": os.path.join( - detic_path, "datasets/metadata/o365_clip_a+cnamefix.npy" - ), - }, - "openimages": 
{ - "metadata": "oid_val_expanded", - "classifier": os.path.join(detic_path, "datasets/metadata/oid_clip_a+cname.npy"), - }, - "coco": { - "metadata": "coco_2017_val", - "classifier": os.path.join(detic_path, "datasets/metadata/coco_clip_a+cname.npy"), - }, - } - - # Override config paths to use absolute paths - self.cfg.MODEL.ROI_BOX_HEAD.CAT_FREQ_PATH = os.path.join( - detic_path, "datasets/metadata/lvis_v1_train_cat_info.json" - ) - - # Initialize model - self.predictor = None - - # Setup with initial vocabulary - vocabulary = vocabulary or "lvis" - self.setup_vocabulary(vocabulary) - - # Initialize our simple tracker - self.tracker = SimpleTracker(iou_threshold=0.5, max_age=5) - - def setup_vocabulary(self, vocabulary): - """ - Setup the model's vocabulary. - - Args: - vocabulary: Either a string ('lvis', 'objects365', 'openimages', 'coco') - or a list of class names for custom vocabulary. - """ - if self.predictor is None: - # Initialize the model - from detectron2.engine import DefaultPredictor - - self.predictor = DefaultPredictor(self.cfg) - - if isinstance(vocabulary, str) and vocabulary in self.builtin_datasets: - # Use built-in dataset - dataset = vocabulary - metadata = MetadataCatalog.get(self.builtin_datasets[dataset]["metadata"]) - classifier = self.builtin_datasets[dataset]["classifier"] - num_classes = len(metadata.thing_classes) - self.class_names = metadata.thing_classes - else: - # Use custom vocabulary - if isinstance(vocabulary, str): - # If it's a string but not a built-in dataset, treat as a file - try: - with open(vocabulary, "r") as f: - class_names = [line.strip() for line in f if line.strip()] - except: - # Default to LVIS if there's an issue - print(f"Error loading vocabulary from {vocabulary}, using LVIS") - return self.setup_vocabulary("lvis") - else: - # Assume it's a list of class names - class_names = vocabulary - - # Create classifier from text embeddings - metadata = MetadataCatalog.get("__unused") - metadata.thing_classes = 
class_names - self.class_names = class_names - - # Generate CLIP embeddings for custom vocabulary - classifier = self._get_clip_embeddings(class_names) - num_classes = len(class_names) - - # Reset model with new vocabulary - self.reset_cls_test(self.predictor.model, classifier, num_classes) - return self.class_names - - def _get_clip_embeddings(self, vocabulary, prompt="a "): - """ - Generate CLIP embeddings for a vocabulary list. - - Args: - vocabulary (list): List of class names - prompt (str): Prompt prefix to use for CLIP - - Returns: - torch.Tensor: Tensor of embeddings - """ - text_encoder = self.build_text_encoder(pretrain=True) - text_encoder.eval() - texts = [prompt + x for x in vocabulary] - emb = text_encoder(texts).detach().permute(1, 0).contiguous().cpu() - return emb - - def process_image(self, image): - """ - Process an image and return detection results. - - Args: - image: Input image in BGR format (OpenCV) - - Returns: - tuple: (bboxes, track_ids, class_ids, confidences, names, masks) - - bboxes: list of [x1, y1, x2, y2] coordinates - - track_ids: list of tracking IDs (or -1 if no tracking) - - class_ids: list of class indices - - confidences: list of detection confidences - - names: list of class names - - masks: list of segmentation masks (numpy arrays) - """ - # Run inference with Detic - outputs = self.predictor(image) - instances = outputs["instances"].to("cpu") - - # Extract bounding boxes, classes, scores, and masks - if len(instances) == 0: - return [], [], [], [], [], [] - - boxes = instances.pred_boxes.tensor.numpy() - class_ids = instances.pred_classes.numpy() - scores = instances.scores.numpy() - masks = instances.pred_masks.numpy() - - # Convert boxes to [x1, y1, x2, y2] format - bboxes = [] - for box in boxes: - x1, y1, x2, y2 = box.tolist() - bboxes.append([x1, y1, x2, y2]) - - # Get class names - names = [self.class_names[class_id] for class_id in class_ids] - - # Apply tracking - detections = [] - filtered_masks = [] - for i, bbox 
in enumerate(bboxes): - if scores[i] >= self.threshold: - # Format for tracker: [x1, y1, x2, y2, score, class_id] - detections.append(bbox + [scores[i], class_ids[i]]) - filtered_masks.append(masks[i]) - - if not detections: - return [], [], [], [], [], [] - - # Update tracker with detections and correctly aligned masks - track_results = self.tracker.update(detections, filtered_masks) - - # Process tracking results - track_ids = [] - tracked_bboxes = [] - tracked_class_ids = [] - tracked_scores = [] - tracked_names = [] - tracked_masks = [] - - for track_id, bbox, score, class_id, mask in track_results: - track_ids.append(int(track_id)) - tracked_bboxes.append(bbox.tolist() if isinstance(bbox, np.ndarray) else bbox) - tracked_class_ids.append(int(class_id)) - tracked_scores.append(score) - tracked_names.append(self.class_names[int(class_id)]) - tracked_masks.append(mask) - - return ( - tracked_bboxes, - track_ids, - tracked_class_ids, - tracked_scores, - tracked_names, - tracked_masks, - ) - - def visualize_results(self, image, bboxes, track_ids, class_ids, confidences, names): - """ - Generate visualization of detection results. - - Args: - image: Original input image - bboxes: List of bounding boxes - track_ids: List of tracking IDs - class_ids: List of class indices - confidences: List of detection confidences - names: List of class names - - Returns: - Image with visualized detections - """ - from dimos.perception.detection2d.utils import plot_results - - return plot_results(image, bboxes, track_ids, class_ids, confidences, names) - - def cleanup(self): - """Clean up resources.""" - # Nothing specific to clean up for Detic - pass diff --git a/build/lib/dimos/perception/detection2d/test_yolo_2d_det.py b/build/lib/dimos/perception/detection2d/test_yolo_2d_det.py deleted file mode 100644 index 4240625744..0000000000 --- a/build/lib/dimos/perception/detection2d/test_yolo_2d_det.py +++ /dev/null @@ -1,177 +0,0 @@ -# Copyright 2025 Dimensional Inc. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import time - -import cv2 -import numpy as np -import pytest -import reactivex as rx -from reactivex import operators as ops - -from dimos.perception.detection2d.yolo_2d_det import Yolo2DDetector -from dimos.stream.video_provider import VideoProvider - - -class TestYolo2DDetector: - def test_yolo_detector_initialization(self): - """Test YOLO detector initializes correctly with default model path.""" - try: - detector = Yolo2DDetector() - assert detector is not None - assert detector.model is not None - except Exception as e: - # If the model file doesn't exist, the test should still pass with a warning - pytest.skip(f"Skipping test due to model initialization error: {e}") - - def test_yolo_detector_process_image(self): - """Test YOLO detector can process video frames and return detection results.""" - try: - # Import data inside method to avoid pytest fixture confusion - from dimos.utils.data import get_data - - detector = Yolo2DDetector() - - video_path = get_data("assets") / "trimmed_video_office.mov" - - # Create video provider and directly get a video stream observable - assert os.path.exists(video_path), f"Test video not found: {video_path}" - video_provider = VideoProvider(dev_name="test_video", video_source=video_path) - # Process more frames for thorough testing - video_stream = video_provider.capture_video_as_observable(realtime=False, fps=15) - - # Use ReactiveX operators to process the stream - 
def process_frame(frame): - try: - # Process frame with YOLO - bboxes, track_ids, class_ids, confidences, names = detector.process_image(frame) - print( - f"YOLO results - boxes: {(bboxes)}, tracks: {len(track_ids)}, classes: {(class_ids)}, confidences: {(confidences)}, names: {(names)}" - ) - - return { - "frame": frame, - "bboxes": bboxes, - "track_ids": track_ids, - "class_ids": class_ids, - "confidences": confidences, - "names": names, - } - except Exception as e: - print(f"Exception in process_frame: {e}") - return {} - - # Create the detection stream using pipe and map operator - detection_stream = video_stream.pipe(ops.map(process_frame)) - - # Collect results from the stream - results = [] - - frames_processed = 0 - target_frames = 10 - - def on_next(result): - nonlocal frames_processed - if not result: - return - - results.append(result) - frames_processed += 1 - - # Stop after processing target number of frames - if frames_processed >= target_frames: - subscription.dispose() - - def on_error(error): - pytest.fail(f"Error in detection stream: {error}") - - def on_completed(): - pass - - # Subscribe and wait for results - subscription = detection_stream.subscribe( - on_next=on_next, on_error=on_error, on_completed=on_completed - ) - - timeout = 10.0 - start_time = time.time() - while frames_processed < target_frames and time.time() - start_time < timeout: - time.sleep(0.5) - - # Clean up subscription - subscription.dispose() - video_provider.dispose_all() - # Check that we got detection results - if len(results) == 0: - pytest.skip("Skipping test due to error: Failed to get any detection results") - - # Verify we have detection results with expected properties - assert len(results) > 0, "No detection results were received" - - # Print statistics about detections - total_detections = sum(len(r["bboxes"]) for r in results if r.get("bboxes")) - avg_detections = total_detections / len(results) if results else 0 - print(f"Total detections: {total_detections}, 
Average per frame: {avg_detections:.2f}") - - # Print most common detected objects - object_counts = {} - for r in results: - if r.get("names"): - for name in r["names"]: - if name: - object_counts[name] = object_counts.get(name, 0) + 1 - - if object_counts: - print("Detected objects:") - for obj, count in sorted(object_counts.items(), key=lambda x: x[1], reverse=True)[ - :5 - ]: - print(f" - {obj}: {count} times") - - # Analyze the first result - result = results[0] - - # Check that we have a frame - assert "frame" in result, "Result doesn't contain a frame" - assert isinstance(result["frame"], np.ndarray), "Frame is not a numpy array" - - # Check that detection results are valid - assert isinstance(result["bboxes"], list) - assert isinstance(result["track_ids"], list) - assert isinstance(result["class_ids"], list) - assert isinstance(result["confidences"], list) - assert isinstance(result["names"], list) - - # All result lists should be the same length - assert ( - len(result["bboxes"]) - == len(result["track_ids"]) - == len(result["class_ids"]) - == len(result["confidences"]) - == len(result["names"]) - ) - - # If we have detections, check that bbox format is valid - if result["bboxes"]: - assert len(result["bboxes"][0]) == 4, ( - "Bounding boxes should be in [x1, y1, x2, y2] format" - ) - - except Exception as e: - pytest.skip(f"Skipping test due to error: {e}") - - -if __name__ == "__main__": - pytest.main(["-v", __file__]) diff --git a/build/lib/dimos/perception/detection2d/utils.py b/build/lib/dimos/perception/detection2d/utils.py deleted file mode 100644 index dbe19baf30..0000000000 --- a/build/lib/dimos/perception/detection2d/utils.py +++ /dev/null @@ -1,338 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -import cv2 -from dimos.types.vector import Vector -from dimos.utils.transform_utils import distance_angle_to_goal_xy - - -def filter_detections( - bboxes, - track_ids, - class_ids, - confidences, - names, - class_filter=None, - name_filter=None, - track_id_filter=None, -): - """ - Filter detection results based on class IDs, names, and/or tracking IDs. - - Args: - bboxes: List of bounding boxes [x1, y1, x2, y2] - track_ids: List of tracking IDs - class_ids: List of class indices - confidences: List of detection confidences - names: List of class names - class_filter: List/set of class IDs to keep, or None to keep all - name_filter: List/set of class names to keep, or None to keep all - track_id_filter: List/set of track IDs to keep, or None to keep all - - Returns: - tuple: (filtered_bboxes, filtered_track_ids, filtered_class_ids, - filtered_confidences, filtered_names) - """ - # Convert filters to sets for efficient lookup - if class_filter is not None: - class_filter = set(class_filter) - if name_filter is not None: - name_filter = set(name_filter) - if track_id_filter is not None: - track_id_filter = set(track_id_filter) - - # Initialize lists for filtered results - filtered_bboxes = [] - filtered_track_ids = [] - filtered_class_ids = [] - filtered_confidences = [] - filtered_names = [] - - # Filter detections - for bbox, track_id, class_id, conf, name in zip( - bboxes, track_ids, class_ids, confidences, names - ): - # Check if detection passes all specified filters - keep = True - - if class_filter is not None: - keep = keep and 
(class_id in class_filter) - - if name_filter is not None: - keep = keep and (name in name_filter) - - if track_id_filter is not None: - keep = keep and (track_id in track_id_filter) - - # If detection passes all filters, add it to results - if keep: - filtered_bboxes.append(bbox) - filtered_track_ids.append(track_id) - filtered_class_ids.append(class_id) - filtered_confidences.append(conf) - filtered_names.append(name) - - return ( - filtered_bboxes, - filtered_track_ids, - filtered_class_ids, - filtered_confidences, - filtered_names, - ) - - -def extract_detection_results(result, class_filter=None, name_filter=None, track_id_filter=None): - """ - Extract and optionally filter detection information from a YOLO result object. - - Args: - result: Ultralytics result object - class_filter: List/set of class IDs to keep, or None to keep all - name_filter: List/set of class names to keep, or None to keep all - track_id_filter: List/set of track IDs to keep, or None to keep all - - Returns: - tuple: (bboxes, track_ids, class_ids, confidences, names) - - bboxes: list of [x1, y1, x2, y2] coordinates - - track_ids: list of tracking IDs - - class_ids: list of class indices - - confidences: list of detection confidences - - names: list of class names - """ - bboxes = [] - track_ids = [] - class_ids = [] - confidences = [] - names = [] - - if result.boxes is None: - return bboxes, track_ids, class_ids, confidences, names - - for box in result.boxes: - # Extract bounding box coordinates - x1, y1, x2, y2 = box.xyxy[0].tolist() - - # Extract tracking ID if available - track_id = -1 - if hasattr(box, "id") and box.id is not None: - track_id = int(box.id[0].item()) - - # Extract class information - cls_idx = int(box.cls[0]) - name = result.names[cls_idx] - - # Extract confidence - conf = float(box.conf[0]) - - # Check filters before adding to results - keep = True - if class_filter is not None: - keep = keep and (cls_idx in class_filter) - if name_filter is not None: - keep = keep 
and (name in name_filter) - if track_id_filter is not None: - keep = keep and (track_id in track_id_filter) - - if keep: - bboxes.append([x1, y1, x2, y2]) - track_ids.append(track_id) - class_ids.append(cls_idx) - confidences.append(conf) - names.append(name) - - return bboxes, track_ids, class_ids, confidences, names - - -def plot_results(image, bboxes, track_ids, class_ids, confidences, names, alpha=0.5): - """ - Draw bounding boxes and labels on the image. - - Args: - image: Original input image - bboxes: List of bounding boxes [x1, y1, x2, y2] - track_ids: List of tracking IDs - class_ids: List of class indices - confidences: List of detection confidences - names: List of class names - alpha: Transparency of the overlay - - Returns: - Image with visualized detections - """ - vis_img = image.copy() - - for bbox, track_id, conf, name in zip(bboxes, track_ids, confidences, names): - # Generate consistent color based on track_id or class name - if track_id != -1: - np.random.seed(track_id) - else: - np.random.seed(hash(name) % 100000) - color = np.random.randint(0, 255, (3,), dtype=np.uint8) - np.random.seed(None) - - # Draw bounding box - x1, y1, x2, y2 = map(int, bbox) - cv2.rectangle(vis_img, (x1, y1), (x2, y2), color.tolist(), 2) - - # Prepare label text - if track_id != -1: - label = f"ID:{track_id} {name} {conf:.2f}" - else: - label = f"{name} {conf:.2f}" - - # Calculate text size for background rectangle - (text_w, text_h), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1) - - # Draw background rectangle for text - cv2.rectangle(vis_img, (x1, y1 - text_h - 8), (x1 + text_w + 4, y1), color.tolist(), -1) - - # Draw text with white color for better visibility - cv2.putText( - vis_img, label, (x1 + 2, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1 - ) - - return vis_img - - -def calculate_depth_from_bbox(depth_map, bbox): - """ - Calculate the average depth of an object within a bounding box. 
- Uses the 25th to 75th percentile range to filter outliers. - - Args: - depth_map: The depth map - bbox: Bounding box in format [x1, y1, x2, y2] - - Returns: - float: Average depth in meters, or None if depth estimation fails - """ - try: - # Extract region of interest from the depth map - x1, y1, x2, y2 = map(int, bbox) - roi_depth = depth_map[y1:y2, x1:x2] - - if roi_depth.size == 0: - return None - - # Calculate 25th and 75th percentile to filter outliers - p25 = np.percentile(roi_depth, 25) - p75 = np.percentile(roi_depth, 75) - - # Filter depth values within this range - filtered_depth = roi_depth[(roi_depth >= p25) & (roi_depth <= p75)] - - # Calculate average depth (convert to meters) - if filtered_depth.size > 0: - return np.mean(filtered_depth) / 1000.0 # Convert mm to meters - - return None - except Exception as e: - print(f"Error calculating depth from bbox: {e}") - return None - - -def calculate_distance_angle_from_bbox(bbox, depth, camera_intrinsics): - """ - Calculate distance and angle to object center based on bbox and depth. 
- - Args: - bbox: Bounding box [x1, y1, x2, y2] - depth: Depth value in meters - camera_intrinsics: List [fx, fy, cx, cy] with camera parameters - - Returns: - tuple: (distance, angle) in meters and radians - """ - if camera_intrinsics is None: - raise ValueError("Camera intrinsics required for distance calculation") - - # Extract camera parameters - fx, fy, cx, cy = camera_intrinsics - - # Calculate center of bounding box in pixels - x1, y1, x2, y2 = bbox - center_x = (x1 + x2) / 2 - center_y = (y1 + y2) / 2 - - # Calculate normalized image coordinates - x_norm = (center_x - cx) / fx - - # Calculate angle (positive to the right) - angle = np.arctan(x_norm) - - # Calculate distance using depth and angle - distance = depth / np.cos(angle) if np.cos(angle) != 0 else depth - - return distance, angle - - -def calculate_object_size_from_bbox(bbox, depth, camera_intrinsics): - """ - Estimate physical width and height of object in meters. - - Args: - bbox: Bounding box [x1, y1, x2, y2] - depth: Depth value in meters - camera_intrinsics: List [fx, fy, cx, cy] with camera parameters - - Returns: - tuple: (width, height) in meters - """ - if camera_intrinsics is None: - return 0.0, 0.0 - - fx, fy, _, _ = camera_intrinsics - - # Calculate bbox dimensions in pixels - x1, y1, x2, y2 = bbox - width_px = x2 - x1 - height_px = y2 - y1 - - # Convert to meters using similar triangles and depth - width_m = (width_px * depth) / fx - height_m = (height_px * depth) / fy - - return width_m, height_m - - -def calculate_position_rotation_from_bbox(bbox, depth, camera_intrinsics): - """ - Calculate position (xyz) and rotation (roll, pitch, yaw) for an object - based on its bounding box and depth. 
- - Args: - bbox: Bounding box [x1, y1, x2, y2] - depth: Depth value in meters - camera_intrinsics: List [fx, fy, cx, cy] with camera parameters - - Returns: - Vector: position - Vector: rotation - """ - # Calculate distance and angle to object - distance, angle = calculate_distance_angle_from_bbox(bbox, depth, camera_intrinsics) - - # Convert distance and angle to x,y coordinates (in camera frame) - # Note: We negate the angle since positive angle means object is to the right, - # but we want positive y to be to the left in the standard coordinate system - x, y = distance_angle_to_goal_xy(distance, -angle) - - # For now, rotation is only in yaw (around z-axis) - # We can use the negative of the angle as an estimate of the object's yaw - # assuming objects tend to face the camera - position = Vector([x, y, 0.0]) - rotation = Vector([0.0, 0.0, -angle]) - - return position, rotation diff --git a/build/lib/dimos/perception/detection2d/yolo_2d_det.py b/build/lib/dimos/perception/detection2d/yolo_2d_det.py deleted file mode 100644 index b9b04165cd..0000000000 --- a/build/lib/dimos/perception/detection2d/yolo_2d_det.py +++ /dev/null @@ -1,157 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import os - -import cv2 -import onnxruntime -from ultralytics import YOLO - -from dimos.perception.detection2d.utils import ( - extract_detection_results, - filter_detections, - plot_results, -) -from dimos.utils.data import get_data -from dimos.utils.gpu_utils import is_cuda_available -from dimos.utils.logging_config import setup_logger -from dimos.utils.path_utils import get_project_root - -logger = setup_logger("dimos.perception.detection2d.yolo_2d_det") - - -class Yolo2DDetector: - def __init__(self, model_path="models_yolo", model_name="yolo11n.onnx", device="cpu"): - """ - Initialize the YOLO detector. - - Args: - model_path (str): Path to the YOLO model weights in tests/data LFS directory - model_name (str): Name of the YOLO model weights file - device (str): Device to run inference on ('cuda' or 'cpu') - """ - self.device = device - self.model = YOLO(get_data(model_path) / model_name) - - module_dir = os.path.dirname(__file__) - self.tracker_config = os.path.join(module_dir, "config", "custom_tracker.yaml") - if is_cuda_available(): - if hasattr(onnxruntime, "preload_dlls"): # Handles CUDA 11 / onnxruntime-gpu<=1.18 - onnxruntime.preload_dlls(cuda=True, cudnn=True) - self.device = "cuda" - logger.info("Using CUDA for YOLO 2d detector") - else: - self.device = "cpu" - logger.info("Using CPU for YOLO 2d detector") - - def process_image(self, image): - """ - Process an image and return detection results. 
- - Args: - image: Input image in BGR format (OpenCV) - - Returns: - tuple: (bboxes, track_ids, class_ids, confidences, names) - - bboxes: list of [x1, y1, x2, y2] coordinates - - track_ids: list of tracking IDs (or -1 if no tracking) - - class_ids: list of class indices - - confidences: list of detection confidences - - names: list of class names - """ - results = self.model.track( - source=image, - device=self.device, - conf=0.5, - iou=0.6, - persist=True, - verbose=False, - tracker=self.tracker_config, - ) - - if len(results) > 0: - # Extract detection results - bboxes, track_ids, class_ids, confidences, names = extract_detection_results(results[0]) - return bboxes, track_ids, class_ids, confidences, names - - return [], [], [], [], [] - - def visualize_results(self, image, bboxes, track_ids, class_ids, confidences, names): - """ - Generate visualization of detection results. - - Args: - image: Original input image - bboxes: List of bounding boxes - track_ids: List of tracking IDs - class_ids: List of class indices - confidences: List of detection confidences - names: List of class names - - Returns: - Image with visualized detections - """ - return plot_results(image, bboxes, track_ids, class_ids, confidences, names) - - -def main(): - """Example usage of the Yolo2DDetector class.""" - # Initialize video capture - cap = cv2.VideoCapture(0) - - # Initialize detector - detector = Yolo2DDetector() - - enable_person_filter = True - - try: - while cap.isOpened(): - ret, frame = cap.read() - if not ret: - break - - # Process frame - bboxes, track_ids, class_ids, confidences, names = detector.process_image(frame) - - # Apply person filtering if enabled - if enable_person_filter and len(bboxes) > 0: - # Person is class_id 0 in COCO dataset - bboxes, track_ids, class_ids, confidences, names = filter_detections( - bboxes, - track_ids, - class_ids, - confidences, - names, - class_filter=[0], # 0 is the class_id for person - name_filter=["person"], - ) - - # Visualize 
results - if len(bboxes) > 0: - frame = detector.visualize_results( - frame, bboxes, track_ids, class_ids, confidences, names - ) - - # Display results - cv2.imshow("YOLO Detection", frame) - if cv2.waitKey(1) & 0xFF == ord("q"): - break - - finally: - cap.release() - cv2.destroyAllWindows() - - -if __name__ == "__main__": - main() diff --git a/build/lib/dimos/perception/grasp_generation/__init__.py b/build/lib/dimos/perception/grasp_generation/__init__.py deleted file mode 100644 index 16281fe0b6..0000000000 --- a/build/lib/dimos/perception/grasp_generation/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .utils import * diff --git a/build/lib/dimos/perception/grasp_generation/grasp_generation.py b/build/lib/dimos/perception/grasp_generation/grasp_generation.py deleted file mode 100644 index 947a3bcd96..0000000000 --- a/build/lib/dimos/perception/grasp_generation/grasp_generation.py +++ /dev/null @@ -1,228 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -AnyGrasp-based grasp generation for manipulation pipeline. 
-""" - -import asyncio -import numpy as np -import open3d as o3d -from typing import Dict, List, Optional - -from dimos.types.manipulation import ObjectData -from dimos.utils.logging_config import setup_logger -from dimos.perception.grasp_generation.utils import parse_anygrasp_results - -logger = setup_logger("dimos.perception.grasp_generation") - - -class AnyGraspGenerator: - """ - AnyGrasp-based grasp generator using WebSocket communication. - """ - - def __init__(self, server_url: str): - """ - Initialize AnyGrasp generator. - - Args: - server_url: WebSocket URL for AnyGrasp server - """ - self.server_url = server_url - logger.info(f"Initialized AnyGrasp generator with server: {server_url}") - - def generate_grasps_from_objects( - self, objects: List[ObjectData], full_pcd: o3d.geometry.PointCloud - ) -> List[Dict]: - """ - Generate grasps from ObjectData objects using AnyGrasp. - - Args: - objects: List of ObjectData with point clouds - full_pcd: Open3D point cloud of full scene - - Returns: - Parsed grasp results as list of dictionaries - """ - try: - # Combine all point clouds - all_points = [] - all_colors = [] - valid_objects = 0 - - for obj in objects: - if "point_cloud_numpy" not in obj or obj["point_cloud_numpy"] is None: - continue - - points = obj["point_cloud_numpy"] - if not isinstance(points, np.ndarray) or points.size == 0: - continue - - if len(points.shape) != 2 or points.shape[1] != 3: - continue - - colors = None - if "colors_numpy" in obj and obj["colors_numpy"] is not None: - colors = obj["colors_numpy"] - if isinstance(colors, np.ndarray) and colors.size > 0: - if ( - colors.shape[0] != points.shape[0] - or len(colors.shape) != 2 - or colors.shape[1] != 3 - ): - colors = None - - all_points.append(points) - if colors is not None: - all_colors.append(colors) - valid_objects += 1 - - if not all_points: - return [] - - # Combine point clouds - combined_points = np.vstack(all_points) - combined_colors = None - if len(all_colors) == valid_objects 
and len(all_colors) > 0: - combined_colors = np.vstack(all_colors) - - # Send grasp request - grasps = self._send_grasp_request_sync(combined_points, combined_colors) - - if not grasps: - return [] - - # Parse and return results in list of dictionaries format - return parse_anygrasp_results(grasps) - - except Exception as e: - logger.error(f"AnyGrasp generation failed: {e}") - return [] - - def _send_grasp_request_sync( - self, points: np.ndarray, colors: Optional[np.ndarray] - ) -> Optional[List[Dict]]: - """Send synchronous grasp request to AnyGrasp server.""" - - try: - # Prepare colors - colors = np.ones((points.shape[0], 3), dtype=np.float32) * 0.5 - - # Ensure correct data types - points = points.astype(np.float32) - colors = colors.astype(np.float32) - - # Validate ranges - if np.any(np.isnan(points)) or np.any(np.isinf(points)): - logger.error("Points contain NaN or Inf values") - return None - if np.any(np.isnan(colors)) or np.any(np.isinf(colors)): - logger.error("Colors contain NaN or Inf values") - return None - - colors = np.clip(colors, 0.0, 1.0) - - # Run async request in sync context - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - result = loop.run_until_complete(self._async_grasp_request(points, colors)) - return result - finally: - loop.close() - - except Exception as e: - logger.error(f"Error in synchronous grasp request: {e}") - return None - - async def _async_grasp_request( - self, points: np.ndarray, colors: np.ndarray - ) -> Optional[List[Dict]]: - """Async grasp request helper.""" - import json - import websockets - - try: - async with websockets.connect(self.server_url) as websocket: - request = { - "points": points.tolist(), - "colors": colors.tolist(), - "lims": [-1.0, 1.0, -1.0, 1.0, 0.0, 2.0], - } - - await websocket.send(json.dumps(request)) - response = await websocket.recv() - grasps = json.loads(response) - - if isinstance(grasps, dict) and "error" in grasps: - logger.error(f"Server returned error: 
{grasps['error']}") - return None - elif isinstance(grasps, (int, float)) and grasps == 0: - return None - elif not isinstance(grasps, list): - logger.error(f"Server returned unexpected response type: {type(grasps)}") - return None - elif len(grasps) == 0: - return None - - return self._convert_grasp_format(grasps) - - except Exception as e: - logger.error(f"Async grasp request failed: {e}") - return None - - def _convert_grasp_format(self, anygrasp_grasps: List[dict]) -> List[dict]: - """Convert AnyGrasp format to visualization format.""" - converted = [] - - for i, grasp in enumerate(anygrasp_grasps): - rotation_matrix = np.array(grasp.get("rotation_matrix", np.eye(3))) - euler_angles = self._rotation_matrix_to_euler(rotation_matrix) - - converted_grasp = { - "id": f"grasp_{i}", - "score": grasp.get("score", 0.0), - "width": grasp.get("width", 0.0), - "height": grasp.get("height", 0.0), - "depth": grasp.get("depth", 0.0), - "translation": grasp.get("translation", [0, 0, 0]), - "rotation_matrix": rotation_matrix.tolist(), - "euler_angles": euler_angles, - } - converted.append(converted_grasp) - - converted.sort(key=lambda x: x["score"], reverse=True) - return converted - - def _rotation_matrix_to_euler(self, rotation_matrix: np.ndarray) -> Dict[str, float]: - """Convert rotation matrix to Euler angles (in radians).""" - sy = np.sqrt(rotation_matrix[0, 0] ** 2 + rotation_matrix[1, 0] ** 2) - - singular = sy < 1e-6 - - if not singular: - x = np.arctan2(rotation_matrix[2, 1], rotation_matrix[2, 2]) - y = np.arctan2(-rotation_matrix[2, 0], sy) - z = np.arctan2(rotation_matrix[1, 0], rotation_matrix[0, 0]) - else: - x = np.arctan2(-rotation_matrix[1, 2], rotation_matrix[1, 1]) - y = np.arctan2(-rotation_matrix[2, 0], sy) - z = 0 - - return {"roll": x, "pitch": y, "yaw": z} - - def cleanup(self): - """Clean up resources.""" - logger.info("AnyGrasp generator cleaned up") diff --git a/build/lib/dimos/perception/grasp_generation/utils.py 
b/build/lib/dimos/perception/grasp_generation/utils.py deleted file mode 100644 index ba461f9d90..0000000000 --- a/build/lib/dimos/perception/grasp_generation/utils.py +++ /dev/null @@ -1,621 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Utilities for grasp generation and visualization.""" - -import numpy as np -import open3d as o3d -import cv2 -from typing import List, Dict, Tuple, Optional, Union - - -def project_3d_points_to_2d( - points_3d: np.ndarray, camera_intrinsics: Union[List[float], np.ndarray] -) -> np.ndarray: - """ - Project 3D points to 2D image coordinates using camera intrinsics. 
- - Args: - points_3d: Nx3 array of 3D points (X, Y, Z) - camera_intrinsics: Camera parameters as [fx, fy, cx, cy] list or 3x3 matrix - - Returns: - Nx2 array of 2D image coordinates (u, v) - """ - if len(points_3d) == 0: - return np.zeros((0, 2), dtype=np.int32) - - # Filter out points with zero or negative depth - valid_mask = points_3d[:, 2] > 0 - if not np.any(valid_mask): - return np.zeros((0, 2), dtype=np.int32) - - valid_points = points_3d[valid_mask] - - # Extract camera parameters - if isinstance(camera_intrinsics, list) and len(camera_intrinsics) == 4: - fx, fy, cx, cy = camera_intrinsics - else: - camera_matrix = np.array(camera_intrinsics) - fx = camera_matrix[0, 0] - fy = camera_matrix[1, 1] - cx = camera_matrix[0, 2] - cy = camera_matrix[1, 2] - - # Project to image coordinates - u = (valid_points[:, 0] * fx / valid_points[:, 2]) + cx - v = (valid_points[:, 1] * fy / valid_points[:, 2]) + cy - - # Round to integer pixel coordinates - points_2d = np.column_stack([u, v]).astype(np.int32) - - return points_2d - - -def euler_to_rotation_matrix(roll: float, pitch: float, yaw: float) -> np.ndarray: - """ - Convert Euler angles to rotation matrix. - - Args: - roll: Roll angle in radians - pitch: Pitch angle in radians - yaw: Yaw angle in radians - - Returns: - 3x3 rotation matrix - """ - Rx = np.array([[1, 0, 0], [0, np.cos(roll), -np.sin(roll)], [0, np.sin(roll), np.cos(roll)]]) - - Ry = np.array( - [[np.cos(pitch), 0, np.sin(pitch)], [0, 1, 0], [-np.sin(pitch), 0, np.cos(pitch)]] - ) - - Rz = np.array([[np.cos(yaw), -np.sin(yaw), 0], [np.sin(yaw), np.cos(yaw), 0], [0, 0, 1]]) - - # Combined rotation matrix - R = Rz @ Ry @ Rx - - return R - - -def create_gripper_geometry( - grasp_data: dict, - finger_length: float = 0.08, - finger_thickness: float = 0.004, -) -> List[o3d.geometry.TriangleMesh]: - """ - Create a simple fork-like gripper geometry from grasp data. 
- - Args: - grasp_data: Dictionary containing grasp parameters - - translation: 3D position list - - rotation_matrix: 3x3 rotation matrix defining gripper coordinate system - * X-axis: gripper width direction (opening/closing) - * Y-axis: finger length direction - * Z-axis: approach direction (toward object) - - width: Gripper opening width - finger_length: Length of gripper fingers (longer) - finger_thickness: Thickness of gripper fingers - base_height: Height of gripper base (longer) - color: RGB color for the gripper (solid blue) - - Returns: - List of Open3D TriangleMesh geometries for the gripper - """ - - translation = np.array(grasp_data["translation"]) - rotation_matrix = np.array(grasp_data["rotation_matrix"]) - - width = grasp_data.get("width", 0.04) - - # Create transformation matrix - transform = np.eye(4) - transform[:3, :3] = rotation_matrix - transform[:3, 3] = translation - - geometries = [] - - # Gripper dimensions - finger_width = 0.006 # Thickness of each finger - handle_length = 0.05 # Length of handle extending backward - - # Build gripper in local coordinate system: - # X-axis = width direction (left/right finger separation) - # Y-axis = finger length direction (fingers extend along +Y) - # Z-axis = approach direction (toward object, handle extends along -Z) - # IMPORTANT: Fingertips should be at origin (translation point) - - # Create left finger extending along +Y, positioned at +X - left_finger = o3d.geometry.TriangleMesh.create_box( - width=finger_width, # Thin finger - height=finger_length, # Extends along Y (finger length direction) - depth=finger_thickness, # Thin in Z direction - ) - left_finger.translate( - [ - width / 2 - finger_width / 2, # Position at +X (half width from center) - -finger_length, # Shift so fingertips are at origin - -finger_thickness / 2, # Center in Z - ] - ) - - # Create right finger extending along +Y, positioned at -X - right_finger = o3d.geometry.TriangleMesh.create_box( - width=finger_width, # Thin finger - 
height=finger_length, # Extends along Y (finger length direction) - depth=finger_thickness, # Thin in Z direction - ) - right_finger.translate( - [ - -width / 2 - finger_width / 2, # Position at -X (half width from center) - -finger_length, # Shift so fingertips are at origin - -finger_thickness / 2, # Center in Z - ] - ) - - # Create base connecting fingers - flat like a stickman body - base = o3d.geometry.TriangleMesh.create_box( - width=width + finger_width, # Full width plus finger thickness - height=finger_thickness, # Flat like fingers (stickman style) - depth=finger_thickness, # Thin like fingers - ) - base.translate( - [ - -width / 2 - finger_width / 2, # Start from left finger position - -finger_length - finger_thickness, # Behind fingers, adjusted for fingertips at origin - -finger_thickness / 2, # Center in Z - ] - ) - - # Create handle extending backward - flat stick like stickman arm - handle = o3d.geometry.TriangleMesh.create_box( - width=finger_width, # Same width as fingers - height=handle_length, # Extends backward along Y direction (same plane) - depth=finger_thickness, # Thin like fingers (same plane) - ) - handle.translate( - [ - -finger_width / 2, # Center in X - -finger_length - - finger_thickness - - handle_length, # Extend backward from base, adjusted for fingertips at origin - -finger_thickness / 2, # Same Z plane as other components - ] - ) - - # Use solid red color for all parts (user changed to red) - solid_color = [1.0, 0.0, 0.0] # Red color - - left_finger.paint_uniform_color(solid_color) - right_finger.paint_uniform_color(solid_color) - base.paint_uniform_color(solid_color) - handle.paint_uniform_color(solid_color) - - # Apply transformation to all parts - left_finger.transform(transform) - right_finger.transform(transform) - base.transform(transform) - handle.transform(transform) - - geometries.extend([left_finger, right_finger, base, handle]) - - return geometries - - -def create_all_gripper_geometries( - grasp_list: List[dict], 
max_grasps: int = -1 -) -> List[o3d.geometry.TriangleMesh]: - """ - Create gripper geometries for multiple grasps. - - Args: - grasp_list: List of grasp dictionaries - max_grasps: Maximum number of grasps to visualize (-1 for all) - - Returns: - List of all gripper geometries - """ - all_geometries = [] - - grasps_to_show = grasp_list if max_grasps < 0 else grasp_list[:max_grasps] - - for grasp in grasps_to_show: - gripper_parts = create_gripper_geometry(grasp) - all_geometries.extend(gripper_parts) - - return all_geometries - - -def draw_grasps_on_image( - image: np.ndarray, - grasp_data: Union[dict, Dict[Union[int, str], List[dict]], List[dict]], - camera_intrinsics: Union[List[float], np.ndarray], # [fx, fy, cx, cy] or 3x3 matrix - max_grasps: int = -1, # -1 means show all grasps - finger_length: float = 0.08, # Match 3D gripper - finger_thickness: float = 0.004, # Match 3D gripper -) -> np.ndarray: - """ - Draw fork-like gripper visualizations on the image matching 3D gripper design. 
- - Args: - image: Base image to draw on - grasp_data: Can be: - - A single grasp dict - - A list of grasp dicts - - A dictionary mapping object IDs or "scene" to list of grasps - camera_intrinsics: Camera parameters as [fx, fy, cx, cy] list or 3x3 matrix - max_grasps: Maximum number of grasps to visualize (-1 for all) - finger_length: Length of gripper fingers (matches 3D design) - finger_thickness: Thickness of gripper fingers (matches 3D design) - - Returns: - Image with grasps drawn - """ - result = image.copy() - - # Convert camera intrinsics to 3x3 matrix if needed - if isinstance(camera_intrinsics, list) and len(camera_intrinsics) == 4: - fx, fy, cx, cy = camera_intrinsics - camera_matrix = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]]) - else: - camera_matrix = np.array(camera_intrinsics) - - # Convert input to standard format - if isinstance(grasp_data, dict) and not any( - key in grasp_data for key in ["scene", 0, 1, 2, 3, 4, 5] - ): - # Single grasp - grasps_to_draw = [(grasp_data, 0)] - elif isinstance(grasp_data, list): - # List of grasps - grasps_to_draw = [(grasp, i) for i, grasp in enumerate(grasp_data)] - else: - # Dictionary of grasps by object ID - grasps_to_draw = [] - for obj_id, grasps in grasp_data.items(): - for i, grasp in enumerate(grasps): - grasps_to_draw.append((grasp, i)) - - # Limit number of grasps if specified - if max_grasps > 0: - grasps_to_draw = grasps_to_draw[:max_grasps] - - # Define grasp colors (solid red to match 3D design) - def get_grasp_color(index: int) -> tuple: - # Use solid red color for all grasps to match 3D design - return (0, 0, 255) # Red in BGR format for OpenCV - - # Draw each grasp - for grasp, index in grasps_to_draw: - try: - color = get_grasp_color(index) - thickness = max(1, 4 - index // 3) - - # Extract grasp parameters (using translation and rotation_matrix) - if "translation" not in grasp or "rotation_matrix" not in grasp: - continue - - translation = np.array(grasp["translation"]) - rotation_matrix = 
np.array(grasp["rotation_matrix"]) - width = grasp.get("width", 0.04) - - # Match 3D gripper dimensions - finger_width = 0.006 # Thickness of each finger (matches 3D) - handle_length = 0.05 # Length of handle extending backward (matches 3D) - - # Create gripper geometry in local coordinate system matching 3D design: - # X-axis = width direction (left/right finger separation) - # Y-axis = finger length direction (fingers extend along +Y) - # Z-axis = approach direction (toward object, handle extends along -Z) - # IMPORTANT: Fingertips should be at origin (translation point) - - # Left finger extending along +Y, positioned at +X - left_finger_points = np.array( - [ - [ - width / 2 - finger_width / 2, - -finger_length, - -finger_thickness / 2, - ], # Back left - [ - width / 2 + finger_width / 2, - -finger_length, - -finger_thickness / 2, - ], # Back right - [ - width / 2 + finger_width / 2, - 0, - -finger_thickness / 2, - ], # Front right (at origin) - [ - width / 2 - finger_width / 2, - 0, - -finger_thickness / 2, - ], # Front left (at origin) - ] - ) - - # Right finger extending along +Y, positioned at -X - right_finger_points = np.array( - [ - [ - -width / 2 - finger_width / 2, - -finger_length, - -finger_thickness / 2, - ], # Back left - [ - -width / 2 + finger_width / 2, - -finger_length, - -finger_thickness / 2, - ], # Back right - [ - -width / 2 + finger_width / 2, - 0, - -finger_thickness / 2, - ], # Front right (at origin) - [ - -width / 2 - finger_width / 2, - 0, - -finger_thickness / 2, - ], # Front left (at origin) - ] - ) - - # Base connecting fingers - flat rectangle behind fingers - base_points = np.array( - [ - [ - -width / 2 - finger_width / 2, - -finger_length - finger_thickness, - -finger_thickness / 2, - ], # Back left - [ - width / 2 + finger_width / 2, - -finger_length - finger_thickness, - -finger_thickness / 2, - ], # Back right - [ - width / 2 + finger_width / 2, - -finger_length, - -finger_thickness / 2, - ], # Front right - [ - -width / 2 - 
finger_width / 2, - -finger_length, - -finger_thickness / 2, - ], # Front left - ] - ) - - # Handle extending backward - thin rectangle - handle_points = np.array( - [ - [ - -finger_width / 2, - -finger_length - finger_thickness - handle_length, - -finger_thickness / 2, - ], # Back left - [ - finger_width / 2, - -finger_length - finger_thickness - handle_length, - -finger_thickness / 2, - ], # Back right - [ - finger_width / 2, - -finger_length - finger_thickness, - -finger_thickness / 2, - ], # Front right - [ - -finger_width / 2, - -finger_length - finger_thickness, - -finger_thickness / 2, - ], # Front left - ] - ) - - # Transform all points to world frame - def transform_points(points): - # Apply rotation and translation - world_points = (rotation_matrix @ points.T).T + translation - return world_points - - left_finger_world = transform_points(left_finger_points) - right_finger_world = transform_points(right_finger_points) - base_world = transform_points(base_points) - handle_world = transform_points(handle_points) - - # Project to 2D - left_finger_2d = project_3d_points_to_2d(left_finger_world, camera_matrix) - right_finger_2d = project_3d_points_to_2d(right_finger_world, camera_matrix) - base_2d = project_3d_points_to_2d(base_world, camera_matrix) - handle_2d = project_3d_points_to_2d(handle_world, camera_matrix) - - # Draw left finger - pts = left_finger_2d.astype(np.int32) - cv2.polylines(result, [pts], True, color, thickness) - - # Draw right finger - pts = right_finger_2d.astype(np.int32) - cv2.polylines(result, [pts], True, color, thickness) - - # Draw base - pts = base_2d.astype(np.int32) - cv2.polylines(result, [pts], True, color, thickness) - - # Draw handle - pts = handle_2d.astype(np.int32) - cv2.polylines(result, [pts], True, color, thickness) - - # Draw grasp center (fingertips at origin) - center_2d = project_3d_points_to_2d(translation.reshape(1, -1), camera_matrix)[0] - cv2.circle(result, tuple(center_2d.astype(int)), 3, color, -1) - - except 
Exception as e: - # Skip this grasp if there's an error - continue - - return result - - -def get_standard_coordinate_transform(): - """ - Get a standard coordinate transformation matrix for consistent visualization. - - This transformation ensures that: - - X (red) axis points right - - Y (green) axis points up - - Z (blue) axis points toward viewer - - Returns: - 4x4 transformation matrix - """ - # Standard transformation matrix to ensure consistent coordinate frame orientation - transform = np.array( - [ - [1, 0, 0, 0], # X points right - [0, -1, 0, 0], # Y points up (flip from OpenCV to standard) - [0, 0, -1, 0], # Z points toward viewer (flip depth) - [0, 0, 0, 1], - ] - ) - return transform - - -def visualize_grasps_3d( - point_cloud: o3d.geometry.PointCloud, - grasp_list: List[dict], - max_grasps: int = -1, -): - """ - Visualize grasps in 3D with point cloud. - - Args: - point_cloud: Open3D point cloud - grasp_list: List of grasp dictionaries - max_grasps: Maximum number of grasps to visualize - """ - # Apply standard coordinate transformation - transform = get_standard_coordinate_transform() - - # Transform point cloud - pc_copy = o3d.geometry.PointCloud(point_cloud) - pc_copy.transform(transform) - geometries = [pc_copy] - - # Transform gripper geometries - gripper_geometries = create_all_gripper_geometries(grasp_list, max_grasps) - for geom in gripper_geometries: - geom.transform(transform) - geometries.extend(gripper_geometries) - - # Add transformed coordinate frame - origin_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1) - origin_frame.transform(transform) - geometries.append(origin_frame) - - o3d.visualization.draw_geometries(geometries, window_name="3D Grasp Visualization") - - -def rotation_matrix_to_euler(rotation_matrix: np.ndarray) -> Tuple[float, float, float]: - """ - Convert 3x3 rotation matrix to Euler angles (roll, pitch, yaw). 
- - Args: - rotation_matrix: 3x3 rotation matrix - - Returns: - Tuple of (roll, pitch, yaw) in radians - """ - sy = np.sqrt(rotation_matrix[0, 0] ** 2 + rotation_matrix[1, 0] ** 2) - singular = sy < 1e-6 - - if not singular: - x = np.arctan2(rotation_matrix[2, 1], rotation_matrix[2, 2]) # roll - y = np.arctan2(-rotation_matrix[2, 0], sy) # pitch - z = np.arctan2(rotation_matrix[1, 0], rotation_matrix[0, 0]) # yaw - else: - x = np.arctan2(-rotation_matrix[1, 2], rotation_matrix[1, 1]) # roll - y = np.arctan2(-rotation_matrix[2, 0], sy) # pitch - z = 0 # yaw - - return x, y, z - - -def parse_anygrasp_results(grasps: List[Dict]) -> List[Dict]: - """ - Parse AnyGrasp results into visualization format. - - Args: - grasps: List of AnyGrasp grasp dictionaries - - Returns: - List of dictionaries containing: - - id: Unique grasp identifier - - score: Confidence score (float) - - width: Gripper opening width (float) - - translation: 3D position [x, y, z] - - rotation_matrix: 3x3 rotation matrix as nested list - """ - if not grasps: - return [] - - parsed_grasps = [] - - for i, grasp in enumerate(grasps): - # Extract data from each grasp - translation = grasp.get("translation", [0, 0, 0]) - rotation_matrix = np.array(grasp.get("rotation_matrix", np.eye(3))) - score = float(grasp.get("score", 0.0)) - width = float(grasp.get("width", 0.08)) - - parsed_grasp = { - "id": f"grasp_{i}", - "score": score, - "width": width, - "translation": translation, - "rotation_matrix": rotation_matrix.tolist(), - } - parsed_grasps.append(parsed_grasp) - - return parsed_grasps - - -def create_grasp_overlay( - rgb_image: np.ndarray, - grasps: List[Dict], - camera_intrinsics: Union[List[float], np.ndarray], -) -> np.ndarray: - """ - Create grasp visualization overlay on RGB image. 
- - Args: - rgb_image: RGB input image - grasps: List of grasp dictionaries in viz format - camera_intrinsics: Camera parameters - - Returns: - RGB image with grasp overlay - """ - try: - bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR) - - result_bgr = draw_grasps_on_image( - bgr_image, - grasps, - camera_intrinsics, - max_grasps=-1, - ) - return cv2.cvtColor(result_bgr, cv2.COLOR_BGR2RGB) - except Exception as e: - return rgb_image.copy() diff --git a/build/lib/dimos/perception/object_detection_stream.py b/build/lib/dimos/perception/object_detection_stream.py deleted file mode 100644 index 3284d99f8b..0000000000 --- a/build/lib/dimos/perception/object_detection_stream.py +++ /dev/null @@ -1,373 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import cv2 -import time -import numpy as np -from reactivex import Observable -from reactivex import operators as ops - -from dimos.perception.detection2d.yolo_2d_det import Yolo2DDetector - -try: - from dimos.perception.detection2d.detic_2d_det import Detic2DDetector - - DETIC_AVAILABLE = True -except (ModuleNotFoundError, ImportError): - DETIC_AVAILABLE = False - Detic2DDetector = None -from dimos.models.depth.metric3d import Metric3D -from dimos.perception.detection2d.utils import ( - calculate_depth_from_bbox, - calculate_object_size_from_bbox, - calculate_position_rotation_from_bbox, -) -from dimos.types.vector import Vector -from typing import Optional, Union, Callable -from dimos.types.manipulation import ObjectData -from dimos.utils.transform_utils import transform_robot_to_map - -from dimos.utils.logging_config import setup_logger - -# Initialize logger for the ObjectDetectionStream -logger = setup_logger("dimos.perception.object_detection_stream") - - -class ObjectDetectionStream: - """ - A stream processor that: - 1. Detects objects using a Detector (Detic or Yolo) - 2. Estimates depth using Metric3D - 3. Calculates 3D position and dimensions using camera intrinsics - 4. Transforms coordinates to map frame - 5. Draws bounding boxes and segmentation masks on the frame - - Provides a stream of structured object data with position and rotation information. 
- """ - - def __init__( - self, - camera_intrinsics=None, # [fx, fy, cx, cy] - device="cuda", - gt_depth_scale=1000.0, - min_confidence=0.7, - class_filter=None, # Optional list of class names to filter (e.g., ["person", "car"]) - get_pose: Callable = None, # Optional function to transform coordinates to map frame - detector: Optional[Union[Detic2DDetector, Yolo2DDetector]] = None, - video_stream: Observable = None, - disable_depth: bool = False, # Flag to disable monocular Metric3D depth estimation - draw_masks: bool = False, # Flag to enable drawing segmentation masks - ): - """ - Initialize the ObjectDetectionStream. - - Args: - camera_intrinsics: List [fx, fy, cx, cy] with camera parameters - device: Device to run inference on ("cuda" or "cpu") - gt_depth_scale: Ground truth depth scale for Metric3D - min_confidence: Minimum confidence for detections - class_filter: Optional list of class names to filter - get_pose: Optional function to transform pose to map coordinates - detector: Optional detector instance (Detic or Yolo) - video_stream: Observable of video frames to process (if provided, returns a stream immediately) - disable_depth: Flag to disable monocular Metric3D depth estimation - draw_masks: Flag to enable drawing segmentation masks - """ - self.min_confidence = min_confidence - self.class_filter = class_filter - self.get_pose = get_pose - self.disable_depth = disable_depth - self.draw_masks = draw_masks - # Initialize object detector - if detector is not None: - self.detector = detector - else: - if DETIC_AVAILABLE: - try: - self.detector = Detic2DDetector(vocabulary=None, threshold=min_confidence) - logger.info("Using Detic2DDetector") - except Exception as e: - logger.warning( - f"Failed to initialize Detic2DDetector: {e}. Falling back to Yolo2DDetector." - ) - self.detector = Yolo2DDetector() - else: - logger.info("Detic not available. 
Using Yolo2DDetector.") - self.detector = Yolo2DDetector() - # Set up camera intrinsics - self.camera_intrinsics = camera_intrinsics - - # Initialize depth estimation model - self.depth_model = None - if not disable_depth: - try: - self.depth_model = Metric3D(gt_depth_scale) - - if camera_intrinsics is not None: - self.depth_model.update_intrinsic(camera_intrinsics) - - # Create 3x3 camera matrix for calculations - fx, fy, cx, cy = camera_intrinsics - self.camera_matrix = np.array( - [[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32 - ) - else: - raise ValueError("camera_intrinsics must be provided") - - logger.info("Depth estimation enabled with Metric3D") - except Exception as e: - logger.warning(f"Failed to initialize Metric3D depth model: {e}") - logger.warning("Falling back to disable_depth=True mode") - self.disable_depth = True - self.depth_model = None - else: - logger.info("Depth estimation disabled") - - # If video_stream is provided, create and store the stream immediately - self.stream = None - if video_stream is not None: - self.stream = self.create_stream(video_stream) - - def create_stream(self, video_stream: Observable) -> Observable: - """ - Create an Observable stream of object data from a video stream. 
- - Args: - video_stream: Observable that emits video frames - - Returns: - Observable that emits dictionaries containing object data - with position and rotation information - """ - - def process_frame(frame): - # TODO: More modular detector output interface - bboxes, track_ids, class_ids, confidences, names, *mask_data = ( - self.detector.process_image(frame) + ([],) - ) - - masks = ( - mask_data[0] - if mask_data and len(mask_data[0]) == len(bboxes) - else [None] * len(bboxes) - ) - - # Create visualization - viz_frame = frame.copy() - - # Process detections - objects = [] - if not self.disable_depth: - depth_map = self.depth_model.infer_depth(frame) - depth_map = np.array(depth_map) - else: - depth_map = None - - for i, bbox in enumerate(bboxes): - # Skip if confidence is too low - if i < len(confidences) and confidences[i] < self.min_confidence: - continue - - # Skip if class filter is active and class not in filter - class_name = names[i] if i < len(names) else None - if self.class_filter and class_name not in self.class_filter: - continue - - if not self.disable_depth and depth_map is not None: - # Get depth for this object - depth = calculate_depth_from_bbox(depth_map, bbox) - if depth is None: - # Skip objects with invalid depth - continue - # Calculate object position and rotation - position, rotation = calculate_position_rotation_from_bbox( - bbox, depth, self.camera_intrinsics - ) - # Get object dimensions - width, height = calculate_object_size_from_bbox( - bbox, depth, self.camera_intrinsics - ) - - # Transform to map frame if a transform function is provided - try: - if self.get_pose: - # position and rotation are already Vector objects, no need to convert - robot_pose = self.get_pose() - position, rotation = transform_robot_to_map( - robot_pose["position"], robot_pose["rotation"], position, rotation - ) - except Exception as e: - logger.error(f"Error transforming to map frame: {e}") - position, rotation = position, rotation - - else: - depth = -1 - 
position = Vector(0, 0, 0) - rotation = Vector(0, 0, 0) - width = -1 - height = -1 - - # Create a properly typed ObjectData instance - object_data: ObjectData = { - "object_id": track_ids[i] if i < len(track_ids) else -1, - "bbox": bbox, - "depth": depth, - "confidence": confidences[i] if i < len(confidences) else None, - "class_id": class_ids[i] if i < len(class_ids) else None, - "label": class_name, - "position": position, - "rotation": rotation, - "size": {"width": width, "height": height}, - "segmentation_mask": masks[i], - } - - objects.append(object_data) - - # Add visualization - x1, y1, x2, y2 = map(int, bbox) - color = (0, 255, 0) # Green for detected objects - mask_color = (0, 200, 200) # Yellow-green for masks - - # Draw segmentation mask if available and valid - try: - if self.draw_masks and object_data["segmentation_mask"] is not None: - # Create a colored mask overlay - mask = object_data["segmentation_mask"].astype(np.uint8) - colored_mask = np.zeros_like(viz_frame) - colored_mask[mask > 0] = mask_color - - # Apply the mask with transparency - alpha = 0.5 # transparency factor - mask_area = mask > 0 - viz_frame[mask_area] = cv2.addWeighted( - viz_frame[mask_area], 1 - alpha, colored_mask[mask_area], alpha, 0 - ) - - # Draw mask contour - contours, _ = cv2.findContours( - mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE - ) - cv2.drawContours(viz_frame, contours, -1, mask_color, 2) - except Exception as e: - logger.warning(f"Error drawing segmentation mask: {e}") - - # Draw bounding box with metadata - try: - cv2.rectangle(viz_frame, (x1, y1), (x2, y2), color, 1) - - # Add text for class only (removed position data) - # Handle possible None values for class_name or track_ids[i] - class_text = class_name if class_name is not None else "Unknown" - id_text = ( - track_ids[i] if i < len(track_ids) and track_ids[i] is not None else "?" 
- ) - text = f"{class_text}, ID: {id_text}" - - # Draw text background with smaller font - text_size = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.3, 1)[0] - cv2.rectangle( - viz_frame, - (x1, y1 - text_size[1] - 5), - (x1 + text_size[0], y1), - (0, 0, 0), - -1, - ) - - # Draw text with smaller font - cv2.putText( - viz_frame, - text, - (x1, y1 - 5), - cv2.FONT_HERSHEY_SIMPLEX, - 0.3, - (255, 255, 255), - 1, - ) - except Exception as e: - logger.warning(f"Error drawing bounding box or text: {e}") - - return {"frame": frame, "viz_frame": viz_frame, "objects": objects} - - self.stream = video_stream.pipe(ops.map(process_frame)) - - return self.stream - - def get_stream(self): - """ - Returns the current detection stream if available. - Creates a new one with the provided video_stream if not already created. - - Returns: - Observable: The reactive stream of detection results - """ - if self.stream is None: - raise ValueError( - "Stream not initialized. Either provide a video_stream during initialization or call create_stream first." - ) - return self.stream - - def get_formatted_stream(self): - """ - Returns a formatted stream of object detection data for better readability. - This is especially useful for LLMs like Claude that need structured text input. - - Returns: - Observable: A stream of formatted string representations of object data - """ - if self.stream is None: - raise ValueError( - "Stream not initialized. Either provide a video_stream during initialization or call create_stream first." - ) - - def format_detection_data(result): - # Extract objects from result - objects = result.get("objects", []) - - if not objects: - return "No objects detected." 
- - formatted_data = "[DETECTED OBJECTS]\n" - try: - for i, obj in enumerate(objects): - pos = obj["position"] - rot = obj["rotation"] - size = obj["size"] - bbox = obj["bbox"] - - # Format each object with a multiline f-string for better readability - bbox_str = f"[{bbox[0]}, {bbox[1]}, {bbox[2]}, {bbox[3]}]" - formatted_data += ( - f"Object {i + 1}: {obj['label']}\n" - f" ID: {obj['object_id']}\n" - f" Confidence: {obj['confidence']:.2f}\n" - f" Position: x={pos.x:.2f}m, y={pos.y:.2f}m, z={pos.z:.2f}m\n" - f" Rotation: yaw={rot.z:.2f} rad\n" - f" Size: width={size['width']:.2f}m, height={size['height']:.2f}m\n" - f" Depth: {obj['depth']:.2f}m\n" - f" Bounding box: {bbox_str}\n" - "----------------------------------\n" - ) - except Exception as e: - logger.warning(f"Error formatting object {i}: {e}") - formatted_data += f"Object {i + 1}: [Error formatting data]" - formatted_data += "\n----------------------------------\n" - - return formatted_data - - # Return a new stream with the formatter applied - return self.stream.pipe(ops.map(format_detection_data)) - - def cleanup(self): - """Clean up resources.""" - pass diff --git a/build/lib/dimos/perception/object_tracker.py b/build/lib/dimos/perception/object_tracker.py deleted file mode 100644 index 010dbb9f3e..0000000000 --- a/build/lib/dimos/perception/object_tracker.py +++ /dev/null @@ -1,357 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import cv2 -from reactivex import Observable -from reactivex import operators as ops -import numpy as np -from dimos.perception.common.ibvs import ObjectDistanceEstimator -from dimos.models.depth.metric3d import Metric3D -from dimos.perception.detection2d.utils import calculate_depth_from_bbox - - -class ObjectTrackingStream: - def __init__( - self, - camera_intrinsics=None, - camera_pitch=0.0, - camera_height=1.0, - reid_threshold=5, - reid_fail_tolerance=10, - gt_depth_scale=1000.0, - ): - """ - Initialize an object tracking stream using OpenCV's CSRT tracker with ORB re-ID. - - Args: - camera_intrinsics: List in format [fx, fy, cx, cy] where: - - fx: Focal length in x direction (pixels) - - fy: Focal length in y direction (pixels) - - cx: Principal point x-coordinate (pixels) - - cy: Principal point y-coordinate (pixels) - camera_pitch: Camera pitch angle in radians (positive is up) - camera_height: Height of the camera from the ground in meters - reid_threshold: Minimum good feature matches needed to confirm re-ID. - reid_fail_tolerance: Number of consecutive frames Re-ID can fail before - tracking is stopped. 
- gt_depth_scale: Ground truth depth scale factor for Metric3D model - """ - self.tracker = None - self.tracking_bbox = None # Stores (x, y, w, h) for tracker initialization - self.tracking_initialized = False - self.orb = cv2.ORB_create() - self.bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=False) - self.original_des = None # Store original ORB descriptors - self.reid_threshold = reid_threshold - self.reid_fail_tolerance = reid_fail_tolerance - self.reid_fail_count = 0 # Counter for consecutive re-id failures - - # Initialize distance estimator if camera parameters are provided - self.distance_estimator = None - if camera_intrinsics is not None: - # Convert [fx, fy, cx, cy] to 3x3 camera matrix - fx, fy, cx, cy = camera_intrinsics - K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32) - - self.distance_estimator = ObjectDistanceEstimator( - K=K, camera_pitch=camera_pitch, camera_height=camera_height - ) - - # Initialize depth model - self.depth_model = Metric3D(gt_depth_scale) - if camera_intrinsics is not None: - self.depth_model.update_intrinsic(camera_intrinsics) - - def track(self, bbox, frame=None, distance=None, size=None): - """ - Set the initial bounding box for tracking. Features are extracted later. - - Args: - bbox: Bounding box in format [x1, y1, x2, y2] - frame: Optional - Current frame for depth estimation and feature extraction - distance: Optional - Known distance to object (meters) - size: Optional - Known size of object (meters) - - Returns: - bool: True if intention to track is set (bbox is valid) - """ - x1, y1, x2, y2 = map(int, bbox) - w, h = x2 - x1, y2 - y1 - if w <= 0 or h <= 0: - print(f"Warning: Invalid initial bbox provided: {bbox}. 
Tracking not started.") - self.stop_track() # Ensure clean state - return False - - self.tracking_bbox = (x1, y1, w, h) # Store in (x, y, w, h) format - self.tracker = cv2.legacy.TrackerCSRT_create() - self.tracking_initialized = False # Reset flag - self.original_des = None # Clear previous descriptors - self.reid_fail_count = 0 # Reset counter on new track - print(f"Tracking target set with bbox: {self.tracking_bbox}") - - # Calculate depth only if distance and size not provided - if frame is not None and distance is None and size is None: - depth_map = self.depth_model.infer_depth(frame) - depth_map = np.array(depth_map) - depth_estimate = calculate_depth_from_bbox(depth_map, bbox) - if depth_estimate is not None: - print(f"Estimated depth for object: {depth_estimate:.2f}m") - - # Update distance estimator if needed - if self.distance_estimator is not None: - if size is not None: - self.distance_estimator.set_estimated_object_size(size) - elif distance is not None: - self.distance_estimator.estimate_object_size(bbox, distance) - elif depth_estimate is not None: - self.distance_estimator.estimate_object_size(bbox, depth_estimate) - else: - print("No distance or size provided. Cannot estimate object size.") - - return True # Indicate intention to track is set - - def calculate_depth_from_bbox(self, frame, bbox): - """ - Calculate the average depth of an object within a bounding box. - Uses the 25th to 75th percentile range to filter outliers. 
- - Args: - frame: The image frame - bbox: Bounding box in format [x1, y1, x2, y2] - - Returns: - float: Average depth in meters, or None if depth estimation fails - """ - try: - # Get depth map for the entire frame - depth_map = self.depth_model.infer_depth(frame) - depth_map = np.array(depth_map) - - # Extract region of interest from the depth map - x1, y1, x2, y2 = map(int, bbox) - roi_depth = depth_map[y1:y2, x1:x2] - - if roi_depth.size == 0: - return None - - # Calculate 25th and 75th percentile to filter outliers - p25 = np.percentile(roi_depth, 25) - p75 = np.percentile(roi_depth, 75) - - # Filter depth values within this range - filtered_depth = roi_depth[(roi_depth >= p25) & (roi_depth <= p75)] - - # Calculate average depth (convert to meters) - if filtered_depth.size > 0: - return np.mean(filtered_depth) / 1000.0 # Convert mm to meters - - return None - except Exception as e: - print(f"Error calculating depth from bbox: {e}") - return None - - def reid(self, frame, current_bbox) -> bool: - """Check if features in current_bbox match stored original features.""" - if self.original_des is None: - return True # Cannot re-id if no original features - x1, y1, x2, y2 = map(int, current_bbox) - roi = frame[y1:y2, x1:x2] - if roi.size == 0: - return False # Empty ROI cannot match - - _, des_current = self.orb.detectAndCompute(roi, None) - if des_current is None or len(des_current) < 2: - return False # Need at least 2 descriptors for knnMatch - - # Handle case where original_des has only 1 descriptor (cannot use knnMatch with k=2) - if len(self.original_des) < 2: - matches = self.bf.match(self.original_des, des_current) - good_matches = len(matches) - else: - matches = self.bf.knnMatch(self.original_des, des_current, k=2) - # Apply Lowe's ratio test robustly - good_matches = 0 - for match_pair in matches: - if len(match_pair) == 2: - m, n = match_pair - if m.distance < 0.75 * n.distance: - good_matches += 1 - - # print(f"ReID: Good Matches={good_matches}, 
Threshold={self.reid_threshold}") # Debug - return good_matches >= self.reid_threshold - - def stop_track(self): - """ - Stop tracking the current object. - This resets the tracker and all tracking state. - - Returns: - bool: True if tracking was successfully stopped - """ - self.tracker = None - self.tracking_bbox = None - self.tracking_initialized = False - self.original_des = None - self.reid_fail_count = 0 # Reset counter - return True - - def create_stream(self, video_stream: Observable) -> Observable: - """ - Create an Observable stream of object tracking results from a video stream. - - Args: - video_stream: Observable that emits video frames - - Returns: - Observable that emits dictionaries containing tracking results and visualizations - """ - - def process_frame(frame): - viz_frame = frame.copy() - tracker_succeeded = False # Success from tracker.update() - reid_confirmed_this_frame = False # Result of reid() call for this frame - final_success = False # Overall success considering re-id tolerance - target_data = None - current_bbox_x1y1x2y2 = None # Store current bbox if tracking succeeds - - if self.tracker is not None and self.tracking_bbox is not None: - if not self.tracking_initialized: - # Extract initial features and initialize tracker on first frame - x_init, y_init, w_init, h_init = self.tracking_bbox - roi = frame[y_init : y_init + h_init, x_init : x_init + w_init] - - if roi.size > 0: - _, self.original_des = self.orb.detectAndCompute(roi, None) - if self.original_des is None: - print( - "Warning: No ORB features found in initial ROI during stream processing." 
- ) - else: - print(f"Initial ORB features extracted: {len(self.original_des)}") - - # Initialize the tracker - init_success = self.tracker.init(frame, self.tracking_bbox) - if init_success: - self.tracking_initialized = True - tracker_succeeded = True - reid_confirmed_this_frame = True # Assume re-id true on init - current_bbox_x1y1x2y2 = [ - x_init, - y_init, - x_init + w_init, - y_init + h_init, - ] - print("Tracker initialized successfully.") - else: - print("Error: Tracker initialization failed in stream.") - self.stop_track() # Reset if init fails - else: - print("Error: Empty ROI during tracker initialization in stream.") - self.stop_track() # Reset if ROI is bad - - else: # Tracker already initialized, perform update and re-id - tracker_succeeded, bbox_cv = self.tracker.update(frame) - if tracker_succeeded: - x, y, w, h = map(int, bbox_cv) - current_bbox_x1y1x2y2 = [x, y, x + w, y + h] - # Perform re-ID check - reid_confirmed_this_frame = self.reid(frame, current_bbox_x1y1x2y2) - - if reid_confirmed_this_frame: - self.reid_fail_count = 0 # Reset counter on success - else: - self.reid_fail_count += 1 # Increment counter on failure - print( - f"Re-ID failed ({self.reid_fail_count}/{self.reid_fail_tolerance}). Continuing track..." - ) - - # --- Determine final success and stop tracking if needed --- - if tracker_succeeded: - if self.reid_fail_count >= self.reid_fail_tolerance: - print(f"Re-ID failed consecutively {self.reid_fail_count} times. Target lost.") - final_success = False # Stop tracking - else: - final_success = True # Tracker ok, Re-ID ok or within tolerance - else: # Tracker update failed - final_success = False - if self.tracking_initialized: - print("Tracker update failed. 
Stopping track.") - - # --- Post-processing based on final_success --- - if final_success and current_bbox_x1y1x2y2 is not None: - # Tracking is considered successful (tracker ok, re-id ok or within tolerance) - x1, y1, x2, y2 = current_bbox_x1y1x2y2 - # Visualize based on *this frame's* re-id result - viz_color = ( - (0, 255, 0) if reid_confirmed_this_frame else (0, 165, 255) - ) # Green if confirmed, Orange if failed but tolerated - cv2.rectangle(viz_frame, (x1, y1), (x2, y2), viz_color, 2) - - target_data = { - "target_id": 0, - "bbox": current_bbox_x1y1x2y2, - "confidence": 1.0, - "reid_confirmed": reid_confirmed_this_frame, # Report actual re-id status - } - - dist_text = "Object Tracking" - if not reid_confirmed_this_frame: - dist_text += " (Re-ID Failed - Tolerated)" - - if ( - self.distance_estimator is not None - and self.distance_estimator.estimated_object_size is not None - ): - distance, angle = self.distance_estimator.estimate_distance_angle( - current_bbox_x1y1x2y2 - ) - if distance is not None: - target_data["distance"] = distance - target_data["angle"] = angle - dist_text = f"Object: {distance:.2f}m, {np.rad2deg(angle):.1f} deg" - if not reid_confirmed_this_frame: - dist_text += " (Re-ID Failed - Tolerated)" - - text_size = cv2.getTextSize(dist_text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)[0] - label_bg_y = max(y1 - text_size[1] - 5, 0) - cv2.rectangle(viz_frame, (x1, label_bg_y), (x1 + text_size[0], y1), (0, 0, 0), -1) - cv2.putText( - viz_frame, - dist_text, - (x1, y1 - 5), - cv2.FONT_HERSHEY_SIMPLEX, - 0.5, - (255, 255, 255), - 1, - ) - - elif ( - self.tracking_initialized - ): # Tracking stopped this frame (either tracker fail or re-id tolerance exceeded) - self.stop_track() # Reset tracker state and counter - - # else: # Not tracking or initialization failed, do nothing, return empty result - # pass - - return { - "frame": frame, - "viz_frame": viz_frame, - "targets": [target_data] if target_data else [], - } - - return 
video_stream.pipe(ops.map(process_frame)) - - def cleanup(self): - """Clean up resources.""" - self.stop_track() diff --git a/build/lib/dimos/perception/person_tracker.py b/build/lib/dimos/perception/person_tracker.py deleted file mode 100644 index 0a2f9cc7b7..0000000000 --- a/build/lib/dimos/perception/person_tracker.py +++ /dev/null @@ -1,154 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dimos.perception.detection2d.yolo_2d_det import Yolo2DDetector -from dimos.perception.detection2d.utils import filter_detections -from dimos.perception.common.ibvs import PersonDistanceEstimator -from reactivex import Observable -from reactivex import operators as ops -import numpy as np -import cv2 - - -class PersonTrackingStream: - def __init__( - self, - camera_intrinsics=None, - camera_pitch=0.0, - camera_height=1.0, - ): - """ - Initialize a person tracking stream using Yolo2DDetector and PersonDistanceEstimator. 
- - Args: - camera_intrinsics: List in format [fx, fy, cx, cy] where: - - fx: Focal length in x direction (pixels) - - fy: Focal length in y direction (pixels) - - cx: Principal point x-coordinate (pixels) - - cy: Principal point y-coordinate (pixels) - camera_pitch: Camera pitch angle in radians (positive is up) - camera_height: Height of the camera from the ground in meters - """ - self.detector = Yolo2DDetector() - - # Initialize distance estimator - if camera_intrinsics is None: - raise ValueError("Camera intrinsics are required for distance estimation") - - # Validate camera intrinsics format [fx, fy, cx, cy] - if ( - not isinstance(camera_intrinsics, (list, tuple, np.ndarray)) - or len(camera_intrinsics) != 4 - ): - raise ValueError("Camera intrinsics must be provided as [fx, fy, cx, cy]") - - # Convert [fx, fy, cx, cy] to 3x3 camera matrix - fx, fy, cx, cy = camera_intrinsics - K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32) - - self.distance_estimator = PersonDistanceEstimator( - K=K, camera_pitch=camera_pitch, camera_height=camera_height - ) - - def create_stream(self, video_stream: Observable) -> Observable: - """ - Create an Observable stream of person tracking results from a video stream. 
- - Args: - video_stream: Observable that emits video frames - - Returns: - Observable that emits dictionaries containing tracking results and visualizations - """ - - def process_frame(frame): - # Detect people in the frame - bboxes, track_ids, class_ids, confidences, names = self.detector.process_image(frame) - - # Filter to keep only person detections using filter_detections - ( - filtered_bboxes, - filtered_track_ids, - filtered_class_ids, - filtered_confidences, - filtered_names, - ) = filter_detections( - bboxes, - track_ids, - class_ids, - confidences, - names, - class_filter=[0], # 0 is the class_id for person - name_filter=["person"], - ) - - # Create visualization - viz_frame = self.detector.visualize_results( - frame, - filtered_bboxes, - filtered_track_ids, - filtered_class_ids, - filtered_confidences, - filtered_names, - ) - - # Calculate distance and angle for each person - targets = [] - for i, bbox in enumerate(filtered_bboxes): - target_data = { - "target_id": filtered_track_ids[i] if i < len(filtered_track_ids) else -1, - "bbox": bbox, - "confidence": filtered_confidences[i] - if i < len(filtered_confidences) - else None, - } - - distance, angle = self.distance_estimator.estimate_distance_angle(bbox) - target_data["distance"] = distance - target_data["angle"] = angle - - # Add text to visualization - x1, y1, x2, y2 = map(int, bbox) - dist_text = f"{distance:.2f}m, {np.rad2deg(angle):.1f} deg" - - # Add black background for better visibility - text_size = cv2.getTextSize(dist_text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0] - # Position at top-right corner - cv2.rectangle( - viz_frame, (x2 - text_size[0], y1 - text_size[1] - 5), (x2, y1), (0, 0, 0), -1 - ) - - # Draw text in white at top-right - cv2.putText( - viz_frame, - dist_text, - (x2 - text_size[0], y1 - 5), - cv2.FONT_HERSHEY_SIMPLEX, - 0.5, - (255, 255, 255), - 2, - ) - - targets.append(target_data) - - # Create the result dictionary - result = {"frame": frame, "viz_frame": viz_frame, "targets": 
targets} - - return result - - return video_stream.pipe(ops.map(process_frame)) - - def cleanup(self): - """Clean up resources.""" - pass # No specific cleanup needed for now diff --git a/build/lib/dimos/perception/pointcloud/__init__.py b/build/lib/dimos/perception/pointcloud/__init__.py deleted file mode 100644 index 1f282bb738..0000000000 --- a/build/lib/dimos/perception/pointcloud/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .utils import * -from .cuboid_fit import * -from .pointcloud_filtering import * diff --git a/build/lib/dimos/perception/pointcloud/cuboid_fit.py b/build/lib/dimos/perception/pointcloud/cuboid_fit.py deleted file mode 100644 index d567f40395..0000000000 --- a/build/lib/dimos/perception/pointcloud/cuboid_fit.py +++ /dev/null @@ -1,414 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -import open3d as o3d -import cv2 -from typing import Dict, Optional, Union, Tuple - - -def fit_cuboid( - points: Union[np.ndarray, o3d.geometry.PointCloud], method: str = "minimal" -) -> Optional[Dict]: - """ - Fit a cuboid to a point cloud using Open3D's built-in methods. 
- - Args: - points: Nx3 array of points or Open3D PointCloud - method: Fitting method: - - 'minimal': Minimal oriented bounding box (best fit) - - 'oriented': PCA-based oriented bounding box - - 'axis_aligned': Axis-aligned bounding box - - Returns: - Dictionary containing: - - center: 3D center point - - dimensions: 3D dimensions (extent) - - rotation: 3x3 rotation matrix - - error: Fitting error - - bounding_box: Open3D OrientedBoundingBox object - Returns None if insufficient points or fitting fails. - - Raises: - ValueError: If method is invalid or inputs are malformed - """ - # Validate method - valid_methods = ["minimal", "oriented", "axis_aligned"] - if method not in valid_methods: - raise ValueError(f"method must be one of {valid_methods}, got '{method}'") - - # Convert to point cloud if needed - if isinstance(points, np.ndarray): - points = np.asarray(points) - if len(points.shape) != 2 or points.shape[1] != 3: - raise ValueError(f"points array must be Nx3, got shape {points.shape}") - if len(points) < 4: - return None - - pcd = o3d.geometry.PointCloud() - pcd.points = o3d.utility.Vector3dVector(points) - elif isinstance(points, o3d.geometry.PointCloud): - pcd = points - points = np.asarray(pcd.points) - if len(points) < 4: - return None - else: - raise ValueError(f"points must be numpy array or Open3D PointCloud, got {type(points)}") - - try: - # Get bounding box based on method - if method == "minimal": - obb = pcd.get_minimal_oriented_bounding_box(robust=True) - elif method == "oriented": - obb = pcd.get_oriented_bounding_box(robust=True) - elif method == "axis_aligned": - # Convert axis-aligned to oriented format for consistency - aabb = pcd.get_axis_aligned_bounding_box() - obb = o3d.geometry.OrientedBoundingBox() - obb.center = aabb.get_center() - obb.extent = aabb.get_extent() - obb.R = np.eye(3) # Identity rotation for axis-aligned - - # Extract parameters - center = np.asarray(obb.center) - dimensions = np.asarray(obb.extent) - rotation = 
np.asarray(obb.R) - - # Calculate fitting error - error = _compute_fitting_error(points, center, dimensions, rotation) - - return { - "center": center, - "dimensions": dimensions, - "rotation": rotation, - "error": error, - "bounding_box": obb, - "method": method, - } - - except Exception as e: - # Log error but don't crash - return None for graceful handling - print(f"Warning: Cuboid fitting failed with method '{method}': {e}") - return None - - -def fit_cuboid_simple(points: Union[np.ndarray, o3d.geometry.PointCloud]) -> Optional[Dict]: - """ - Simple wrapper for minimal oriented bounding box fitting. - - Args: - points: Nx3 array of points or Open3D PointCloud - - Returns: - Dictionary with center, dimensions, rotation, and bounding_box, - or None if insufficient points - """ - return fit_cuboid(points, method="minimal") - - -def _compute_fitting_error( - points: np.ndarray, center: np.ndarray, dimensions: np.ndarray, rotation: np.ndarray -) -> float: - """ - Compute fitting error as mean squared distance from points to cuboid surface. 
- - Args: - points: Nx3 array of points - center: 3D center point - dimensions: 3D dimensions - rotation: 3x3 rotation matrix - - Returns: - Mean squared error - """ - if len(points) == 0: - return 0.0 - - # Transform points to local coordinates - local_points = (points - center) @ rotation - half_dims = dimensions / 2 - - # Calculate distance to cuboid surface - dx = np.abs(local_points[:, 0]) - half_dims[0] - dy = np.abs(local_points[:, 1]) - half_dims[1] - dz = np.abs(local_points[:, 2]) - half_dims[2] - - # Points outside: distance to nearest face - # Points inside: negative distance to nearest face - outside_dist = np.sqrt(np.maximum(dx, 0) ** 2 + np.maximum(dy, 0) ** 2 + np.maximum(dz, 0) ** 2) - inside_dist = np.minimum(np.minimum(dx, dy), dz) - distances = np.where((dx > 0) | (dy > 0) | (dz > 0), outside_dist, -inside_dist) - - return float(np.mean(distances**2)) - - -def get_cuboid_corners( - center: np.ndarray, dimensions: np.ndarray, rotation: np.ndarray -) -> np.ndarray: - """ - Get the 8 corners of a cuboid. 
- - Args: - center: 3D center point - dimensions: 3D dimensions - rotation: 3x3 rotation matrix - - Returns: - 8x3 array of corner coordinates - """ - half_dims = dimensions / 2 - corners_local = ( - np.array( - [ - [-1, -1, -1], # 0: left bottom back - [-1, -1, 1], # 1: left bottom front - [-1, 1, -1], # 2: left top back - [-1, 1, 1], # 3: left top front - [1, -1, -1], # 4: right bottom back - [1, -1, 1], # 5: right bottom front - [1, 1, -1], # 6: right top back - [1, 1, 1], # 7: right top front - ] - ) - * half_dims - ) - - # Apply rotation and translation - return corners_local @ rotation.T + center - - -def visualize_cuboid_on_image( - image: np.ndarray, - cuboid_params: Dict, - camera_matrix: np.ndarray, - extrinsic_rotation: Optional[np.ndarray] = None, - extrinsic_translation: Optional[np.ndarray] = None, - color: Tuple[int, int, int] = (0, 255, 0), - thickness: int = 2, - show_dimensions: bool = True, -) -> np.ndarray: - """ - Draw a fitted cuboid on an image using camera projection. 
- - Args: - image: Input image to draw on - cuboid_params: Dictionary containing cuboid parameters - camera_matrix: Camera intrinsic matrix (3x3) - extrinsic_rotation: Optional external rotation (3x3) - extrinsic_translation: Optional external translation (3x1) - color: Line color as (B, G, R) tuple - thickness: Line thickness - show_dimensions: Whether to display dimension text - - Returns: - Image with cuboid visualization - - Raises: - ValueError: If required parameters are missing or invalid - """ - # Validate inputs - required_keys = ["center", "dimensions", "rotation"] - if not all(key in cuboid_params for key in required_keys): - raise ValueError(f"cuboid_params must contain keys: {required_keys}") - - if camera_matrix.shape != (3, 3): - raise ValueError(f"camera_matrix must be 3x3, got {camera_matrix.shape}") - - # Get corners in world coordinates - corners = get_cuboid_corners( - cuboid_params["center"], cuboid_params["dimensions"], cuboid_params["rotation"] - ) - - # Transform corners if extrinsic parameters are provided - if extrinsic_rotation is not None and extrinsic_translation is not None: - if extrinsic_rotation.shape != (3, 3): - raise ValueError(f"extrinsic_rotation must be 3x3, got {extrinsic_rotation.shape}") - if extrinsic_translation.shape not in [(3,), (3, 1)]: - raise ValueError( - f"extrinsic_translation must be (3,) or (3,1), got {extrinsic_translation.shape}" - ) - - extrinsic_translation = extrinsic_translation.flatten() - corners = (extrinsic_rotation @ corners.T).T + extrinsic_translation - - try: - # Project 3D corners to image coordinates - corners_img, _ = cv2.projectPoints( - corners.astype(np.float32), - np.zeros(3), - np.zeros(3), # No additional rotation/translation - camera_matrix.astype(np.float32), - None, # No distortion - ) - corners_img = corners_img.reshape(-1, 2).astype(int) - - # Check if corners are within image bounds - h, w = image.shape[:2] - valid_corners = ( - (corners_img[:, 0] >= 0) - & (corners_img[:, 0] < w) - 
& (corners_img[:, 1] >= 0) - & (corners_img[:, 1] < h) - ) - - if not np.any(valid_corners): - print("Warning: All cuboid corners are outside image bounds") - return image.copy() - - except Exception as e: - print(f"Warning: Failed to project cuboid corners: {e}") - return image.copy() - - # Define edges for wireframe visualization - edges = [ - # Bottom face - (0, 1), - (1, 5), - (5, 4), - (4, 0), - # Top face - (2, 3), - (3, 7), - (7, 6), - (6, 2), - # Vertical edges - (0, 2), - (1, 3), - (5, 7), - (4, 6), - ] - - # Draw edges - vis_img = image.copy() - for i, j in edges: - # Only draw edge if both corners are valid - if valid_corners[i] and valid_corners[j]: - cv2.line(vis_img, tuple(corners_img[i]), tuple(corners_img[j]), color, thickness) - - # Add dimension text if requested - if show_dimensions and np.any(valid_corners): - dims = cuboid_params["dimensions"] - dim_text = f"Dims: {dims[0]:.3f} x {dims[1]:.3f} x {dims[2]:.3f}" - - # Find a good position for text (top-left of image) - text_pos = (10, 30) - font_scale = 0.7 - - # Add background rectangle for better readability - text_size = cv2.getTextSize(dim_text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, 2)[0] - cv2.rectangle( - vis_img, - (text_pos[0] - 5, text_pos[1] - text_size[1] - 5), - (text_pos[0] + text_size[0] + 5, text_pos[1] + 5), - (0, 0, 0), - -1, - ) - - cv2.putText(vis_img, dim_text, text_pos, cv2.FONT_HERSHEY_SIMPLEX, font_scale, color, 2) - - return vis_img - - -def compute_cuboid_volume(cuboid_params: Dict) -> float: - """ - Compute the volume of a cuboid. - - Args: - cuboid_params: Dictionary containing cuboid parameters - - Returns: - Volume in cubic units - """ - if "dimensions" not in cuboid_params: - raise ValueError("cuboid_params must contain 'dimensions' key") - - dims = cuboid_params["dimensions"] - return float(np.prod(dims)) - - -def compute_cuboid_surface_area(cuboid_params: Dict) -> float: - """ - Compute the surface area of a cuboid. 
- - Args: - cuboid_params: Dictionary containing cuboid parameters - - Returns: - Surface area in square units - """ - if "dimensions" not in cuboid_params: - raise ValueError("cuboid_params must contain 'dimensions' key") - - dims = cuboid_params["dimensions"] - return 2.0 * (dims[0] * dims[1] + dims[1] * dims[2] + dims[2] * dims[0]) - - -def check_cuboid_quality(cuboid_params: Dict, points: np.ndarray) -> Dict: - """ - Assess the quality of a cuboid fit. - - Args: - cuboid_params: Dictionary containing cuboid parameters - points: Original points used for fitting - - Returns: - Dictionary with quality metrics - """ - if len(points) == 0: - return {"error": "No points provided"} - - # Basic metrics - volume = compute_cuboid_volume(cuboid_params) - surface_area = compute_cuboid_surface_area(cuboid_params) - error = cuboid_params.get("error", 0.0) - - # Aspect ratio analysis - dims = cuboid_params["dimensions"] - aspect_ratios = [ - dims[0] / dims[1] if dims[1] > 0 else float("inf"), - dims[1] / dims[2] if dims[2] > 0 else float("inf"), - dims[2] / dims[0] if dims[0] > 0 else float("inf"), - ] - max_aspect_ratio = max(aspect_ratios) - - # Volume ratio (cuboid volume vs convex hull volume) - try: - pcd = o3d.geometry.PointCloud() - pcd.points = o3d.utility.Vector3dVector(points) - hull, _ = pcd.compute_convex_hull() - hull_volume = hull.get_volume() - volume_ratio = volume / hull_volume if hull_volume > 0 else float("inf") - except: - volume_ratio = None - - return { - "fitting_error": error, - "volume": volume, - "surface_area": surface_area, - "max_aspect_ratio": max_aspect_ratio, - "volume_ratio": volume_ratio, - "num_points": len(points), - "method": cuboid_params.get("method", "unknown"), - } - - -# Backward compatibility -def visualize_fit(image, cuboid_params, camera_matrix, R=None, t=None): - """ - Legacy function for backward compatibility. - Use visualize_cuboid_on_image instead. 
- """ - return visualize_cuboid_on_image( - image, cuboid_params, camera_matrix, R, t, show_dimensions=True - ) diff --git a/build/lib/dimos/perception/pointcloud/pointcloud_filtering.py b/build/lib/dimos/perception/pointcloud/pointcloud_filtering.py deleted file mode 100644 index ef033bff3f..0000000000 --- a/build/lib/dimos/perception/pointcloud/pointcloud_filtering.py +++ /dev/null @@ -1,674 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -import cv2 -import os -import torch -import open3d as o3d -import argparse -import pickle -from typing import Dict, List, Optional, Union -import time -from dimos.types.manipulation import ObjectData -from dimos.types.vector import Vector -from dimos.perception.pointcloud.utils import ( - load_camera_matrix_from_yaml, - create_point_cloud_and_extract_masks, - o3d_point_cloud_to_numpy, -) -from dimos.perception.pointcloud.cuboid_fit import fit_cuboid - - -class PointcloudFiltering: - """ - A production-ready point cloud filtering pipeline for segmented objects. - - This class takes segmentation results and produces clean, filtered point clouds - for each object with consistent coloring and optional outlier removal. 
- """ - - def __init__( - self, - color_intrinsics: Optional[Union[str, List[float], np.ndarray]] = None, - depth_intrinsics: Optional[Union[str, List[float], np.ndarray]] = None, - color_weight: float = 0.3, - enable_statistical_filtering: bool = True, - statistical_neighbors: int = 20, - statistical_std_ratio: float = 1.5, - enable_radius_filtering: bool = True, - radius_filtering_radius: float = 0.015, - radius_filtering_min_neighbors: int = 25, - enable_subsampling: bool = True, - voxel_size: float = 0.005, - max_num_objects: int = 10, - min_points_for_cuboid: int = 10, - cuboid_method: str = "oriented", - max_bbox_size_percent: float = 30.0, - ): - """ - Initialize the point cloud filtering pipeline. - - Args: - color_intrinsics: Camera intrinsics for color image - depth_intrinsics: Camera intrinsics for depth image - color_weight: Weight for blending generated color with original (0.0-1.0) - enable_statistical_filtering: Enable/disable statistical outlier filtering - statistical_neighbors: Number of neighbors for statistical filtering - statistical_std_ratio: Standard deviation ratio for statistical filtering - enable_radius_filtering: Enable/disable radius outlier filtering - radius_filtering_radius: Search radius for radius filtering (meters) - radius_filtering_min_neighbors: Min neighbors within radius - enable_subsampling: Enable/disable point cloud subsampling - voxel_size: Voxel size for downsampling (meters, when subsampling enabled) - max_num_objects: Maximum number of objects to process (top N by confidence) - min_points_for_cuboid: Minimum points required for cuboid fitting - cuboid_method: Method for cuboid fitting ('minimal', 'oriented', 'axis_aligned') - max_bbox_size_percent: Maximum percentage of image size for object bboxes (0-100) - - Raises: - ValueError: If invalid parameters are provided - """ - # Validate parameters - if not 0.0 <= color_weight <= 1.0: - raise ValueError(f"color_weight must be between 0.0 and 1.0, got {color_weight}") - 
if not 0.0 <= max_bbox_size_percent <= 100.0: - raise ValueError( - f"max_bbox_size_percent must be between 0.0 and 100.0, got {max_bbox_size_percent}" - ) - - # Store settings - self.color_weight = color_weight - self.enable_statistical_filtering = enable_statistical_filtering - self.statistical_neighbors = statistical_neighbors - self.statistical_std_ratio = statistical_std_ratio - self.enable_radius_filtering = enable_radius_filtering - self.radius_filtering_radius = radius_filtering_radius - self.radius_filtering_min_neighbors = radius_filtering_min_neighbors - self.enable_subsampling = enable_subsampling - self.voxel_size = voxel_size - self.max_num_objects = max_num_objects - self.min_points_for_cuboid = min_points_for_cuboid - self.cuboid_method = cuboid_method - self.max_bbox_size_percent = max_bbox_size_percent - - # Load camera matrices - self.color_camera_matrix = load_camera_matrix_from_yaml(color_intrinsics) - self.depth_camera_matrix = load_camera_matrix_from_yaml(depth_intrinsics) - - # Store the full point cloud - self.full_pcd = None - - def generate_color_from_id(self, object_id: int) -> np.ndarray: - """Generate a consistent color for a given object ID.""" - np.random.seed(object_id) - color = np.random.randint(0, 255, 3, dtype=np.uint8) - np.random.seed(None) - return color - - def _validate_inputs( - self, color_img: np.ndarray, depth_img: np.ndarray, objects: List[ObjectData] - ): - """Validate input parameters.""" - if color_img.shape[:2] != depth_img.shape: - raise ValueError("Color and depth image dimensions don't match") - - def _prepare_masks(self, masks: List[np.ndarray], target_shape: tuple) -> List[np.ndarray]: - """Prepare and validate masks to match target shape.""" - processed_masks = [] - for mask in masks: - # Convert mask to numpy if it's a tensor - if hasattr(mask, "cpu"): - mask = mask.cpu().numpy() - - mask = mask.astype(bool) - - # Handle shape mismatches - if mask.shape != target_shape: - if len(mask.shape) > 2: - mask = 
mask[:, :, 0] - - if mask.shape != target_shape: - mask = cv2.resize( - mask.astype(np.uint8), - (target_shape[1], target_shape[0]), - interpolation=cv2.INTER_NEAREST, - ).astype(bool) - - processed_masks.append(mask) - - return processed_masks - - def _apply_color_mask( - self, pcd: o3d.geometry.PointCloud, rgb_color: np.ndarray - ) -> o3d.geometry.PointCloud: - """Apply weighted color mask to point cloud.""" - if len(np.asarray(pcd.colors)) > 0: - original_colors = np.asarray(pcd.colors) - generated_color = rgb_color.astype(np.float32) / 255.0 - colored_mask = ( - 1.0 - self.color_weight - ) * original_colors + self.color_weight * generated_color - colored_mask = np.clip(colored_mask, 0.0, 1.0) - pcd.colors = o3d.utility.Vector3dVector(colored_mask) - return pcd - - def _apply_filtering(self, pcd: o3d.geometry.PointCloud) -> o3d.geometry.PointCloud: - """Apply optional filtering to point cloud based on enabled flags.""" - current_pcd = pcd - - # Apply statistical filtering if enabled - if self.enable_statistical_filtering: - current_pcd, _ = current_pcd.remove_statistical_outlier( - nb_neighbors=self.statistical_neighbors, std_ratio=self.statistical_std_ratio - ) - - # Apply radius filtering if enabled - if self.enable_radius_filtering: - current_pcd, _ = current_pcd.remove_radius_outlier( - nb_points=self.radius_filtering_min_neighbors, radius=self.radius_filtering_radius - ) - - return current_pcd - - def _apply_subsampling(self, pcd: o3d.geometry.PointCloud) -> o3d.geometry.PointCloud: - """Apply subsampling to limit point cloud size using Open3D's voxel downsampling.""" - if self.enable_subsampling: - return pcd.voxel_down_sample(self.voxel_size) - return pcd - - def _extract_masks_from_objects(self, objects: List[ObjectData]) -> List[np.ndarray]: - """Extract segmentation masks from ObjectData objects.""" - return [obj["segmentation_mask"] for obj in objects] - - def get_full_point_cloud(self) -> o3d.geometry.PointCloud: - """Get the full point cloud.""" - 
return self._apply_subsampling(self.full_pcd) - - def process_images( - self, color_img: np.ndarray, depth_img: np.ndarray, objects: List[ObjectData] - ) -> List[ObjectData]: - """ - Process color and depth images with object detection results to create filtered point clouds. - - Args: - color_img: RGB image as numpy array (H, W, 3) - depth_img: Depth image as numpy array (H, W) in meters - objects: List of ObjectData from object detection stream - - Returns: - List of updated ObjectData with pointcloud and 3D information. Each ObjectData - dictionary is enhanced with the following new fields: - - **3D Spatial Information** (added when sufficient points for cuboid fitting): - - "position": Vector(x, y, z) - 3D center position in world coordinates (meters) - - "rotation": Vector(roll, pitch, yaw) - 3D orientation as Euler angles (radians) - - "size": {"width": float, "height": float, "depth": float} - 3D bounding box dimensions (meters) - - **Point Cloud Data**: - - "point_cloud": o3d.geometry.PointCloud - Filtered Open3D point cloud with colors - - "color": np.ndarray - Consistent RGB color [R,G,B] (0-255) generated from object_id - - **Grasp Generation Arrays** (AnyGrasp format): - - "point_cloud_numpy": np.ndarray - Nx3 XYZ coordinates as float32 (meters) - - "colors_numpy": np.ndarray - Nx3 RGB colors as float32 (0.0-1.0 range) - - Raises: - ValueError: If inputs are invalid - RuntimeError: If processing fails - """ - # Validate inputs - self._validate_inputs(color_img, depth_img, objects) - - if not objects: - return [] - - # Filter to top N objects by confidence - if len(objects) > self.max_num_objects: - # Sort objects by confidence (highest first), handle None confidences - sorted_objects = sorted( - objects, - key=lambda obj: obj.get("confidence", 0.0) - if obj.get("confidence") is not None - else 0.0, - reverse=True, - ) - objects = sorted_objects[: self.max_num_objects] - - # Filter out objects with bboxes too large - image_area = color_img.shape[0] * 
color_img.shape[1] - max_bbox_area = image_area * (self.max_bbox_size_percent / 100.0) - - filtered_objects = [] - for obj in objects: - if "bbox" in obj and obj["bbox"] is not None: - bbox = obj["bbox"] - # Calculate bbox area (assuming bbox format [x1, y1, x2, y2]) - bbox_area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) - if bbox_area <= max_bbox_area: - filtered_objects.append(obj) - else: - filtered_objects.append(obj) - - objects = filtered_objects - - # Extract masks from ObjectData - masks = self._extract_masks_from_objects(objects) - - # Prepare masks - processed_masks = self._prepare_masks(masks, depth_img.shape) - - # Create point clouds efficiently - self.full_pcd, masked_pcds = create_point_cloud_and_extract_masks( - color_img, depth_img, processed_masks, self.depth_camera_matrix, depth_scale=1.0 - ) - - # Process each object and update ObjectData - updated_objects = [] - - for i, (obj, mask, pcd) in enumerate(zip(objects, processed_masks, masked_pcds)): - # Skip empty point clouds - if len(np.asarray(pcd.points)) == 0: - continue - - # Create a copy of the object data to avoid modifying the original - updated_obj = obj.copy() - - # Generate consistent color - object_id = obj.get("object_id", i) - rgb_color = self.generate_color_from_id(object_id) - - # Apply color mask - pcd = self._apply_color_mask(pcd, rgb_color) - - # Apply subsampling to control point cloud size - pcd = self._apply_subsampling(pcd) - - # Apply filtering (optional based on flags) - pcd_filtered = self._apply_filtering(pcd) - - # Fit cuboid and extract 3D information - points = np.asarray(pcd_filtered.points) - if len(points) >= self.min_points_for_cuboid: - cuboid_params = fit_cuboid(points, method=self.cuboid_method) - if cuboid_params is not None: - # Update position, rotation, and size from cuboid - center = cuboid_params["center"] - dimensions = cuboid_params["dimensions"] - rotation_matrix = cuboid_params["rotation"] - - # Convert rotation matrix to euler angles (roll, pitch, 
yaw) - sy = np.sqrt( - rotation_matrix[0, 0] * rotation_matrix[0, 0] - + rotation_matrix[1, 0] * rotation_matrix[1, 0] - ) - singular = sy < 1e-6 - - if not singular: - roll = np.arctan2(rotation_matrix[2, 1], rotation_matrix[2, 2]) - pitch = np.arctan2(-rotation_matrix[2, 0], sy) - yaw = np.arctan2(rotation_matrix[1, 0], rotation_matrix[0, 0]) - else: - roll = np.arctan2(-rotation_matrix[1, 2], rotation_matrix[1, 1]) - pitch = np.arctan2(-rotation_matrix[2, 0], sy) - yaw = 0 - - # Update position, rotation, and size from cuboid - updated_obj["position"] = Vector(center[0], center[1], center[2]) - updated_obj["rotation"] = Vector(roll, pitch, yaw) - updated_obj["size"] = { - "width": float(dimensions[0]), - "height": float(dimensions[1]), - "depth": float(dimensions[2]), - } - - # Add point cloud data to ObjectData - updated_obj["point_cloud"] = pcd_filtered - updated_obj["color"] = rgb_color - - # Extract numpy arrays for grasp generation (anygrasp format) - points_array = np.asarray(pcd_filtered.points).astype(np.float32) # Nx3 XYZ coordinates - if pcd_filtered.has_colors(): - colors_array = np.asarray(pcd_filtered.colors).astype( - np.float32 - ) # Nx3 RGB (0-1 range) - else: - # If no colors, create array of zeros - colors_array = np.zeros((len(points_array), 3), dtype=np.float32) - - updated_obj["point_cloud_numpy"] = points_array - updated_obj["colors_numpy"] = colors_array - - updated_objects.append(updated_obj) - - return updated_objects - - def cleanup(self): - """Clean up resources.""" - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - -def create_test_pipeline(data_dir: str) -> tuple: - """ - Create a test pipeline with default settings. 
- - Args: - data_dir: Directory containing camera info files - - Returns: - Tuple of (filter_pipeline, color_info_path, depth_info_path) - """ - color_info_path = os.path.join(data_dir, "color_camera_info.yaml") - depth_info_path = os.path.join(data_dir, "depth_camera_info.yaml") - - # Default pipeline with subsampling disabled by default - filter_pipeline = PointcloudFiltering( - color_intrinsics=color_info_path, - depth_intrinsics=depth_info_path, - ) - - return filter_pipeline, color_info_path, depth_info_path - - -def load_test_images(data_dir: str) -> tuple: - """ - Load the first available test images from data directory. - - Args: - data_dir: Directory containing color and depth subdirectories - - Returns: - Tuple of (color_img, depth_img) or raises FileNotFoundError - """ - - def find_first_image(directory): - """Find the first image file in the given directory.""" - if not os.path.exists(directory): - return None - - image_extensions = [".jpg", ".jpeg", ".png", ".bmp"] - for filename in sorted(os.listdir(directory)): - if any(filename.lower().endswith(ext) for ext in image_extensions): - return os.path.join(directory, filename) - return None - - color_dir = os.path.join(data_dir, "color") - depth_dir = os.path.join(data_dir, "depth") - - color_img_path = find_first_image(color_dir) - depth_img_path = find_first_image(depth_dir) - - if not color_img_path or not depth_img_path: - raise FileNotFoundError(f"Could not find color or depth images in {data_dir}") - - # Load color image - color_img = cv2.imread(color_img_path) - if color_img is None: - raise FileNotFoundError(f"Could not load color image from {color_img_path}") - color_img = cv2.cvtColor(color_img, cv2.COLOR_BGR2RGB) - - # Load depth image - depth_img = cv2.imread(depth_img_path, cv2.IMREAD_UNCHANGED) - if depth_img is None: - raise FileNotFoundError(f"Could not load depth image from {depth_img_path}") - - # Convert depth to meters if needed - if depth_img.dtype == np.uint16: - depth_img = 
depth_img.astype(np.float32) / 1000.0 - - return color_img, depth_img - - -def run_segmentation(color_img: np.ndarray, device: str = "auto") -> List[ObjectData]: - """ - Run segmentation on color image and return ObjectData objects. - - Args: - color_img: RGB color image - device: Device to use ('auto', 'cuda', or 'cpu') - - Returns: - List of ObjectData objects with segmentation results - """ - if device == "auto": - device = "cuda" if torch.cuda.is_available() else "cpu" - - # Import here to avoid circular imports - from dimos.perception.segmentation import Sam2DSegmenter - - segmenter = Sam2DSegmenter( - model_path="FastSAM-s.pt", device=device, use_tracker=False, use_analyzer=False - ) - - try: - masks, bboxes, target_ids, probs, names = segmenter.process_image(np.array(color_img)) - - # Create ObjectData objects - objects = [] - for i in range(len(bboxes)): - obj_data: ObjectData = { - "object_id": target_ids[i] if i < len(target_ids) else i, - "bbox": bboxes[i], - "depth": -1.0, # Will be populated by pointcloud filtering - "confidence": probs[i] if i < len(probs) else 1.0, - "class_id": i, - "label": names[i] if i < len(names) else "", - "segmentation_mask": masks[i].cpu().numpy() - if hasattr(masks[i], "cpu") - else masks[i], - "position": Vector(0, 0, 0), # Will be populated by pointcloud filtering - "rotation": Vector(0, 0, 0), # Will be populated by pointcloud filtering - "size": { - "width": 0.0, - "height": 0.0, - "depth": 0.0, - }, # Will be populated by pointcloud filtering - } - objects.append(obj_data) - - return objects - - finally: - segmenter.cleanup() - - -def visualize_results(objects: List[ObjectData]): - """ - Visualize point cloud filtering results with 3D bounding boxes. 
- - Args: - objects: List of ObjectData with point clouds - """ - all_pcds = [] - - for obj in objects: - if "point_cloud" in obj and obj["point_cloud"] is not None: - pcd = obj["point_cloud"] - all_pcds.append(pcd) - - # Draw 3D bounding box if position, rotation, and size are available - if ( - "position" in obj - and "rotation" in obj - and "size" in obj - and obj["position"] is not None - and obj["rotation"] is not None - and obj["size"] is not None - ): - try: - position = obj["position"] - rotation = obj["rotation"] - size = obj["size"] - - # Convert position to numpy array - if hasattr(position, "x"): # Vector object - center = np.array([position.x, position.y, position.z]) - else: # Dictionary - center = np.array([position["x"], position["y"], position["z"]]) - - # Convert rotation (euler angles) to rotation matrix - if hasattr(rotation, "x"): # Vector object (roll, pitch, yaw) - roll, pitch, yaw = rotation.x, rotation.y, rotation.z - else: # Dictionary - roll, pitch, yaw = rotation["roll"], rotation["pitch"], rotation["yaw"] - - # Create rotation matrix from euler angles (ZYX order) - # Roll (X), Pitch (Y), Yaw (Z) - cos_r, sin_r = np.cos(roll), np.sin(roll) - cos_p, sin_p = np.cos(pitch), np.sin(pitch) - cos_y, sin_y = np.cos(yaw), np.sin(yaw) - - # Rotation matrix for ZYX euler angles - R = np.array( - [ - [ - cos_y * cos_p, - cos_y * sin_p * sin_r - sin_y * cos_r, - cos_y * sin_p * cos_r + sin_y * sin_r, - ], - [ - sin_y * cos_p, - sin_y * sin_p * sin_r + cos_y * cos_r, - sin_y * sin_p * cos_r - cos_y * sin_r, - ], - [-sin_p, cos_p * sin_r, cos_p * cos_r], - ] - ) - - # Get dimensions - width = size.get("width", 0.1) - height = size.get("height", 0.1) - depth = size.get("depth", 0.1) - extent = np.array([width, height, depth]) - - # Create oriented bounding box - obb = o3d.geometry.OrientedBoundingBox(center=center, R=R, extent=extent) - obb.color = [1, 0, 0] # Red bounding boxes - all_pcds.append(obb) - - except Exception as e: - print( - f"Warning: 
Failed to create bounding box for object {obj.get('object_id', 'unknown')}: {e}" - ) - - # Add coordinate frame - coordinate_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1) - all_pcds.append(coordinate_frame) - - # Visualize - if all_pcds: - o3d.visualization.draw_geometries( - all_pcds, - window_name="Filtered Point Clouds with 3D Bounding Boxes", - width=1280, - height=720, - ) - - -def main(): - """Main function to demonstrate the PointcloudFiltering pipeline.""" - parser = argparse.ArgumentParser(description="Point cloud filtering pipeline demonstration") - parser.add_argument( - "--save-pickle", - type=str, - help="Save generated ObjectData to pickle file (provide filename)", - ) - parser.add_argument( - "--data-dir", type=str, help="Directory containing RGBD data (default: auto-detect)" - ) - args = parser.parse_args() - - try: - # Setup paths - if args.data_dir: - data_dir = args.data_dir - else: - script_dir = os.path.dirname(os.path.abspath(__file__)) - dimos_dir = os.path.abspath(os.path.join(script_dir, "../../../")) - data_dir = os.path.join(dimos_dir, "assets/rgbd_data") - - # Load test data - print("Loading test images...") - color_img, depth_img = load_test_images(data_dir) - print(f"Loaded images: color {color_img.shape}, depth {depth_img.shape}") - - # Run segmentation - print("Running segmentation...") - objects = run_segmentation(color_img) - print(f"Found {len(objects)} objects") - - # Create filtering pipeline - print("Creating filtering pipeline...") - filter_pipeline, _, _ = create_test_pipeline(data_dir) - - # Process images - print("Processing point clouds...") - updated_objects = filter_pipeline.process_images(color_img, depth_img, objects) - - # Print results - print(f"Processing complete:") - print(f" Objects processed: {len(updated_objects)}/{len(objects)}") - - # Print per-object stats - for i, obj in enumerate(updated_objects): - if "point_cloud" in obj and obj["point_cloud"] is not None: - num_points = 
len(np.asarray(obj["point_cloud"].points)) - position = obj.get("position", Vector(0, 0, 0)) - size = obj.get("size", {}) - print(f" Object {i + 1} (ID: {obj['object_id']}): {num_points} points") - print(f" Position: ({position.x:.2f}, {position.y:.2f}, {position.z:.2f})") - print( - f" Size: {size.get('width', 0):.3f} x {size.get('height', 0):.3f} x {size.get('depth', 0):.3f}" - ) - - # Save to pickle file if requested - if args.save_pickle: - pickle_path = args.save_pickle - if not pickle_path.endswith(".pkl"): - pickle_path += ".pkl" - - print(f"Saving ObjectData to {pickle_path}...") - - # Create serializable objects (exclude Open3D point clouds) - serializable_objects = [] - for obj in updated_objects: - obj_copy = obj.copy() - # Remove the Open3D point cloud object (can't be pickled) - if "point_cloud" in obj_copy: - del obj_copy["point_cloud"] - serializable_objects.append(obj_copy) - - with open(pickle_path, "wb") as f: - pickle.dump(serializable_objects, f) - - print(f"Successfully saved {len(serializable_objects)} objects to {pickle_path}") - print("To load: objects = pickle.load(open('filename.pkl', 'rb'))") - print( - "Note: Open3D point clouds excluded - use point_cloud_numpy and colors_numpy for processing" - ) - - # Visualize results - print("Visualizing results...") - visualize_results(updated_objects) - - except Exception as e: - print(f"Error: {e}") - import traceback - - traceback.print_exc() - - -if __name__ == "__main__": - main() diff --git a/build/lib/dimos/perception/pointcloud/utils.py b/build/lib/dimos/perception/pointcloud/utils.py deleted file mode 100644 index b1174253e3..0000000000 --- a/build/lib/dimos/perception/pointcloud/utils.py +++ /dev/null @@ -1,1451 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Point cloud utilities for RGBD data processing. - -This module provides efficient utilities for creating and manipulating point clouds -from RGBD images using Open3D. -""" - -import numpy as np -import yaml -import os -import cv2 -import open3d as o3d -from typing import List, Optional, Tuple, Union, Dict, Any -from scipy.spatial import cKDTree - - -def depth_to_point_cloud(depth_image, camera_intrinsics, subsample_factor=4): - """ - Convert depth image to point cloud using camera intrinsics. - Subsamples points to reduce density. - - Args: - depth_image: HxW depth image in meters - camera_intrinsics: Camera parameters as [fx, fy, cx, cy] list or 3x3 matrix - subsample_factor: Factor to subsample points (higher = fewer points) - - Returns: - Nx3 array of 3D points - """ - # Filter out inf and nan values from depth image - depth_filtered = depth_image.copy() - - # Create mask for valid depth values (finite, positive, non-zero) - valid_mask = np.isfinite(depth_filtered) & (depth_filtered > 0) - - # Set invalid values to 0 - depth_filtered[~valid_mask] = 0.0 - - # Extract camera parameters - if isinstance(camera_intrinsics, list) and len(camera_intrinsics) == 4: - fx, fy, cx, cy = camera_intrinsics - else: - fx = camera_intrinsics[0, 0] - fy = camera_intrinsics[1, 1] - cx = camera_intrinsics[0, 2] - cy = camera_intrinsics[1, 2] - - # Create pixel coordinate grid - rows, cols = depth_filtered.shape - x_grid, y_grid = np.meshgrid( - np.arange(0, cols, subsample_factor), np.arange(0, rows, subsample_factor) - ) - - # Flatten grid and depth - x = 
x_grid.flatten() - y = y_grid.flatten() - z = depth_filtered[y_grid, x_grid].flatten() - - # Remove points with invalid depth (after filtering, this catches zeros) - valid = z > 0 - x = x[valid] - y = y[valid] - z = z[valid] - - # Convert to 3D points - X = (x - cx) * z / fx - Y = (y - cy) * z / fy - Z = z - - return np.column_stack([X, Y, Z]) - - -def load_camera_matrix_from_yaml( - camera_info: Optional[Union[str, List[float], np.ndarray, dict]], -) -> Optional[np.ndarray]: - """ - Load camera intrinsic matrix from various input formats. - - Args: - camera_info: Can be: - - Path to YAML file containing camera parameters - - List of [fx, fy, cx, cy] - - 3x3 numpy array (returned as-is) - - Dict with camera parameters - - None (returns None) - - Returns: - 3x3 camera intrinsic matrix or None if input is None - - Raises: - ValueError: If camera_info format is invalid or file cannot be read - FileNotFoundError: If YAML file path doesn't exist - """ - if camera_info is None: - return None - - # Handle case where camera_info is already a matrix - if isinstance(camera_info, np.ndarray) and camera_info.shape == (3, 3): - return camera_info.astype(np.float32) - - # Handle case where camera_info is [fx, fy, cx, cy] format - if isinstance(camera_info, list) and len(camera_info) == 4: - fx, fy, cx, cy = camera_info - return np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32) - - # Handle case where camera_info is a dict - if isinstance(camera_info, dict): - return _extract_matrix_from_dict(camera_info) - - # Handle case where camera_info is a path to a YAML file - if isinstance(camera_info, str): - if not os.path.isfile(camera_info): - raise FileNotFoundError(f"Camera info file not found: {camera_info}") - - try: - with open(camera_info, "r") as f: - data = yaml.safe_load(f) - return _extract_matrix_from_dict(data) - except Exception as e: - raise ValueError(f"Failed to read camera info from {camera_info}: {e}") - - raise ValueError( - f"Invalid camera_info 
format. Expected str, list, dict, or numpy array, got {type(camera_info)}" - ) - - -def _extract_matrix_from_dict(data: dict) -> np.ndarray: - """Extract camera matrix from dictionary with various formats.""" - # ROS format with 'K' field (most common) - if "K" in data: - k_data = data["K"] - if len(k_data) == 9: - return np.array(k_data, dtype=np.float32).reshape(3, 3) - - # Standard format with 'camera_matrix' - if "camera_matrix" in data: - if "data" in data["camera_matrix"]: - matrix_data = data["camera_matrix"]["data"] - if len(matrix_data) == 9: - return np.array(matrix_data, dtype=np.float32).reshape(3, 3) - - # Explicit intrinsics format - if all(k in data for k in ["fx", "fy", "cx", "cy"]): - fx, fy = float(data["fx"]), float(data["fy"]) - cx, cy = float(data["cx"]), float(data["cy"]) - return np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32) - - # Error case - provide helpful debug info - available_keys = list(data.keys()) - if "K" in data: - k_info = f"K field length: {len(data['K']) if hasattr(data['K'], '__len__') else 'unknown'}" - else: - k_info = "K field not found" - - raise ValueError( - f"Cannot extract camera matrix from data. " - f"Available keys: {available_keys}. {k_info}. " - f"Expected formats: 'K' (9 elements), 'camera_matrix.data' (9 elements), " - f"or individual 'fx', 'fy', 'cx', 'cy' fields." - ) - - -def create_o3d_point_cloud_from_rgbd( - color_img: np.ndarray, - depth_img: np.ndarray, - intrinsic: np.ndarray, - depth_scale: float = 1.0, - depth_trunc: float = 3.0, -) -> o3d.geometry.PointCloud: - """ - Create an Open3D point cloud from RGB and depth images. 
- - Args: - color_img: RGB image as numpy array (H, W, 3) - depth_img: Depth image as numpy array (H, W) - intrinsic: Camera intrinsic matrix (3x3 numpy array) - depth_scale: Scale factor to convert depth to meters - depth_trunc: Maximum depth in meters - - Returns: - Open3D point cloud object - - Raises: - ValueError: If input dimensions are invalid - """ - # Validate inputs - if len(color_img.shape) != 3 or color_img.shape[2] != 3: - raise ValueError(f"color_img must be (H, W, 3), got {color_img.shape}") - if len(depth_img.shape) != 2: - raise ValueError(f"depth_img must be (H, W), got {depth_img.shape}") - if color_img.shape[:2] != depth_img.shape: - raise ValueError( - f"Color and depth image dimensions don't match: {color_img.shape[:2]} vs {depth_img.shape}" - ) - if intrinsic.shape != (3, 3): - raise ValueError(f"intrinsic must be (3, 3), got {intrinsic.shape}") - - # Convert to Open3D format - color_o3d = o3d.geometry.Image(color_img.astype(np.uint8)) - - # Filter out inf and nan values from depth image - depth_filtered = depth_img.copy() - - # Create mask for valid depth values (finite, positive, non-zero) - valid_mask = np.isfinite(depth_filtered) & (depth_filtered > 0) - - # Set invalid values to 0 (which Open3D treats as no depth) - depth_filtered[~valid_mask] = 0.0 - - depth_o3d = o3d.geometry.Image(depth_filtered.astype(np.float32)) - - # Create Open3D intrinsic object - height, width = color_img.shape[:2] - fx, fy = intrinsic[0, 0], intrinsic[1, 1] - cx, cy = intrinsic[0, 2], intrinsic[1, 2] - intrinsic_o3d = o3d.camera.PinholeCameraIntrinsic( - width, - height, - fx, - fy, # fx, fy - cx, - cy, # cx, cy - ) - - # Create RGBD image - rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth( - color_o3d, - depth_o3d, - depth_scale=depth_scale, - depth_trunc=depth_trunc, - convert_rgb_to_intensity=False, - ) - - # Create point cloud - pcd = o3d.geometry.PointCloud.create_from_rgbd_image(rgbd, intrinsic_o3d) - - return pcd - - -def 
o3d_point_cloud_to_numpy(pcd: o3d.geometry.PointCloud) -> np.ndarray: - """ - Convert Open3D point cloud to numpy array of XYZRGB points. - - Args: - pcd: Open3D point cloud object - - Returns: - Nx6 array of XYZRGB points (empty array if no points) - """ - points = np.asarray(pcd.points) - if len(points) == 0: - return np.zeros((0, 6), dtype=np.float32) - - # Get colors if available - if pcd.has_colors(): - colors = np.asarray(pcd.colors) * 255.0 # Convert from [0,1] to [0,255] - return np.column_stack([points, colors]).astype(np.float32) - else: - # No colors available, return points with zero colors - zeros = np.zeros((len(points), 3), dtype=np.float32) - return np.column_stack([points, zeros]).astype(np.float32) - - -def numpy_to_o3d_point_cloud(points_rgb: np.ndarray) -> o3d.geometry.PointCloud: - """ - Convert numpy array of XYZRGB points to Open3D point cloud. - - Args: - points_rgb: Nx6 array of XYZRGB points or Nx3 array of XYZ points - - Returns: - Open3D point cloud object - - Raises: - ValueError: If array shape is invalid - """ - if len(points_rgb) == 0: - return o3d.geometry.PointCloud() - - if points_rgb.shape[1] < 3: - raise ValueError( - f"points_rgb must have at least 3 columns (XYZ), got {points_rgb.shape[1]}" - ) - - pcd = o3d.geometry.PointCloud() - pcd.points = o3d.utility.Vector3dVector(points_rgb[:, :3]) - - # Add colors if available - if points_rgb.shape[1] >= 6: - colors = points_rgb[:, 3:6] / 255.0 # Convert from [0,255] to [0,1] - colors = np.clip(colors, 0.0, 1.0) # Ensure valid range - pcd.colors = o3d.utility.Vector3dVector(colors) - - return pcd - - -def create_masked_point_cloud(color_img, depth_img, mask, intrinsic, depth_scale=1.0): - """ - Create a point cloud for a masked region of RGBD data using Open3D. 
- - Args: - color_img: RGB image (H, W, 3) - depth_img: Depth image (H, W) - mask: Boolean mask of the same size as color_img and depth_img - intrinsic: Camera intrinsic matrix (3x3 numpy array) - depth_scale: Scale factor to convert depth to meters - - Returns: - Open3D point cloud object for the masked region - """ - # Filter out inf and nan values from depth image - depth_filtered = depth_img.copy() - - # Create mask for valid depth values (finite, positive, non-zero) - valid_depth_mask = np.isfinite(depth_filtered) & (depth_filtered > 0) - - # Set invalid values to 0 - depth_filtered[~valid_depth_mask] = 0.0 - - # Create masked color and depth images - masked_color = color_img.copy() - masked_depth = depth_filtered.copy() - - # Apply mask - if not mask.shape[:2] == color_img.shape[:2]: - raise ValueError(f"Mask shape {mask.shape} doesn't match image shape {color_img.shape[:2]}") - - # Create a boolean mask that is properly expanded for the RGB channels - # For RGB image, we need to properly broadcast the mask to all 3 channels - if len(color_img.shape) == 3 and color_img.shape[2] == 3: - # Properly broadcast mask to match the RGB dimensions - mask_rgb = np.broadcast_to(mask[:, :, np.newaxis], color_img.shape) - masked_color[~mask_rgb] = 0 - else: - # For grayscale images - masked_color[~mask] = 0 - - # Apply mask to depth image - masked_depth[~mask] = 0 - - # Create point cloud - pcd = create_o3d_point_cloud_from_rgbd(masked_color, masked_depth, intrinsic, depth_scale) - - # Remove points with coordinates at origin (0,0,0) which are likely from masked out regions - points = np.asarray(pcd.points) - if len(points) > 0: - # Find points that are not at origin - dist_from_origin = np.sum(points**2, axis=1) - valid_indices = dist_from_origin > 1e-6 - - # Filter points and colors - pcd = pcd.select_by_index(np.where(valid_indices)[0]) - - return pcd - - -def create_point_cloud_and_extract_masks( - color_img: np.ndarray, - depth_img: np.ndarray, - masks: 
List[np.ndarray], - intrinsic: np.ndarray, - depth_scale: float = 1.0, - depth_trunc: float = 3.0, -) -> Tuple[o3d.geometry.PointCloud, List[o3d.geometry.PointCloud]]: - """ - Efficiently create a point cloud once and extract multiple masked regions. - - Args: - color_img: RGB image (H, W, 3) - depth_img: Depth image (H, W) - masks: List of boolean masks, each of shape (H, W) - intrinsic: Camera intrinsic matrix (3x3 numpy array) - depth_scale: Scale factor to convert depth to meters - depth_trunc: Maximum depth in meters - - Returns: - Tuple of (full_point_cloud, list_of_masked_point_clouds) - """ - if not masks: - return o3d.geometry.PointCloud(), [] - - # Create the full point cloud - full_pcd = create_o3d_point_cloud_from_rgbd( - color_img, depth_img, intrinsic, depth_scale, depth_trunc - ) - - if len(np.asarray(full_pcd.points)) == 0: - return full_pcd, [o3d.geometry.PointCloud() for _ in masks] - - # Create pixel-to-point mapping - valid_depth_mask = np.isfinite(depth_img) & (depth_img > 0) & (depth_img <= depth_trunc) - - valid_depth = valid_depth_mask.flatten() - if not np.any(valid_depth): - return full_pcd, [o3d.geometry.PointCloud() for _ in masks] - - pixel_to_point = np.full(len(valid_depth), -1, dtype=np.int32) - pixel_to_point[valid_depth] = np.arange(np.sum(valid_depth)) - - # Extract point clouds for each mask - masked_pcds = [] - max_points = len(np.asarray(full_pcd.points)) - - for mask in masks: - if mask.shape != depth_img.shape: - masked_pcds.append(o3d.geometry.PointCloud()) - continue - - mask_flat = mask.flatten() - valid_mask_indices = mask_flat & valid_depth - point_indices = pixel_to_point[valid_mask_indices] - valid_point_indices = point_indices[point_indices >= 0] - - if len(valid_point_indices) > 0: - valid_point_indices = np.clip(valid_point_indices, 0, max_points - 1) - valid_point_indices = np.unique(valid_point_indices) - masked_pcd = full_pcd.select_by_index(valid_point_indices.tolist()) - else: - masked_pcd = 
o3d.geometry.PointCloud() - - masked_pcds.append(masked_pcd) - - return full_pcd, masked_pcds - - -def extract_masked_point_cloud_efficient( - full_pcd: o3d.geometry.PointCloud, depth_img: np.ndarray, mask: np.ndarray -) -> o3d.geometry.PointCloud: - """ - Extract a masked region from an existing point cloud efficiently. - - This assumes the point cloud was created from the given depth image. - Use this when you have a pre-computed full point cloud and want to extract - individual masked regions. - - Args: - full_pcd: Complete Open3D point cloud - depth_img: Depth image used to create the point cloud (H, W) - mask: Boolean mask (H, W) - - Returns: - Open3D point cloud for the masked region - - Raises: - ValueError: If mask shape doesn't match depth image - """ - if mask.shape != depth_img.shape: - raise ValueError( - f"Mask shape {mask.shape} doesn't match depth image shape {depth_img.shape}" - ) - - # Early return if no points in full point cloud - if len(np.asarray(full_pcd.points)) == 0: - return o3d.geometry.PointCloud() - - # Get valid depth mask - valid_depth = depth_img.flatten() > 0 - mask_flat = mask.flatten() - - # Find pixels that are both valid and in the mask - valid_mask_indices = mask_flat & valid_depth - - # Get indices of valid points - point_indices = np.where(valid_mask_indices[valid_depth])[0] - - # Extract the masked point cloud - if len(point_indices) > 0: - return full_pcd.select_by_index(point_indices) - else: - return o3d.geometry.PointCloud() - - -def segment_and_remove_plane(pcd, distance_threshold=0.02, ransac_n=3, num_iterations=1000): - """ - Segment the dominant plane from a point cloud using RANSAC and remove it. - Often used to remove table tops, floors, walls, or other planar surfaces. 
- - Args: - pcd: Open3D point cloud object - distance_threshold: Maximum distance a point can be from the plane to be considered an inlier (in meters) - ransac_n: Number of points to sample for each RANSAC iteration - num_iterations: Number of RANSAC iterations - - Returns: - Open3D point cloud with the dominant plane removed - """ - # Make a copy of the input point cloud to avoid modifying the original - pcd_filtered = o3d.geometry.PointCloud() - pcd_filtered.points = o3d.utility.Vector3dVector(np.asarray(pcd.points)) - if pcd.has_colors(): - pcd_filtered.colors = o3d.utility.Vector3dVector(np.asarray(pcd.colors)) - if pcd.has_normals(): - pcd_filtered.normals = o3d.utility.Vector3dVector(np.asarray(pcd.normals)) - - # Check if point cloud has enough points - if len(pcd_filtered.points) < ransac_n: - return pcd_filtered - - # Run RANSAC to find the largest plane - _, inliers = pcd_filtered.segment_plane( - distance_threshold=distance_threshold, ransac_n=ransac_n, num_iterations=num_iterations - ) - - # Remove the dominant plane (regardless of orientation) - pcd_without_dominant_plane = pcd_filtered.select_by_index(inliers, invert=True) - return pcd_without_dominant_plane - - -def filter_point_cloud_statistical( - pcd: o3d.geometry.PointCloud, nb_neighbors: int = 20, std_ratio: float = 2.0 -) -> Tuple[o3d.geometry.PointCloud, np.ndarray]: - """ - Apply statistical outlier filtering to point cloud. 
- - Args: - pcd: Input point cloud - nb_neighbors: Number of neighbors to analyze for each point - std_ratio: Threshold level based on standard deviation - - Returns: - Tuple of (filtered_point_cloud, outlier_indices) - """ - if len(np.asarray(pcd.points)) == 0: - return pcd, np.array([]) - - return pcd.remove_statistical_outlier(nb_neighbors=nb_neighbors, std_ratio=std_ratio) - - -def filter_point_cloud_radius( - pcd: o3d.geometry.PointCloud, nb_points: int = 16, radius: float = 0.05 -) -> Tuple[o3d.geometry.PointCloud, np.ndarray]: - """ - Apply radius-based outlier filtering to point cloud. - - Args: - pcd: Input point cloud - nb_points: Minimum number of points within radius - radius: Search radius in meters - - Returns: - Tuple of (filtered_point_cloud, outlier_indices) - """ - if len(np.asarray(pcd.points)) == 0: - return pcd, np.array([]) - - return pcd.remove_radius_outlier(nb_points=nb_points, radius=radius) - - -def compute_point_cloud_bounds(pcd: o3d.geometry.PointCloud) -> dict: - """ - Compute bounding box information for a point cloud. - - Args: - pcd: Input point cloud - - Returns: - Dictionary with bounds information - """ - points = np.asarray(pcd.points) - if len(points) == 0: - return { - "min": np.array([0, 0, 0]), - "max": np.array([0, 0, 0]), - "center": np.array([0, 0, 0]), - "size": np.array([0, 0, 0]), - "volume": 0.0, - } - - min_bound = points.min(axis=0) - max_bound = points.max(axis=0) - center = (min_bound + max_bound) / 2 - size = max_bound - min_bound - volume = np.prod(size) - - return {"min": min_bound, "max": max_bound, "center": center, "size": size, "volume": volume} - - -def project_3d_points_to_2d( - points_3d: np.ndarray, camera_intrinsics: Union[List[float], np.ndarray] -) -> np.ndarray: - """ - Project 3D points to 2D image coordinates using camera intrinsics. 
- - Args: - points_3d: Nx3 array of 3D points (X, Y, Z) - camera_intrinsics: Camera parameters as [fx, fy, cx, cy] list or 3x3 matrix - - Returns: - Nx2 array of 2D image coordinates (u, v) - """ - if len(points_3d) == 0: - return np.zeros((0, 2), dtype=np.int32) - - # Filter out points with zero or negative depth - valid_mask = points_3d[:, 2] > 0 - if not np.any(valid_mask): - return np.zeros((0, 2), dtype=np.int32) - - valid_points = points_3d[valid_mask] - - # Extract camera parameters - if isinstance(camera_intrinsics, list) and len(camera_intrinsics) == 4: - fx, fy, cx, cy = camera_intrinsics - else: - fx = camera_intrinsics[0, 0] - fy = camera_intrinsics[1, 1] - cx = camera_intrinsics[0, 2] - cy = camera_intrinsics[1, 2] - - # Project to image coordinates - u = (valid_points[:, 0] * fx / valid_points[:, 2]) + cx - v = (valid_points[:, 1] * fy / valid_points[:, 2]) + cy - - # Round to integer pixel coordinates - points_2d = np.column_stack([u, v]).astype(np.int32) - - return points_2d - - -def overlay_point_clouds_on_image( - base_image: np.ndarray, - point_clouds: List[o3d.geometry.PointCloud], - camera_intrinsics: Union[List[float], np.ndarray], - colors: List[Tuple[int, int, int]], - point_size: int = 2, - alpha: float = 0.7, -) -> np.ndarray: - """ - Overlay multiple colored point clouds onto an image. - - Args: - base_image: Base image to overlay onto (H, W, 3) - assumed to be RGB - point_clouds: List of Open3D point cloud objects - camera_intrinsics: Camera parameters as [fx, fy, cx, cy] list or 3x3 matrix - colors: List of RGB color tuples for each point cloud. If None, generates distinct colors. 
- point_size: Size of points to draw (in pixels) - alpha: Blending factor for overlay (0.0 = fully transparent, 1.0 = fully opaque) - - Returns: - Image with overlaid point clouds (H, W, 3) - """ - if len(point_clouds) == 0: - return base_image.copy() - - # Create overlay image - overlay = base_image.copy() - height, width = base_image.shape[:2] - - # Process each point cloud - for i, pcd in enumerate(point_clouds): - if pcd is None: - continue - - points_3d = np.asarray(pcd.points) - if len(points_3d) == 0: - continue - - # Project 3D points to 2D - points_2d = project_3d_points_to_2d(points_3d, camera_intrinsics) - - if len(points_2d) == 0: - continue - - # Filter points within image bounds - valid_mask = ( - (points_2d[:, 0] >= 0) - & (points_2d[:, 0] < width) - & (points_2d[:, 1] >= 0) - & (points_2d[:, 1] < height) - ) - valid_points_2d = points_2d[valid_mask] - - if len(valid_points_2d) == 0: - continue - - # Get color for this point cloud - color = colors[i % len(colors)] - - # Ensure color is a tuple of integers for OpenCV - if isinstance(color, (list, tuple, np.ndarray)): - color = tuple(int(c) for c in color[:3]) - else: - color = (255, 255, 255) - - # Draw points on overlay - for point in valid_points_2d: - u, v = point - # Draw a small filled circle for each point - cv2.circle(overlay, (u, v), point_size, color, -1) - - # Blend overlay with base image - result = cv2.addWeighted(base_image, 1 - alpha, overlay, alpha, 0) - - return result - - -def create_point_cloud_overlay_visualization( - base_image: np.ndarray, - objects: List[dict], - intrinsics: np.ndarray, -) -> np.ndarray: - """ - Create a visualization showing object point clouds and bounding boxes overlaid on a base image. 
- - Args: - base_image: Base image to overlay onto (H, W, 3) - objects: List of object dictionaries containing 'point_cloud', 'color', 'position', 'rotation', 'size' keys - intrinsics: Camera intrinsics as [fx, fy, cx, cy] or 3x3 matrix - - Returns: - Visualization image with overlaid point clouds and bounding boxes (H, W, 3) - """ - # Extract point clouds and colors from objects - point_clouds = [] - colors = [] - for obj in objects: - if "point_cloud" in obj and obj["point_cloud"] is not None: - point_clouds.append(obj["point_cloud"]) - - # Convert color to tuple - color = obj["color"] - if isinstance(color, np.ndarray): - color = tuple(int(c) for c in color) - elif isinstance(color, (list, tuple)): - color = tuple(int(c) for c in color[:3]) - colors.append(color) - - # Create visualization - if point_clouds: - result = overlay_point_clouds_on_image( - base_image=base_image, - point_clouds=point_clouds, - camera_intrinsics=intrinsics, - colors=colors, - point_size=3, - alpha=0.8, - ) - else: - result = base_image.copy() - - # Draw 3D bounding boxes - height_img, width_img = result.shape[:2] - for i, obj in enumerate(objects): - if all(key in obj and obj[key] is not None for key in ["position", "rotation", "size"]): - try: - # Create and project 3D bounding box - corners_3d = create_3d_bounding_box_corners( - obj["position"], obj["rotation"], obj["size"] - ) - corners_2d = project_3d_points_to_2d(corners_3d, intrinsics) - - # Check if any corners are visible - valid_mask = ( - (corners_2d[:, 0] >= 0) - & (corners_2d[:, 0] < width_img) - & (corners_2d[:, 1] >= 0) - & (corners_2d[:, 1] < height_img) - ) - - if np.any(valid_mask): - # Get color - bbox_color = colors[i] if i < len(colors) else (255, 255, 255) - draw_3d_bounding_box_on_image(result, corners_2d, bbox_color, thickness=2) - except: - continue - - return result - - -def create_3d_bounding_box_corners(position, rotation, size): - """ - Create 8 corners of a 3D bounding box from position, rotation, and size. 
- - Args: - position: Vector or dict with x, y, z coordinates - rotation: Vector or dict with roll, pitch, yaw angles - size: Dict with width, height, depth - - Returns: - 8x3 numpy array of corner coordinates - """ - # Convert position to numpy array - if hasattr(position, "x"): # Vector object - center = np.array([position.x, position.y, position.z]) - else: # Dictionary - center = np.array([position["x"], position["y"], position["z"]]) - - # Convert rotation (euler angles) to rotation matrix - if hasattr(rotation, "x"): # Vector object (roll, pitch, yaw) - roll, pitch, yaw = rotation.x, rotation.y, rotation.z - else: # Dictionary - roll, pitch, yaw = rotation["roll"], rotation["pitch"], rotation["yaw"] - - # Create rotation matrix from euler angles (ZYX order) - cos_r, sin_r = np.cos(roll), np.sin(roll) - cos_p, sin_p = np.cos(pitch), np.sin(pitch) - cos_y, sin_y = np.cos(yaw), np.sin(yaw) - - # Rotation matrix for ZYX euler angles - R = np.array( - [ - [ - cos_y * cos_p, - cos_y * sin_p * sin_r - sin_y * cos_r, - cos_y * sin_p * cos_r + sin_y * sin_r, - ], - [ - sin_y * cos_p, - sin_y * sin_p * sin_r + cos_y * cos_r, - sin_y * sin_p * cos_r - cos_y * sin_r, - ], - [-sin_p, cos_p * sin_r, cos_p * cos_r], - ] - ) - - # Get dimensions - width = size.get("width", 0.1) - height = size.get("height", 0.1) - depth = size.get("depth", 0.1) - - # Create 8 corners of the bounding box (before rotation) - corners = np.array( - [ - [-width / 2, -height / 2, -depth / 2], # 0 - [width / 2, -height / 2, -depth / 2], # 1 - [width / 2, height / 2, -depth / 2], # 2 - [-width / 2, height / 2, -depth / 2], # 3 - [-width / 2, -height / 2, depth / 2], # 4 - [width / 2, -height / 2, depth / 2], # 5 - [width / 2, height / 2, depth / 2], # 6 - [-width / 2, height / 2, depth / 2], # 7 - ] - ) - - # Apply rotation and translation - rotated_corners = corners @ R.T + center - - return rotated_corners - - -def draw_3d_bounding_box_on_image(image, corners_2d, color, thickness=2): - """ - Draw 
a 3D bounding box on an image using projected 2D corners. - - Args: - image: Image to draw on - corners_2d: 8x2 array of 2D corner coordinates - color: RGB color tuple - thickness: Line thickness - """ - # Define the 12 edges of a cube (connecting corner indices) - edges = [ - (0, 1), - (1, 2), - (2, 3), - (3, 0), # Bottom face - (4, 5), - (5, 6), - (6, 7), - (7, 4), # Top face - (0, 4), - (1, 5), - (2, 6), - (3, 7), # Vertical edges - ] - - # Draw each edge - for start_idx, end_idx in edges: - start_point = tuple(corners_2d[start_idx].astype(int)) - end_point = tuple(corners_2d[end_idx].astype(int)) - cv2.line(image, start_point, end_point, color, thickness) - - -def extract_and_cluster_misc_points( - full_pcd: o3d.geometry.PointCloud, - all_objects: List[dict], - eps: float = 0.03, - min_points: int = 100, - enable_filtering: bool = True, - voxel_size: float = 0.02, -) -> Tuple[List[o3d.geometry.PointCloud], o3d.geometry.VoxelGrid]: - """ - Extract miscellaneous/background points and cluster them using DBSCAN. 
- - Args: - full_pcd: Complete scene point cloud - all_objects: List of objects with point clouds to subtract - eps: DBSCAN epsilon parameter (max distance between points in cluster) - min_points: DBSCAN min_samples parameter (min points to form cluster) - enable_filtering: Whether to apply statistical and radius filtering - voxel_size: Size of voxels for voxel grid generation - - Returns: - Tuple of (clustered_point_clouds, voxel_grid) - """ - if full_pcd is None or len(np.asarray(full_pcd.points)) == 0: - return [], o3d.geometry.VoxelGrid() - - if not all_objects: - # If no objects detected, cluster the full point cloud - clusters = _cluster_point_cloud_dbscan(full_pcd, eps, min_points) - voxel_grid = _create_voxel_grid_from_clusters(clusters, voxel_size) - return clusters, voxel_grid - - try: - # Start with a copy of the full point cloud - misc_pcd = o3d.geometry.PointCloud(full_pcd) - - # Remove object points by combining all object point clouds - all_object_points = [] - for obj in all_objects: - if "point_cloud" in obj and obj["point_cloud"] is not None: - obj_points = np.asarray(obj["point_cloud"].points) - if len(obj_points) > 0: - all_object_points.append(obj_points) - - if not all_object_points: - # No object points to remove, cluster full point cloud - clusters = _cluster_point_cloud_dbscan(misc_pcd, eps, min_points) - voxel_grid = _create_voxel_grid_from_clusters(clusters, voxel_size) - return clusters, voxel_grid - - # Combine all object points - combined_obj_points = np.vstack(all_object_points) - - # For efficiency, downsample both point clouds - misc_downsampled = misc_pcd.voxel_down_sample(voxel_size=0.005) - - # Create object point cloud for efficient operations - obj_pcd = o3d.geometry.PointCloud() - obj_pcd.points = o3d.utility.Vector3dVector(combined_obj_points) - obj_downsampled = obj_pcd.voxel_down_sample(voxel_size=0.005) - - misc_points = np.asarray(misc_downsampled.points) - obj_points_down = np.asarray(obj_downsampled.points) - - if 
len(misc_points) == 0 or len(obj_points_down) == 0: - clusters = _cluster_point_cloud_dbscan(misc_downsampled, eps, min_points) - voxel_grid = _create_voxel_grid_from_clusters(clusters, voxel_size) - return clusters, voxel_grid - - # Build tree for object points - obj_tree = cKDTree(obj_points_down) - - # Find distances from misc points to nearest object points - distances, _ = obj_tree.query(misc_points, k=1) - - # Keep points that are far enough from any object point - threshold = 0.015 # 1.5cm threshold - keep_mask = distances > threshold - - if not np.any(keep_mask): - return [], o3d.geometry.VoxelGrid() - - # Filter misc points - misc_indices = np.where(keep_mask)[0] - final_misc_pcd = misc_downsampled.select_by_index(misc_indices) - - if len(np.asarray(final_misc_pcd.points)) == 0: - return [], o3d.geometry.VoxelGrid() - - # Apply additional filtering if enabled - if enable_filtering: - # Apply statistical outlier filtering - filtered_misc_pcd, _ = filter_point_cloud_statistical( - final_misc_pcd, nb_neighbors=30, std_ratio=2.0 - ) - - if len(np.asarray(filtered_misc_pcd.points)) == 0: - return [], o3d.geometry.VoxelGrid() - - # Apply radius outlier filtering - final_filtered_misc_pcd, _ = filter_point_cloud_radius( - filtered_misc_pcd, - nb_points=20, - radius=0.03, # 3cm radius - ) - - if len(np.asarray(final_filtered_misc_pcd.points)) == 0: - return [], o3d.geometry.VoxelGrid() - - final_misc_pcd = final_filtered_misc_pcd - - # Cluster the misc points using DBSCAN - clusters = _cluster_point_cloud_dbscan(final_misc_pcd, eps, min_points) - - # Create voxel grid from all misc points (before clustering) - voxel_grid = _create_voxel_grid_from_point_cloud(final_misc_pcd, voxel_size) - - return clusters, voxel_grid - - except Exception as e: - print(f"Error in misc point extraction and clustering: {e}") - # Fallback: return downsampled full point cloud as single cluster - try: - downsampled = full_pcd.voxel_down_sample(voxel_size=0.02) - if 
len(np.asarray(downsampled.points)) > 0: - voxel_grid = _create_voxel_grid_from_point_cloud(downsampled, voxel_size) - return [downsampled], voxel_grid - else: - return [], o3d.geometry.VoxelGrid() - except: - return [], o3d.geometry.VoxelGrid() - - -def _create_voxel_grid_from_point_cloud( - pcd: o3d.geometry.PointCloud, voxel_size: float = 0.02 -) -> o3d.geometry.VoxelGrid: - """ - Create a voxel grid from a point cloud. - - Args: - pcd: Input point cloud - voxel_size: Size of each voxel - - Returns: - Open3D VoxelGrid object - """ - if len(np.asarray(pcd.points)) == 0: - return o3d.geometry.VoxelGrid() - - try: - # Create voxel grid from point cloud - voxel_grid = o3d.geometry.VoxelGrid.create_from_point_cloud(pcd, voxel_size) - - # Color the voxels with a semi-transparent gray - for voxel in voxel_grid.get_voxels(): - voxel.color = [0.5, 0.5, 0.5] # Gray color - - print( - f"Created voxel grid with {len(voxel_grid.get_voxels())} voxels (voxel_size={voxel_size})" - ) - return voxel_grid - - except Exception as e: - print(f"Error creating voxel grid: {e}") - return o3d.geometry.VoxelGrid() - - -def _create_voxel_grid_from_clusters( - clusters: List[o3d.geometry.PointCloud], voxel_size: float = 0.02 -) -> o3d.geometry.VoxelGrid: - """ - Create a voxel grid from multiple clustered point clouds. 
- - Args: - clusters: List of clustered point clouds - voxel_size: Size of each voxel - - Returns: - Open3D VoxelGrid object - """ - if not clusters: - return o3d.geometry.VoxelGrid() - - # Combine all clusters into one point cloud - combined_points = [] - for cluster in clusters: - points = np.asarray(cluster.points) - if len(points) > 0: - combined_points.append(points) - - if not combined_points: - return o3d.geometry.VoxelGrid() - - # Create combined point cloud - all_points = np.vstack(combined_points) - combined_pcd = o3d.geometry.PointCloud() - combined_pcd.points = o3d.utility.Vector3dVector(all_points) - - return _create_voxel_grid_from_point_cloud(combined_pcd, voxel_size) - - -def _cluster_point_cloud_dbscan( - pcd: o3d.geometry.PointCloud, eps: float = 0.05, min_points: int = 50 -) -> List[o3d.geometry.PointCloud]: - """ - Cluster a point cloud using DBSCAN and return list of clustered point clouds. - - Args: - pcd: Point cloud to cluster - eps: DBSCAN epsilon parameter - min_points: DBSCAN min_samples parameter - - Returns: - List of point clouds, one for each cluster - """ - if len(np.asarray(pcd.points)) == 0: - return [] - - try: - # Apply DBSCAN clustering - labels = np.array(pcd.cluster_dbscan(eps=eps, min_points=min_points)) - - # Get unique cluster labels (excluding noise points labeled as -1) - unique_labels = np.unique(labels) - cluster_pcds = [] - - for label in unique_labels: - if label == -1: # Skip noise points - continue - - # Get indices for this cluster - cluster_indices = np.where(labels == label)[0] - - if len(cluster_indices) > 0: - # Create point cloud for this cluster - cluster_pcd = pcd.select_by_index(cluster_indices) - - # Assign a random color to this cluster - cluster_color = np.random.rand(3) # Random RGB color - cluster_pcd.paint_uniform_color(cluster_color) - - cluster_pcds.append(cluster_pcd) - - print( - f"DBSCAN clustering found {len(cluster_pcds)} clusters from {len(np.asarray(pcd.points))} points" - ) - return 
cluster_pcds - - except Exception as e: - print(f"Error in DBSCAN clustering: {e}") - return [pcd] # Return original point cloud as fallback - - -def get_standard_coordinate_transform(): - """ - Get a standard coordinate transformation matrix for consistent visualization. - - This transformation ensures that: - - X (red) axis points right - - Y (green) axis points up - - Z (blue) axis points toward viewer - - Returns: - 4x4 transformation matrix - """ - # Standard transformation matrix to ensure consistent coordinate frame orientation - transform = np.array( - [ - [1, 0, 0, 0], # X points right - [0, -1, 0, 0], # Y points up (flip from OpenCV to standard) - [0, 0, -1, 0], # Z points toward viewer (flip depth) - [0, 0, 0, 1], - ] - ) - return transform - - -def visualize_clustered_point_clouds( - clustered_pcds: List[o3d.geometry.PointCloud], - window_name: str = "Clustered Point Clouds", - point_size: float = 2.0, - show_coordinate_frame: bool = True, - coordinate_frame_size: float = 0.1, -) -> None: - """ - Visualize multiple clustered point clouds with different colors. 
- - Args: - clustered_pcds: List of point clouds (already colored) - window_name: Name of the visualization window - point_size: Size of points in the visualization - show_coordinate_frame: Whether to show coordinate frame - coordinate_frame_size: Size of the coordinate frame - """ - if not clustered_pcds: - print("Warning: No clustered point clouds to visualize") - return - - # Apply standard coordinate transformation - transform = get_standard_coordinate_transform() - geometries = [] - for pcd in clustered_pcds: - pcd_copy = o3d.geometry.PointCloud(pcd) - pcd_copy.transform(transform) - geometries.append(pcd_copy) - - # Add coordinate frame - if show_coordinate_frame: - coordinate_frame = o3d.geometry.TriangleMesh.create_coordinate_frame( - size=coordinate_frame_size - ) - coordinate_frame.transform(transform) - geometries.append(coordinate_frame) - - total_points = sum(len(np.asarray(pcd.points)) for pcd in clustered_pcds) - print(f"Visualizing {len(clustered_pcds)} clusters with {total_points} total points") - - try: - vis = o3d.visualization.Visualizer() - vis.create_window(window_name=window_name, width=1280, height=720) - for geom in geometries: - vis.add_geometry(geom) - render_option = vis.get_render_option() - render_option.point_size = point_size - vis.run() - vis.destroy_window() - except Exception as e: - print(f"Failed to create interactive visualization: {e}") - o3d.visualization.draw_geometries( - geometries, window_name=window_name, width=1280, height=720 - ) - - -def visualize_pcd( - pcd: o3d.geometry.PointCloud, - window_name: str = "Point Cloud Visualization", - point_size: float = 1.0, - show_coordinate_frame: bool = True, - coordinate_frame_size: float = 0.1, -) -> None: - """ - Visualize an Open3D point cloud using Open3D's visualization window. 
- - Args: - pcd: Open3D point cloud to visualize - window_name: Name of the visualization window - point_size: Size of points in the visualization - show_coordinate_frame: Whether to show coordinate frame - coordinate_frame_size: Size of the coordinate frame - """ - if pcd is None: - print("Warning: Point cloud is None, nothing to visualize") - return - - if len(np.asarray(pcd.points)) == 0: - print("Warning: Point cloud is empty, nothing to visualize") - return - - # Apply standard coordinate transformation - transform = get_standard_coordinate_transform() - pcd_copy = o3d.geometry.PointCloud(pcd) - pcd_copy.transform(transform) - geometries = [pcd_copy] - - # Add coordinate frame - if show_coordinate_frame: - coordinate_frame = o3d.geometry.TriangleMesh.create_coordinate_frame( - size=coordinate_frame_size - ) - coordinate_frame.transform(transform) - geometries.append(coordinate_frame) - - print(f"Visualizing point cloud with {len(np.asarray(pcd.points))} points") - - try: - vis = o3d.visualization.Visualizer() - vis.create_window(window_name=window_name, width=1280, height=720) - for geom in geometries: - vis.add_geometry(geom) - render_option = vis.get_render_option() - render_option.point_size = point_size - vis.run() - vis.destroy_window() - except Exception as e: - print(f"Failed to create interactive visualization: {e}") - o3d.visualization.draw_geometries( - geometries, window_name=window_name, width=1280, height=720 - ) - - -def visualize_voxel_grid( - voxel_grid: o3d.geometry.VoxelGrid, - window_name: str = "Voxel Grid Visualization", - show_coordinate_frame: bool = True, - coordinate_frame_size: float = 0.1, -) -> None: - """ - Visualize an Open3D voxel grid using Open3D's visualization window. 
- - Args: - voxel_grid: Open3D voxel grid to visualize - window_name: Name of the visualization window - show_coordinate_frame: Whether to show coordinate frame - coordinate_frame_size: Size of the coordinate frame - """ - if voxel_grid is None: - print("Warning: Voxel grid is None, nothing to visualize") - return - - if len(voxel_grid.get_voxels()) == 0: - print("Warning: Voxel grid is empty, nothing to visualize") - return - - # VoxelGrid doesn't support transform, so we need to transform the source points instead - # For now, just visualize as-is with transformed coordinate frame - geometries = [voxel_grid] - - # Add coordinate frame - if show_coordinate_frame: - coordinate_frame = o3d.geometry.TriangleMesh.create_coordinate_frame( - size=coordinate_frame_size - ) - coordinate_frame.transform(get_standard_coordinate_transform()) - geometries.append(coordinate_frame) - - print(f"Visualizing voxel grid with {len(voxel_grid.get_voxels())} voxels") - - try: - vis = o3d.visualization.Visualizer() - vis.create_window(window_name=window_name, width=1280, height=720) - for geom in geometries: - vis.add_geometry(geom) - vis.run() - vis.destroy_window() - except Exception as e: - print(f"Failed to create interactive visualization: {e}") - o3d.visualization.draw_geometries( - geometries, window_name=window_name, width=1280, height=720 - ) - - -def combine_object_pointclouds( - point_clouds: Union[List[np.ndarray], List[o3d.geometry.PointCloud]], - colors: Optional[List[np.ndarray]] = None, -) -> o3d.geometry.PointCloud: - """ - Combine multiple point clouds into a single Open3D point cloud. 
- - Args: - point_clouds: List of point clouds as numpy arrays or Open3D point clouds - colors: List of colors as numpy arrays - Returns: - Combined Open3D point cloud - """ - all_points = [] - all_colors = [] - - for i, pcd in enumerate(point_clouds): - if isinstance(pcd, np.ndarray): - points = pcd[:, :3] - all_points.append(points) - if colors: - all_colors.append(colors[i]) - - elif isinstance(pcd, o3d.geometry.PointCloud): - points = np.asarray(pcd.points) - all_points.append(points) - if pcd.has_colors(): - colors = np.asarray(pcd.colors) - all_colors.append(colors) - - if not all_points: - return o3d.geometry.PointCloud() - - combined_pcd = o3d.geometry.PointCloud() - combined_pcd.points = o3d.utility.Vector3dVector(np.vstack(all_points)) - - if all_colors: - combined_pcd.colors = o3d.utility.Vector3dVector(np.vstack(all_colors)) - - return combined_pcd - - -def extract_centroids_from_masks( - rgb_image: np.ndarray, - depth_image: np.ndarray, - masks: List[np.ndarray], - camera_intrinsics: Union[List[float], np.ndarray], - min_points: int = 10, - max_depth: float = 10.0, -) -> List[Dict[str, Any]]: - """ - Extract 3D centroids and orientations from segmentation masks. 
- - Args: - rgb_image: RGB image (H, W, 3) - depth_image: Depth image (H, W) in meters - masks: List of boolean masks (H, W) - camera_intrinsics: Camera parameters as [fx, fy, cx, cy] or 3x3 matrix - min_points: Minimum number of valid 3D points required for a detection - max_depth: Maximum valid depth in meters - - Returns: - List of dictionaries containing: - - centroid: 3D centroid position [x, y, z] in camera frame - - orientation: Normalized direction vector from camera to centroid - - num_points: Number of valid 3D points - - mask_idx: Index of the mask in the input list - """ - # Extract camera parameters - if isinstance(camera_intrinsics, list) and len(camera_intrinsics) == 4: - fx, fy, cx, cy = camera_intrinsics - else: - fx = camera_intrinsics[0, 0] - fy = camera_intrinsics[1, 1] - cx = camera_intrinsics[0, 2] - cy = camera_intrinsics[1, 2] - - results = [] - - for mask_idx, mask in enumerate(masks): - if mask is None or mask.sum() == 0: - continue - - # Get pixel coordinates where mask is True - y_coords, x_coords = np.where(mask) - - # Get depth values at mask locations - depths = depth_image[y_coords, x_coords] - - # Filter valid depths - valid_mask = (depths > 0) & (depths < max_depth) & np.isfinite(depths) - if valid_mask.sum() < min_points: - continue - - # Get valid coordinates and depths - valid_x = x_coords[valid_mask] - valid_y = y_coords[valid_mask] - valid_z = depths[valid_mask] - - # Convert to 3D points in camera frame - X = (valid_x - cx) * valid_z / fx - Y = (valid_y - cy) * valid_z / fy - Z = valid_z - - # Calculate centroid - centroid_x = np.mean(X) - centroid_y = np.mean(Y) - centroid_z = np.mean(Z) - centroid = np.array([centroid_x, centroid_y, centroid_z]) - - # Calculate orientation as normalized direction from camera origin to centroid - # Camera origin is at (0, 0, 0) - orientation = centroid / np.linalg.norm(centroid) - - results.append( - { - "centroid": centroid, - "orientation": orientation, - "num_points": 
int(valid_mask.sum()), - "mask_idx": mask_idx, - } - ) - - return results diff --git a/build/lib/dimos/perception/segmentation/__init__.py b/build/lib/dimos/perception/segmentation/__init__.py deleted file mode 100644 index a8f9a291ce..0000000000 --- a/build/lib/dimos/perception/segmentation/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .utils import * -from .sam_2d_seg import * diff --git a/build/lib/dimos/perception/segmentation/image_analyzer.py b/build/lib/dimos/perception/segmentation/image_analyzer.py deleted file mode 100644 index 1260e41fe7..0000000000 --- a/build/lib/dimos/perception/segmentation/image_analyzer.py +++ /dev/null @@ -1,161 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import base64 -from openai import OpenAI -import cv2 -import os - -NORMAL_PROMPT = "What are in these images? Give a short word answer with at most two words, \ - if not sure, give a description of its shape or color like 'small tube', 'blue item'. \" \ - if does not look like an object, say 'unknown'. Export objects as a list of strings \ - in this exact format '['object 1', 'object 2', '...']'." - -RICH_PROMPT = ( - "What are in these images? Give a detailed description of each item, the first n images will be \ - cropped patches of the original image detected by the object detection model. \ - The last image will be the original image. Use the last image only for context, \ - do not describe objects in the last image. 
\ - Export the objects as a list of strings in this exact format, '['description of object 1', '...', '...']', \ - don't include anything else. " -) - - -class ImageAnalyzer: - def __init__(self): - """ - Initializes the ImageAnalyzer with OpenAI API credentials. - """ - self.client = OpenAI() - - def encode_image(self, image): - """ - Encodes an image to Base64. - - Parameters: - image (numpy array): Image array (BGR format). - - Returns: - str: Base64 encoded string of the image. - """ - _, buffer = cv2.imencode(".jpg", image) - return base64.b64encode(buffer).decode("utf-8") - - def analyze_images(self, images, detail="auto", prompt_type="normal"): - """ - Takes a list of cropped images and returns descriptions from OpenAI's Vision model. - - Parameters: - images (list of numpy arrays): Cropped images from the original frame. - detail (str): "low", "high", or "auto" to set image processing detail. - prompt_type (str): "normal" or "rich" to set the prompt type. - - Returns: - list of str: Descriptions of objects in each image. 
- """ - image_data = [ - { - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{self.encode_image(img)}", - "detail": detail, - }, - } - for img in images - ] - - if prompt_type == "normal": - prompt = NORMAL_PROMPT - elif prompt_type == "rich": - prompt = RICH_PROMPT - else: - raise ValueError(f"Invalid prompt type: {prompt_type}") - - response = self.client.chat.completions.create( - model="gpt-4o-mini", - messages=[ - { - "role": "user", - "content": [{"type": "text", "text": prompt}] + image_data, - } - ], - max_tokens=300, - timeout=5, - ) - - # Accessing the content of the response using dot notation - return [choice.message.content for choice in response.choices][0] - - -def main(): - # Define the directory containing cropped images - cropped_images_dir = "cropped_images" - if not os.path.exists(cropped_images_dir): - print(f"Directory '{cropped_images_dir}' does not exist.") - return - - # Load all images from the directory - images = [] - for filename in os.listdir(cropped_images_dir): - if filename.endswith(".jpg") or filename.endswith(".png"): - image_path = os.path.join(cropped_images_dir, filename) - image = cv2.imread(image_path) - if image is not None: - images.append(image) - else: - print(f"Warning: Could not read image {image_path}") - - if not images: - print("No valid images found in the directory.") - return - - # Initialize ImageAnalyzer - analyzer = ImageAnalyzer() - - # Analyze images - results = analyzer.analyze_images(images) - - # Split results into a list of items - object_list = [item.strip()[2:] for item in results.split("\n")] - - # Overlay text on images and display them - for i, (img, obj) in enumerate(zip(images, object_list)): - if obj: # Only process non-empty lines - # Add text to image - font = cv2.FONT_HERSHEY_SIMPLEX - font_scale = 0.5 - thickness = 2 - text = obj.strip() - - # Get text size - (text_width, text_height), _ = cv2.getTextSize(text, font, font_scale, thickness) - - # Position text at top of 
image - x = 10 - y = text_height + 10 - - # Add white background for text - cv2.rectangle( - img, (x - 5, y - text_height - 5), (x + text_width + 5, y + 5), (255, 255, 255), -1 - ) - # Add text - cv2.putText(img, text, (x, y), font, font_scale, (0, 0, 0), thickness) - - # Save or display the image - cv2.imwrite(f"annotated_image_{i}.jpg", img) - print(f"Detected object: {obj}") - - -if __name__ == "__main__": - main() diff --git a/build/lib/dimos/perception/segmentation/sam_2d_seg.py b/build/lib/dimos/perception/segmentation/sam_2d_seg.py deleted file mode 100644 index d33c7faa0d..0000000000 --- a/build/lib/dimos/perception/segmentation/sam_2d_seg.py +++ /dev/null @@ -1,335 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import os -import time -from collections import deque -from concurrent.futures import ThreadPoolExecutor - -import cv2 -import onnxruntime -from ultralytics import FastSAM - -from dimos.perception.common.detection2d_tracker import get_tracked_results, target2dTracker -from dimos.perception.segmentation.image_analyzer import ImageAnalyzer -from dimos.perception.segmentation.utils import ( - crop_images_from_bboxes, - extract_masks_bboxes_probs_names, - filter_segmentation_results, - plot_results, -) -from dimos.utils.data import get_data -from dimos.utils.gpu_utils import is_cuda_available -from dimos.utils.logging_config import setup_logger -from dimos.utils.path_utils import get_project_root - -logger = setup_logger("dimos.perception.segmentation.sam_2d_seg") - - -class Sam2DSegmenter: - def __init__( - self, - model_path="models_fastsam", - model_name="FastSAM-s.onnx", - device="cpu", - min_analysis_interval=5.0, - use_tracker=True, - use_analyzer=True, - use_rich_labeling=False, - ): - self.device = device - if is_cuda_available(): - logger.info("Using CUDA for SAM 2d segmenter") - if hasattr(onnxruntime, "preload_dlls"): # Handles CUDA 11 / onnxruntime-gpu<=1.18 - onnxruntime.preload_dlls(cuda=True, cudnn=True) - self.device = "cuda" - else: - logger.info("Using CPU for SAM 2d segmenter") - self.device = "cpu" - # Core components - self.model = FastSAM(get_data(model_path) / model_name) - self.use_tracker = use_tracker - self.use_analyzer = use_analyzer - self.use_rich_labeling = use_rich_labeling - - module_dir = os.path.dirname(__file__) - self.tracker_config = os.path.join(module_dir, "config", "custom_tracker.yaml") - - # Initialize tracker if enabled - if self.use_tracker: - self.tracker = target2dTracker( - history_size=80, - score_threshold_start=0.7, - score_threshold_stop=0.05, - min_frame_count=10, - max_missed_frames=50, - min_area_ratio=0.05, - max_area_ratio=0.4, - texture_range=(0.0, 0.35), - border_safe_distance=100, - weights={"prob": 1.0, 
"temporal": 3.0, "texture": 2.0, "border": 3.0, "size": 1.0}, - ) - - # Initialize analyzer components if enabled - if self.use_analyzer: - self.image_analyzer = ImageAnalyzer() - self.min_analysis_interval = min_analysis_interval - self.last_analysis_time = 0 - self.to_be_analyzed = deque() - self.object_names = {} - self.analysis_executor = ThreadPoolExecutor(max_workers=1) - self.current_future = None - self.current_queue_ids = None - - def process_image(self, image): - """Process an image and return segmentation results.""" - results = self.model.track( - source=image, - device=self.device, - retina_masks=True, - conf=0.6, - iou=0.9, - persist=True, - verbose=False, - tracker=self.tracker_config, - ) - - if len(results) > 0: - # Get initial segmentation results - masks, bboxes, track_ids, probs, names, areas = extract_masks_bboxes_probs_names( - results[0] - ) - - # Filter results - ( - filtered_masks, - filtered_bboxes, - filtered_track_ids, - filtered_probs, - filtered_names, - filtered_texture_values, - ) = filter_segmentation_results(image, masks, bboxes, track_ids, probs, names, areas) - - if self.use_tracker: - # Update tracker with filtered results - tracked_targets = self.tracker.update( - image, - filtered_masks, - filtered_bboxes, - filtered_track_ids, - filtered_probs, - filtered_names, - filtered_texture_values, - ) - - # Get tracked results - tracked_masks, tracked_bboxes, tracked_target_ids, tracked_probs, tracked_names = ( - get_tracked_results(tracked_targets) - ) - - if self.use_analyzer: - # Update analysis queue with tracked IDs - target_id_set = set(tracked_target_ids) - - # Remove untracked objects from object_names - all_target_ids = list(self.tracker.targets.keys()) - self.object_names = { - track_id: name - for track_id, name in self.object_names.items() - if track_id in all_target_ids - } - - # Remove untracked objects from queue and results - self.to_be_analyzed = deque( - [track_id for track_id in self.to_be_analyzed if track_id in 
target_id_set] - ) - - # Filter out any IDs being analyzed from the to_be_analyzed queue - if self.current_queue_ids: - self.to_be_analyzed = deque( - [ - tid - for tid in self.to_be_analyzed - if tid not in self.current_queue_ids - ] - ) - - # Add new track_ids to analysis queue - for track_id in tracked_target_ids: - if ( - track_id not in self.object_names - and track_id not in self.to_be_analyzed - ): - self.to_be_analyzed.append(track_id) - - return ( - tracked_masks, - tracked_bboxes, - tracked_target_ids, - tracked_probs, - tracked_names, - ) - else: - # Return filtered results directly if tracker is disabled - return ( - filtered_masks, - filtered_bboxes, - filtered_track_ids, - filtered_probs, - filtered_names, - ) - return [], [], [], [], [] - - def check_analysis_status(self, tracked_target_ids): - """Check if analysis is complete and prepare new queue if needed.""" - if not self.use_analyzer: - return None, None - - current_time = time.time() - - # Check if current queue analysis is complete - if self.current_future and self.current_future.done(): - try: - results = self.current_future.result() - if results is not None: - # Map results to track IDs - object_list = eval(results) - for track_id, result in zip(self.current_queue_ids, object_list): - self.object_names[track_id] = result - except Exception as e: - print(f"Queue analysis failed: {e}") - self.current_future = None - self.current_queue_ids = None - self.last_analysis_time = current_time - - # If enough time has passed and we have items to analyze, start new analysis - if ( - not self.current_future - and self.to_be_analyzed - and current_time - self.last_analysis_time >= self.min_analysis_interval - ): - queue_indices = [] - queue_ids = [] - - # Collect all valid track IDs from the queue - while self.to_be_analyzed: - track_id = self.to_be_analyzed[0] - if track_id in tracked_target_ids: - bbox_idx = tracked_target_ids.index(track_id) - queue_indices.append(bbox_idx) - 
queue_ids.append(track_id) - self.to_be_analyzed.popleft() - - if queue_indices: - return queue_indices, queue_ids - return None, None - - def run_analysis(self, frame, tracked_bboxes, tracked_target_ids): - """Run queue image analysis in background.""" - if not self.use_analyzer: - return - - queue_indices, queue_ids = self.check_analysis_status(tracked_target_ids) - if queue_indices: - selected_bboxes = [tracked_bboxes[i] for i in queue_indices] - cropped_images = crop_images_from_bboxes(frame, selected_bboxes) - if cropped_images: - self.current_queue_ids = queue_ids - print(f"Analyzing objects with track_ids: {queue_ids}") - - if self.use_rich_labeling: - prompt_type = "rich" - cropped_images.append(frame) - else: - prompt_type = "normal" - - self.current_future = self.analysis_executor.submit( - self.image_analyzer.analyze_images, cropped_images, prompt_type=prompt_type - ) - - def get_object_names(self, track_ids, tracked_names): - """Get object names for the given track IDs, falling back to tracked names.""" - if not self.use_analyzer: - return tracked_names - - return [ - self.object_names.get(track_id, tracked_name) - for track_id, tracked_name in zip(track_ids, tracked_names) - ] - - def visualize_results(self, image, masks, bboxes, track_ids, probs, names): - """Generate an overlay visualization with segmentation results and object names.""" - return plot_results(image, masks, bboxes, track_ids, probs, names) - - def cleanup(self): - """Cleanup resources.""" - if self.use_analyzer: - self.analysis_executor.shutdown() - - -def main(): - # Example usage with different configurations - cap = cv2.VideoCapture(0) - - # Example 1: Full functionality with rich labeling - segmenter = Sam2DSegmenter( - min_analysis_interval=4.0, - use_tracker=True, - use_analyzer=True, - use_rich_labeling=True, # Enable rich labeling - ) - - # Example 2: Full functionality with normal labeling - # segmenter = Sam2DSegmenter(min_analysis_interval=4.0, use_tracker=True, 
use_analyzer=True) - - # Example 3: Tracker only (analyzer disabled) - # segmenter = Sam2DSegmenter(use_analyzer=False) - - # Example 4: Basic segmentation only (both tracker and analyzer disabled) - # segmenter = Sam2DSegmenter(use_tracker=False, use_analyzer=False) - - try: - while cap.isOpened(): - ret, frame = cap.read() - if not ret: - break - - start_time = time.time() - - # Process image and get results - masks, bboxes, target_ids, probs, names = segmenter.process_image(frame) - - # Run analysis if enabled - if segmenter.use_tracker and segmenter.use_analyzer: - segmenter.run_analysis(frame, bboxes, target_ids) - names = segmenter.get_object_names(target_ids, names) - - # processing_time = time.time() - start_time - # print(f"Processing time: {processing_time:.2f}s") - - overlay = segmenter.visualize_results(frame, masks, bboxes, target_ids, probs, names) - - cv2.imshow("Segmentation", overlay) - key = cv2.waitKey(1) - if key & 0xFF == ord("q"): - break - - finally: - segmenter.cleanup() - cap.release() - cv2.destroyAllWindows() - - -if __name__ == "__main__": - main() diff --git a/build/lib/dimos/perception/segmentation/test_sam_2d_seg.py b/build/lib/dimos/perception/segmentation/test_sam_2d_seg.py deleted file mode 100644 index 297b265415..0000000000 --- a/build/lib/dimos/perception/segmentation/test_sam_2d_seg.py +++ /dev/null @@ -1,214 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import os -import time - -import cv2 -import numpy as np -import pytest -import reactivex as rx -from reactivex import operators as ops - -from dimos.perception.segmentation.sam_2d_seg import Sam2DSegmenter -from dimos.perception.segmentation.utils import extract_masks_bboxes_probs_names -from dimos.stream import video_provider -from dimos.stream.video_provider import VideoProvider - - -@pytest.mark.heavy -class TestSam2DSegmenter: - def test_sam_segmenter_initialization(self): - """Test FastSAM segmenter initializes correctly with default model path.""" - try: - # Try to initialize with the default model path and existing device setting - segmenter = Sam2DSegmenter(use_analyzer=False) - assert segmenter is not None - assert segmenter.model is not None - except Exception as e: - # If the model file doesn't exist, the test should still pass with a warning - pytest.skip(f"Skipping test due to model initialization error: {e}") - - def test_sam_segmenter_process_image(self): - """Test FastSAM segmenter can process video frames and return segmentation masks.""" - # Import get data inside method to avoid pytest fixture confusion - from dimos.utils.data import get_data - - # Get test video path directly - video_path = get_data("assets") / "trimmed_video_office.mov" - try: - # Initialize segmenter without analyzer for faster testing - segmenter = Sam2DSegmenter(use_analyzer=False) - - # Note: conf and iou are parameters for process_image, not constructor - # We'll monkey patch the process_image method to use lower thresholds - original_process_image = segmenter.process_image - - def patched_process_image(image): - results = segmenter.model.track( - source=image, - device=segmenter.device, - retina_masks=True, - conf=0.1, # Lower confidence threshold for testing - iou=0.5, # Lower IoU threshold - persist=True, - verbose=False, - tracker=segmenter.tracker_config - if hasattr(segmenter, "tracker_config") - else None, - ) - - if len(results) > 0: - masks, bboxes, track_ids, 
probs, names, areas = ( - extract_masks_bboxes_probs_names(results[0]) - ) - return masks, bboxes, track_ids, probs, names - return [], [], [], [], [] - - # Replace the method - segmenter.process_image = patched_process_image - - # Create video provider and directly get a video stream observable - assert os.path.exists(video_path), f"Test video not found: {video_path}" - video_provider = VideoProvider(dev_name="test_video", video_source=video_path) - - video_stream = video_provider.capture_video_as_observable(realtime=False, fps=1) - - # Use ReactiveX operators to process the stream - def process_frame(frame): - try: - # Process frame with FastSAM - masks, bboxes, track_ids, probs, names = segmenter.process_image(frame) - print( - f"SAM results - masks: {len(masks)}, bboxes: {len(bboxes)}, track_ids: {len(track_ids)}, names: {len(names)}" - ) - - return { - "frame": frame, - "masks": masks, - "bboxes": bboxes, - "track_ids": track_ids, - "probs": probs, - "names": names, - } - except Exception as e: - print(f"Error in process_frame: {e}") - return {} - - # Create the segmentation stream using pipe and map operator - segmentation_stream = video_stream.pipe(ops.map(process_frame)) - - # Collect results from the stream - results = [] - frames_processed = 0 - target_frames = 5 - - def on_next(result): - nonlocal frames_processed, results - if not result: - return - - results.append(result) - frames_processed += 1 - - # Stop processing after target frames - if frames_processed >= target_frames: - subscription.dispose() - - def on_error(error): - pytest.fail(f"Error in segmentation stream: {error}") - - def on_completed(): - pass - - # Subscribe and wait for results - subscription = segmentation_stream.subscribe( - on_next=on_next, on_error=on_error, on_completed=on_completed - ) - - # Wait for frames to be processed - timeout = 30.0 # seconds - start_time = time.time() - while frames_processed < target_frames and time.time() - start_time < timeout: - time.sleep(0.5) - - 
# Clean up subscription - subscription.dispose() - video_provider.dispose_all() - - # Check if we have results - if len(results) == 0: - pytest.skip( - "No segmentation results found, but test connection established correctly" - ) - return - - print(f"Processed {len(results)} frames with segmentation results") - - # Analyze the first result - result = results[0] - - # Check that we have a frame - assert "frame" in result, "Result doesn't contain a frame" - assert isinstance(result["frame"], np.ndarray), "Frame is not a numpy array" - - # Check that segmentation results are valid - assert isinstance(result["masks"], list) - assert isinstance(result["bboxes"], list) - assert isinstance(result["track_ids"], list) - assert isinstance(result["probs"], list) - assert isinstance(result["names"], list) - - # All result lists should be the same length - assert ( - len(result["masks"]) - == len(result["bboxes"]) - == len(result["track_ids"]) - == len(result["probs"]) - == len(result["names"]) - ) - - # If we have masks, check that they have valid shape - if result.get("masks") and len(result["masks"]) > 0: - assert result["masks"][0].shape == ( - result["frame"].shape[0], - result["frame"].shape[1], - ), "Mask shape should match image dimensions" - print(f"Found {len(result['masks'])} masks in first frame") - else: - print("No masks found in first frame, but test connection established correctly") - - # Test visualization function - if result["masks"]: - vis_frame = segmenter.visualize_results( - result["frame"], - result["masks"], - result["bboxes"], - result["track_ids"], - result["probs"], - result["names"], - ) - assert isinstance(vis_frame, np.ndarray), "Visualization output should be an image" - assert vis_frame.shape == result["frame"].shape, ( - "Visualization should have same dimensions as input frame" - ) - - # We've already tested visualization above, so no need for a duplicate test - - except Exception as e: - pytest.skip(f"Skipping test due to error: {e}") - - 
-if __name__ == "__main__": - pytest.main(["-v", __file__]) diff --git a/build/lib/dimos/perception/segmentation/utils.py b/build/lib/dimos/perception/segmentation/utils.py deleted file mode 100644 index c96a7d4a64..0000000000 --- a/build/lib/dimos/perception/segmentation/utils.py +++ /dev/null @@ -1,315 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -import cv2 -import torch - - -class SimpleTracker: - def __init__(self, history_size=100, min_count=10, count_window=20): - """ - Simple temporal tracker that counts appearances in a fixed window. 
- :param history_size: Number of past frames to remember - :param min_count: Minimum number of appearances required - :param count_window: Number of latest frames to consider for counting - """ - self.history = [] - self.history_size = history_size - self.min_count = min_count - self.count_window = count_window - self.total_counts = {} - - def update(self, track_ids): - # Add new frame's track IDs to history - self.history.append(track_ids) - if len(self.history) > self.history_size: - self.history.pop(0) - - # Consider only the latest `count_window` frames for counting - recent_history = self.history[-self.count_window :] - all_tracks = np.concatenate(recent_history) if recent_history else np.array([]) - - # Compute occurrences efficiently using numpy - unique_ids, counts = np.unique(all_tracks, return_counts=True) - id_counts = dict(zip(unique_ids, counts)) - - # Update total counts but ensure it only contains IDs within the history size - total_tracked_ids = np.concatenate(self.history) if self.history else np.array([]) - unique_total_ids, total_counts = np.unique(total_tracked_ids, return_counts=True) - self.total_counts = dict(zip(unique_total_ids, total_counts)) - - # Return IDs that appear often enough - return [track_id for track_id, count in id_counts.items() if count >= self.min_count] - - def get_total_counts(self): - """Returns the total count of each tracking ID seen over time, limited to history size.""" - return self.total_counts - - -def extract_masks_bboxes_probs_names(result, max_size=0.7): - """ - Extracts masks, bounding boxes, probabilities, and class names from one Ultralytics result object. 
- - Parameters: - result: Ultralytics result object - max_size: float, maximum allowed size of object relative to image (0-1) - - Returns: - tuple: (masks, bboxes, track_ids, probs, names, areas) - """ - masks = [] - bboxes = [] - track_ids = [] - probs = [] - names = [] - areas = [] - - if result.masks is None: - return masks, bboxes, track_ids, probs, names, areas - - total_area = result.masks.orig_shape[0] * result.masks.orig_shape[1] - - for box, mask_data in zip(result.boxes, result.masks.data): - mask_numpy = mask_data - - # Extract bounding box - x1, y1, x2, y2 = box.xyxy[0].tolist() - - # Extract track_id if available - track_id = -1 # default if no tracking - if hasattr(box, "id") and box.id is not None: - track_id = int(box.id[0].item()) - - # Extract probability and class index - conf = float(box.conf[0]) - cls_idx = int(box.cls[0]) - area = (x2 - x1) * (y2 - y1) - - if area / total_area > max_size: - continue - - masks.append(mask_numpy) - bboxes.append([x1, y1, x2, y2]) - track_ids.append(track_id) - probs.append(conf) - names.append(result.names[cls_idx]) - areas.append(area) - - return masks, bboxes, track_ids, probs, names, areas - - -def compute_texture_map(frame, blur_size=3): - """ - Compute texture map using gradient statistics. - Returns high values for textured regions and low values for smooth regions. 
- - Parameters: - frame: BGR image - blur_size: Size of Gaussian blur kernel for pre-processing - - Returns: - numpy array: Texture map with values normalized to [0,1] - """ - # Convert to grayscale - if len(frame.shape) == 3: - gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) - else: - gray = frame - - # Pre-process with slight blur to reduce noise - if blur_size > 0: - gray = cv2.GaussianBlur(gray, (blur_size, blur_size), 0) - - # Compute gradients in x and y directions - grad_x = cv2.Sobel(gray, cv2.CV_32F, 1, 0, ksize=3) - grad_y = cv2.Sobel(gray, cv2.CV_32F, 0, 1, ksize=3) - - # Compute gradient magnitude and direction - magnitude = np.sqrt(grad_x**2 + grad_y**2) - - # Compute local standard deviation of gradient magnitude - texture_map = cv2.GaussianBlur(magnitude, (15, 15), 0) - - # Normalize to [0,1] - texture_map = (texture_map - texture_map.min()) / (texture_map.max() - texture_map.min() + 1e-8) - - return texture_map - - -def filter_segmentation_results( - frame, masks, bboxes, track_ids, probs, names, areas, texture_threshold=0.07, size_filter=800 -): - """ - Filters segmentation results using both overlap and saliency detection. - Uses mask_sum tensor for efficient overlap detection. 
- - Parameters: - masks: list of torch.Tensor containing mask data - bboxes: list of bounding boxes [x1, y1, x2, y2] - track_ids: list of tracking IDs - probs: list of confidence scores - names: list of class names - areas: list of object areas - frame: BGR image for computing saliency - texture_threshold: Average texture value required for mask to be kept - size_filter: Minimum size of the object to be kept - - Returns: - tuple: (filtered_masks, filtered_bboxes, filtered_track_ids, filtered_probs, filtered_names, filtered_texture_values, texture_map) - """ - if len(masks) <= 1: - return masks, bboxes, track_ids, probs, names, [] - - # Compute texture map once and convert to tensor - texture_map = compute_texture_map(frame) - - # Sort by area (smallest to largest) - sorted_indices = torch.tensor(areas).argsort(descending=False) - - device = masks[0].device # Get the device of the first mask - - # Create mask_sum tensor where each pixel stores the index of the mask that claims it - mask_sum = torch.zeros_like(masks[0], dtype=torch.int32) - - texture_map = torch.from_numpy(texture_map).to( - device - ) # Convert texture_map to tensor and move to device - - filtered_texture_values = [] # List to store texture values of filtered masks - - for i, idx in enumerate(sorted_indices): - mask = masks[idx] - # Compute average texture value within mask - texture_value = torch.mean(texture_map[mask > 0]) if torch.any(mask > 0) else 0 - - # Only claim pixels if mask passes texture threshold - if texture_value >= texture_threshold: - mask_sum[mask > 0] = i - filtered_texture_values.append( - texture_value.item() - ) # Store the texture value as a Python float - - # Get indices that appear in mask_sum (these are the masks we want to keep) - keep_indices, counts = torch.unique(mask_sum[mask_sum > 0], return_counts=True) - size_indices = counts > size_filter - keep_indices = keep_indices[size_indices] - - sorted_indices = sorted_indices.cpu() - keep_indices = keep_indices.cpu() - - # 
Map back to original indices and filter - final_indices = sorted_indices[keep_indices].tolist() - - filtered_masks = [masks[i] for i in final_indices] - filtered_bboxes = [bboxes[i] for i in final_indices] - filtered_track_ids = [track_ids[i] for i in final_indices] - filtered_probs = [probs[i] for i in final_indices] - filtered_names = [names[i] for i in final_indices] - - return ( - filtered_masks, - filtered_bboxes, - filtered_track_ids, - filtered_probs, - filtered_names, - filtered_texture_values, - ) - - -def plot_results(image, masks, bboxes, track_ids, probs, names, alpha=0.5): - """ - Draws bounding boxes, masks, and labels on the given image with enhanced visualization. - Includes object names in the overlay and improved text visibility. - """ - h, w = image.shape[:2] - overlay = image.copy() - - for mask, bbox, track_id, prob, name in zip(masks, bboxes, track_ids, probs, names): - # Convert mask tensor to numpy if needed - if isinstance(mask, torch.Tensor): - mask = mask.cpu().numpy() - - mask_resized = cv2.resize(mask, (w, h), interpolation=cv2.INTER_LINEAR) - - # Generate consistent color based on track_id - if track_id != -1: - np.random.seed(track_id) - color = np.random.randint(0, 255, (3,), dtype=np.uint8) - np.random.seed(None) - else: - color = np.random.randint(0, 255, (3,), dtype=np.uint8) - - # Apply mask color - overlay[mask_resized > 0.5] = color - - # Draw bounding box - x1, y1, x2, y2 = map(int, bbox) - cv2.rectangle(overlay, (x1, y1), (x2, y2), color.tolist(), 2) - - # Prepare label text - label = f"ID:{track_id} {prob:.2f}" - if name: # Add object name if available - label += f" {name}" - - # Calculate text size for background rectangle - (text_w, text_h), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1) - - # Draw background rectangle for text - cv2.rectangle(overlay, (x1, y1 - text_h - 8), (x1 + text_w + 4, y1), color.tolist(), -1) - - # Draw text with white color for better visibility - cv2.putText( - overlay, - label, - 
(x1 + 2, y1 - 5), - cv2.FONT_HERSHEY_SIMPLEX, - 0.5, - (255, 255, 255), # White text - 1, - ) - - # Blend overlay with original image - result = cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0) - return result - - -def crop_images_from_bboxes(image, bboxes, buffer=0): - """ - Crops regions from an image based on bounding boxes with an optional buffer. - - Parameters: - image (numpy array): Input image. - bboxes (list of lists): List of bounding boxes [x1, y1, x2, y2]. - buffer (int): Number of pixels to expand each bounding box. - - Returns: - list of numpy arrays: Cropped image regions. - """ - height, width, _ = image.shape - cropped_images = [] - - for bbox in bboxes: - x1, y1, x2, y2 = bbox - - # Apply buffer - x1 = max(0, x1 - buffer) - y1 = max(0, y1 - buffer) - x2 = min(width, x2 + buffer) - y2 = min(height, y2 + buffer) - - cropped_image = image[int(y1) : int(y2), int(x1) : int(x2)] - cropped_images.append(cropped_image) - - return cropped_images diff --git a/build/lib/dimos/perception/semantic_seg.py b/build/lib/dimos/perception/semantic_seg.py deleted file mode 100644 index a07e69c279..0000000000 --- a/build/lib/dimos/perception/semantic_seg.py +++ /dev/null @@ -1,245 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from dimos.perception.segmentation import Sam2DSegmenter -from dimos.models.depth.metric3d import Metric3D -from dimos.hardware.camera import Camera -from reactivex import Observable -from reactivex import operators as ops -from dimos.types.segmentation import SegmentationType -import numpy as np -import cv2 - - -class SemanticSegmentationStream: - def __init__( - self, - device: str = "cuda", - enable_mono_depth: bool = True, - enable_rich_labeling: bool = True, - camera_params: dict = None, - gt_depth_scale=256.0, - ): - """ - Initialize a semantic segmentation stream using Sam2DSegmenter. - - Args: - device: Computation device ("cuda" or "cpu") - enable_mono_depth: Whether to enable monocular depth processing - enable_rich_labeling: Whether to enable rich labeling - camera_params: Dictionary containing either: - - Direct intrinsics: [fx, fy, cx, cy] - - Physical parameters: resolution, focal_length, sensor_size - """ - self.segmenter = Sam2DSegmenter( - device=device, - min_analysis_interval=5.0, - use_tracker=True, - use_analyzer=True, - use_rich_labeling=enable_rich_labeling, - ) - - self.enable_mono_depth = enable_mono_depth - if enable_mono_depth: - self.depth_model = Metric3D(gt_depth_scale) - - if camera_params: - # Check if direct intrinsics are provided - if "intrinsics" in camera_params: - intrinsics = camera_params["intrinsics"] - if len(intrinsics) != 4: - raise ValueError("Intrinsics must be a list of 4 values: [fx, fy, cx, cy]") - self.depth_model.update_intrinsic(intrinsics) - else: - # Create camera object and calculate intrinsics from physical parameters - self.camera = Camera( - resolution=camera_params.get("resolution"), - focal_length=camera_params.get("focal_length"), - sensor_size=camera_params.get("sensor_size"), - ) - intrinsics = self.camera.calculate_intrinsics() - self.depth_model.update_intrinsic( - [ - intrinsics["focal_length_x"], - intrinsics["focal_length_y"], - intrinsics["principal_point_x"], - intrinsics["principal_point_y"], 
- ] - ) - else: - raise ValueError("Camera parameters are required for monocular depth processing.") - - def create_stream(self, video_stream: Observable) -> Observable[SegmentationType]: - """ - Create an Observable stream of segmentation results from a video stream. - - Args: - video_stream: Observable that emits video frames - - Returns: - Observable that emits SegmentationType objects containing masks and metadata - """ - - def process_frame(frame): - # Process image and get results - masks, bboxes, target_ids, probs, names = self.segmenter.process_image(frame) - - # Run analysis if enabled - if self.segmenter.use_analyzer: - self.segmenter.run_analysis(frame, bboxes, target_ids) - names = self.segmenter.get_object_names(target_ids, names) - - viz_frame = self.segmenter.visualize_results( - frame, masks, bboxes, target_ids, probs, names - ) - - # Process depth if enabled - depth_viz = None - object_depths = [] - if self.enable_mono_depth: - # Get depth map - depth_map = self.depth_model.infer_depth(frame) - depth_map = np.array(depth_map) - - # Calculate average depth for each object - object_depths = [] - for mask in masks: - # Convert mask to numpy if needed - mask_np = mask.cpu().numpy() if hasattr(mask, "cpu") else mask - # Get depth values where mask is True - object_depth = depth_map[mask_np > 0.5] - # Calculate average depth (in meters) - avg_depth = np.mean(object_depth) if len(object_depth) > 0 else 0 - object_depths.append(avg_depth / 1000) - - # Create colorized depth visualization - depth_viz = self._create_depth_visualization(depth_map) - - # Overlay depth values on the visualization frame - for bbox, depth in zip(bboxes, object_depths): - x1, y1, x2, y2 = map(int, bbox) - # Draw depth text at bottom left of bounding box - depth_text = f"{depth:.2f}mm" - # Add black background for better visibility - text_size = cv2.getTextSize(depth_text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0] - cv2.rectangle( - viz_frame, - (x1, y2 - text_size[1] - 5), - (x1 + 
text_size[0], y2), - (0, 0, 0), - -1, - ) - # Draw text in white - cv2.putText( - viz_frame, - depth_text, - (x1, y2 - 5), - cv2.FONT_HERSHEY_SIMPLEX, - 0.5, - (255, 255, 255), - 2, - ) - - # Create metadata in the new requested format - objects = [] - for i in range(len(bboxes)): - obj_data = { - "object_id": target_ids[i] if i < len(target_ids) else None, - "bbox": bboxes[i], - "prob": probs[i] if i < len(probs) else None, - "label": names[i] if i < len(names) else None, - } - - # Add depth if available - if self.enable_mono_depth and i < len(object_depths): - obj_data["depth"] = object_depths[i] - - objects.append(obj_data) - - # Create the new metadata dictionary - metadata = {"frame": frame, "viz_frame": viz_frame, "objects": objects} - - # Add depth visualization if available - if depth_viz is not None: - metadata["depth_viz"] = depth_viz - - # Convert masks to numpy arrays if they aren't already - numpy_masks = [mask.cpu().numpy() if hasattr(mask, "cpu") else mask for mask in masks] - - return SegmentationType(masks=numpy_masks, metadata=metadata) - - return video_stream.pipe(ops.map(process_frame)) - - def _create_depth_visualization(self, depth_map): - """ - Create a colorized visualization of the depth map. 
- - Args: - depth_map: Raw depth map in meters - - Returns: - Colorized depth map visualization - """ - # Normalize depth map to 0-255 range for visualization - depth_min = np.min(depth_map) - depth_max = np.max(depth_map) - depth_normalized = ((depth_map - depth_min) / (depth_max - depth_min) * 255).astype( - np.uint8 - ) - - # Apply colormap (using JET colormap for better depth perception) - depth_colored = cv2.applyColorMap(depth_normalized, cv2.COLORMAP_JET) - - # Add depth scale bar - scale_height = 30 - scale_width = depth_map.shape[1] # Match width with depth map - scale_bar = np.zeros((scale_height, scale_width, 3), dtype=np.uint8) - - # Create gradient for scale bar - for i in range(scale_width): - color = cv2.applyColorMap( - np.array([[i * 255 // scale_width]], dtype=np.uint8), cv2.COLORMAP_JET - ) - scale_bar[:, i] = color[0, 0] - - # Add depth values to scale bar - cv2.putText( - scale_bar, - f"{depth_min:.1f}mm", - (5, 20), - cv2.FONT_HERSHEY_SIMPLEX, - 0.5, - (255, 255, 255), - 1, - ) - cv2.putText( - scale_bar, - f"{depth_max:.1f}mm", - (scale_width - 60, 20), - cv2.FONT_HERSHEY_SIMPLEX, - 0.5, - (255, 255, 255), - 1, - ) - - # Combine depth map and scale bar - combined_viz = np.vstack((depth_colored, scale_bar)) - - return combined_viz - - def cleanup(self): - """Clean up resources.""" - self.segmenter.cleanup() - if self.enable_mono_depth: - del self.depth_model diff --git a/build/lib/dimos/perception/spatial_perception.py b/build/lib/dimos/perception/spatial_perception.py deleted file mode 100644 index b994b52bc4..0000000000 --- a/build/lib/dimos/perception/spatial_perception.py +++ /dev/null @@ -1,438 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Spatial Memory module for creating a semantic map of the environment. -""" - -import uuid -import time -import os -from typing import Dict, List, Optional, Any - -import numpy as np -from reactivex import Observable, disposable -from reactivex import operators as ops -from datetime import datetime - -from dimos.utils.logging_config import setup_logger -from dimos.agents.memory.spatial_vector_db import SpatialVectorDB -from dimos.agents.memory.image_embedding import ImageEmbeddingProvider -from dimos.agents.memory.visual_memory import VisualMemory -from dimos.types.vector import Vector -from dimos.types.robot_location import RobotLocation - -logger = setup_logger("dimos.perception.spatial_memory") - - -class SpatialMemory: - """ - A class for building and querying Robot spatial memory. - - This class processes video frames from ROSControl, associates them with - XY locations, and stores them in a vector database for later retrieval. - It also maintains a list of named robot locations that can be queried by name. 
- """ - - def __init__( - self, - collection_name: str = "spatial_memory", - embedding_model: str = "clip", - embedding_dimensions: int = 512, - min_distance_threshold: float = 0.01, # Min distance in meters to store a new frame - min_time_threshold: float = 1.0, # Min time in seconds to record a new frame - db_path: Optional[str] = None, # Path for ChromaDB persistence - visual_memory_path: Optional[str] = None, # Path for saving/loading visual memory - new_memory: bool = False, # Whether to create a new memory from scratch - output_dir: Optional[str] = None, # Directory for storing visual memory data - chroma_client: Any = None, # Optional ChromaDB client for persistence - visual_memory: Optional[ - "VisualMemory" - ] = None, # Optional VisualMemory instance for storing images - video_stream: Optional[Observable] = None, # Video stream to process - get_pose: Optional[callable] = None, # Function that returns position and rotation - ): - """ - Initialize the spatial perception system. - - Args: - collection_name: Name of the vector database collection - embedding_model: Model to use for image embeddings ("clip", "resnet", etc.) 
- embedding_dimensions: Dimensions of the embedding vectors - min_distance_threshold: Minimum distance in meters to record a new frame - min_time_threshold: Minimum time in seconds to record a new frame - chroma_client: Optional ChromaDB client for persistent storage - visual_memory: Optional VisualMemory instance for storing images - output_dir: Directory for storing visual memory data if visual_memory is not provided - """ - self.collection_name = collection_name - self.embedding_model = embedding_model - self.embedding_dimensions = embedding_dimensions - self.min_distance_threshold = min_distance_threshold - self.min_time_threshold = min_time_threshold - - # Set up paths for persistence - self.db_path = db_path - self.visual_memory_path = visual_memory_path - self.output_dir = output_dir - - # Setup ChromaDB client if not provided - self._chroma_client = chroma_client - if chroma_client is None and db_path is not None: - # Create db directory if needed - os.makedirs(db_path, exist_ok=True) - - # Clean up existing DB if creating new memory - if new_memory and os.path.exists(db_path): - try: - logger.info("Creating new ChromaDB database (new_memory=True)") - # Try to delete any existing database files - import shutil - - for item in os.listdir(db_path): - item_path = os.path.join(db_path, item) - if os.path.isfile(item_path): - os.unlink(item_path) - elif os.path.isdir(item_path): - shutil.rmtree(item_path) - logger.info(f"Removed existing ChromaDB files from {db_path}") - except Exception as e: - logger.error(f"Error clearing ChromaDB directory: {e}") - - from chromadb.config import Settings - import chromadb - - self._chroma_client = chromadb.PersistentClient( - path=db_path, settings=Settings(anonymized_telemetry=False) - ) - - # Initialize or load visual memory - self._visual_memory = visual_memory - if visual_memory is None: - if new_memory or not os.path.exists(visual_memory_path or ""): - logger.info("Creating new visual memory") - self._visual_memory = 
VisualMemory(output_dir=output_dir) - else: - try: - logger.info(f"Loading existing visual memory from {visual_memory_path}...") - self._visual_memory = VisualMemory.load( - visual_memory_path, output_dir=output_dir - ) - logger.info(f"Loaded {self._visual_memory.count()} images from previous runs") - except Exception as e: - logger.error(f"Error loading visual memory: {e}") - self._visual_memory = VisualMemory(output_dir=output_dir) - - # Initialize vector database - self.vector_db: SpatialVectorDB = SpatialVectorDB( - collection_name=collection_name, - chroma_client=self._chroma_client, - visual_memory=self._visual_memory, - ) - - self.embedding_provider: ImageEmbeddingProvider = ImageEmbeddingProvider( - model_name=embedding_model, dimensions=embedding_dimensions - ) - - self.last_position: Optional[Vector] = None - self.last_record_time: Optional[float] = None - - self.frame_count: int = 0 - self.stored_frame_count: int = 0 - - # For tracking stream subscription - self._subscription = None - - # List to store robot locations - self.robot_locations: List[RobotLocation] = [] - - logger.info(f"SpatialMemory initialized with model {embedding_model}") - - # Start processing video stream if provided - if video_stream is not None and get_pose is not None: - self.start_continuous_processing(video_stream, get_pose) - - def query_by_location( - self, x: float, y: float, radius: float = 2.0, limit: int = 5 - ) -> List[Dict]: - """ - Query the vector database for images near the specified location. - - Args: - x: X coordinate - y: Y coordinate - radius: Search radius in meters - limit: Maximum number of results to return - - Returns: - List of results, each containing the image and its metadata - """ - return self.vector_db.query_by_location(x, y, radius, limit) - - def start_continuous_processing( - self, video_stream: Observable, get_pose: callable - ) -> disposable.Disposable: - """ - Start continuous processing of video frames from an Observable stream. 
- - Args: - video_stream: Observable of video frames - get_pose: Callable that returns position and rotation for each frame - - Returns: - Disposable subscription that can be used to stop processing - """ - # Stop any existing subscription - self.stop_continuous_processing() - - # Map each video frame to include transform data - combined_stream = video_stream.pipe( - ops.map(lambda video_frame: {"frame": video_frame, **get_pose()}), - # Filter out bad transforms - ops.filter( - lambda data: data.get("position") is not None and data.get("rotation") is not None - ), - ) - - # Process with spatial memory - result_stream = self.process_stream(combined_stream) - - # Subscribe to the result stream - self._subscription = result_stream.subscribe( - on_next=self._on_frame_processed, - on_error=lambda e: logger.error(f"Error in spatial memory stream: {e}"), - on_completed=lambda: logger.info("Spatial memory stream completed"), - ) - - logger.info("Continuous spatial memory processing started") - return self._subscription - - def stop_continuous_processing(self) -> None: - """ - Stop continuous processing of video frames. - """ - if self._subscription is not None: - try: - self._subscription.dispose() - self._subscription = None - logger.info("Stopped continuous spatial memory processing") - except Exception as e: - logger.error(f"Error stopping spatial memory processing: {e}") - - def _on_frame_processed(self, result: Dict[str, Any]) -> None: - """ - Handle updates from the spatial memory processing stream. 
- """ - # Log successful frame storage (if stored) - position = result.get("position") - if position is not None: - logger.debug( - f"Spatial memory updated with frame at ({position[0]:.2f}, {position[1]:.2f}, {position[2]:.2f})" - ) - - # Periodically save visual memory to disk (e.g., every 100 frames) - if self._visual_memory is not None and self.visual_memory_path is not None: - if self.stored_frame_count % 100 == 0: - self.save() - - def save(self) -> bool: - """ - Save the visual memory component to disk. - - Returns: - True if memory was saved successfully, False otherwise - """ - if self._visual_memory is not None and self.visual_memory_path is not None: - try: - saved_path = self._visual_memory.save(self.visual_memory_path) - logger.info(f"Saved {self._visual_memory.count()} images to {saved_path}") - return True - except Exception as e: - logger.error(f"Failed to save visual memory: {e}") - return False - - def process_stream(self, combined_stream: Observable) -> Observable: - """ - Process a combined stream of video frames and positions. - - This method handles a stream where each item already contains both the frame and position, - such as the stream created by combining video and transform streams with the - with_latest_from operator. 
- - Args: - combined_stream: Observable stream of dictionaries containing 'frame' and 'position' - - Returns: - Observable of processing results, including the stored frame and its metadata - """ - self.last_position = None - self.last_record_time = None - - def process_combined_data(data): - self.frame_count += 1 - - frame = data.get("frame") - position_vec = data.get("position") # Use .get() for consistency - rotation_vec = data.get("rotation") # Get rotation data if available - - if not position_vec or not rotation_vec: - logger.info("No position or rotation data available, skipping frame") - return None - - if ( - self.last_position is not None - and (self.last_position - position_vec).length() < self.min_distance_threshold - ): - logger.debug("Position has not moved, skipping frame") - return None - - if ( - self.last_record_time is not None - and (time.time() - self.last_record_time) < self.min_time_threshold - ): - logger.debug("Time since last record too short, skipping frame") - return None - - current_time = time.time() - - frame_embedding = self.embedding_provider.get_embedding(frame) - - frame_id = f"frame_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}" - - # Create metadata dictionary with primitive types only - metadata = { - "pos_x": float(position_vec.x), - "pos_y": float(position_vec.y), - "pos_z": float(position_vec.z), - "rot_x": float(rotation_vec.x), - "rot_y": float(rotation_vec.y), - "rot_z": float(rotation_vec.z), - "timestamp": current_time, - "frame_id": frame_id, - } - - self.vector_db.add_image_vector( - vector_id=frame_id, image=frame, embedding=frame_embedding, metadata=metadata - ) - - self.last_position = position_vec - self.last_record_time = current_time - self.stored_frame_count += 1 - - logger.info( - f"Stored frame at position {position_vec}, rotation {rotation_vec})" - f" stored {self.stored_frame_count}/{self.frame_count} frames" - ) - - # Create return dictionary with primitive-compatible values - return { 
- "frame": frame, - "position": (position_vec.x, position_vec.y, position_vec.z), - "rotation": (rotation_vec.x, rotation_vec.y, rotation_vec.z), - "frame_id": frame_id, - "timestamp": current_time, - } - - return combined_stream.pipe( - ops.map(process_combined_data), ops.filter(lambda result: result is not None) - ) - - def query_by_image(self, image: np.ndarray, limit: int = 5) -> List[Dict]: - """ - Query the vector database for images similar to the provided image. - - Args: - image: Query image - limit: Maximum number of results to return - - Returns: - List of results, each containing the image and its metadata - """ - embedding = self.embedding_provider.get_embedding(image) - return self.vector_db.query_by_embedding(embedding, limit) - - def query_by_text(self, text: str, limit: int = 5) -> List[Dict]: - """ - Query the vector database for images matching the provided text description. - - This method uses CLIP's text-to-image matching capability to find images - that semantically match the text query (e.g., "where is the kitchen"). - - Args: - text: Text query to search for - limit: Maximum number of results to return - - Returns: - List of results, each containing the image, its metadata, and similarity score - """ - logger.info(f"Querying spatial memory with text: '{text}'") - return self.vector_db.query_by_text(text, limit) - - def add_robot_location(self, location: RobotLocation) -> bool: - """ - Add a named robot location to spatial memory. - - Args: - location: The RobotLocation object to add - - Returns: - True if successfully added, False otherwise - """ - try: - # Add to our list of robot locations - self.robot_locations.append(location) - logger.info(f"Added robot location '{location.name}' at position {location.position}") - return True - - except Exception as e: - logger.error(f"Error adding robot location: {e}") - return False - - def get_robot_locations(self) -> List[RobotLocation]: - """ - Get all stored robot locations. 
- - Returns: - List of RobotLocation objects - """ - return self.robot_locations - - def find_robot_location(self, name: str) -> Optional[RobotLocation]: - """ - Find a robot location by name. - - Args: - name: Name of the location to find - - Returns: - RobotLocation object if found, None otherwise - """ - # Simple search through our list of locations - for location in self.robot_locations: - if location.name.lower() == name.lower(): - return location - - return None - - def cleanup(self): - """Clean up resources.""" - # Stop any ongoing processing - self.stop_continuous_processing() - - # Save data if possible - self.save() - - # Log cleanup - if self.vector_db: - logger.info(f"Cleaning up SpatialMemory, stored {self.stored_frame_count} frames") diff --git a/build/lib/dimos/perception/test_spatial_memory.py b/build/lib/dimos/perception/test_spatial_memory.py deleted file mode 100644 index 9a519fe59c..0000000000 --- a/build/lib/dimos/perception/test_spatial_memory.py +++ /dev/null @@ -1,214 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import os -import shutil -import tempfile -import time - -import cv2 -import numpy as np -import pytest -import reactivex as rx -from reactivex import Observable -from reactivex import operators as ops -from reactivex.subject import Subject - -from dimos.perception.spatial_perception import SpatialMemory -from dimos.stream.video_provider import VideoProvider -from dimos.types.pose import Pose -from dimos.types.vector import Vector - - -@pytest.mark.heavy -class TestSpatialMemory: - @pytest.fixture(scope="function") - def temp_dir(self): - # Create a temporary directory for storing spatial memory data - temp_dir = tempfile.mkdtemp() - yield temp_dir - # Clean up - shutil.rmtree(temp_dir) - - def test_spatial_memory_initialization(self): - """Test SpatialMemory initializes correctly with CLIP model.""" - try: - # Initialize spatial memory with default CLIP model - memory = SpatialMemory( - collection_name="test_collection", embedding_model="clip", new_memory=True - ) - assert memory is not None - assert memory.embedding_model == "clip" - assert memory.embedding_provider is not None - except Exception as e: - # If the model doesn't initialize, skip the test - pytest.fail(f"Failed to initialize model: {e}") - - def test_image_embedding(self): - """Test generating image embeddings using CLIP.""" - try: - # Initialize spatial memory with CLIP model - memory = SpatialMemory( - collection_name="test_collection", embedding_model="clip", new_memory=True - ) - - # Create a test image - use a simple colored square - test_image = np.zeros((224, 224, 3), dtype=np.uint8) - test_image[50:150, 50:150] = [0, 0, 255] # Blue square - - # Generate embedding - embedding = memory.embedding_provider.get_embedding(test_image) - - # Check embedding shape and characteristics - assert embedding is not None - assert isinstance(embedding, np.ndarray) - assert embedding.shape[0] == memory.embedding_dimensions - - # Check that embedding is normalized (unit vector) - assert 
np.isclose(np.linalg.norm(embedding), 1.0, atol=1e-5) - - # Test text embedding - text_embedding = memory.embedding_provider.get_text_embedding("a blue square") - assert text_embedding is not None - assert isinstance(text_embedding, np.ndarray) - assert text_embedding.shape[0] == memory.embedding_dimensions - assert np.isclose(np.linalg.norm(text_embedding), 1.0, atol=1e-5) - except Exception as e: - pytest.fail(f"Error in test: {e}") - - def test_spatial_memory_processing(self, temp_dir): - """Test processing video frames and building spatial memory with CLIP embeddings.""" - try: - # Initialize spatial memory with temporary storage - memory = SpatialMemory( - collection_name="test_collection", - embedding_model="clip", - new_memory=True, - db_path=os.path.join(temp_dir, "chroma_db"), - visual_memory_path=os.path.join(temp_dir, "visual_memory.pkl"), - output_dir=os.path.join(temp_dir, "images"), - min_distance_threshold=0.01, - min_time_threshold=0.01, - ) - - from dimos.utils.data import get_data - - video_path = get_data("assets") / "trimmed_video_office.mov" - assert os.path.exists(video_path), f"Test video not found: {video_path}" - video_provider = VideoProvider(dev_name="test_video", video_source=video_path) - video_stream = video_provider.capture_video_as_observable(realtime=False, fps=15) - - # Create a frame counter for position generation - frame_counter = 0 - - # Process each video frame directly - def process_frame(frame): - nonlocal frame_counter - - # Generate a unique position for this frame to ensure minimum distance threshold is met - pos = Pose(frame_counter * 0.5, frame_counter * 0.5, 0) - transform = {"position": pos, "timestamp": time.time()} - frame_counter += 1 - - # Create a dictionary with frame, position and rotation for SpatialMemory.process_stream - return { - "frame": frame, - "position": transform["position"], - "rotation": transform["position"], # Using position as rotation for testing - } - - # Create a stream that processes each 
frame - formatted_stream = video_stream.pipe(ops.map(process_frame)) - - # Process the stream using SpatialMemory's built-in processing - print("Creating spatial memory stream...") - spatial_stream = memory.process_stream(formatted_stream) - - # Stream is now created above using memory.process_stream() - - # Collect results from the stream - results = [] - - frames_processed = 0 - target_frames = 100 # Process more frames for thorough testing - - def on_next(result): - nonlocal results, frames_processed - if not result: # Skip None results - return - - results.append(result) - frames_processed += 1 - - # Stop processing after target frames - if frames_processed >= target_frames: - subscription.dispose() - - def on_error(error): - pytest.fail(f"Error in spatial stream: {error}") - - def on_completed(): - pass - - # Subscribe and wait for results - subscription = spatial_stream.subscribe( - on_next=on_next, on_error=on_error, on_completed=on_completed - ) - - # Wait for frames to be processed - timeout = 30.0 # seconds - start_time = time.time() - while frames_processed < target_frames and time.time() - start_time < timeout: - time.sleep(0.5) - - subscription.dispose() - - assert len(results) > 0, "Failed to process any frames with spatial memory" - - relevant_queries = ["office", "room with furniture"] - irrelevant_query = "star wars" - - for query in relevant_queries: - results = memory.query_by_text(query, limit=2) - print(f"\nResults for query: '{query}'") - - assert len(results) > 0, f"No results found for relevant query: {query}" - - similarities = [1 - r.get("distance") for r in results] - print(f"Similarities: {similarities}") - - assert any(d > 0.24 for d in similarities), ( - f"Expected at least one result with similarity > 0.24 for query '{query}'" - ) - - results = memory.query_by_text(irrelevant_query, limit=2) - print(f"\nResults for query: '{irrelevant_query}'") - - if results: - similarities = [1 - r.get("distance") for r in results] - 
print(f"Similarities: {similarities}") - - assert all(d < 0.25 for d in similarities), ( - f"Expected all results to have similarity < 0.25 for irrelevant query '{irrelevant_query}'" - ) - - except Exception as e: - pytest.fail(f"Error in test: {e}") - finally: - memory.cleanup() - video_provider.dispose_all() - - -if __name__ == "__main__": - pytest.main(["-v", __file__]) diff --git a/build/lib/dimos/perception/visual_servoing.py b/build/lib/dimos/perception/visual_servoing.py deleted file mode 100644 index 40cee7c60c..0000000000 --- a/build/lib/dimos/perception/visual_servoing.py +++ /dev/null @@ -1,500 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import time -import threading -from typing import Dict, Optional, List, Tuple -import logging -import numpy as np - -from dimos.utils.simple_controller import VisualServoingController - -# Configure logging -logger = logging.getLogger(__name__) - - -def calculate_iou(box1, box2): - """Calculate Intersection over Union between two bounding boxes.""" - x1 = max(box1[0], box2[0]) - y1 = max(box1[1], box2[1]) - x2 = min(box1[2], box2[2]) - y2 = min(box1[3], box2[3]) - - intersection = max(0, x2 - x1) * max(0, y2 - y1) - area1 = (box1[2] - box1[0]) * (box1[3] - box1[1]) - area2 = (box2[2] - box2[0]) * (box2[3] - box2[1]) - union = area1 + area2 - intersection - - return intersection / union if union > 0 else 0 - - -class VisualServoing: - """ - A class that performs visual servoing to track and follow a human target. - - The class will use the provided tracking stream to detect people and estimate - their distance and angle, then use a VisualServoingController to generate - appropriate velocity commands to track the target. - """ - - def __init__( - self, - tracking_stream=None, - max_linear_speed=0.8, - max_angular_speed=1.5, - desired_distance=1.5, - max_lost_frames=10000, - iou_threshold=0.6, - ): - """Initialize the visual servoing. 
- - Args: - tracking_stream: Observable tracking stream (must be already set up) - max_linear_speed: Maximum linear speed in m/s - max_angular_speed: Maximum angular speed in rad/s - desired_distance: Desired distance to maintain from target in meters - max_lost_frames: Maximum number of frames target can be lost before stopping tracking - iou_threshold: Minimum IOU threshold to consider bounding boxes as matching - """ - self.tracking_stream = tracking_stream - self.max_linear_speed = max_linear_speed - self.max_angular_speed = max_angular_speed - self.desired_distance = desired_distance - self.max_lost_frames = max_lost_frames - self.iou_threshold = iou_threshold - - # Initialize the controller with PID parameters tuned for slow-moving robot - # Distance PID: (kp, ki, kd, output_limits, integral_limit, deadband, output_deadband) - distance_pid_params = ( - 1.0, # kp: Moderate proportional gain for smooth approach - 0.2, # ki: Small integral gain to eliminate steady-state error - 0.1, # kd: Some damping for smooth motion - (-self.max_linear_speed, self.max_linear_speed), # output_limits - 0.5, # integral_limit: Prevent windup - 0.1, # deadband: Small deadband for distance control - 0.05, # output_deadband: Minimum output to overcome friction - ) - - # Angle PID: (kp, ki, kd, output_limits, integral_limit, deadband, output_deadband) - angle_pid_params = ( - 1.4, # kp: Higher proportional gain for responsive turning - 0.1, # ki: Small integral gain - 0.05, # kd: Light damping to prevent oscillation - (-self.max_angular_speed, self.max_angular_speed), # output_limits - 0.3, # integral_limit: Prevent windup - 0.1, # deadband: Small deadband for angle control - 0.1, # output_deadband: Minimum output to overcome friction - True, # Invert output for angular control - ) - - # Initialize the visual servoing controller - self.controller = VisualServoingController( - distance_pid_params=distance_pid_params, angle_pid_params=angle_pid_params - ) - - # Initialize tracking 
state - self.last_control_time = time.time() - self.running = False - self.current_target = None # (target_id, bbox) - self.target_lost_frames = 0 - - # Add variables to track current distance and angle - self.current_distance = None - self.current_angle = None - - # Stream subscription management - self.subscription = None - self.latest_result = None - self.result_lock = threading.Lock() - self.stop_event = threading.Event() - - # Subscribe to the tracking stream - self._subscribe_to_tracking_stream() - - def start_tracking( - self, - desired_distance: int = None, - point: Tuple[int, int] = None, - timeout_wait_for_target: float = 20.0, - ) -> bool: - """ - Start tracking a human target using visual servoing. - - Args: - point: Optional tuple of (x, y) coordinates in image space. If provided, - will find the target whose bounding box contains this point. - If None, will track the closest person. - - Returns: - bool: True if tracking was successfully started, False otherwise - """ - if desired_distance is not None: - self.desired_distance = desired_distance - - if self.tracking_stream is None: - self.running = False - return False - - # Get the latest frame and targets from person tracker - try: - # Try getting the result multiple times with delays - for attempt in range(10): - result = self._get_current_tracking_result() - - if result is not None: - break - - logger.warning( - f"Attempt {attempt + 1}: No tracking result, retrying in 1 second..." 
- ) - time.sleep(3) # Wait 1 second between attempts - - if result is None: - logger.warning("Stream error, no targets found after multiple attempts") - return False - - targets = result.get("targets") - - # If bbox is provided, find matching target based on IOU - if point is not None and not self.running: - # Find the target with highest IOU to the provided bbox - best_target = self._find_target_by_point(point, targets) - # If no bbox is provided, find the closest person - elif not self.running: - if timeout_wait_for_target > 0.0 and len(targets) == 0: - # Wait for target to appear - start_time = time.time() - while time.time() - start_time < timeout_wait_for_target: - time.sleep(0.2) - result = self._get_current_tracking_result() - targets = result.get("targets") - if len(targets) > 0: - break - best_target = self._find_closest_target(targets) - else: - # Already tracking - return True - - if best_target: - # Set as current target and reset lost counter - target_id = best_target.get("target_id") - target_bbox = best_target.get("bbox") - self.current_target = (target_id, target_bbox) - self.target_lost_frames = 0 - self.running = True - logger.info(f"Started tracking target ID: {target_id}") - - # Get distance and angle and compute control (store as initial control values) - distance = best_target.get("distance") - angle = best_target.get("angle") - self._compute_control(distance, angle) - return True - else: - if point is not None: - logger.warning("No matching target found") - else: - logger.warning("No suitable target found for tracking") - self.running = False - return False - except Exception as e: - logger.error(f"Error starting tracking: {e}") - self.running = False - return False - - def _find_target_by_point(self, point, targets): - """Find the target whose bounding box contains the given point. 
- - Args: - point: Tuple of (x, y) coordinates in image space - targets: List of target dictionaries - - Returns: - dict: The target whose bbox contains the point, or None if no match - """ - x, y = point - for target in targets: - bbox = target.get("bbox") - if not bbox: - continue - - x1, y1, x2, y2 = bbox - if x1 <= x <= x2 and y1 <= y <= y2: - return target - return None - - def updateTracking(self) -> Dict[str, any]: - """ - Update tracking of current target. - - Returns: - Dict with linear_vel, angular_vel, and running state - """ - if not self.running or self.current_target is None: - self.running = False - self.current_distance = None - self.current_angle = None - return {"linear_vel": 0.0, "angular_vel": 0.0} - - # Get the latest tracking result - result = self._get_current_tracking_result() - - # Get targets from result - targets = result.get("targets") - - # Try to find current target by ID or IOU - current_target_id, current_bbox = self.current_target - target_found = False - - # First try to find by ID - for target in targets: - if target.get("target_id") == current_target_id: - # Found by ID, update bbox - self.current_target = (current_target_id, target.get("bbox")) - self.target_lost_frames = 0 - target_found = True - - # Store current distance and angle - self.current_distance = target.get("distance") - self.current_angle = target.get("angle") - - # Compute control - control = self._compute_control(self.current_distance, self.current_angle) - return control - - # If not found by ID, try to find by IOU - if not target_found and current_bbox is not None: - best_target = self._find_best_target_by_iou(current_bbox, targets) - if best_target: - # Update target - new_id = best_target.get("target_id") - new_bbox = best_target.get("bbox") - self.current_target = (new_id, new_bbox) - self.target_lost_frames = 0 - logger.info(f"Target ID updated: {current_target_id} -> {new_id}") - - # Store current distance and angle - self.current_distance = 
best_target.get("distance") - self.current_angle = best_target.get("angle") - - # Compute control - control = self._compute_control(self.current_distance, self.current_angle) - return control - - # Target not found, increment lost counter - if not target_found: - self.target_lost_frames += 1 - logger.warning(f"Target lost: frame {self.target_lost_frames}/{self.max_lost_frames}") - - # Check if target is lost for too many frames - if self.target_lost_frames >= self.max_lost_frames: - logger.info("Target lost for too many frames, stopping tracking") - self.stop_tracking() - return {"linear_vel": 0.0, "angular_vel": 0.0, "running": False} - - return {"linear_vel": 0.0, "angular_vel": 0.0} - - def _compute_control(self, distance: float, angle: float) -> Dict[str, float]: - """ - Compute control commands based on measured distance and angle. - - Args: - distance: Measured distance to target in meters - angle: Measured angle to target in radians - - Returns: - Dict with linear_vel and angular_vel keys - """ - current_time = time.time() - dt = current_time - self.last_control_time - self.last_control_time = current_time - - # Compute control with visual servoing controller - linear_vel, angular_vel = self.controller.compute_control( - measured_distance=distance, - measured_angle=angle, - desired_distance=self.desired_distance, - desired_angle=0.0, # Keep target centered - dt=dt, - ) - - # Log control values for debugging - logger.debug(f"Distance: {distance:.2f}m, Angle: {np.rad2deg(angle):.1f}°") - logger.debug(f"Control: linear={linear_vel:.2f}m/s, angular={angular_vel:.2f}rad/s") - - return {"linear_vel": linear_vel, "angular_vel": angular_vel} - - def _find_best_target_by_iou(self, bbox: List[float], targets: List[Dict]) -> Optional[Dict]: - """ - Find the target with highest IOU to the given bbox. 
- - Args: - bbox: Bounding box to match [x1, y1, x2, y2] - targets: List of target dictionaries - - Returns: - Best matching target or None if no match found - """ - if not targets: - return None - - best_iou = self.iou_threshold - best_target = None - - for target in targets: - target_bbox = target.get("bbox") - if target_bbox is None: - continue - - iou = calculate_iou(bbox, target_bbox) - if iou > best_iou: - best_iou = iou - best_target = target - - return best_target - - def _find_closest_target(self, targets: List[Dict]) -> Optional[Dict]: - """ - Find the target with shortest distance to the camera. - - Args: - targets: List of target dictionaries - - Returns: - The closest target or None if no targets available - """ - if not targets: - return None - - closest_target = None - min_distance = float("inf") - - for target in targets: - distance = target.get("distance") - if distance is not None and distance < min_distance: - min_distance = distance - closest_target = target - - return closest_target - - def _subscribe_to_tracking_stream(self): - """ - Subscribe to the already set up tracking stream. - """ - if self.tracking_stream is None: - logger.warning("No tracking stream provided to subscribe to") - return - - try: - # Set up subscription to process frames - self.subscription = self.tracking_stream.subscribe( - on_next=self._on_tracking_result, - on_error=self._on_tracking_error, - on_completed=self._on_tracking_completed, - ) - - logger.info("Subscribed to tracking stream successfully") - except Exception as e: - logger.error(f"Error subscribing to tracking stream: {e}") - - def _on_tracking_result(self, result): - """ - Callback for tracking stream results. - - This updates the latest result for use by _get_current_tracking_result. 
- - Args: - result: The result from the tracking stream - """ - if self.stop_event.is_set(): - return - - # Update the latest result - with self.result_lock: - self.latest_result = result - - def _on_tracking_error(self, error): - """ - Callback for tracking stream errors. - - Args: - error: The error from the tracking stream - """ - logger.error(f"Tracking stream error: {error}") - self.stop_event.set() - - def _on_tracking_completed(self): - """Callback for tracking stream completion.""" - logger.info("Tracking stream completed") - self.stop_event.set() - - def _get_current_tracking_result(self) -> Optional[Dict]: - """ - Get the current tracking result. - - Returns the latest result cached from the tracking stream subscription. - - Returns: - Dict with 'frame' and 'targets' or None if not available - """ - # Return the latest cached result - with self.result_lock: - return self.latest_result - - def stop_tracking(self): - """Stop tracking and reset controller state.""" - self.running = False - self.current_target = None - self.target_lost_frames = 0 - self.current_distance = None - self.current_angle = None - return {"linear_vel": 0.0, "angular_vel": 0.0, "running": False} - - def is_goal_reached(self, distance_threshold=0.2, angle_threshold=0.1) -> bool: - """ - Check if the robot has reached the tracking goal (desired distance and angle). 
- - Args: - distance_threshold: Maximum allowed difference between current and desired distance (meters) - angle_threshold: Maximum allowed difference between current and desired angle (radians) - - Returns: - bool: True if both distance and angle are within threshold of desired values - """ - if not self.running or self.current_target is None: - return False - - # Use the stored distance and angle values - if self.current_distance is None or self.current_angle is None: - return False - - # Check if within thresholds - distance_error = abs(self.current_distance - self.desired_distance) - angle_error = abs(self.current_angle) # Desired angle is always 0 (centered) - - logger.debug( - f"Goal check - Distance error: {distance_error:.2f}m, Angle error: {angle_error:.2f}rad" - ) - - return (distance_error <= distance_threshold) and (angle_error <= angle_threshold) - - def cleanup(self): - """Clean up all resources used by the visual servoing.""" - self.stop_event.set() - if self.subscription: - self.subscription.dispose() - self.subscription = None - - def __del__(self): - """Destructor to ensure cleanup on object deletion.""" - self.cleanup() diff --git a/build/lib/dimos/robot/__init__.py b/build/lib/dimos/robot/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/build/lib/dimos/robot/connection_interface.py b/build/lib/dimos/robot/connection_interface.py deleted file mode 100644 index 1f327a7939..0000000000 --- a/build/lib/dimos/robot/connection_interface.py +++ /dev/null @@ -1,70 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from abc import ABC, abstractmethod -from typing import Optional -from reactivex.observable import Observable -from dimos.types.vector import Vector - -__all__ = ["ConnectionInterface"] - - -class ConnectionInterface(ABC): - """Abstract base class for robot connection interfaces. - - This class defines the minimal interface that all connection types (ROS, WebRTC, etc.) - must implement to provide robot control and data streaming capabilities. - """ - - @abstractmethod - def move(self, velocity: Vector, duration: float = 0.0) -> bool: - """Send movement command to the robot using velocity commands. - - Args: - velocity: Velocity vector [x, y, yaw] where: - x: Forward/backward velocity (m/s) - y: Left/right velocity (m/s) - yaw: Rotational velocity (rad/s) - duration: How long to move (seconds). If 0, command is continuous - - Returns: - bool: True if command was sent successfully - """ - pass - - @abstractmethod - def get_video_stream(self, fps: int = 30) -> Optional[Observable]: - """Get the video stream from the robot's camera. - - Args: - fps: Frames per second for the video stream - - Returns: - Observable: An observable stream of video frames or None if not available - """ - pass - - @abstractmethod - def stop(self) -> bool: - """Stop the robot's movement. 
- - Returns: - bool: True if stop command was sent successfully - """ - pass - - @abstractmethod - def disconnect(self) -> None: - """Disconnect from the robot and clean up resources.""" - pass diff --git a/build/lib/dimos/robot/foxglove_bridge.py b/build/lib/dimos/robot/foxglove_bridge.py deleted file mode 100644 index a0374fc251..0000000000 --- a/build/lib/dimos/robot/foxglove_bridge.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import asyncio -import threading - -# this is missing, I'm just trying to import lcm_foxglove_bridge.py from dimos_lcm -import dimos_lcm.lcm_foxglove_bridge as bridge - -from dimos.core import Module, rpc - - -class FoxgloveBridge(Module): - _thread: threading.Thread - _loop: asyncio.AbstractEventLoop - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.start() - - @rpc - def start(self): - def run_bridge(): - self._loop = asyncio.new_event_loop() - asyncio.set_event_loop(self._loop) - try: - self._loop.run_until_complete(bridge.main()) - except Exception as e: - print(f"Foxglove bridge error: {e}") - - self._thread = threading.Thread(target=run_bridge, daemon=True) - self._thread.start() - - @rpc - def stop(self): - if self._loop and self._loop.is_running(): - self._loop.call_soon_threadsafe(self._loop.stop) - self._thread.join(timeout=2) diff --git a/build/lib/dimos/robot/frontier_exploration/__init__.py b/build/lib/dimos/robot/frontier_exploration/__init__.py deleted file mode 100644 index 2b69011a9f..0000000000 --- a/build/lib/dimos/robot/frontier_exploration/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from utils import * diff --git a/build/lib/dimos/robot/frontier_exploration/qwen_frontier_predictor.py b/build/lib/dimos/robot/frontier_exploration/qwen_frontier_predictor.py deleted file mode 100644 index 10a1d8a265..0000000000 --- a/build/lib/dimos/robot/frontier_exploration/qwen_frontier_predictor.py +++ /dev/null @@ -1,368 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Qwen-based frontier exploration goal predictor using vision language model. - -This module provides a frontier goal detector that uses Qwen's vision capabilities -to analyze costmap images and predict optimal exploration goals. -""" - -import os -import glob -import json -import re -from typing import Optional, List, Tuple - -import numpy as np -from PIL import Image, ImageDraw - -from dimos.types.costmap import Costmap -from dimos.types.vector import Vector -from dimos.models.qwen.video_query import query_single_frame -from dimos.robot.frontier_exploration.utils import ( - costmap_to_pil_image, - smooth_costmap_for_frontiers, -) - - -class QwenFrontierPredictor: - """ - Qwen-based frontier exploration goal predictor. - - Uses Qwen's vision capabilities to analyze costmap images and predict - optimal exploration goals based on visual understanding of the map structure. - """ - - def __init__( - self, - api_key: Optional[str] = None, - model_name: str = "qwen2.5-vl-72b-instruct", - use_smoothed_costmap: bool = True, - image_scale_factor: int = 4, - ): - """ - Initialize the Qwen frontier predictor. 
- - Args: - api_key: Alibaba API key for Qwen access - model_name: Qwen model to use for predictions - image_scale_factor: Scale factor for image processing - """ - self.api_key = api_key or os.getenv("ALIBABA_API_KEY") - if not self.api_key: - raise ValueError( - "Alibaba API key must be provided or set in ALIBABA_API_KEY environment variable" - ) - - self.model_name = model_name - self.image_scale_factor = image_scale_factor - self.use_smoothed_costmap = use_smoothed_costmap - - # Storage for previously explored goals - self.explored_goals: List[Vector] = [] - - def _world_to_image_coords(self, world_pos: Vector, costmap: Costmap) -> Tuple[int, int]: - """Convert world coordinates to image pixel coordinates.""" - grid_pos = costmap.world_to_grid(world_pos) - img_x = int(grid_pos.x * self.image_scale_factor) - img_y = int((costmap.height - grid_pos.y) * self.image_scale_factor) # Flip Y - return img_x, img_y - - def _image_to_world_coords(self, img_x: int, img_y: int, costmap: Costmap) -> Vector: - """Convert image pixel coordinates to world coordinates.""" - # Unscale and flip Y coordinate - grid_x = img_x / self.image_scale_factor - grid_y = costmap.height - (img_y / self.image_scale_factor) - - # Convert grid to world coordinates - world_pos = costmap.grid_to_world(Vector([grid_x, grid_y])) - return world_pos - - def _draw_goals_on_image( - self, - image: Image.Image, - robot_pose: Vector, - costmap: Costmap, - latest_goal: Optional[Vector] = None, - ) -> Image.Image: - """ - Draw explored goals and robot position on the costmap image. 
- - Args: - image: PIL Image to draw on - robot_pose: Current robot position - costmap: Costmap for coordinate conversion - latest_goal: Latest predicted goal to highlight in red - - Returns: - PIL Image with goals drawn - """ - img_copy = image.copy() - draw = ImageDraw.Draw(img_copy) - - # Draw previously explored goals as green dots - for explored_goal in self.explored_goals: - x, y = self._world_to_image_coords(explored_goal, costmap) - radius = 8 - draw.ellipse( - [x - radius, y - radius, x + radius, y + radius], - fill=(0, 255, 0), - outline=(0, 128, 0), - width=2, - ) - - # Draw robot position as blue dot - robot_x, robot_y = self._world_to_image_coords(robot_pose, costmap) - robot_radius = 10 - draw.ellipse( - [ - robot_x - robot_radius, - robot_y - robot_radius, - robot_x + robot_radius, - robot_y + robot_radius, - ], - fill=(0, 0, 255), - outline=(0, 0, 128), - width=3, - ) - - # Draw latest predicted goal as red dot - if latest_goal: - goal_x, goal_y = self._world_to_image_coords(latest_goal, costmap) - goal_radius = 12 - draw.ellipse( - [ - goal_x - goal_radius, - goal_y - goal_radius, - goal_x + goal_radius, - goal_y + goal_radius, - ], - fill=(255, 0, 0), - outline=(128, 0, 0), - width=3, - ) - - return img_copy - - def _create_vision_prompt(self) -> str: - """Create the vision prompt for Qwen model.""" - prompt = """You are an expert robot navigation system analyzing a costmap for frontier exploration. - -COSTMAP LEGEND: -- Light gray pixels (205,205,205): FREE SPACE - areas the robot can navigate -- Dark gray pixels (128,128,128): UNKNOWN SPACE - unexplored areas that need exploration -- Black pixels (0,0,0): OBSTACLES - walls, furniture, blocked areas -- Blue dot: CURRENT ROBOT POSITION -- Green dots: PREVIOUSLY EXPLORED GOALS - avoid these areas - -TASK: Find the best frontier exploration goal by identifying the optimal point where: -1. It's at the boundary between FREE SPACE (light gray) and UNKNOWN SPACE (dark gray) (HIGHEST Priority) -2. 
It's reasonably far from the robot position (blue dot) (MEDIUM Priority) -3. It's reasonably far from previously explored goals (green dots) (MEDIUM Priority) -4. It leads to a large area of unknown space to explore (HIGH Priority) -5. It's accessible from the robot's current position through free space (MEDIUM Priority) -6. It's not near or on obstacles (HIGHEST Priority) - -RESPONSE FORMAT: Return ONLY the pixel coordinates as a JSON object: -{"x": pixel_x_coordinate, "y": pixel_y_coordinate, "reasoning": "brief explanation"} - -Example: {"x": 245, "y": 187, "reasoning": "Large unknown area to the north, good distance from robot and previous goals"} - -Analyze the image and identify the single best frontier exploration goal.""" - - return prompt - - def _parse_prediction_response(self, response: str) -> Optional[Tuple[int, int, str]]: - """ - Parse the model's response to extract coordinates and reasoning. - - Args: - response: Raw response from Qwen model - - Returns: - Tuple of (x, y, reasoning) or None if parsing failed - """ - try: - # Try to find JSON object in response - json_match = re.search(r"\{[^}]*\}", response) - if json_match: - json_str = json_match.group() - data = json.loads(json_str) - - if "x" in data and "y" in data: - x = int(data["x"]) - y = int(data["y"]) - reasoning = data.get("reasoning", "No reasoning provided") - return (x, y, reasoning) - - # Fallback: try to extract coordinates with regex - coord_match = re.search(r"[^\d]*(\d+)[^\d]+(\d+)", response) - if coord_match: - x = int(coord_match.group(1)) - y = int(coord_match.group(2)) - return (x, y, "Coordinates extracted from response") - - except (json.JSONDecodeError, ValueError, KeyError) as e: - print(f"DEBUG: Failed to parse prediction response: {e}") - - return None - - def get_exploration_goal(self, robot_pose: Vector, costmap: Costmap) -> Optional[Vector]: - """ - Get the best exploration goal using Qwen vision analysis. 
- - Args: - robot_pose: Current robot position in world coordinates - costmap: Current costmap for analysis - - Returns: - Single best frontier goal in world coordinates, or None if no suitable goal found - """ - print( - f"DEBUG: Qwen frontier prediction starting with {len(self.explored_goals)} explored goals" - ) - - # Create costmap image - if self.use_smoothed_costmap: - costmap = smooth_costmap_for_frontiers(costmap, alpha=4.0) - - base_image = costmap_to_pil_image(costmap, self.image_scale_factor) - - # Draw goals on image (without latest goal initially) - annotated_image = self._draw_goals_on_image(base_image, robot_pose, costmap) - - # Query Qwen model for frontier prediction - try: - prompt = self._create_vision_prompt() - response = query_single_frame( - annotated_image, prompt, api_key=self.api_key, model_name=self.model_name - ) - - print(f"DEBUG: Qwen response: {response}") - - # Parse response to get coordinates - parsed_result = self._parse_prediction_response(response) - if not parsed_result: - print("DEBUG: Failed to parse Qwen response") - return None - - img_x, img_y, reasoning = parsed_result - print(f"DEBUG: Parsed coordinates: ({img_x}, {img_y}), Reasoning: {reasoning}") - - # Convert image coordinates to world coordinates - predicted_goal = self._image_to_world_coords(img_x, img_y, costmap) - print( - f"DEBUG: Predicted goal in world coordinates: ({predicted_goal.x:.2f}, {predicted_goal.y:.2f})" - ) - - # Store the goal in explored_goals for future reference - self.explored_goals.append(predicted_goal) - - print(f"DEBUG: Successfully predicted frontier goal: {predicted_goal}") - return predicted_goal - - except Exception as e: - print(f"DEBUG: Error during Qwen prediction: {e}") - return None - - -def test_qwen_frontier_detection(): - """ - Visual test for Qwen frontier detection using saved costmaps. - Shows frontier detection results with Qwen predictions. 
- """ - - # Path to saved costmaps - saved_maps_dir = os.path.join(os.getcwd(), "assets", "saved_maps") - - if not os.path.exists(saved_maps_dir): - print(f"Error: Saved maps directory not found: {saved_maps_dir}") - return - - # Get all pickle files - pickle_files = sorted(glob.glob(os.path.join(saved_maps_dir, "*.pickle"))) - - if not pickle_files: - print(f"No pickle files found in {saved_maps_dir}") - return - - print(f"Found {len(pickle_files)} costmap files for Qwen testing") - - # Initialize Qwen frontier predictor - predictor = QwenFrontierPredictor(image_scale_factor=4, use_smoothed_costmap=False) - - # Track the robot pose across iterations - robot_pose = None - - # Process each costmap - for i, pickle_file in enumerate(pickle_files): - print( - f"\n--- Processing costmap {i + 1}/{len(pickle_files)}: {os.path.basename(pickle_file)} ---" - ) - - try: - # Load the costmap - costmap = Costmap.from_pickle(pickle_file) - print( - f"Loaded costmap: {costmap.width}x{costmap.height}, resolution: {costmap.resolution}" - ) - - # Set robot pose: first iteration uses center, subsequent use last predicted goal - if robot_pose is None: - # First iteration: use center of costmap as robot position - center_world = costmap.grid_to_world( - Vector([costmap.width / 2, costmap.height / 2]) - ) - robot_pose = Vector([center_world.x, center_world.y]) - # else: robot_pose remains the last predicted goal - - print(f"Using robot position: {robot_pose}") - - # Get frontier prediction from Qwen - print("Getting Qwen frontier prediction...") - predicted_goal = predictor.get_exploration_goal(robot_pose, costmap) - - if predicted_goal: - distance = np.sqrt( - (predicted_goal.x - robot_pose.x) ** 2 + (predicted_goal.y - robot_pose.y) ** 2 - ) - print(f"Predicted goal: {predicted_goal}, Distance: {distance:.2f}m") - - # Show the final visualization - base_image = costmap_to_pil_image(costmap, predictor.image_scale_factor) - final_image = predictor._draw_goals_on_image( - base_image, 
robot_pose, costmap, predicted_goal - ) - - # Display image - title = f"Qwen Frontier Prediction {i + 1:04d}" - final_image.show(title=title) - - # Update robot pose for next iteration - robot_pose = predicted_goal - - else: - print("No suitable frontier goal predicted by Qwen") - - except Exception as e: - print(f"Error processing {pickle_file}: {e}") - continue - - print(f"\n=== Qwen Frontier Detection Test Complete ===") - print(f"Final explored goals count: {len(predictor.explored_goals)}") - - -if __name__ == "__main__": - test_qwen_frontier_detection() diff --git a/build/lib/dimos/robot/frontier_exploration/test_wavefront_frontier_goal_selector.py b/build/lib/dimos/robot/frontier_exploration/test_wavefront_frontier_goal_selector.py deleted file mode 100644 index c9b75b28d8..0000000000 --- a/build/lib/dimos/robot/frontier_exploration/test_wavefront_frontier_goal_selector.py +++ /dev/null @@ -1,297 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from typing import List, Optional - -import numpy as np -import pytest -from PIL import Image, ImageDraw -from reactivex import operators as ops - -from dimos.robot.frontier_exploration.utils import costmap_to_pil_image -from dimos.robot.frontier_exploration.wavefront_frontier_goal_selector import ( - WavefrontFrontierExplorer, -) -from dimos.robot.unitree_webrtc.type.lidar import LidarMessage -from dimos.robot.unitree_webrtc.type.map import Map -from dimos.types.vector import Vector -from dimos.utils.testing import SensorReplay - - -def get_office_lidar_costmap(take_frames: int = 1, voxel_size: float = 0.5) -> tuple: - """ - Get a costmap from office_lidar data using SensorReplay. - - Args: - take_frames: Number of lidar frames to take (default 1) - voxel_size: Voxel size for map construction - - Returns: - Tuple of (costmap, first_lidar_message) for testing - """ - # Load office lidar data using SensorReplay as documented - lidar_stream = SensorReplay("office_lidar", autocast=LidarMessage.from_msg) - - # Create map with specified voxel size - map_obj = Map(voxel_size=voxel_size) - - # Take only the specified number of frames and build map - limited_stream = lidar_stream.stream().pipe(ops.take(take_frames)) - - # Store the first lidar message for reference - first_lidar = None - - def capture_first_and_add(lidar_msg): - nonlocal first_lidar - if first_lidar is None: - first_lidar = lidar_msg - return map_obj.add_frame(lidar_msg) - - # Process the stream - limited_stream.pipe(ops.map(capture_first_and_add)).run() - - # Get the resulting costmap - costmap = map_obj.costmap() - - return costmap, first_lidar - - -def test_frontier_detection_with_office_lidar(): - """Test frontier detection using a single frame from office_lidar data.""" - # Get costmap from office lidar data - costmap, first_lidar = get_office_lidar_costmap(take_frames=1, voxel_size=0.3) - - # Verify we have a valid costmap - assert costmap is not None, "Costmap should not be None" - assert 
costmap.width > 0 and costmap.height > 0, "Costmap should have valid dimensions" - - print(f"Costmap dimensions: {costmap.width}x{costmap.height}") - print(f"Costmap resolution: {costmap.resolution}") - print(f"Unknown percent: {costmap.unknown_percent:.1f}%") - print(f"Free percent: {costmap.free_percent:.1f}%") - print(f"Occupied percent: {costmap.occupied_percent:.1f}%") - - # Initialize frontier explorer with default parameters - explorer = WavefrontFrontierExplorer() - - # Set robot pose near the center of free space in the costmap - # We'll use the lidar origin as a reasonable robot position - robot_pose = first_lidar.origin - print(f"Robot pose: {robot_pose}") - - # Detect frontiers - frontiers = explorer.detect_frontiers(robot_pose, costmap) - - # Verify frontier detection results - assert isinstance(frontiers, list), "Frontiers should be returned as a list" - print(f"Detected {len(frontiers)} frontiers") - - # Test that we get some frontiers (office environment should have unexplored areas) - if len(frontiers) > 0: - print("Frontier detection successful - found unexplored areas") - - # Verify frontiers are Vector objects with valid coordinates - for i, frontier in enumerate(frontiers[:5]): # Check first 5 - assert isinstance(frontier, Vector), f"Frontier {i} should be a Vector" - assert hasattr(frontier, "x") and hasattr(frontier, "y"), ( - f"Frontier {i} should have x,y coordinates" - ) - print(f" Frontier {i}: ({frontier.x:.2f}, {frontier.y:.2f})") - else: - print("No frontiers detected - map may be fully explored or parameters too restrictive") - - -def test_exploration_goal_selection(): - """Test the complete exploration goal selection pipeline.""" - # Get costmap from office lidar data - costmap, first_lidar = get_office_lidar_costmap(take_frames=1, voxel_size=0.3) - - # Initialize frontier explorer with default parameters - explorer = WavefrontFrontierExplorer() - - # Use lidar origin as robot position - robot_pose = first_lidar.origin - - # Get 
exploration goal - goal = explorer.get_exploration_goal(robot_pose, costmap) - - if goal is not None: - assert isinstance(goal, Vector), "Goal should be a Vector" - print(f"Selected exploration goal: ({goal.x:.2f}, {goal.y:.2f})") - - # Verify goal is at reasonable distance from robot - distance = np.sqrt((goal.x - robot_pose.x) ** 2 + (goal.y - robot_pose.y) ** 2) - print(f"Goal distance from robot: {distance:.2f}m") - assert distance >= explorer.min_distance_from_robot, ( - "Goal should respect minimum distance from robot" - ) - - # Test that goal gets marked as explored - assert len(explorer.explored_goals) == 1, "Goal should be marked as explored" - assert explorer.explored_goals[0] == goal, "Explored goal should match selected goal" - - else: - print("No exploration goal selected - map may be fully explored") - - -def test_exploration_session_reset(): - """Test exploration session reset functionality.""" - # Get costmap - costmap, first_lidar = get_office_lidar_costmap(take_frames=1, voxel_size=0.3) - - # Initialize explorer and select a goal - explorer = WavefrontFrontierExplorer() - robot_pose = first_lidar.origin - - # Select a goal to populate exploration state - goal = explorer.get_exploration_goal(robot_pose, costmap) - - # Verify state is populated - initial_explored_count = len(explorer.explored_goals) - initial_direction = explorer.exploration_direction - - # Reset exploration session - explorer.reset_exploration_session() - - # Verify state is cleared - assert len(explorer.explored_goals) == 0, "Explored goals should be cleared after reset" - assert explorer.exploration_direction.x == 0.0 and explorer.exploration_direction.y == 0.0, ( - "Exploration direction should be reset" - ) - assert explorer.last_costmap is None, "Last costmap should be cleared" - assert explorer.num_no_gain_attempts == 0, "No-gain attempts should be reset" - - print("Exploration session reset successfully") - - -@pytest.mark.vis -def test_frontier_detection_visualization(): - 
"""Test frontier detection with visualization (marked with @pytest.mark.vis).""" - # Get costmap from office lidar data - costmap, first_lidar = get_office_lidar_costmap(take_frames=1, voxel_size=0.2) - - # Initialize frontier explorer with default parameters - explorer = WavefrontFrontierExplorer() - - # Use lidar origin as robot position - robot_pose = first_lidar.origin - - # Detect all frontiers for visualization - all_frontiers = explorer.detect_frontiers(robot_pose, costmap) - - # Get selected goal - selected_goal = explorer.get_exploration_goal(robot_pose, costmap) - - print(f"Visualizing {len(all_frontiers)} frontier candidates") - if selected_goal: - print(f"Selected goal: ({selected_goal.x:.2f}, {selected_goal.y:.2f})") - - # Create visualization - image_scale_factor = 4 - base_image = costmap_to_pil_image(costmap, image_scale_factor) - - # Helper function to convert world coordinates to image coordinates - def world_to_image_coords(world_pos: Vector) -> tuple[int, int]: - grid_pos = costmap.world_to_grid(world_pos) - img_x = int(grid_pos.x * image_scale_factor) - img_y = int((costmap.height - grid_pos.y) * image_scale_factor) # Flip Y - return img_x, img_y - - # Draw visualization - draw = ImageDraw.Draw(base_image) - - # Draw frontier candidates as gray dots - for frontier in all_frontiers[:20]: # Limit to top 20 - x, y = world_to_image_coords(frontier) - radius = 6 - draw.ellipse( - [x - radius, y - radius, x + radius, y + radius], - fill=(128, 128, 128), # Gray - outline=(64, 64, 64), - width=1, - ) - - # Draw robot position as blue dot - robot_x, robot_y = world_to_image_coords(robot_pose) - robot_radius = 10 - draw.ellipse( - [ - robot_x - robot_radius, - robot_y - robot_radius, - robot_x + robot_radius, - robot_y + robot_radius, - ], - fill=(0, 0, 255), # Blue - outline=(0, 0, 128), - width=3, - ) - - # Draw selected goal as red dot - if selected_goal: - goal_x, goal_y = world_to_image_coords(selected_goal) - goal_radius = 12 - draw.ellipse( - [ - 
goal_x - goal_radius, - goal_y - goal_radius, - goal_x + goal_radius, - goal_y + goal_radius, - ], - fill=(255, 0, 0), # Red - outline=(128, 0, 0), - width=3, - ) - - # Display the image - base_image.show(title="Frontier Detection - Office Lidar") - - print("Visualization displayed. Close the image window to continue.") - - -def test_multi_frame_exploration(): - """Tool test for multi-frame exploration analysis.""" - print("=== Multi-Frame Exploration Analysis ===") - - # Test with different numbers of frames - frame_counts = [1, 3, 5] - - for frame_count in frame_counts: - print(f"\n--- Testing with {frame_count} lidar frame(s) ---") - - # Get costmap with multiple frames - costmap, first_lidar = get_office_lidar_costmap(take_frames=frame_count, voxel_size=0.3) - - print( - f"Costmap: {costmap.width}x{costmap.height}, " - f"unknown: {costmap.unknown_percent:.1f}%, " - f"free: {costmap.free_percent:.1f}%, " - f"occupied: {costmap.occupied_percent:.1f}%" - ) - - # Initialize explorer with default parameters - explorer = WavefrontFrontierExplorer() - - # Detect frontiers - robot_pose = first_lidar.origin - frontiers = explorer.detect_frontiers(robot_pose, costmap) - - print(f"Detected {len(frontiers)} frontiers") - - # Get exploration goal - goal = explorer.get_exploration_goal(robot_pose, costmap) - if goal: - distance = np.sqrt((goal.x - robot_pose.x) ** 2 + (goal.y - robot_pose.y) ** 2) - print(f"Selected goal at distance {distance:.2f}m") - else: - print("No exploration goal selected") diff --git a/build/lib/dimos/robot/frontier_exploration/utils.py b/build/lib/dimos/robot/frontier_exploration/utils.py deleted file mode 100644 index 746f72e2f5..0000000000 --- a/build/lib/dimos/robot/frontier_exploration/utils.py +++ /dev/null @@ -1,188 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Utility functions for frontier exploration visualization and testing. -""" - -import numpy as np -from PIL import Image, ImageDraw -from typing import List, Tuple -from dimos.types.costmap import Costmap, CostValues -from dimos.types.vector import Vector -import os -import pickle -import cv2 - - -def costmap_to_pil_image(costmap: Costmap, scale_factor: int = 2) -> Image.Image: - """ - Convert costmap to PIL Image with ROS-style coloring and optional scaling. - - Args: - costmap: Costmap to convert - scale_factor: Factor to scale up the image for better visibility - - Returns: - PIL Image with ROS-style colors - """ - # Create image array (height, width, 3 for RGB) - img_array = np.zeros((costmap.height, costmap.width, 3), dtype=np.uint8) - - # Apply ROS-style coloring based on costmap values - for i in range(costmap.height): - for j in range(costmap.width): - value = costmap.grid[i, j] - if value == CostValues.FREE: # Free space = light grey - img_array[i, j] = [205, 205, 205] - elif value == CostValues.UNKNOWN: # Unknown = dark gray - img_array[i, j] = [128, 128, 128] - elif value >= CostValues.OCCUPIED: # Occupied/obstacles = black - img_array[i, j] = [0, 0, 0] - else: # Any other values (low cost) = light grey - img_array[i, j] = [205, 205, 205] - - # Flip vertically to match ROS convention (origin at bottom-left) - img_array = np.flipud(img_array) - - # Create PIL image - img = Image.fromarray(img_array, "RGB") - - # Scale up if requested - if scale_factor > 1: - new_size = (img.width * scale_factor, img.height * scale_factor) - img = 
img.resize(new_size, Image.NEAREST) # Use NEAREST to keep sharp pixels - - return img - - -def draw_frontiers_on_image( - image: Image.Image, - costmap: Costmap, - frontiers: List[Vector], - scale_factor: int = 2, - unfiltered_frontiers: List[Vector] = None, -) -> Image.Image: - """ - Draw frontier points on the costmap image. - - Args: - image: PIL Image to draw on - costmap: Original costmap for coordinate conversion - frontiers: List of frontier centroids (top 5) - scale_factor: Scaling factor used for the image - unfiltered_frontiers: All unfiltered frontier results (light green) - - Returns: - PIL Image with frontiers drawn - """ - img_copy = image.copy() - draw = ImageDraw.Draw(img_copy) - - def world_to_image_coords(world_pos: Vector) -> Tuple[int, int]: - """Convert world coordinates to image pixel coordinates.""" - grid_pos = costmap.world_to_grid(world_pos) - # Flip Y coordinate and apply scaling - img_x = int(grid_pos.x * scale_factor) - img_y = int((costmap.height - grid_pos.y) * scale_factor) # Flip Y - return img_x, img_y - - # Draw all unfiltered frontiers as light green circles - if unfiltered_frontiers: - for frontier in unfiltered_frontiers: - x, y = world_to_image_coords(frontier) - radius = 3 * scale_factor - draw.ellipse( - [x - radius, y - radius, x + radius, y + radius], - fill=(144, 238, 144), - outline=(144, 238, 144), - ) # Light green - - # Draw top 5 frontiers as green circles - for i, frontier in enumerate(frontiers[1:]): # Skip the best one for now - x, y = world_to_image_coords(frontier) - radius = 4 * scale_factor - draw.ellipse( - [x - radius, y - radius, x + radius, y + radius], - fill=(0, 255, 0), - outline=(0, 128, 0), - width=2, - ) # Green - - # Add number label - draw.text((x + radius + 2, y - radius), str(i + 2), fill=(0, 255, 0)) - - # Draw best frontier as red circle - if frontiers: - best_frontier = frontiers[0] - x, y = world_to_image_coords(best_frontier) - radius = 6 * scale_factor - draw.ellipse( - [x - radius, y - 
radius, x + radius, y + radius], - fill=(255, 0, 0), - outline=(128, 0, 0), - width=3, - ) # Red - - # Add "BEST" label - draw.text((x + radius + 2, y - radius), "BEST", fill=(255, 0, 0)) - - return img_copy - - -def smooth_costmap_for_frontiers( - costmap: Costmap, -) -> Costmap: - """ - Smooth a costmap using morphological operations for frontier exploration. - - This function applies OpenCV morphological operations to smooth free space - areas and improve connectivity for better frontier detection. It's designed - specifically for frontier exploration. - - Args: - costmap: Input Costmap object - - Returns: - Smoothed Costmap object with enhanced free space connectivity - """ - # Extract grid data and metadata from costmap - grid = costmap.grid - resolution = costmap.resolution - - # Work with a copy to avoid modifying input - filtered_grid = grid.copy() - - # 1. Create binary mask for free space - free_mask = (grid == CostValues.FREE).astype(np.uint8) * 255 - - # 2. Apply morphological operations for smoothing - kernel_size = 7 - kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size)) - - # Dilate free space to connect nearby areas - dilated = cv2.dilate(free_mask, kernel, iterations=1) - - # Morphological closing to fill small gaps - closed = cv2.morphologyEx(dilated, cv2.MORPH_CLOSE, kernel, iterations=1) - - eroded = cv2.erode(closed, kernel, iterations=1) - - # Apply the smoothed free space back to costmap - # Only change unknown areas to free, don't override obstacles - smoothed_free = eroded == 255 - unknown_mask = grid == CostValues.UNKNOWN - filtered_grid[smoothed_free & unknown_mask] = CostValues.FREE - - return Costmap(grid=filtered_grid, origin=costmap.origin, resolution=resolution) diff --git a/build/lib/dimos/robot/frontier_exploration/wavefront_frontier_goal_selector.py b/build/lib/dimos/robot/frontier_exploration/wavefront_frontier_goal_selector.py deleted file mode 100644 index 76f2ddbb0a..0000000000 --- 
a/build/lib/dimos/robot/frontier_exploration/wavefront_frontier_goal_selector.py +++ /dev/null @@ -1,665 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Simple wavefront frontier exploration algorithm implementation using dimos types. - -This module provides frontier detection and exploration goal selection -for autonomous navigation using the dimos Costmap and Vector types. -""" - -from typing import List, Tuple, Optional, Callable -from collections import deque -import numpy as np -from dataclasses import dataclass -from enum import IntFlag -import threading -from dimos.utils.logging_config import setup_logger - -from dimos.types.costmap import Costmap, CostValues -from dimos.types.vector import Vector -from dimos.robot.frontier_exploration.utils import smooth_costmap_for_frontiers - -logger = setup_logger("dimos.robot.unitree.frontier_exploration") - - -class PointClassification(IntFlag): - """Point classification flags for frontier detection algorithm.""" - - NoInformation = 0 - MapOpen = 1 - MapClosed = 2 - FrontierOpen = 4 - FrontierClosed = 8 - - -@dataclass -class GridPoint: - """Represents a point in the grid map with classification.""" - - x: int - y: int - classification: int = PointClassification.NoInformation - - -class FrontierCache: - """Cache for grid points to avoid duplicate point creation.""" - - def __init__(self): - self.points = {} - - def get_point(self, x: int, y: int) -> GridPoint: - """Get or 
create a grid point at the given coordinates.""" - key = (x, y) - if key not in self.points: - self.points[key] = GridPoint(x, y) - return self.points[key] - - def clear(self): - """Clear the point cache.""" - self.points.clear() - - -class WavefrontFrontierExplorer: - """ - Wavefront frontier exploration algorithm implementation. - - This class encapsulates the frontier detection and exploration goal selection - functionality using the wavefront algorithm with BFS exploration. - """ - - def __init__( - self, - min_frontier_size: int = 10, - occupancy_threshold: int = 65, - subsample_resolution: int = 2, - min_distance_from_robot: float = 0.5, - explored_area_buffer: float = 0.5, - min_distance_from_obstacles: float = 0.6, - info_gain_threshold: float = 0.03, - num_no_gain_attempts: int = 4, - set_goal: Optional[Callable] = None, - get_costmap: Optional[Callable] = None, - get_robot_pos: Optional[Callable] = None, - ): - """ - Initialize the frontier explorer. - - Args: - min_frontier_size: Minimum number of points to consider a valid frontier - occupancy_threshold: Cost threshold above which a cell is considered occupied (0-255) - subsample_resolution: Factor by which to subsample the costmap for faster processing (1=no subsampling, 2=half resolution, 4=quarter resolution) - min_distance_from_robot: Minimum distance frontier must be from robot (meters) - explored_area_buffer: Buffer distance around free areas to consider as explored (meters) - min_distance_from_obstacles: Minimum distance frontier must be from obstacles (meters) - info_gain_threshold: Minimum percentage increase in costmap information required to continue exploration (0.05 = 5%) - num_no_gain_attempts: Maximum number of consecutive attempts with no information gain - set_goal: Callable to set navigation goal, signature: (goal: Vector, stop_event: Optional[threading.Event]) -> bool - get_costmap: Callable to get current costmap, signature: () -> Costmap - get_robot_pos: Callable to get current 
robot position, signature: () -> Vector - """ - self.min_frontier_size = min_frontier_size - self.occupancy_threshold = occupancy_threshold - self.subsample_resolution = subsample_resolution - self.min_distance_from_robot = min_distance_from_robot - self.explored_area_buffer = explored_area_buffer - self.min_distance_from_obstacles = min_distance_from_obstacles - self.info_gain_threshold = info_gain_threshold - self.num_no_gain_attempts = num_no_gain_attempts - self.set_goal = set_goal - self.get_costmap = get_costmap - self.get_robot_pos = get_robot_pos - self._cache = FrontierCache() - self.explored_goals = [] # list of explored goals - self.exploration_direction = Vector([0.0, 0.0]) # current exploration direction - self.last_costmap = None # store last costmap for information comparison - - def _count_costmap_information(self, costmap: Costmap) -> int: - """ - Count the amount of information in a costmap (free space + obstacles). - - Args: - costmap: Costmap to analyze - - Returns: - Number of cells that are free space or obstacles (not unknown) - """ - free_count = np.sum(costmap.grid == CostValues.FREE) - obstacle_count = np.sum(costmap.grid >= self.occupancy_threshold) - return int(free_count + obstacle_count) - - def _get_neighbors(self, point: GridPoint, costmap: Costmap) -> List[GridPoint]: - """Get valid neighboring points for a given grid point.""" - neighbors = [] - - # 8-connected neighbors - for dx in [-1, 0, 1]: - for dy in [-1, 0, 1]: - if dx == 0 and dy == 0: - continue - - nx, ny = point.x + dx, point.y + dy - - # Check bounds - if 0 <= nx < costmap.width and 0 <= ny < costmap.height: - neighbors.append(self._cache.get_point(nx, ny)) - - return neighbors - - def _is_frontier_point(self, point: GridPoint, costmap: Costmap) -> bool: - """ - Check if a point is a frontier point. - A frontier point is an unknown cell adjacent to at least one free cell - and not adjacent to any occupied cells. 
- """ - # Point must be unknown - world_pos = costmap.grid_to_world(Vector([float(point.x), float(point.y)])) - cost = costmap.get_value(world_pos) - if cost != CostValues.UNKNOWN: - return False - - has_free = False - - for neighbor in self._get_neighbors(point, costmap): - neighbor_world = costmap.grid_to_world(Vector([float(neighbor.x), float(neighbor.y)])) - neighbor_cost = costmap.get_value(neighbor_world) - - # If adjacent to occupied space, not a frontier - if neighbor_cost and neighbor_cost > self.occupancy_threshold: - return False - - # Check if adjacent to free space - if neighbor_cost == CostValues.FREE: - has_free = True - - return has_free - - def _find_free_space(self, start_x: int, start_y: int, costmap: Costmap) -> Tuple[int, int]: - """ - Find the nearest free space point using BFS from the starting position. - """ - queue = deque([self._cache.get_point(start_x, start_y)]) - visited = set() - - while queue: - point = queue.popleft() - - if (point.x, point.y) in visited: - continue - visited.add((point.x, point.y)) - - # Check if this point is free space - world_pos = costmap.grid_to_world(Vector([float(point.x), float(point.y)])) - if costmap.get_value(world_pos) == CostValues.FREE: - return (point.x, point.y) - - # Add neighbors to search - for neighbor in self._get_neighbors(point, costmap): - if (neighbor.x, neighbor.y) not in visited: - queue.append(neighbor) - - # If no free space found, return original position - return (start_x, start_y) - - def _compute_centroid(self, frontier_points: List[Vector]) -> Vector: - """Compute the centroid of a list of frontier points.""" - if not frontier_points: - return Vector([0.0, 0.0]) - - # Vectorized approach using numpy - points_array = np.array([[point.x, point.y] for point in frontier_points]) - centroid = np.mean(points_array, axis=0) - - return Vector([centroid[0], centroid[1]]) - - def detect_frontiers(self, robot_pose: Vector, costmap: Costmap) -> List[Vector]: - """ - Main frontier detection 
algorithm using wavefront exploration. - - Args: - robot_pose: Current robot position in world coordinates (Vector with x, y) - costmap: Costmap for additional analysis - - Returns: - List of frontier centroids in world coordinates - """ - self._cache.clear() - - # Apply filtered costmap (now default) - working_costmap = smooth_costmap_for_frontiers(costmap) - - # Subsample the costmap for faster processing - if self.subsample_resolution > 1: - subsampled_costmap = working_costmap.subsample(self.subsample_resolution) - else: - subsampled_costmap = working_costmap - - # Convert robot pose to subsampled grid coordinates - subsampled_grid_pos = subsampled_costmap.world_to_grid(robot_pose) - grid_x, grid_y = int(subsampled_grid_pos.x), int(subsampled_grid_pos.y) - - # Find nearest free space to start exploration - free_x, free_y = self._find_free_space(grid_x, grid_y, subsampled_costmap) - start_point = self._cache.get_point(free_x, free_y) - start_point.classification = PointClassification.MapOpen - - # Main exploration queue - explore ALL reachable free space - map_queue = deque([start_point]) - frontiers = [] - frontier_sizes = [] - - points_checked = 0 - frontier_candidates = 0 - - while map_queue: - current_point = map_queue.popleft() - points_checked += 1 - - # Skip if already processed - if current_point.classification & PointClassification.MapClosed: - continue - - # Mark as processed - current_point.classification |= PointClassification.MapClosed - - # Check if this point starts a new frontier - if self._is_frontier_point(current_point, subsampled_costmap): - frontier_candidates += 1 - current_point.classification |= PointClassification.FrontierOpen - frontier_queue = deque([current_point]) - new_frontier = [] - - # Explore this frontier region using BFS - while frontier_queue: - frontier_point = frontier_queue.popleft() - - # Skip if already processed - if frontier_point.classification & PointClassification.FrontierClosed: - continue - - # If this is still a 
frontier point, add to current frontier - if self._is_frontier_point(frontier_point, subsampled_costmap): - new_frontier.append(frontier_point) - - # Add neighbors to frontier queue - for neighbor in self._get_neighbors(frontier_point, subsampled_costmap): - if not ( - neighbor.classification - & ( - PointClassification.FrontierOpen - | PointClassification.FrontierClosed - ) - ): - neighbor.classification |= PointClassification.FrontierOpen - frontier_queue.append(neighbor) - - frontier_point.classification |= PointClassification.FrontierClosed - - # Check if we found a large enough frontier - if len(new_frontier) >= self.min_frontier_size: - world_points = [] - for point in new_frontier: - world_pos = subsampled_costmap.grid_to_world( - Vector([float(point.x), float(point.y)]) - ) - world_points.append(world_pos) - - # Compute centroid in world coordinates (already correctly scaled) - centroid = self._compute_centroid(world_points) - frontiers.append(centroid) # Store centroid - frontier_sizes.append(len(new_frontier)) # Store frontier size - - # Add ALL neighbors to main exploration queue to explore entire free space - for neighbor in self._get_neighbors(current_point, subsampled_costmap): - if not ( - neighbor.classification - & (PointClassification.MapOpen | PointClassification.MapClosed) - ): - # Check if neighbor is free space or unknown (explorable) - neighbor_world = subsampled_costmap.grid_to_world( - Vector([float(neighbor.x), float(neighbor.y)]) - ) - neighbor_cost = subsampled_costmap.get_value(neighbor_world) - - # Add free space and unknown space to exploration queue - if neighbor_cost is not None and ( - neighbor_cost == CostValues.FREE or neighbor_cost == CostValues.UNKNOWN - ): - neighbor.classification |= PointClassification.MapOpen - map_queue.append(neighbor) - - # Extract just the centroids for ranking - frontier_centroids = frontiers - - if not frontier_centroids: - return [] - - # Rank frontiers using original costmap for proper filtering - 
ranked_frontiers = self._rank_frontiers( - frontier_centroids, frontier_sizes, robot_pose, costmap - ) - - return ranked_frontiers - - def _update_exploration_direction(self, robot_pose: Vector, goal_pose: Optional[Vector] = None): - """Update the current exploration direction based on robot movement or selected goal.""" - if goal_pose is not None: - # Calculate direction from robot to goal - direction = Vector([goal_pose.x - robot_pose.x, goal_pose.y - robot_pose.y]) - magnitude = np.sqrt(direction.x**2 + direction.y**2) - if magnitude > 0.1: # Avoid division by zero for very close goals - self.exploration_direction = Vector( - [direction.x / magnitude, direction.y / magnitude] - ) - - def _compute_direction_momentum_score(self, frontier: Vector, robot_pose: Vector) -> float: - """Compute direction momentum score for a frontier.""" - if self.exploration_direction.x == 0 and self.exploration_direction.y == 0: - return 0.0 # No momentum if no previous direction - - # Calculate direction from robot to frontier - frontier_direction = Vector([frontier.x - robot_pose.x, frontier.y - robot_pose.y]) - magnitude = np.sqrt(frontier_direction.x**2 + frontier_direction.y**2) - - if magnitude < 0.1: - return 0.0 # Too close to calculate meaningful direction - - # Normalize frontier direction - frontier_direction = Vector( - [frontier_direction.x / magnitude, frontier_direction.y / magnitude] - ) - - # Calculate dot product for directional alignment - dot_product = ( - self.exploration_direction.x * frontier_direction.x - + self.exploration_direction.y * frontier_direction.y - ) - - # Return momentum score (higher for same direction, lower for opposite) - return max(0.0, dot_product) # Only positive momentum, no penalty for different directions - - def _compute_distance_to_explored_goals(self, frontier: Vector) -> float: - """Compute distance from frontier to the nearest explored goal.""" - if not self.explored_goals: - return 5.0 # Default consistent value when no explored 
goals - # Calculate distance to nearest explored goal - min_distance = float("inf") - for goal in self.explored_goals: - distance = np.sqrt((frontier.x - goal.x) ** 2 + (frontier.y - goal.y) ** 2) - min_distance = min(min_distance, distance) - - return min_distance - - def _compute_distance_to_obstacles(self, frontier: Vector, costmap: Costmap) -> float: - """ - Compute the minimum distance from a frontier point to the nearest obstacle. - - Args: - frontier: Frontier point in world coordinates - costmap: Costmap to check for obstacles - - Returns: - Minimum distance to nearest obstacle in meters - """ - # Convert frontier to grid coordinates - grid_pos = costmap.world_to_grid(frontier) - grid_x, grid_y = int(grid_pos.x), int(grid_pos.y) - - # Check if frontier is within costmap bounds - if grid_x < 0 or grid_x >= costmap.width or grid_y < 0 or grid_y >= costmap.height: - return 0.0 # Consider out-of-bounds as obstacle - - min_distance = float("inf") - search_radius = ( - int(self.min_distance_from_obstacles / costmap.resolution) + 5 - ) # Search a bit beyond minimum - - # Search in a square around the frontier point - for dy in range(-search_radius, search_radius + 1): - for dx in range(-search_radius, search_radius + 1): - check_x = grid_x + dx - check_y = grid_y + dy - - # Skip if out of bounds - if ( - check_x < 0 - or check_x >= costmap.width - or check_y < 0 - or check_y >= costmap.height - ): - continue - - # Check if this cell is an obstacle - if costmap.grid[check_y, check_x] >= self.occupancy_threshold: - # Calculate distance in meters - distance = np.sqrt(dx**2 + dy**2) * costmap.resolution - min_distance = min(min_distance, distance) - - return min_distance if min_distance != float("inf") else float("inf") - - def _compute_comprehensive_frontier_score( - self, frontier: Vector, frontier_size: int, robot_pose: Vector, costmap: Costmap - ) -> float: - """Compute comprehensive score considering multiple criteria.""" - - # 1. 
Distance from robot (preference for moderate distances) - robot_distance = np.sqrt( - (frontier.x - robot_pose.x) ** 2 + (frontier.y - robot_pose.y) ** 2 - ) - - # Distance score: prefer moderate distances (not too close, not too far) - optimal_distance = 4.0 # meters - distance_score = 1.0 / (1.0 + abs(robot_distance - optimal_distance)) - - # 2. Information gain (frontier size) - info_gain_score = frontier_size - - # 3. Distance to explored goals (bonus for being far from explored areas) - explored_goals_distance = self._compute_distance_to_explored_goals(frontier) - explored_goals_score = explored_goals_distance - - # 4. Distance to obstacles (penalty for being too close) - obstacles_distance = self._compute_distance_to_obstacles(frontier, costmap) - obstacles_score = obstacles_distance - - # 5. Direction momentum (if we have a current direction) - momentum_score = self._compute_direction_momentum_score(frontier, robot_pose) - - # Combine scores with consistent scaling (no arbitrary multipliers) - total_score = ( - 0.3 * info_gain_score # 30% information gain - + 0.3 * explored_goals_score # 30% distance from explored goals - + 0.2 * distance_score # 20% distance optimization - + 0.15 * obstacles_score # 15% distance from obstacles - + 0.05 * momentum_score # 5% direction momentum - ) - - return total_score - - def _rank_frontiers( - self, - frontier_centroids: List[Vector], - frontier_sizes: List[int], - robot_pose: Vector, - costmap: Costmap, - ) -> List[Vector]: - """ - Find the single best frontier using comprehensive scoring and filtering. 
- - Args: - frontier_centroids: List of frontier centroids - frontier_sizes: List of frontier sizes - robot_pose: Current robot position - costmap: Costmap for additional analysis - - Returns: - List containing single best frontier, or empty list if none suitable - """ - if not frontier_centroids: - return [] - - valid_frontiers = [] - - for i, frontier in enumerate(frontier_centroids): - robot_distance = np.sqrt( - (frontier.x - robot_pose.x) ** 2 + (frontier.y - robot_pose.y) ** 2 - ) - - # Filter 1: Skip frontiers too close to robot - if robot_distance < self.min_distance_from_robot: - continue - - # Filter 2: Skip frontiers too close to obstacles - obstacle_distance = self._compute_distance_to_obstacles(frontier, costmap) - if obstacle_distance < self.min_distance_from_obstacles: - continue - - # Compute comprehensive score - frontier_size = frontier_sizes[i] if i < len(frontier_sizes) else 1 - score = self._compute_comprehensive_frontier_score( - frontier, frontier_size, robot_pose, costmap - ) - - valid_frontiers.append((frontier, score)) - - logger.info(f"Valid frontiers: {len(valid_frontiers)}") - - if not valid_frontiers: - return [] - - # Sort by score and return all valid frontiers (highest scores first) - valid_frontiers.sort(key=lambda x: x[1], reverse=True) - - # Extract just the frontiers (remove scores) and return as list - return [frontier for frontier, _ in valid_frontiers] - - def get_exploration_goal(self, robot_pose: Vector, costmap: Costmap) -> Optional[Vector]: - """ - Get the single best exploration goal using comprehensive frontier scoring. 
- - Args: - robot_pose: Current robot position in world coordinates (Vector with x, y) - costmap: Costmap for additional analysis - - Returns: - Single best frontier goal in world coordinates, or None if no suitable frontiers found - """ - # Check if we should compare costmaps for information gain - if len(self.explored_goals) > 5 and self.last_costmap is not None: - current_info = self._count_costmap_information(costmap) - last_info = self._count_costmap_information(self.last_costmap) - - # Check if information increase meets minimum percentage threshold - if last_info > 0: # Avoid division by zero - info_increase_percent = (current_info - last_info) / last_info - if info_increase_percent < self.info_gain_threshold: - logger.info( - f"Information increase ({info_increase_percent:.2f}) below threshold ({self.info_gain_threshold:.2f})" - ) - logger.info( - f"Current information: {current_info}, Last information: {last_info}" - ) - self.num_no_gain_attempts += 1 - if self.num_no_gain_attempts >= self.num_no_gain_attempts: - logger.info( - "No information gain for {} consecutive attempts, skipping frontier selection".format( - self.num_no_gain_attempts - ) - ) - self.reset_exploration_session() - return None - - # Always detect new frontiers to get most up-to-date information - # The new algorithm filters out explored areas and returns only the best frontier - frontiers = self.detect_frontiers(robot_pose, costmap) - - if not frontiers: - # Store current costmap before returning - self.last_costmap = costmap - self.reset_exploration_session() - return None - - # Update exploration direction based on best goal selection - if frontiers: - self._update_exploration_direction(robot_pose, frontiers[0]) - - # Store the selected goal as explored - selected_goal = frontiers[0] - self.mark_explored_goal(selected_goal) - - # Store current costmap for next comparison - self.last_costmap = costmap - - return selected_goal - - # Store current costmap before returning - 
self.last_costmap = costmap - return None - - def mark_explored_goal(self, goal: Vector): - """Mark a goal as explored.""" - self.explored_goals.append(goal) - - def reset_exploration_session(self): - """ - Reset all exploration state variables for a new exploration session. - - Call this method when starting a new exploration or when the robot - needs to forget its previous exploration history. - """ - self.explored_goals.clear() # Clear all previously explored goals - self.exploration_direction = Vector([0.0, 0.0]) # Reset exploration direction - self.last_costmap = None # Clear last costmap comparison - self.num_no_gain_attempts = 0 # Reset no-gain attempt counter - self._cache.clear() # Clear frontier point cache - - logger.info("Exploration session reset - all state variables cleared") - - def explore(self, stop_event: Optional[threading.Event] = None) -> bool: - """ - Perform autonomous frontier exploration by continuously finding and navigating to frontiers. - - Args: - stop_event: Optional threading.Event to signal when exploration should stop - - Returns: - bool: True if exploration completed successfully, False if stopped or failed - """ - - logger.info("Starting autonomous frontier exploration") - - while True: - # Check if stop event is set - if stop_event and stop_event.is_set(): - logger.info("Exploration stopped by stop event") - return False - - # Get fresh robot position and costmap data - robot_pose = self.get_robot_pos() - costmap = self.get_costmap() - - # Get the next frontier goal - next_goal = self.get_exploration_goal(robot_pose, costmap) - if not next_goal: - logger.info("No more frontiers found, exploration complete") - return True - - # Navigate to the frontier - logger.info(f"Navigating to frontier at {next_goal}") - navigation_successful = self.set_goal(next_goal, stop_event=stop_event) - - if not navigation_successful: - logger.warning("Failed to navigate to frontier, continuing exploration") - # Continue to try other frontiers instead 
of stopping - continue diff --git a/build/lib/dimos/robot/global_planner/__init__.py b/build/lib/dimos/robot/global_planner/__init__.py deleted file mode 100644 index f26a5e8f7c..0000000000 --- a/build/lib/dimos/robot/global_planner/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from dimos.robot.global_planner.planner import AstarPlanner, Planner diff --git a/build/lib/dimos/robot/global_planner/algo.py b/build/lib/dimos/robot/global_planner/algo.py deleted file mode 100644 index 236725ce05..0000000000 --- a/build/lib/dimos/robot/global_planner/algo.py +++ /dev/null @@ -1,273 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -import heapq -from typing import Optional, Tuple -from collections import deque -from dimos.types.path import Path -from dimos.types.vector import VectorLike, Vector -from dimos.types.costmap import Costmap - - -def find_nearest_free_cell( - costmap: Costmap, position: VectorLike, cost_threshold: int = 90, max_search_radius: int = 20 -) -> Tuple[int, int]: - """ - Find the nearest unoccupied cell in the costmap using BFS. 
- - Args: - costmap: Costmap object containing the environment - position: Position to find nearest free cell from - cost_threshold: Cost threshold above which a cell is considered an obstacle - max_search_radius: Maximum search radius in cells - - Returns: - Tuple of (x, y) in grid coordinates of the nearest free cell, - or the original position if no free cell is found within max_search_radius - """ - # Convert world coordinates to grid coordinates - grid_pos = costmap.world_to_grid(position) - start_x, start_y = int(grid_pos.x), int(grid_pos.y) - - # If the cell is already free, return it - if 0 <= start_x < costmap.width and 0 <= start_y < costmap.height: - if costmap.grid[start_y, start_x] < cost_threshold: - return (start_x, start_y) - - # BFS to find nearest free cell - queue = deque([(start_x, start_y, 0)]) # (x, y, distance) - visited = set([(start_x, start_y)]) - - # Possible movements (8-connected grid) - directions = [ - (0, 1), - (1, 0), - (0, -1), - (-1, 0), # horizontal/vertical - (1, 1), - (1, -1), - (-1, 1), - (-1, -1), # diagonal - ] - - while queue: - x, y, dist = queue.popleft() - - # Check if we've reached the maximum search radius - if dist > max_search_radius: - print( - f"Could not find free cell within {max_search_radius} cells of ({start_x}, {start_y})" - ) - return (start_x, start_y) # Return original position if no free cell found - - # Check if this cell is valid and free - if 0 <= x < costmap.width and 0 <= y < costmap.height: - if costmap.grid[y, x] < cost_threshold: - print( - f"Found free cell at ({x}, {y}), {dist} cells away from ({start_x}, {start_y})" - ) - return (x, y) - - # Add neighbors to the queue - for dx, dy in directions: - nx, ny = x + dx, y + dy - if (nx, ny) not in visited: - visited.add((nx, ny)) - queue.append((nx, ny, dist + 1)) - - # If the queue is empty and no free cell is found, return the original position - return (start_x, start_y) - - -def astar( - costmap: Costmap, - goal: VectorLike, - start: VectorLike = 
(0.0, 0.0), - cost_threshold: int = 90, - allow_diagonal: bool = True, -) -> Optional[Path]: - """ - A* path planning algorithm from start to goal position. - - Args: - costmap: Costmap object containing the environment - goal: Goal position as any vector-like object - start: Start position as any vector-like object (default: origin [0,0]) - cost_threshold: Cost threshold above which a cell is considered an obstacle - allow_diagonal: Whether to allow diagonal movements - - Returns: - Path object containing waypoints, or None if no path found - """ - # Convert world coordinates to grid coordinates directly using vector-like inputs - start_vector = costmap.world_to_grid(start) - goal_vector = costmap.world_to_grid(goal) - - # Store original positions for reference - original_start = (int(start_vector.x), int(start_vector.y)) - original_goal = (int(goal_vector.x), int(goal_vector.y)) - - adjusted_start = original_start - adjusted_goal = original_goal - - # Check if start is out of bounds or in an obstacle - start_valid = 0 <= start_vector.x < costmap.width and 0 <= start_vector.y < costmap.height - - start_in_obstacle = False - if start_valid: - start_in_obstacle = costmap.grid[int(start_vector.y), int(start_vector.x)] >= cost_threshold - - if not start_valid or start_in_obstacle: - print("Start position is out of bounds or in an obstacle, finding nearest free cell") - adjusted_start = find_nearest_free_cell(costmap, start, cost_threshold) - # Update start_vector for later use - start_vector = Vector(adjusted_start[0], adjusted_start[1]) - - # Check if goal is out of bounds or in an obstacle - goal_valid = 0 <= goal_vector.x < costmap.width and 0 <= goal_vector.y < costmap.height - - goal_in_obstacle = False - if goal_valid: - goal_in_obstacle = costmap.grid[int(goal_vector.y), int(goal_vector.x)] >= cost_threshold - - if not goal_valid or goal_in_obstacle: - print("Goal position is out of bounds or in an obstacle, finding nearest free cell") - adjusted_goal = 
find_nearest_free_cell(costmap, goal, cost_threshold) - # Update goal_vector for later use - goal_vector = Vector(adjusted_goal[0], adjusted_goal[1]) - - # Define possible movements (8-connected grid) - if allow_diagonal: - # 8-connected grid: horizontal, vertical, and diagonal movements - directions = [ - (0, 1), - (1, 0), - (0, -1), - (-1, 0), - (1, 1), - (1, -1), - (-1, 1), - (-1, -1), - ] - else: - # 4-connected grid: only horizontal and vertical ts - directions = [(0, 1), (1, 0), (0, -1), (-1, 0)] - - # Cost for each movement (straight vs diagonal) - sc = 1.0 - dc = 1.42 - movement_costs = [sc, sc, sc, sc, dc, dc, dc, dc] if allow_diagonal else [sc, sc, sc, sc] - - # A* algorithm implementation - open_set = [] # Priority queue for nodes to explore - closed_set = set() # Set of explored nodes - - # Use adjusted positions as tuples for dictionary keys - start_tuple = adjusted_start - goal_tuple = adjusted_goal - - # Dictionary to store cost from start and parents for each node - g_score = {start_tuple: 0} - parents = {} - - # Heuristic function (Euclidean distance) - def heuristic(x1, y1, x2, y2): - return math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2) - - # Start with the starting node - f_score = g_score[start_tuple] + heuristic( - start_tuple[0], start_tuple[1], goal_tuple[0], goal_tuple[1] - ) - heapq.heappush(open_set, (f_score, start_tuple)) - - while open_set: - # Get the node with the lowest f_score - _, current = heapq.heappop(open_set) - current_x, current_y = current - - # Check if we've reached the goal - if current == goal_tuple: - # Reconstruct the path - waypoints = [] - while current in parents: - world_point = costmap.grid_to_world(current) - waypoints.append(world_point) - current = parents[current] - - # Add the start position - start_world_point = costmap.grid_to_world(start_tuple) - waypoints.append(start_world_point) - - # Reverse the path (start to goal) - waypoints.reverse() - - # Add the goal position if it's not already included - 
goal_point = costmap.grid_to_world(goal_tuple) - - if not waypoints or waypoints[-1].distance(goal_point) > 1e-5: - waypoints.append(goal_point) - - # If we adjusted the goal, add the original goal as the final point - if adjusted_goal != original_goal and goal_valid: - original_goal_point = costmap.grid_to_world(original_goal) - waypoints.append(original_goal_point) - - return Path(waypoints) - - # Add current node to closed set - closed_set.add(current) - - # Explore neighbors - for i, (dx, dy) in enumerate(directions): - neighbor_x, neighbor_y = current_x + dx, current_y + dy - neighbor = (neighbor_x, neighbor_y) - - # Check if the neighbor is valid - if not (0 <= neighbor_x < costmap.width and 0 <= neighbor_y < costmap.height): - continue - - # Check if the neighbor is already explored - if neighbor in closed_set: - continue - - # Check if the neighbor is an obstacle - neighbor_val = costmap.grid[neighbor_y, neighbor_x] - if neighbor_val >= cost_threshold: # or neighbor_val < 0: - continue - - obstacle_proximity_penalty = costmap.grid[neighbor_y, neighbor_x] / 25 - tentative_g_score = ( - g_score[current] - + movement_costs[i] - + (obstacle_proximity_penalty * movement_costs[i]) - ) - - # Get the current g_score for the neighbor or set to infinity if not yet explored - neighbor_g_score = g_score.get(neighbor, float("inf")) - - # If this path to the neighbor is better than any previous one - if tentative_g_score < neighbor_g_score: - # Update the neighbor's scores and parent - parents[neighbor] = current - g_score[neighbor] = tentative_g_score - f_score = tentative_g_score + heuristic( - neighbor_x, neighbor_y, goal_tuple[0], goal_tuple[1] - ) - - # Add the neighbor to the open set with its f_score - heapq.heappush(open_set, (f_score, neighbor)) - - # If we get here, no path was found - return None diff --git a/build/lib/dimos/robot/global_planner/planner.py b/build/lib/dimos/robot/global_planner/planner.py deleted file mode 100644 index 55eea616a0..0000000000 
--- a/build/lib/dimos/robot/global_planner/planner.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import threading -from abc import abstractmethod -from dataclasses import dataclass -from typing import Callable, Optional - -from dimos.core import In, Module, Out, rpc -from dimos.msgs.geometry_msgs import Vector3 -from dimos.robot.global_planner.algo import astar -from dimos.types.costmap import Costmap -from dimos.types.path import Path -from dimos.types.vector import Vector, VectorLike, to_vector -from dimos.utils.logging_config import setup_logger -from dimos.web.websocket_vis.helpers import Visualizable - -logger = setup_logger("dimos.robot.unitree.global_planner") - - -@dataclass -class Planner(Visualizable, Module): - target: In[Vector3] = None - path: Out[Path] = None - - def __init__(self): - Module.__init__(self) - Visualizable.__init__(self) - - # def set_goal( - # self, - # goal: VectorLike, - # goal_theta: Optional[float] = None, - # stop_event: Optional[threading.Event] = None, - # ): - # path = self.plan(goal) - # if not path: - # logger.warning("No path found to the goal.") - # return False - - # print("pathing success", path) - # return self.set_local_nav(path, stop_event=stop_event, goal_theta=goal_theta) - - -class AstarPlanner(Planner): - target: In[Vector3] = None - path: Out[Path] = None - - get_costmap: Callable[[], Costmap] - get_robot_pos: Callable[[], Vector3] - - 
conservativism: int = 8 - - def __init__( - self, - get_costmap: Callable[[], Costmap], - get_robot_pos: Callable[[], Vector3], - ): - super().__init__() - self.get_costmap = get_costmap - self.get_robot_pos = get_robot_pos - - @rpc - def start(self): - self.target.subscribe(self.plan) - - def plan(self, goal: VectorLike) -> Path: - print("planning path to goal", goal) - goal = to_vector(goal).to_2d() - pos = self.get_robot_pos() - print("current pos", pos) - costmap = self.get_costmap().smudge() - - print("current costmap", costmap) - self.vis("target", goal) - - print("ASTAR ", costmap, goal, pos) - path = astar(costmap, goal, pos) - - if path: - path = path.resample(0.1) - self.vis("a*", path) - self.path.publish(path) - return path - logger.warning("No path found to the goal.") diff --git a/build/lib/dimos/robot/local_planner/__init__.py b/build/lib/dimos/robot/local_planner/__init__.py deleted file mode 100644 index 472b58dcd2..0000000000 --- a/build/lib/dimos/robot/local_planner/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from dimos.robot.local_planner.local_planner import ( - BaseLocalPlanner, - navigate_to_goal_local, - navigate_path_local, -) - -from dimos.robot.local_planner.vfh_local_planner import VFHPurePursuitPlanner diff --git a/build/lib/dimos/robot/local_planner/local_planner.py b/build/lib/dimos/robot/local_planner/local_planner.py deleted file mode 100644 index 286ee94f2b..0000000000 --- a/build/lib/dimos/robot/local_planner/local_planner.py +++ /dev/null @@ -1,1442 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -import numpy as np -from typing import Dict, Tuple, Optional, Callable, Any -from abc import ABC, abstractmethod -import cv2 -from reactivex import Observable -from reactivex.subject import Subject -import threading -import time -import logging -from collections import deque -from dimos.utils.logging_config import setup_logger -from dimos.utils.transform_utils import normalize_angle, distance_angle_to_goal_xy - -from dimos.types.vector import VectorLike, Vector, to_tuple -from dimos.types.path import Path -from dimos.types.costmap import Costmap - -logger = setup_logger("dimos.robot.unitree.local_planner", level=logging.DEBUG) - - -class BaseLocalPlanner(ABC): - """ - Abstract base class for local planners that handle obstacle avoidance and path following. - - This class defines the common interface and shared functionality that all local planners - must implement, regardless of the specific algorithm used. 
- - Args: - get_costmap: Function to get the latest local costmap - get_robot_pose: Function to get the latest robot pose (returning odom object) - move: Function to send velocity commands - safety_threshold: Distance to maintain from obstacles (meters) - max_linear_vel: Maximum linear velocity (m/s) - max_angular_vel: Maximum angular velocity (rad/s) - lookahead_distance: Lookahead distance for path following (meters) - goal_tolerance: Distance at which the goal is considered reached (meters) - angle_tolerance: Angle at which the goal orientation is considered reached (radians) - robot_width: Width of the robot for visualization (meters) - robot_length: Length of the robot for visualization (meters) - visualization_size: Size of the visualization image in pixels - control_frequency: Frequency at which the planner is called (Hz) - safe_goal_distance: Distance at which to adjust the goal and ignore obstacles (meters) - max_recovery_attempts: Maximum number of recovery attempts before failing navigation. - If the robot gets stuck and cannot recover within this many attempts, navigation will fail. - global_planner_plan: Optional callable to plan a global path to the goal. - If provided, this will be used to generate a path to the goal before local planning. 
- """ - - def __init__( - self, - get_costmap: Callable[[], Optional[Costmap]], - get_robot_pose: Callable[[], Any], - move: Callable[[Vector], None], - safety_threshold: float = 0.5, - max_linear_vel: float = 0.8, - max_angular_vel: float = 1.0, - lookahead_distance: float = 1.0, - goal_tolerance: float = 0.75, - angle_tolerance: float = 0.5, - robot_width: float = 0.5, - robot_length: float = 0.7, - visualization_size: int = 400, - control_frequency: float = 10.0, - safe_goal_distance: float = 1.5, - max_recovery_attempts: int = 4, - global_planner_plan: Optional[Callable[[VectorLike], Optional[Any]]] = None, - ): # Control frequency in Hz - # Store callables for robot interactions - self.get_costmap = get_costmap - self.get_robot_pose = get_robot_pose - self.move = move - - # Store parameters - self.safety_threshold = safety_threshold - self.max_linear_vel = max_linear_vel - self.max_angular_vel = max_angular_vel - self.lookahead_distance = lookahead_distance - self.goal_tolerance = goal_tolerance - self.angle_tolerance = angle_tolerance - self.robot_width = robot_width - self.robot_length = robot_length - self.visualization_size = visualization_size - self.control_frequency = control_frequency - self.control_period = 1.0 / control_frequency # Period in seconds - self.safe_goal_distance = safe_goal_distance # Distance to ignore obstacles at goal - self.ignore_obstacles = False # Flag for derived classes to check - self.max_recovery_attempts = max_recovery_attempts # Maximum recovery attempts - self.recovery_attempts = 0 # Current number of recovery attempts - self.global_planner_plan = global_planner_plan # Global planner function for replanning - - # Goal and Waypoint Tracking - self.goal_xy: Optional[Tuple[float, float]] = None # Current target for planning - self.goal_theta: Optional[float] = None # Goal orientation (radians) - self.waypoints: Optional[Path] = None # List of waypoints to follow - self.waypoints_in_absolute: Optional[Path] = None # Full path 
in absolute frame - self.waypoint_is_relative: bool = False # Whether waypoints are in relative frame - self.current_waypoint_index: int = 0 # Index of the next waypoint to reach - self.final_goal_reached: bool = False # Flag indicating if the final waypoint is reached - self.position_reached: bool = False # Flag indicating if position goal is reached - - # Stuck detection - self.stuck_detection_window_seconds = 4.0 # Time window for stuck detection (seconds) - self.position_history_size = int(self.stuck_detection_window_seconds * control_frequency) - self.position_history = deque( - maxlen=self.position_history_size - ) # History of recent positions - self.stuck_distance_threshold = 0.15 # Distance threshold for stuck detection (meters) - self.unstuck_distance_threshold = ( - 0.5 # Distance threshold for unstuck detection (meters) - increased hysteresis - ) - self.stuck_time_threshold = 3.0 # Time threshold for stuck detection (seconds) - increased - self.is_recovery_active = False # Whether recovery behavior is active - self.recovery_start_time = 0.0 # When recovery behavior started - self.recovery_duration = ( - 10.0 # How long to run recovery before giving up (seconds) - increased - ) - self.last_update_time = time.time() # Last time position was updated - self.navigation_failed = False # Flag indicating if navigation should be terminated - - # Recovery improvements - self.recovery_cooldown_time = ( - 3.0 # Seconds to wait after recovery before checking stuck again - ) - self.last_recovery_end_time = 0.0 # When the last recovery ended - self.pre_recovery_position = ( - None # Position when recovery started (for better stuck detection) - ) - self.backup_duration = 4.0 # How long to backup when stuck (seconds) - - # Cached data updated periodically for consistent plan() execution time - self._robot_pose = None - self._costmap = None - self._update_frequency = 10.0 # Hz - how often to update cached data - self._update_timer = None - self._start_periodic_updates() 
- - def _start_periodic_updates(self): - self._update_timer = threading.Thread(target=self._periodic_update, daemon=True) - self._update_timer.start() - - def _periodic_update(self): - while True: - self._robot_pose = self.get_robot_pose() - self._costmap = self.get_costmap() - time.sleep(1.0 / self._update_frequency) - - def reset(self): - """ - Reset all navigation and state tracking variables. - Should be called whenever a new goal is set. - """ - # Reset stuck detection state - self.position_history.clear() - self.is_recovery_active = False - self.recovery_start_time = 0.0 - self.last_update_time = time.time() - - # Reset navigation state flags - self.navigation_failed = False - self.position_reached = False - self.final_goal_reached = False - self.ignore_obstacles = False - - # Reset recovery improvements - self.last_recovery_end_time = 0.0 - self.pre_recovery_position = None - - # Reset recovery attempts - self.recovery_attempts = 0 - - # Clear waypoint following state - self.waypoints = None - self.current_waypoint_index = 0 - self.goal_xy = None # Clear previous goal - self.goal_theta = None # Clear previous goal orientation - - logger.info("Local planner state has been reset") - - def _get_robot_pose(self) -> Tuple[Tuple[float, float], float]: - """ - Get the current robot position and orientation. 
- - Returns: - Tuple containing: - - position as (x, y) tuple - - orientation (theta) in radians - """ - if self._robot_pose is None: - return ((0.0, 0.0), 0.0) # Fallback if not yet initialized - pos, rot = self._robot_pose.pos, self._robot_pose.rot - return (pos.x, pos.y), rot.z - - def _get_costmap(self): - """Get cached costmap data.""" - return self._costmap - - def clear_cache(self): - """Clear all cached data to force fresh retrieval on next access.""" - self._robot_pose = None - self._costmap = None - - def set_goal( - self, goal_xy: VectorLike, is_relative: bool = False, goal_theta: Optional[float] = None - ): - """Set a single goal position, converting to absolute frame if necessary. - This clears any existing waypoints being followed. - - Args: - goal_xy: The goal position to set. - is_relative: Whether the goal is in the robot's relative frame. - goal_theta: Optional goal orientation in radians - """ - # Reset all state variables - self.reset() - - target_goal_xy: Optional[Tuple[float, float]] = None - - # Transform goal to absolute frame if it's relative - if is_relative: - # Get current robot pose - odom = self._robot_pose - if odom is None: - logger.warning("Robot pose not yet available, cannot set relative goal") - return - robot_pos, robot_rot = odom.pos, odom.rot - - # Extract current position and orientation - robot_x, robot_y = robot_pos.x, robot_pos.y - robot_theta = robot_rot.z # Assuming rotation is euler angles - - # Transform the relative goal into absolute coordinates - goal_x, goal_y = to_tuple(goal_xy) - # Rotate - abs_x = goal_x * math.cos(robot_theta) - goal_y * math.sin(robot_theta) - abs_y = goal_x * math.sin(robot_theta) + goal_y * math.cos(robot_theta) - # Translate - target_goal_xy = (robot_x + abs_x, robot_y + abs_y) - - logger.info( - f"Goal set in relative frame, converted to absolute: ({target_goal_xy[0]:.2f}, {target_goal_xy[1]:.2f})" - ) - else: - target_goal_xy = to_tuple(goal_xy) - logger.info( - f"Goal set directly in 
absolute frame: ({target_goal_xy[0]:.2f}, {target_goal_xy[1]:.2f})" - ) - - # Check if goal is valid (in bounds and not colliding) - if not self.is_goal_in_costmap_bounds(target_goal_xy) or self.check_goal_collision( - target_goal_xy - ): - logger.warning( - "Goal is in collision or out of bounds. Adjusting goal to valid position." - ) - self.goal_xy = self.adjust_goal_to_valid_position(target_goal_xy) - else: - self.goal_xy = target_goal_xy # Set the adjusted or original valid goal - - # Set goal orientation if provided - if goal_theta is not None: - if is_relative: - # Transform the orientation to absolute frame - odom = self._robot_pose - if odom is None: - logger.warning( - "Robot pose not yet available, cannot set relative goal orientation" - ) - return - robot_theta = odom.rot.z - self.goal_theta = normalize_angle(goal_theta + robot_theta) - else: - self.goal_theta = goal_theta - - def set_goal_waypoints(self, waypoints: Path, goal_theta: Optional[float] = None): - """Sets a path of waypoints for the robot to follow. - - Args: - waypoints: A list of waypoints to follow. Each waypoint is a tuple of (x, y) coordinates in absolute frame. - goal_theta: Optional final orientation in radians - """ - # Reset all state variables - self.reset() - - if not isinstance(waypoints, Path) or len(waypoints) == 0: - logger.warning("Invalid or empty path provided to set_goal_waypoints. 
Ignoring.") - self.waypoints = None - self.waypoint_is_relative = False - self.goal_xy = None - self.goal_theta = None - self.current_waypoint_index = 0 - return - - logger.info(f"Setting goal waypoints with {len(waypoints)} points.") - self.waypoints = waypoints - self.waypoint_is_relative = False - self.current_waypoint_index = 0 - - # Waypoints are always in absolute frame - self.waypoints_in_absolute = waypoints - - # Set the initial target to the first waypoint, adjusting if necessary - first_waypoint = self.waypoints_in_absolute[0] - if not self.is_goal_in_costmap_bounds(first_waypoint) or self.check_goal_collision( - first_waypoint - ): - logger.warning("First waypoint is invalid. Adjusting...") - self.goal_xy = self.adjust_goal_to_valid_position(first_waypoint) - else: - self.goal_xy = to_tuple(first_waypoint) # Initial target - - # Set goal orientation if provided - if goal_theta is not None: - self.goal_theta = goal_theta - - def _get_final_goal_position(self) -> Optional[Tuple[float, float]]: - """ - Get the final goal position (either last waypoint or direct goal). - - Returns: - Tuple (x, y) of the final goal, or None if no goal is set - """ - if self.waypoints_in_absolute is not None and len(self.waypoints_in_absolute) > 0: - return to_tuple(self.waypoints_in_absolute[-1]) - elif self.goal_xy is not None: - return self.goal_xy - return None - - def _distance_to_position(self, target_position: Tuple[float, float]) -> float: - """ - Calculate distance from the robot to a target position. - - Args: - target_position: Target (x, y) position - - Returns: - Distance in meters - """ - robot_pos, _ = self._get_robot_pose() - return np.linalg.norm( - [target_position[0] - robot_pos[0], target_position[1] - robot_pos[1]] - ) - - def plan(self) -> Dict[str, float]: - """ - Main planning method that computes velocity commands. - This includes common planning logic like waypoint following, - with algorithm-specific calculations delegated to subclasses. 
- - Returns: - Dict[str, float]: Velocity commands with 'x_vel' and 'angular_vel' keys - """ - # If goal orientation is specified, rotate to match it - if ( - self.position_reached - and self.goal_theta is not None - and not self._is_goal_orientation_reached() - ): - return self._rotate_to_goal_orientation() - elif self.position_reached and self.goal_theta is None: - self.final_goal_reached = True - logger.info("Position goal reached. Stopping.") - return {"x_vel": 0.0, "angular_vel": 0.0} - - # Check if the robot is stuck and handle accordingly - if self.check_if_stuck() and not self.position_reached: - # Check if we're stuck but close to our goal - final_goal_pos = self._get_final_goal_position() - - # If we have a goal position, check distance to it - if final_goal_pos is not None: - distance_to_goal = self._distance_to_position(final_goal_pos) - - # If we're stuck but within 2x safe_goal_distance of the goal, consider it a success - if distance_to_goal < 2.0 * self.safe_goal_distance: - logger.info( - f"Robot is stuck but within {distance_to_goal:.2f}m of goal (< {2.0 * self.safe_goal_distance:.2f}m). Considering navigation successful." 
- ) - self.position_reached = True - return {"x_vel": 0.0, "angular_vel": 0.0} - - if self.navigation_failed: - return {"x_vel": 0.0, "angular_vel": 0.0} - - # Otherwise, execute normal recovery behavior - logger.warning("Robot is stuck - executing recovery behavior") - return self.execute_recovery_behavior() - - # Reset obstacle ignore flag - self.ignore_obstacles = False - - # --- Waypoint Following Mode --- - if self.waypoints is not None: - if self.final_goal_reached: - return {"x_vel": 0.0, "angular_vel": 0.0} - - # Get current robot pose - robot_pos, robot_theta = self._get_robot_pose() - robot_pos_np = np.array(robot_pos) - - # Check if close to final waypoint - if self.waypoints_in_absolute is not None and len(self.waypoints_in_absolute) > 0: - final_waypoint = self.waypoints_in_absolute[-1] - dist_to_final = np.linalg.norm(robot_pos_np - final_waypoint) - - # If we're close to the final waypoint, adjust it and ignore obstacles - if dist_to_final < self.safe_goal_distance: - final_wp_tuple = to_tuple(final_waypoint) - adjusted_goal = self.adjust_goal_to_valid_position(final_wp_tuple) - # Create a new Path with the adjusted final waypoint - new_waypoints = self.waypoints_in_absolute[:-1] # Get all but the last waypoint - new_waypoints.append(adjusted_goal) # Append the adjusted goal - self.waypoints_in_absolute = new_waypoints - self.ignore_obstacles = True - - # Update the target goal based on waypoint progression - just_reached_final = self._update_waypoint_target(robot_pos_np) - - # If the helper indicates the final goal was just reached, stop immediately - if just_reached_final: - return {"x_vel": 0.0, "angular_vel": 0.0} - - # --- Single Goal or Current Waypoint Target Set --- - if self.goal_xy is None: - # If no goal is set (e.g., empty path or rejected goal), stop. - return {"x_vel": 0.0, "angular_vel": 0.0} - - # Get necessary data for planning - costmap = self._get_costmap() - if costmap is None: - logger.warning("Local costmap is None. 
Cannot plan.") - return {"x_vel": 0.0, "angular_vel": 0.0} - - # Check if close to single goal mode goal - if self.waypoints is None: - # Get distance to goal - goal_distance = self._distance_to_position(self.goal_xy) - - # If within safe distance of goal, adjust it and ignore obstacles - if goal_distance < self.safe_goal_distance: - self.goal_xy = self.adjust_goal_to_valid_position(self.goal_xy) - self.ignore_obstacles = True - - # First check position - if goal_distance < self.goal_tolerance or self.position_reached: - self.position_reached = True - - else: - self.position_reached = False - - # Call the algorithm-specific planning implementation - return self._compute_velocity_commands() - - @abstractmethod - def _compute_velocity_commands(self) -> Dict[str, float]: - """ - Algorithm-specific method to compute velocity commands. - Must be implemented by derived classes. - - Returns: - Dict[str, float]: Velocity commands with 'x_vel' and 'angular_vel' keys - """ - pass - - def _rotate_to_goal_orientation(self) -> Dict[str, float]: - """Compute velocity commands to rotate to the goal orientation. - - Returns: - Dict[str, float]: Velocity commands with zero linear velocity - """ - # Get current robot orientation - _, robot_theta = self._get_robot_pose() - - # Calculate the angle difference - angle_diff = normalize_angle(self.goal_theta - robot_theta) - - # Determine rotation direction and speed - if abs(angle_diff) < self.angle_tolerance: - # Already at correct orientation - return {"x_vel": 0.0, "angular_vel": 0.0} - - # Calculate rotation speed - proportional to the angle difference - # but capped at max_angular_vel - direction = 1.0 if angle_diff > 0 else -1.0 - angular_vel = direction * min(abs(angle_diff), self.max_angular_vel) - - return {"x_vel": 0.0, "angular_vel": angular_vel} - - def _is_goal_orientation_reached(self) -> bool: - """Check if the current robot orientation matches the goal orientation. 
- - Returns: - bool: True if orientation is reached or no orientation goal is set - """ - if self.goal_theta is None: - return True # No orientation goal set - - # Get current robot orientation - _, robot_theta = self._get_robot_pose() - - # Calculate the angle difference and normalize - angle_diff = abs(normalize_angle(self.goal_theta - robot_theta)) - - return angle_diff <= self.angle_tolerance - - def _update_waypoint_target(self, robot_pos_np: np.ndarray) -> bool: - """Helper function to manage waypoint progression and update the target goal. - - Args: - robot_pos_np: Current robot position as a numpy array [x, y]. - - Returns: - bool: True if the final waypoint has just been reached, False otherwise. - """ - if self.waypoints is None or len(self.waypoints) == 0: - return False # Not in waypoint mode or empty path - - # Waypoints are always in absolute frame - self.waypoints_in_absolute = self.waypoints - - # Check if final goal is reached - final_waypoint = self.waypoints_in_absolute[-1] - dist_to_final = np.linalg.norm(robot_pos_np - final_waypoint) - - if dist_to_final <= self.goal_tolerance: - # Final waypoint position reached - if self.goal_theta is not None: - # Check orientation if specified - if self._is_goal_orientation_reached(): - self.final_goal_reached = True - return True - # Continue rotating - self.position_reached = True - return False - else: - # No orientation goal, mark as reached - self.final_goal_reached = True - return True - - # Always find the lookahead point - lookahead_point = None - for i in range(self.current_waypoint_index, len(self.waypoints_in_absolute)): - wp = self.waypoints_in_absolute[i] - dist_to_wp = np.linalg.norm(robot_pos_np - wp) - if dist_to_wp >= self.lookahead_distance: - lookahead_point = wp - # Update current waypoint index to this point - self.current_waypoint_index = i - break - - # If no point is far enough, target the final waypoint - if lookahead_point is None: - lookahead_point = 
self.waypoints_in_absolute[-1] - self.current_waypoint_index = len(self.waypoints_in_absolute) - 1 - - # Set the lookahead point as the immediate target, adjusting if needed - if not self.is_goal_in_costmap_bounds(lookahead_point) or self.check_goal_collision( - lookahead_point - ): - adjusted_lookahead = self.adjust_goal_to_valid_position(lookahead_point) - # Only update if adjustment didn't fail completely - if adjusted_lookahead is not None: - self.goal_xy = adjusted_lookahead - else: - self.goal_xy = to_tuple(lookahead_point) - - return False # Final goal not reached in this update cycle - - @abstractmethod - def update_visualization(self) -> np.ndarray: - """ - Generate visualization of the planning state. - Must be implemented by derived classes. - - Returns: - np.ndarray: Visualization image as numpy array - """ - pass - - def create_stream(self, frequency_hz: float = None) -> Observable: - """ - Create an Observable stream that emits the visualization image at a fixed frequency. 
- - Args: - frequency_hz: Optional frequency override (defaults to 1/4 of control_frequency if None) - - Returns: - Observable: Stream of visualization frames - """ - # Default to 1/4 of control frequency if not specified (to reduce CPU usage) - if frequency_hz is None: - frequency_hz = self.control_frequency / 4.0 - - subject = Subject() - sleep_time = 1.0 / frequency_hz - - def frame_emitter(): - while True: - try: - # Generate the frame using the updated method - frame = self.update_visualization() - subject.on_next(frame) - except Exception as e: - logger.error(f"Error in frame emitter thread: {e}") - # Optionally, emit an error frame or simply skip - # subject.on_error(e) # This would terminate the stream - time.sleep(sleep_time) - - emitter_thread = threading.Thread(target=frame_emitter, daemon=True) - emitter_thread.start() - logger.info(f"Started visualization frame emitter thread at {frequency_hz:.1f} Hz") - return subject - - @abstractmethod - def check_collision(self, direction: float) -> bool: - """ - Check if there's a collision in the given direction. - Must be implemented by derived classes. 
- - Args: - direction: Direction to check for collision in radians - - Returns: - bool: True if collision detected, False otherwise - """ - pass - - def is_goal_reached(self) -> bool: - """Check if the final goal (single or last waypoint) is reached, including orientation.""" - if self.waypoints is not None: - # Waypoint mode: check if the final waypoint and orientation have been reached - return self.final_goal_reached and self._is_goal_orientation_reached() - else: - # Single goal mode: check distance to the single goal and orientation - if self.goal_xy is None: - return False # No goal set - - if self.goal_theta is None: - return self.position_reached - - return self.position_reached and self._is_goal_orientation_reached() - - def check_goal_collision(self, goal_xy: VectorLike) -> bool: - """Check if the current goal is in collision with obstacles in the costmap. - - Returns: - bool: True if goal is in collision, False if goal is safe or cannot be checked - """ - - costmap = self._get_costmap() - if costmap is None: - logger.warning("Cannot check collision: No costmap available") - return False - - # Check if the position is occupied - collision_threshold = 80 # Consider values above 80 as obstacles - - # Use Costmap's is_occupied method - return costmap.is_occupied(goal_xy, threshold=collision_threshold) - - def is_goal_in_costmap_bounds(self, goal_xy: VectorLike) -> bool: - """Check if the goal position is within the bounds of the costmap. 
- - Args: - goal_xy: Goal position (x, y) in odom frame - - Returns: - bool: True if the goal is within the costmap bounds, False otherwise - """ - costmap = self._get_costmap() - if costmap is None: - logger.warning("Cannot check bounds: No costmap available") - return False - - # Get goal position in grid coordinates - goal_point = costmap.world_to_grid(goal_xy) - goal_cell_x, goal_cell_y = goal_point.x, goal_point.y - - # Check if goal is within the costmap bounds - is_in_bounds = 0 <= goal_cell_x < costmap.width and 0 <= goal_cell_y < costmap.height - - if not is_in_bounds: - logger.warning(f"Goal ({goal_xy[0]:.2f}, {goal_xy[1]:.2f}) is outside costmap bounds") - - return is_in_bounds - - def adjust_goal_to_valid_position( - self, goal_xy: VectorLike, clearance: float = 0.5 - ) -> Tuple[float, float]: - """Find a valid (non-colliding) goal position by moving it towards the robot. - - Args: - goal_xy: Original goal position (x, y) in odom frame - clearance: Additional distance to move back from obstacles for better clearance (meters) - - Returns: - Tuple[float, float]: A valid goal position, or the original goal if already valid - """ - [pos, rot] = self._get_robot_pose() - - robot_x, robot_y = pos[0], pos[1] - - # Original goal - goal_x, goal_y = to_tuple(goal_xy) - - if not self.check_goal_collision((goal_x, goal_y)): - return (goal_x, goal_y) - - # Calculate vector from goal to robot - dx = robot_x - goal_x - dy = robot_y - goal_y - distance = np.sqrt(dx * dx + dy * dy) - - if distance < 0.001: # Goal is at robot position - return to_tuple(goal_xy) - - # Normalize direction vector - dx /= distance - dy /= distance - - # Step size - step_size = 0.25 # meters - - # Move goal towards robot step by step - current_x, current_y = goal_x, goal_y - steps = 0 - max_steps = 50 # Safety limit - - # Variables to store the first valid position found - valid_found = False - valid_x, valid_y = None, None - - while steps < max_steps: - # Move towards robot - current_x += dx 
* step_size - current_y += dy * step_size - steps += 1 - - # Check if we've reached or passed the robot - new_distance = np.sqrt((current_x - robot_x) ** 2 + (current_y - robot_y) ** 2) - if new_distance < step_size: - # We've reached the robot without finding a valid point - # Move back one step from robot to avoid self-collision - current_x = robot_x - dx * step_size - current_y = robot_y - dy * step_size - break - - # Check if this position is valid - if not self.check_goal_collision( - (current_x, current_y) - ) and self.is_goal_in_costmap_bounds((current_x, current_y)): - # Store the first valid position - if not valid_found: - valid_found = True - valid_x, valid_y = current_x, current_y - - # If clearance is requested, continue searching for a better position - if clearance > 0: - continue - - # Calculate position with additional clearance - if clearance > 0: - # Calculate clearance position - clearance_x = current_x + dx * clearance - clearance_y = current_y + dy * clearance - - # Check if the clearance position is also valid - if not self.check_goal_collision( - (clearance_x, clearance_y) - ) and self.is_goal_in_costmap_bounds((clearance_x, clearance_y)): - return (clearance_x, clearance_y) - - # Return the valid position without clearance - return (current_x, current_y) - - # If we found a valid position earlier but couldn't add clearance - if valid_found: - return (valid_x, valid_y) - - logger.warning( - f"Could not find valid goal after {steps} steps, using closest point to robot" - ) - return (current_x, current_y) - - def check_if_stuck(self) -> bool: - """ - Check if the robot is stuck by analyzing movement history. - Includes improvements to prevent oscillation between stuck and recovered states. 
- - Returns: - bool: True if the robot is determined to be stuck, False otherwise - """ - # Get current position and time - current_time = time.time() - - # Get current robot position - [pos, _] = self._get_robot_pose() - current_position = (pos[0], pos[1], current_time) - - # If we're already in recovery, don't add movements to history (they're intentional) - # Instead, check if we should continue or end recovery - if self.is_recovery_active: - # Check if we've moved far enough from our pre-recovery position to consider unstuck - if self.pre_recovery_position is not None: - pre_recovery_x, pre_recovery_y = self.pre_recovery_position[:2] - displacement_from_start = np.sqrt( - (pos[0] - pre_recovery_x) ** 2 + (pos[1] - pre_recovery_y) ** 2 - ) - - # If we've moved far enough, we're unstuck - if displacement_from_start > self.unstuck_distance_threshold: - logger.info( - f"Robot has escaped from stuck state (moved {displacement_from_start:.3f}m from start)" - ) - self.is_recovery_active = False - self.last_recovery_end_time = current_time - # Do not reset recovery attempts here - only reset during replanning or goal reaching - # Clear position history to start fresh tracking - self.position_history.clear() - return False - - # Check if we've been trying to recover for too long - recovery_time = current_time - self.recovery_start_time - if recovery_time > self.recovery_duration: - logger.error( - f"Recovery behavior has been active for {self.recovery_duration}s without success" - ) - self.navigation_failed = True - return True - - # Continue recovery - return True - - # Check cooldown period - don't immediately check for stuck after recovery - if current_time - self.last_recovery_end_time < self.recovery_cooldown_time: - # Add position to history but don't check for stuck yet - self.position_history.append(current_position) - return False - - # Add current position to history (newest is appended at the end) - self.position_history.append(current_position) - - # Need 
enough history to make a determination - min_history_size = int( - self.stuck_detection_window_seconds * self.control_frequency * 0.6 - ) # 60% of window - if len(self.position_history) < min_history_size: - return False - - # Find positions within our detection window - window_start_time = current_time - self.stuck_detection_window_seconds - window_positions = [] - - # Collect positions within the window (newest entries will be at the end) - for pos_x, pos_y, timestamp in self.position_history: - if timestamp >= window_start_time: - window_positions.append((pos_x, pos_y, timestamp)) - - # Need at least a few positions in the window - if len(window_positions) < 3: - return False - - # Ensure correct order: oldest to newest - window_positions.sort(key=lambda p: p[2]) - - # Get the oldest and newest positions in the window - oldest_x, oldest_y, oldest_time = window_positions[0] - newest_x, newest_y, newest_time = window_positions[-1] - - # Calculate time range in the window - time_range = newest_time - oldest_time - - # Calculate displacement from oldest to newest position - displacement = np.sqrt((newest_x - oldest_x) ** 2 + (newest_y - oldest_y) ** 2) - - # Also check average displacement over multiple sub-windows to avoid false positives - sub_window_size = max(3, len(window_positions) // 3) - avg_displacement = 0.0 - displacement_count = 0 - - for i in range(0, len(window_positions) - sub_window_size, sub_window_size // 2): - start_pos = window_positions[i] - end_pos = window_positions[min(i + sub_window_size, len(window_positions) - 1)] - sub_displacement = np.sqrt( - (end_pos[0] - start_pos[0]) ** 2 + (end_pos[1] - start_pos[1]) ** 2 - ) - avg_displacement += sub_displacement - displacement_count += 1 - - if displacement_count > 0: - avg_displacement /= displacement_count - - # Check if we're stuck - moved less than threshold over minimum time - is_currently_stuck = ( - time_range >= self.stuck_time_threshold - and time_range <= 
self.stuck_detection_window_seconds - and displacement < self.stuck_distance_threshold - and avg_displacement < self.stuck_distance_threshold * 1.5 - ) - - if is_currently_stuck: - logger.warning( - f"Robot appears to be stuck! Total displacement: {displacement:.3f}m, " - f"avg displacement: {avg_displacement:.3f}m over {time_range:.1f}s" - ) - - # Start recovery behavior - self.is_recovery_active = True - self.recovery_start_time = current_time - self.pre_recovery_position = current_position - - # Clear position history to avoid contamination during recovery - self.position_history.clear() - - # Increment recovery attempts - self.recovery_attempts += 1 - logger.warning( - f"Starting recovery attempt {self.recovery_attempts}/{self.max_recovery_attempts}" - ) - - # Check if maximum recovery attempts have been exceeded - if self.recovery_attempts > self.max_recovery_attempts: - logger.error( - f"Maximum recovery attempts ({self.max_recovery_attempts}) exceeded. Navigation failed." - ) - self.navigation_failed = True - - return True - - return False - - def execute_recovery_behavior(self) -> Dict[str, float]: - """ - Execute enhanced recovery behavior when the robot is stuck. 
- - First attempt: Backup for a set duration - - Second+ attempts: Replan to the original goal using global planner - - Returns: - Dict[str, float]: Velocity commands for the recovery behavior - """ - current_time = time.time() - recovery_time = current_time - self.recovery_start_time - - # First recovery attempt: Simple backup behavior - if self.recovery_attempts % 2 == 0: - if recovery_time < self.backup_duration: - logger.warning(f"Recovery attempt 1: backup for {recovery_time:.1f}s") - return {"x_vel": -0.5, "angular_vel": 0.0} # Backup at moderate speed - else: - logger.info("Recovery attempt 1: backup completed") - self.recovery_attempts += 1 - return {"x_vel": 0.0, "angular_vel": 0.0} - - final_goal = self.waypoints_in_absolute[-1] - logger.info( - f"Recovery attempt {self.recovery_attempts}: replanning to final waypoint {final_goal}" - ) - - new_path = self.global_planner_plan(Vector([final_goal[0], final_goal[1]])) - - if new_path is not None: - logger.info("Replanning successful. Setting new waypoints.") - attempts = self.recovery_attempts - self.set_goal_waypoints(new_path, self.goal_theta) - self.recovery_attempts = attempts - self.is_recovery_active = False - self.last_recovery_end_time = current_time - else: - logger.error("Global planner could not find a path to the goal. Recovery failed.") - self.navigation_failed = True - - return {"x_vel": 0.0, "angular_vel": 0.0} - - -def navigate_to_goal_local( - robot, - goal_xy_robot: Tuple[float, float], - goal_theta: Optional[float] = None, - distance: float = 0.0, - timeout: float = 60.0, - stop_event: Optional[threading.Event] = None, -) -> bool: - """ - Navigates the robot to a goal specified in the robot's local frame - using the local planner. - - Args: - robot: Robot instance to control - goal_xy_robot: Tuple (x, y) representing the goal position relative - to the robot's current position and orientation. - distance: Desired distance to maintain from the goal in meters. 
- If non-zero, the robot will stop this far away from the goal. - timeout: Maximum time (in seconds) allowed to reach the goal. - stop_event: Optional threading.Event to signal when navigation should stop - - Returns: - bool: True if the goal was reached within the timeout, False otherwise. - """ - logger.info( - f"Starting navigation to local goal {goal_xy_robot} with distance {distance}m and timeout {timeout}s." - ) - - robot.local_planner.reset() - - goal_x, goal_y = goal_xy_robot - - # Calculate goal orientation to face the target - if goal_theta is None: - goal_theta = np.arctan2(goal_y, goal_x) - - # If distance is non-zero, adjust the goal to stop at the desired distance - if distance > 0: - # Calculate magnitude of the goal vector - goal_distance = np.sqrt(goal_x**2 + goal_y**2) - - # Only adjust if goal is further than the desired distance - if goal_distance > distance: - goal_x, goal_y = distance_angle_to_goal_xy(goal_distance - distance, goal_theta) - - # Set the goal in the robot's frame with orientation to face the original target - robot.local_planner.set_goal((goal_x, goal_y), is_relative=True, goal_theta=goal_theta) - - # Get control period from robot's local planner for consistent timing - control_period = 1.0 / robot.local_planner.control_frequency - - start_time = time.time() - goal_reached = False - - try: - while time.time() - start_time < timeout and not (stop_event and stop_event.is_set()): - # Check if goal has been reached - if robot.local_planner.is_goal_reached(): - logger.info("Goal reached successfully.") - goal_reached = True - break - - # Check if navigation failed flag is set - if robot.local_planner.navigation_failed: - logger.error("Navigation aborted due to repeated recovery failures.") - goal_reached = False - break - - # Get planned velocity towards the goal - vel_command = robot.local_planner.plan() - x_vel = vel_command.get("x_vel", 0.0) - angular_vel = vel_command.get("angular_vel", 0.0) - - # Send velocity command - 
robot.local_planner.move(Vector(x_vel, 0, angular_vel)) - - # Control loop frequency - use robot's control frequency - time.sleep(control_period) - - if not goal_reached: - logger.warning(f"Navigation timed out after {timeout} seconds before reaching goal.") - - except KeyboardInterrupt: - logger.info("Navigation to local goal interrupted by user.") - goal_reached = False # Consider interruption as failure - except Exception as e: - logger.error(f"Error during navigation to local goal: {e}") - goal_reached = False # Consider error as failure - finally: - logger.info("Stopping robot after navigation attempt.") - robot.local_planner.move(Vector(0, 0, 0)) # Stop the robot - - return goal_reached - - -def navigate_path_local( - robot, - path: Path, - timeout: float = 120.0, - goal_theta: Optional[float] = None, - stop_event: Optional[threading.Event] = None, -) -> bool: - """ - Navigates the robot along a path of waypoints using the waypoint following capability - of the local planner. - - Args: - robot: Robot instance to control - path: Path object containing waypoints in absolute frame - timeout: Maximum time (in seconds) allowed to follow the complete path - goal_theta: Optional final orientation in radians - stop_event: Optional threading.Event to signal when navigation should stop - - Returns: - bool: True if the entire path was successfully followed, False otherwise - """ - logger.info( - f"Starting navigation along path with {len(path)} waypoints and timeout {timeout}s." 
- ) - - robot.local_planner.reset() - - # Set the path in the local planner - robot.local_planner.set_goal_waypoints(path, goal_theta=goal_theta) - - # Get control period from robot's local planner for consistent timing - control_period = 1.0 / robot.local_planner.control_frequency - - start_time = time.time() - path_completed = False - - try: - while time.time() - start_time < timeout and not (stop_event and stop_event.is_set()): - # Check if the entire path has been traversed - if robot.local_planner.is_goal_reached(): - logger.info("Path traversed successfully.") - path_completed = True - break - - # Check if navigation failed flag is set - if robot.local_planner.navigation_failed: - logger.error("Navigation aborted due to repeated recovery failures.") - path_completed = False - break - - # Get planned velocity towards the current waypoint target - vel_command = robot.local_planner.plan() - x_vel = vel_command.get("x_vel", 0.0) - angular_vel = vel_command.get("angular_vel", 0.0) - - # Send velocity command - robot.local_planner.move(Vector(x_vel, 0, angular_vel)) - - # Control loop frequency - use robot's control frequency - time.sleep(control_period) - - if not path_completed: - logger.warning( - f"Path following timed out after {timeout} seconds before completing the path." 
- ) - - except KeyboardInterrupt: - logger.info("Path navigation interrupted by user.") - path_completed = False - except Exception as e: - logger.error(f"Error during path navigation: {e}") - path_completed = False - finally: - logger.info("Stopping robot after path navigation attempt.") - robot.local_planner.move(Vector(0, 0, 0)) # Stop the robot - - return path_completed - - -def visualize_local_planner_state( - occupancy_grid: np.ndarray, - grid_resolution: float, - grid_origin: Tuple[float, float], - robot_pose: Tuple[float, float, float], - visualization_size: int = 400, - robot_width: float = 0.5, - robot_length: float = 0.7, - map_size_meters: float = 10.0, - goal_xy: Optional[Tuple[float, float]] = None, - goal_theta: Optional[float] = None, - histogram: Optional[np.ndarray] = None, - selected_direction: Optional[float] = None, - waypoints: Optional["Path"] = None, - current_waypoint_index: Optional[int] = None, -) -> np.ndarray: - """Generate a bird's eye view visualization of the local costmap. - Optionally includes VFH histogram, selected direction, and waypoints path. 
- - Args: - occupancy_grid: 2D numpy array of the occupancy grid - grid_resolution: Resolution of the grid in meters/cell - grid_origin: Tuple (x, y) of the grid origin in the odom frame - robot_pose: Tuple (x, y, theta) of the robot pose in the odom frame - visualization_size: Size of the visualization image in pixels - robot_width: Width of the robot in meters - robot_length: Length of the robot in meters - map_size_meters: Size of the map to visualize in meters - goal_xy: Optional tuple (x, y) of the goal position in the odom frame - goal_theta: Optional goal orientation in radians (in odom frame) - histogram: Optional numpy array of the VFH histogram - selected_direction: Optional selected direction angle in radians - waypoints: Optional Path object containing waypoints to visualize - current_waypoint_index: Optional index of the current target waypoint - """ - - robot_x, robot_y, robot_theta = robot_pose - grid_origin_x, grid_origin_y = grid_origin - vis_size = visualization_size - scale = vis_size / map_size_meters - - vis_img = np.ones((vis_size, vis_size, 3), dtype=np.uint8) * 255 - center_x = vis_size // 2 - center_y = vis_size // 2 - - grid_height, grid_width = occupancy_grid.shape - - # Calculate robot position relative to grid origin - robot_rel_x = robot_x - grid_origin_x - robot_rel_y = robot_y - grid_origin_y - robot_cell_x = int(robot_rel_x / grid_resolution) - robot_cell_y = int(robot_rel_y / grid_resolution) - - half_size_cells = int(map_size_meters / grid_resolution / 2) - - # Draw grid cells (using standard occupancy coloring) - for y in range( - max(0, robot_cell_y - half_size_cells), min(grid_height, robot_cell_y + half_size_cells) - ): - for x in range( - max(0, robot_cell_x - half_size_cells), min(grid_width, robot_cell_x + half_size_cells) - ): - cell_rel_x_meters = (x - robot_cell_x) * grid_resolution - cell_rel_y_meters = (y - robot_cell_y) * grid_resolution - - img_x = int(center_x + cell_rel_x_meters * scale) - img_y = int(center_y - 
cell_rel_y_meters * scale) # Flip y-axis - - if 0 <= img_x < vis_size and 0 <= img_y < vis_size: - cell_value = occupancy_grid[y, x] - if cell_value == -1: - color = (200, 200, 200) # Unknown (Light gray) - elif cell_value == 0: - color = (255, 255, 255) # Free (White) - else: # Occupied - # Scale darkness based on occupancy value (0-100) - darkness = 255 - int(155 * (cell_value / 100)) - 100 - color = (darkness, darkness, darkness) # Shades of gray/black - - cell_size_px = max(1, int(grid_resolution * scale)) - cv2.rectangle( - vis_img, - (img_x - cell_size_px // 2, img_y - cell_size_px // 2), - (img_x + cell_size_px // 2, img_y + cell_size_px // 2), - color, - -1, - ) - - # Draw waypoints path if provided - if waypoints is not None and len(waypoints) > 0: - try: - path_points = [] - for i, waypoint in enumerate(waypoints): - # Convert waypoint from odom frame to visualization frame - wp_x, wp_y = waypoint[0], waypoint[1] - wp_rel_x = wp_x - robot_x - wp_rel_y = wp_y - robot_y - - wp_img_x = int(center_x + wp_rel_x * scale) - wp_img_y = int(center_y - wp_rel_y * scale) # Flip y-axis - - if 0 <= wp_img_x < vis_size and 0 <= wp_img_y < vis_size: - path_points.append((wp_img_x, wp_img_y)) - - # Draw each waypoint as a small circle - cv2.circle(vis_img, (wp_img_x, wp_img_y), 3, (0, 128, 0), -1) # Dark green dots - - # Highlight current target waypoint - if current_waypoint_index is not None and i == current_waypoint_index: - cv2.circle(vis_img, (wp_img_x, wp_img_y), 6, (0, 0, 255), 2) # Red circle - - # Connect waypoints with lines to show the path - if len(path_points) > 1: - for i in range(len(path_points) - 1): - cv2.line( - vis_img, path_points[i], path_points[i + 1], (0, 200, 0), 1 - ) # Green line - except Exception as e: - logger.error(f"Error drawing waypoints: {e}") - - # Draw histogram - if histogram is not None: - num_bins = len(histogram) - # Find absolute maximum value (ignoring any negative debug values) - abs_histogram = np.abs(histogram) - 
max_hist_value = np.max(abs_histogram) if np.max(abs_histogram) > 0 else 1.0 - hist_scale = (vis_size / 2) * 0.8 # Scale histogram lines to 80% of half the viz size - - for i in range(num_bins): - # Angle relative to robot's forward direction - angle_relative_to_robot = (i / num_bins) * 2 * math.pi - math.pi - # Angle in the visualization frame (relative to image +X axis) - vis_angle = angle_relative_to_robot + robot_theta - - # Get the value and check if it's a special debug value (negative) - hist_val = histogram[i] - is_debug_value = hist_val < 0 - - # Use absolute value for line length - normalized_val = min(1.0, abs(hist_val) / max_hist_value) - line_length = normalized_val * hist_scale - - # Calculate endpoint using the visualization angle - end_x = int(center_x + line_length * math.cos(vis_angle)) - end_y = int(center_y - line_length * math.sin(vis_angle)) # Flipped Y - - # Color based on value and whether it's a debug value - if is_debug_value: - # Use green for debug values (minimum cost bin) - color = (0, 255, 0) # Green - line_width = 2 # Thicker line for emphasis - else: - # Regular coloring for normal values (blue to red gradient based on obstacle density) - blue = max(0, 255 - int(normalized_val * 255)) - red = min(255, int(normalized_val * 255)) - color = (blue, 0, red) # BGR format: obstacles are redder, clear areas are bluer - line_width = 1 - - cv2.line(vis_img, (center_x, center_y), (end_x, end_y), color, line_width) - - # Draw robot - robot_length_px = int(robot_length * scale) - robot_width_px = int(robot_width * scale) - robot_pts = np.array( - [ - [-robot_length_px / 2, -robot_width_px / 2], - [robot_length_px / 2, -robot_width_px / 2], - [robot_length_px / 2, robot_width_px / 2], - [-robot_length_px / 2, robot_width_px / 2], - ], - dtype=np.float32, - ) - rotation_matrix = np.array( - [ - [math.cos(robot_theta), -math.sin(robot_theta)], - [math.sin(robot_theta), math.cos(robot_theta)], - ] - ) - robot_pts = np.dot(robot_pts, 
rotation_matrix.T) - robot_pts[:, 0] += center_x - robot_pts[:, 1] = center_y - robot_pts[:, 1] # Flip y-axis - cv2.fillPoly( - vis_img, [robot_pts.reshape((-1, 1, 2)).astype(np.int32)], (0, 0, 255) - ) # Red robot - - # Draw robot direction line - front_x = int(center_x + (robot_length_px / 2) * math.cos(robot_theta)) - front_y = int(center_y - (robot_length_px / 2) * math.sin(robot_theta)) - cv2.line(vis_img, (center_x, center_y), (front_x, front_y), (255, 0, 0), 2) # Blue line - - # Draw selected direction - if selected_direction is not None: - # selected_direction is relative to robot frame - # Angle in the visualization frame (relative to image +X axis) - vis_angle_selected = selected_direction + robot_theta - - # Make slightly longer than max histogram line - sel_dir_line_length = (vis_size / 2) * 0.9 - - sel_end_x = int(center_x + sel_dir_line_length * math.cos(vis_angle_selected)) - sel_end_y = int(center_y - sel_dir_line_length * math.sin(vis_angle_selected)) # Flipped Y - - cv2.line( - vis_img, (center_x, center_y), (sel_end_x, sel_end_y), (0, 165, 255), 2 - ) # BGR for Orange - - # Draw goal - if goal_xy is not None: - goal_x, goal_y = goal_xy - goal_rel_x_map = goal_x - robot_x - goal_rel_y_map = goal_y - robot_y - goal_img_x = int(center_x + goal_rel_x_map * scale) - goal_img_y = int(center_y - goal_rel_y_map * scale) # Flip y-axis - if 0 <= goal_img_x < vis_size and 0 <= goal_img_y < vis_size: - cv2.circle(vis_img, (goal_img_x, goal_img_y), 5, (0, 255, 0), -1) # Green circle - cv2.circle(vis_img, (goal_img_x, goal_img_y), 8, (0, 0, 0), 1) # Black outline - - # Draw goal orientation - if goal_theta is not None and goal_xy is not None: - # For waypoint mode, only draw orientation at the final waypoint - if waypoints is not None and len(waypoints) > 0: - # Use the final waypoint position - final_waypoint = waypoints[-1] - goal_x, goal_y = final_waypoint[0], final_waypoint[1] - else: - # Use the current goal position - goal_x, goal_y = goal_xy - - 
goal_rel_x_map = goal_x - robot_x - goal_rel_y_map = goal_y - robot_y - goal_img_x = int(center_x + goal_rel_x_map * scale) - goal_img_y = int(center_y - goal_rel_y_map * scale) # Flip y-axis - - # Calculate goal orientation vector direction in visualization frame - # goal_theta is already in odom frame, need to adjust for visualization orientation - goal_dir_length = 30 # Length of direction indicator in pixels - goal_dir_end_x = int(goal_img_x + goal_dir_length * math.cos(goal_theta)) - goal_dir_end_y = int(goal_img_y - goal_dir_length * math.sin(goal_theta)) # Flip y-axis - - # Draw goal orientation arrow - if 0 <= goal_img_x < vis_size and 0 <= goal_img_y < vis_size: - cv2.arrowedLine( - vis_img, - (goal_img_x, goal_img_y), - (goal_dir_end_x, goal_dir_end_y), - (255, 0, 255), - 4, - ) # Magenta arrow - - # Add scale bar - scale_bar_length_px = int(1.0 * scale) - scale_bar_x = vis_size - scale_bar_length_px - 10 - scale_bar_y = vis_size - 20 - cv2.line( - vis_img, - (scale_bar_x, scale_bar_y), - (scale_bar_x + scale_bar_length_px, scale_bar_y), - (0, 0, 0), - 2, - ) - cv2.putText( - vis_img, "1m", (scale_bar_x, scale_bar_y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 0), 1 - ) - - # Add status info - status_text = [] - if waypoints is not None: - if current_waypoint_index is not None: - status_text.append(f"WP: {current_waypoint_index}/{len(waypoints)}") - else: - status_text.append(f"WPs: {len(waypoints)}") - - y_pos = 20 - for text in status_text: - cv2.putText(vis_img, text, (10, y_pos), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1) - y_pos += 20 - - return vis_img diff --git a/build/lib/dimos/robot/local_planner/simple.py b/build/lib/dimos/robot/local_planner/simple.py deleted file mode 100644 index 8eaf20ba6c..0000000000 --- a/build/lib/dimos/robot/local_planner/simple.py +++ /dev/null @@ -1,265 +0,0 @@ -# Copyright 2025 Dimensional Inc. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -import time -from dataclasses import dataclass -from typing import Any, Callable, Optional - -import reactivex as rx -from plum import dispatch -from reactivex import operators as ops - -from dimos.core import In, Module, Out, rpc -from dimos.msgs.geometry_msgs import PoseStamped, Vector3 -from dimos.robot.unitree_webrtc.type.odometry import Odometry - -# from dimos.robot.local_planner.local_planner import LocalPlanner -from dimos.types.costmap import Costmap -from dimos.types.path import Path -from dimos.types.pose import Pose -from dimos.types.vector import Vector, VectorLike, to_vector -from dimos.utils.logging_config import setup_logger -from dimos.utils.threadpool import get_scheduler - -logger = setup_logger("dimos.robot.unitree.global_planner") - - -def transform_to_robot_frame(global_vector: Vector, robot_position: Pose) -> Vector: - """Transform a global coordinate vector to robot-relative coordinates. 
- - Args: - global_vector: Vector in global coordinates - robot_position: Robot's position and orientation - - Returns: - Vector in robot coordinates where X is forward/backward, Y is left/right - """ - # Get the robot's yaw angle (rotation around Z-axis) - robot_yaw = robot_position.rot.z - - # Create rotation matrix to transform from global to robot frame - # We need to rotate the coordinate system by -robot_yaw to get robot-relative coordinates - cos_yaw = math.cos(-robot_yaw) - sin_yaw = math.sin(-robot_yaw) - - # Apply 2D rotation transformation - # This transforms a global direction vector into the robot's coordinate frame - # In robot frame: X=forward/backward, Y=left/right - # In global frame: X=east/west, Y=north/south - robot_x = global_vector.x * cos_yaw - global_vector.y * sin_yaw # Forward/backward - robot_y = global_vector.x * sin_yaw + global_vector.y * cos_yaw # Left/right - - return Vector(-robot_x, robot_y, 0) - - -class SimplePlanner(Module): - path: In[Path] = None - odom: In[PoseStamped] = None - movecmd: Out[Vector3] = None - - get_costmap: Callable[[], Costmap] - - latest_odom: PoseStamped = None - - goal: Optional[Vector] = None - speed: float = 0.3 - - def __init__( - self, - get_costmap: Callable[[], Costmap], - ): - Module.__init__(self) - self.get_costmap = get_costmap - - def get_move_stream(self, frequency: float = 40.0) -> rx.Observable: - return rx.interval(1.0 / frequency, scheduler=get_scheduler()).pipe( - # do we have a goal? 
- ops.filter(lambda _: self.goal is not None), - # For testing: make robot move left/right instead of rotating - ops.map(lambda _: self._test_translational_movement()), - self.frequency_spy("movement_test"), - ) - - @rpc - def start(self): - self.path.subscribe(self.set_goal) - - def setodom(odom: Odometry): - self.latest_odom = odom - - self.odom.subscribe(setodom) - self.get_move_stream(frequency=20.0).subscribe(self.movecmd.publish) - - @dispatch - def set_goal(self, goal: Path, stop_event=None, goal_theta=None) -> bool: - self.goal = goal.last().to_2d() - logger.info(f"Setting goal: {self.goal}") - return True - - @dispatch - def set_goal(self, goal: VectorLike, stop_event=None, goal_theta=None) -> bool: - self.goal = to_vector(goal).to_2d() - logger.info(f"Setting goal: {self.goal}") - return True - - def calc_move(self, direction: Vector) -> Vector: - """Calculate the movement vector based on the direction to the goal. - - Args: - direction: Direction vector towards the goal - - Returns: - Movement vector scaled by speed - """ - try: - # Normalize the direction vector and scale by speed - normalized_direction = direction.normalize() - move_vector = normalized_direction * self.speed - print("CALC MOVE", direction, normalized_direction, move_vector) - return move_vector - except Exception as e: - print("Error calculating move vector:", e) - - def spy(self, name: str): - def spyfun(x): - print(f"SPY {name}:", x) - return x - - return ops.map(spyfun) - - def frequency_spy(self, name: str, window_size: int = 10): - """Create a frequency spy that logs message rate over a sliding window. 
- - Args: - name: Name for the spy output - window_size: Number of messages to average frequency over - """ - timestamps = [] - - def freq_spy_fun(x): - current_time = time.time() - timestamps.append(current_time) - print(x) - # Keep only the last window_size timestamps - if len(timestamps) > window_size: - timestamps.pop(0) - - # Calculate frequency if we have enough samples - if len(timestamps) >= 2: - time_span = timestamps[-1] - timestamps[0] - if time_span > 0: - frequency = (len(timestamps) - 1) / time_span - print(f"FREQ SPY {name}: {frequency:.2f} Hz ({len(timestamps)} samples)") - else: - print(f"FREQ SPY {name}: calculating... ({len(timestamps)} samples)") - else: - print(f"FREQ SPY {name}: warming up... ({len(timestamps)} samples)") - - return x - - return ops.map(freq_spy_fun) - - def _test_translational_movement(self) -> Vector: - """Test translational movement by alternating left and right movement. - - Returns: - Vector with (x=0, y=left/right, z=0) for testing left-right movement - """ - # Use time to alternate between left and right movement every 3 seconds - current_time = time.time() - cycle_time = 6.0 # 6 second cycle (3 seconds each direction) - phase = (current_time % cycle_time) / cycle_time - - if phase < 0.5: - # First half: move LEFT (positive X according to our documentation) - movement = Vector3(0.2, 0, 0) # Move left at 0.2 m/s - direction = "LEFT (positive X)" - else: - # Second half: move RIGHT (negative X according to our documentation) - movement = Vector3(-0.2, 0, 0) # Move right at 0.2 m/s - direction = "RIGHT (negative X)" - - print("=== LEFT-RIGHT MOVEMENT TEST ===") - print(f"Phase: {phase:.2f}, Direction: {direction}") - print(f"Sending movement command: {movement}") - print(f"Expected: Robot should move {direction.split()[0]} relative to its body") - print("===================================") - return movement - - def _calculate_rotation_to_target(self, direction_to_goal: Vector) -> Vector: - """Calculate the rotation 
needed for the robot to face the target. - - Args: - direction_to_goal: Vector pointing from robot position to goal in global coordinates - - Returns: - Vector with (x=0, y=0, z=angular_velocity) for rotation only - """ - # Calculate the desired yaw angle to face the target - desired_yaw = math.atan2(direction_to_goal.y, direction_to_goal.x) - - # Get current robot yaw - current_yaw = self.latest_odom.orientation.z - - # Calculate the yaw error using a more robust method to avoid oscillation - yaw_error = math.atan2( - math.sin(desired_yaw - current_yaw), math.cos(desired_yaw - current_yaw) - ) - - print( - f"DEBUG: direction_to_goal={direction_to_goal}, desired_yaw={math.degrees(desired_yaw):.1f}°, current_yaw={math.degrees(current_yaw):.1f}°" - ) - print( - f"DEBUG: yaw_error={math.degrees(yaw_error):.1f}°, abs_error={abs(yaw_error):.3f}, tolerance=0.1" - ) - - # Calculate angular velocity (proportional control) - max_angular_speed = 0.15 # rad/s - raw_angular_velocity = yaw_error * 2.0 - angular_velocity = max(-max_angular_speed, min(max_angular_speed, raw_angular_velocity)) - - print( - f"DEBUG: raw_ang_vel={raw_angular_velocity:.3f}, clamped_ang_vel={angular_velocity:.3f}" - ) - - # Stop rotating if we're close enough to the target angle - if abs(yaw_error) < 0.1: # ~5.7 degrees tolerance - print("DEBUG: Within tolerance - stopping rotation") - angular_velocity = 0.0 - else: - print("DEBUG: Outside tolerance - continuing rotation") - - print( - f"Rotation control: current_yaw={math.degrees(current_yaw):.1f}°, desired_yaw={math.degrees(desired_yaw):.1f}°, error={math.degrees(yaw_error):.1f}°, ang_vel={angular_velocity:.3f}" - ) - - # Return movement command: no translation (x=0, y=0), only rotation (z=angular_velocity) - # Try flipping the sign in case the rotation convention is opposite - return Vector(0, 0, -angular_velocity) - - def _debug_direction(self, name: str, direction: Vector) -> Vector: - """Debug helper to log direction information""" - robot_pos = 
self.latest_odom - print( - f"DEBUG {name}: direction={direction}, robot_pos={robot_pos.position.to_2d()}, robot_yaw={math.degrees(robot_pos.rot.z):.1f}°, goal={self.goal}" - ) - return direction - - def _debug_robot_command(self, robot_cmd: Vector) -> Vector: - """Debug helper to log robot command information""" - print( - f"DEBUG robot_command: x={robot_cmd.x:.3f}, y={robot_cmd.y:.3f} (forward/backward, left/right)" - ) - return robot_cmd diff --git a/build/lib/dimos/robot/local_planner/vfh_local_planner.py b/build/lib/dimos/robot/local_planner/vfh_local_planner.py deleted file mode 100644 index f97701e5a5..0000000000 --- a/build/lib/dimos/robot/local_planner/vfh_local_planner.py +++ /dev/null @@ -1,435 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import numpy as np -from typing import Dict, Tuple, Optional, Callable, Any -import cv2 -import logging - -from dimos.utils.logging_config import setup_logger -from dimos.utils.transform_utils import normalize_angle - -from dimos.robot.local_planner.local_planner import BaseLocalPlanner, visualize_local_planner_state -from dimos.types.costmap import Costmap -from dimos.types.vector import Vector, VectorLike - -logger = setup_logger("dimos.robot.unitree.vfh_local_planner", level=logging.DEBUG) - - -class VFHPurePursuitPlanner(BaseLocalPlanner): - """ - A local planner that combines Vector Field Histogram (VFH) for obstacle avoidance - with Pure Pursuit for goal tracking. - """ - - def __init__( - self, - get_costmap: Callable[[], Optional[Costmap]], - get_robot_pose: Callable[[], Any], - move: Callable[[Vector], None], - safety_threshold: float = 0.8, - histogram_bins: int = 144, - max_linear_vel: float = 0.8, - max_angular_vel: float = 1.0, - lookahead_distance: float = 1.0, - goal_tolerance: float = 0.4, - angle_tolerance: float = 0.1, # ~5.7 degrees - robot_width: float = 0.5, - robot_length: float = 0.7, - visualization_size: int = 400, - control_frequency: float = 10.0, - safe_goal_distance: float = 1.0, - max_recovery_attempts: int = 3, - global_planner_plan: Optional[Callable[[VectorLike], Optional[Any]]] = None, - ): - """ - Initialize the VFH + Pure Pursuit planner. 
- - Args: - get_costmap: Function to get the latest local costmap - get_robot_pose: Function to get the latest robot pose (returning odom object) - move: Function to send velocity commands - safety_threshold: Distance to maintain from obstacles (meters) - histogram_bins: Number of directional bins in the polar histogram - max_linear_vel: Maximum linear velocity (m/s) - max_angular_vel: Maximum angular velocity (rad/s) - lookahead_distance: Lookahead distance for pure pursuit (meters) - goal_tolerance: Distance at which the goal is considered reached (meters) - angle_tolerance: Angle at which the goal orientation is considered reached (radians) - robot_width: Width of the robot for visualization (meters) - robot_length: Length of the robot for visualization (meters) - visualization_size: Size of the visualization image in pixels - control_frequency: Frequency at which the planner is called (Hz) - safe_goal_distance: Distance at which to adjust the goal and ignore obstacles (meters) - max_recovery_attempts: Maximum number of recovery attempts - global_planner_plan: Optional function to get the global plan - """ - # Initialize base class - super().__init__( - get_costmap=get_costmap, - get_robot_pose=get_robot_pose, - move=move, - safety_threshold=safety_threshold, - max_linear_vel=max_linear_vel, - max_angular_vel=max_angular_vel, - lookahead_distance=lookahead_distance, - goal_tolerance=goal_tolerance, - angle_tolerance=angle_tolerance, - robot_width=robot_width, - robot_length=robot_length, - visualization_size=visualization_size, - control_frequency=control_frequency, - safe_goal_distance=safe_goal_distance, - max_recovery_attempts=max_recovery_attempts, - global_planner_plan=global_planner_plan, - ) - - # VFH specific parameters - self.histogram_bins = histogram_bins - self.histogram = None - self.selected_direction = None - - # VFH tuning parameters - self.alpha = 0.25 # Histogram smoothing factor - self.obstacle_weight = 5.0 - self.goal_weight = 2.0 - 
self.prev_direction_weight = 1.0 - self.prev_selected_angle = 0.0 - self.prev_linear_vel = 0.0 - self.linear_vel_filter_factor = 0.4 - self.low_speed_nudge = 0.1 - - # Add after other initialization - self.angle_mapping = np.linspace(-np.pi, np.pi, self.histogram_bins, endpoint=False) - self.smoothing_kernel = np.array([self.alpha, (1 - 2 * self.alpha), self.alpha]) - - def _compute_velocity_commands(self) -> Dict[str, float]: - """ - VFH + Pure Pursuit specific implementation of velocity command computation. - - Returns: - Dict[str, float]: Velocity commands with 'x_vel' and 'angular_vel' keys - """ - # Get necessary data for planning - costmap = self._get_costmap() - if costmap is None: - logger.warning("No costmap available for planning") - return {"x_vel": 0.0, "angular_vel": 0.0} - - robot_pos, robot_theta = self._get_robot_pose() - robot_x, robot_y = robot_pos - robot_pose = (robot_x, robot_y, robot_theta) - - # Calculate goal-related parameters - goal_x, goal_y = self.goal_xy - dx = goal_x - robot_x - dy = goal_y - robot_y - goal_distance = np.linalg.norm([dx, dy]) - goal_direction = np.arctan2(dy, dx) - robot_theta - goal_direction = normalize_angle(goal_direction) - - self.histogram = self.build_polar_histogram(costmap, robot_pose) - - # If we're ignoring obstacles near the goal, zero out the histogram - if self.ignore_obstacles: - self.histogram = np.zeros_like(self.histogram) - - self.selected_direction = self.select_direction( - self.goal_weight, - self.obstacle_weight, - self.prev_direction_weight, - self.histogram, - goal_direction, - ) - - # Calculate Pure Pursuit Velocities - linear_vel, angular_vel = self.compute_pure_pursuit(goal_distance, self.selected_direction) - - # Slow down when turning sharply - if abs(self.selected_direction) > 0.25: # ~15 degrees - # Scale from 1.0 (small turn) to 0.5 (sharp turn at 90 degrees or more) - turn_factor = max(0.25, 1.0 - (abs(self.selected_direction) / (np.pi / 2))) - linear_vel *= turn_factor - - # Apply 
Collision Avoidance Stop - skip if ignoring obstacles - if not self.ignore_obstacles and self.check_collision( - self.selected_direction, safety_threshold=0.5 - ): - # Re-select direction prioritizing obstacle avoidance if colliding - self.selected_direction = self.select_direction( - self.goal_weight * 0.2, - self.obstacle_weight, - self.prev_direction_weight * 0.2, - self.histogram, - goal_direction, - ) - linear_vel, angular_vel = self.compute_pure_pursuit( - goal_distance, self.selected_direction - ) - - if self.check_collision(0.0, safety_threshold=self.safety_threshold): - linear_vel = 0.0 - - self.prev_linear_vel = linear_vel - filtered_linear_vel = self.prev_linear_vel * self.linear_vel_filter_factor + linear_vel * ( - 1 - self.linear_vel_filter_factor - ) - - return {"x_vel": filtered_linear_vel, "angular_vel": angular_vel} - - def _smooth_histogram(self, histogram: np.ndarray) -> np.ndarray: - """ - Apply advanced smoothing to the polar histogram to better identify valleys - and reduce noise. 
- - Args: - histogram: Raw histogram to smooth - - Returns: - np.ndarray: Smoothed histogram - """ - # Apply a windowed average with variable width based on obstacle density - smoothed = np.zeros_like(histogram) - bins = len(histogram) - - # First pass: basic smoothing with a 5-point kernel - # This uses a wider window than the original 3-point smoother - for i in range(bins): - # Compute indices with wrap-around - indices = [(i + j) % bins for j in range(-2, 3)] - # Apply weighted average (more weight to the center) - weights = [0.1, 0.2, 0.4, 0.2, 0.1] # Sum = 1.0 - smoothed[i] = sum(histogram[idx] * weight for idx, weight in zip(indices, weights)) - - # Second pass: peak and valley enhancement - enhanced = np.zeros_like(smoothed) - for i in range(bins): - # Check neighboring values - prev_idx = (i - 1) % bins - next_idx = (i + 1) % bins - - # Enhance valleys (low values) - if smoothed[i] < smoothed[prev_idx] and smoothed[i] < smoothed[next_idx]: - # It's a local minimum - make it even lower - enhanced[i] = smoothed[i] * 0.8 - # Enhance peaks (high values) - elif smoothed[i] > smoothed[prev_idx] and smoothed[i] > smoothed[next_idx]: - # It's a local maximum - make it even higher - enhanced[i] = min(1.0, smoothed[i] * 1.2) - else: - enhanced[i] = smoothed[i] - - return enhanced - - def build_polar_histogram(self, costmap: Costmap, robot_pose: Tuple[float, float, float]): - """ - Build a polar histogram of obstacle densities around the robot. 
- - Args: - costmap: Costmap object with grid and metadata - robot_pose: Tuple (x, y, theta) of the robot pose in the odom frame - - Returns: - np.ndarray: Polar histogram of obstacle densities - """ - - # Get grid and find all obstacle cells - occupancy_grid = costmap.grid - y_indices, x_indices = np.where(occupancy_grid > 0) - if len(y_indices) == 0: # No obstacles - return np.zeros(self.histogram_bins) - - # Get robot position in grid coordinates - robot_x, robot_y, robot_theta = robot_pose - robot_point = costmap.world_to_grid((robot_x, robot_y)) - robot_cell_x, robot_cell_y = robot_point.x, robot_point.y - - # Vectorized distance and angle calculation - dx_cells = x_indices - robot_cell_x - dy_cells = y_indices - robot_cell_y - distances = np.sqrt(dx_cells**2 + dy_cells**2) * costmap.resolution - angles_grid = np.arctan2(dy_cells, dx_cells) - angles_robot = normalize_angle(angles_grid - robot_theta) - - # Convert to bin indices - bin_indices = ((angles_robot + np.pi) / (2 * np.pi) * self.histogram_bins).astype( - int - ) % self.histogram_bins - - # Get obstacle values - obstacle_values = occupancy_grid[y_indices, x_indices] / 100.0 - - # Build histogram - histogram = np.zeros(self.histogram_bins) - mask = distances > 0 - # Weight obstacles by inverse square of distance and cell value - np.add.at(histogram, bin_indices[mask], obstacle_values[mask] / (distances[mask] ** 2)) - - # Apply the enhanced smoothing - return self._smooth_histogram(histogram) - - def select_direction( - self, goal_weight, obstacle_weight, prev_direction_weight, histogram, goal_direction - ): - """ - Select best direction based on a simple weighted cost function. 
- - Args: - goal_weight: Weight for the goal direction component - obstacle_weight: Weight for the obstacle avoidance component - prev_direction_weight: Weight for previous direction consistency - histogram: Polar histogram of obstacle density - goal_direction: Desired direction to goal - - Returns: - float: Selected direction in radians - """ - # Normalize histogram if needed - if np.max(histogram) > 0: - histogram = histogram / np.max(histogram) - - # Calculate costs for each possible direction - angle_diffs = np.abs(normalize_angle(self.angle_mapping - goal_direction)) - prev_diffs = np.abs(normalize_angle(self.angle_mapping - self.prev_selected_angle)) - - # Combine costs with weights - obstacle_costs = obstacle_weight * histogram - goal_costs = goal_weight * angle_diffs - prev_costs = prev_direction_weight * prev_diffs - - total_costs = obstacle_costs + goal_costs + prev_costs - - # Select direction with lowest cost - min_cost_idx = np.argmin(total_costs) - selected_angle = self.angle_mapping[min_cost_idx] - - # Update history for next iteration - self.prev_selected_angle = selected_angle - - return selected_angle - - def compute_pure_pursuit( - self, goal_distance: float, goal_direction: float - ) -> Tuple[float, float]: - """Compute pure pursuit velocities.""" - if goal_distance < self.goal_tolerance: - return 0.0, 0.0 - - lookahead = min(self.lookahead_distance, goal_distance) - linear_vel = min(self.max_linear_vel, goal_distance) - angular_vel = 2.0 * np.sin(goal_direction) / lookahead - angular_vel = max(-self.max_angular_vel, min(angular_vel, self.max_angular_vel)) - - return linear_vel, angular_vel - - def check_collision(self, selected_direction: float, safety_threshold: float = 1.0) -> bool: - """Check if there's an obstacle in the selected direction within safety threshold.""" - # Skip collision check if ignoring obstacles - if self.ignore_obstacles: - return False - - # Get the latest costmap and robot pose - costmap = self._get_costmap() - if 
costmap is None: - return False # No costmap available - - robot_pos, robot_theta = self._get_robot_pose() - robot_x, robot_y = robot_pos - - # Direction in world frame - direction_world = robot_theta + selected_direction - - # Safety distance in cells - safety_cells = int(safety_threshold / costmap.resolution) - - # Get robot position in grid coordinates - robot_point = costmap.world_to_grid((robot_x, robot_y)) - robot_cell_x, robot_cell_y = robot_point.x, robot_point.y - - # Check for obstacles along the selected direction - for dist in range(1, safety_cells + 1): - # Calculate cell position - cell_x = robot_cell_x + int(dist * np.cos(direction_world)) - cell_y = robot_cell_y + int(dist * np.sin(direction_world)) - - # Check if cell is within grid bounds - if not (0 <= cell_x < costmap.width and 0 <= cell_y < costmap.height): - continue - - # Check if cell contains an obstacle (threshold at 50) - if costmap.grid[int(cell_y), int(cell_x)] > 50: - return True - - return False # No collision detected - - def update_visualization(self) -> np.ndarray: - """Generate visualization of the planning state.""" - try: - costmap = self._get_costmap() - if costmap is None: - raise ValueError("Costmap is None") - - robot_pos, robot_theta = self._get_robot_pose() - robot_x, robot_y = robot_pos - robot_pose = (robot_x, robot_y, robot_theta) - - goal_xy = self.goal_xy # This could be a lookahead point or final goal - - # Get the latest histogram and selected direction, if available - histogram = getattr(self, "histogram", None) - selected_direction = getattr(self, "selected_direction", None) - - # Get waypoint data if in waypoint mode - waypoints_to_draw = self.waypoints_in_absolute - current_wp_index_to_draw = ( - self.current_waypoint_index if self.waypoints_in_absolute is not None else None - ) - # Ensure index is valid before passing - if waypoints_to_draw is not None and current_wp_index_to_draw is not None: - if not (0 <= current_wp_index_to_draw < len(waypoints_to_draw)): - 
current_wp_index_to_draw = None # Invalidate index if out of bounds - - return visualize_local_planner_state( - occupancy_grid=costmap.grid, - grid_resolution=costmap.resolution, - grid_origin=(costmap.origin.x, costmap.origin.y), - robot_pose=robot_pose, - goal_xy=goal_xy, # Current target (lookahead or final) - goal_theta=self.goal_theta, # Pass goal orientation if available - visualization_size=self.visualization_size, - robot_width=self.robot_width, - robot_length=self.robot_length, - histogram=histogram, - selected_direction=selected_direction, - waypoints=waypoints_to_draw, # Pass the full path - current_waypoint_index=current_wp_index_to_draw, # Pass the target index - ) - except Exception as e: - logger.error(f"Error during visualization update: {e}") - # Return a blank image with error text - blank = ( - np.ones((self.visualization_size, self.visualization_size, 3), dtype=np.uint8) * 255 - ) - cv2.putText( - blank, - "Viz Error", - (self.visualization_size // 4, self.visualization_size // 2), - cv2.FONT_HERSHEY_SIMPLEX, - 1, - (0, 0, 0), - 2, - ) - return blank diff --git a/build/lib/dimos/robot/position_stream.py b/build/lib/dimos/robot/position_stream.py deleted file mode 100644 index 05d80b8bcf..0000000000 --- a/build/lib/dimos/robot/position_stream.py +++ /dev/null @@ -1,162 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Position stream provider for ROS-based robots. 
- -This module creates a reactive stream of position updates from ROS odometry or pose topics. -""" - -import logging -from typing import Tuple, Optional -import time -from reactivex import Subject, Observable -from reactivex import operators as ops -from rclpy.node import Node -from geometry_msgs.msg import PoseStamped -from nav_msgs.msg import Odometry - -from dimos.utils.logging_config import setup_logger - -logger = setup_logger("dimos.robot.position_stream", level=logging.INFO) - - -class PositionStreamProvider: - """ - A provider for streaming position updates from ROS. - - This class creates an Observable stream of position updates by subscribing - to ROS odometry or pose topics. - """ - - def __init__( - self, - ros_node: Node, - odometry_topic: str = "/odom", - pose_topic: Optional[str] = None, - use_odometry: bool = True, - ): - """ - Initialize the position stream provider. - - Args: - ros_node: ROS node to use for subscriptions - odometry_topic: Name of the odometry topic (if use_odometry is True) - pose_topic: Name of the pose topic (if use_odometry is False) - use_odometry: Whether to use odometry (True) or pose (False) for position - """ - self.ros_node = ros_node - self.odometry_topic = odometry_topic - self.pose_topic = pose_topic - self.use_odometry = use_odometry - - self._subject = Subject() - - self.last_position = None - self.last_update_time = None - - self._create_subscription() - - logger.info( - f"PositionStreamProvider initialized with " - f"{'odometry topic' if use_odometry else 'pose topic'}: " - f"{odometry_topic if use_odometry else pose_topic}" - ) - - def _create_subscription(self): - """Create the appropriate ROS subscription based on configuration.""" - if self.use_odometry: - self.subscription = self.ros_node.create_subscription( - Odometry, self.odometry_topic, self._odometry_callback, 10 - ) - logger.info(f"Subscribed to odometry topic: {self.odometry_topic}") - else: - if not self.pose_topic: - raise ValueError("Pose topic 
must be specified when use_odometry is False") - - self.subscription = self.ros_node.create_subscription( - PoseStamped, self.pose_topic, self._pose_callback, 10 - ) - logger.info(f"Subscribed to pose topic: {self.pose_topic}") - - def _odometry_callback(self, msg: Odometry): - """ - Process odometry messages and extract position. - - Args: - msg: Odometry message from ROS - """ - x = msg.pose.pose.position.x - y = msg.pose.pose.position.y - - self._update_position(x, y) - - def _pose_callback(self, msg: PoseStamped): - """ - Process pose messages and extract position. - - Args: - msg: PoseStamped message from ROS - """ - x = msg.pose.position.x - y = msg.pose.position.y - - self._update_position(x, y) - - def _update_position(self, x: float, y: float): - """ - Update the current position and emit to subscribers. - - Args: - x: X coordinate - y: Y coordinate - """ - current_time = time.time() - position = (x, y) - - if self.last_update_time: - update_rate = 1.0 / (current_time - self.last_update_time) - logger.debug(f"Position update rate: {update_rate:.1f} Hz") - - self.last_position = position - self.last_update_time = current_time - - self._subject.on_next(position) - logger.debug(f"Position updated: ({x:.2f}, {y:.2f})") - - def get_position_stream(self) -> Observable: - """ - Get an Observable stream of position updates. - - Returns: - Observable that emits (x, y) tuples - """ - return self._subject.pipe( - ops.share() # Share the stream among multiple subscribers - ) - - def get_current_position(self) -> Optional[Tuple[float, float]]: - """ - Get the most recent position. 
- - Returns: - Tuple of (x, y) coordinates, or None if no position has been received - """ - return self.last_position - - def cleanup(self): - """Clean up resources.""" - if hasattr(self, "subscription") and self.subscription: - self.ros_node.destroy_subscription(self.subscription) - logger.info("Position subscription destroyed") diff --git a/build/lib/dimos/robot/recorder.py b/build/lib/dimos/robot/recorder.py deleted file mode 100644 index 56b6cea888..0000000000 --- a/build/lib/dimos/robot/recorder.py +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# UNDER DEVELOPMENT 🚧🚧🚧, NEEDS TESTING - -import threading -import time -from queue import Queue -from typing import Callable, Literal - -# from dimos.data.recording import Recorder - - -class RobotRecorder: - """A class for recording robot observation and actions. - - Recording at a specified frequency on the observation and action of a robot. It leverages a queue and a worker - thread to handle the recording asynchronously, ensuring that the main operations of the - robot are not blocked. - - Robot class must pass in the `get_state`, `get_observation`, `prepare_action` methods.` - get_state() gets the current state/pose of the robot. - get_observation() captures the observation/image of the robot. - prepare_action() calculates the action between the new and old states. 
- """ - - def __init__( - self, - get_state: Callable, - get_observation: Callable, - prepare_action: Callable, - frequency_hz: int = 5, - recorder_kwargs: dict = None, - on_static: Literal["record", "omit"] = "omit", - ) -> None: - """Initializes the RobotRecorder. - - This constructor sets up the recording mechanism on the given robot, including the recorder instance, - recording frequency, and the asynchronous processing queue and worker thread. It also - initializes attributes to track the last recorded pose and the current instruction. - - Args: - get_state: A function that returns the current state of the robot. - get_observation: A function that captures the observation/image of the robot. - prepare_action: A function that calculates the action between the new and old states. - frequency_hz: Frequency at which to record pose and image data (in Hz). - recorder_kwargs: Keyword arguments to pass to the Recorder constructor. - on_static: Whether to record on static poses or not. If "record", it will record when the robot is not moving. 
- """ - if recorder_kwargs is None: - recorder_kwargs = {} - self.recorder = Recorder(**recorder_kwargs) - self.task = None - - self.last_recorded_state = None - self.last_image = None - - self.recording = False - self.frequency_hz = frequency_hz - self.record_on_static = on_static == "record" - self.recording_queue = Queue() - - self.get_state = get_state - self.get_observation = get_observation - self.prepare_action = prepare_action - - self._worker_thread = threading.Thread(target=self._process_queue, daemon=True) - self._worker_thread.start() - - def __enter__(self): - """Enter the context manager, starting the recording.""" - self.start_recording(self.task) - - def __exit__(self, exc_type, exc_value, traceback) -> None: - """Exit the context manager, stopping the recording.""" - self.stop_recording() - - def record(self, task: str) -> "RobotRecorder": - """Set the task and return the context manager.""" - self.task = task - return self - - def reset_recorder(self) -> None: - """Reset the recorder.""" - while self.recording: - time.sleep(0.1) - self.recorder.reset() - - def record_from_robot(self) -> None: - """Records the current pose and captures an image at the specified frequency.""" - while self.recording: - start_time = time.perf_counter() - self.record_current_state() - elapsed_time = time.perf_counter() - start_time - # Sleep for the remaining time to maintain the desired frequency - sleep_time = max(0, (1.0 / self.frequency_hz) - elapsed_time) - time.sleep(sleep_time) - - def start_recording(self, task: str = "") -> None: - """Starts the recording of pose and image.""" - if not self.recording: - self.task = task - self.recording = True - self.recording_thread = threading.Thread(target=self.record_from_robot) - self.recording_thread.start() - - def stop_recording(self) -> None: - """Stops the recording of pose and image.""" - if self.recording: - self.recording = False - self.recording_thread.join() - - def _process_queue(self) -> None: - """Processes 
the recording queue asynchronously.""" - while True: - image, instruction, action, state = self.recording_queue.get() - self.recorder.record( - observation={"image": image, "instruction": instruction}, action=action, state=state - ) - self.recording_queue.task_done() - - def record_current_state(self) -> None: - """Records the current pose and image if the pose has changed.""" - state = self.get_state() - image = self.get_observation() - - # This is the beginning of the episode - if self.last_recorded_state is None: - self.last_recorded_state = state - self.last_image = image - return - - if state != self.last_recorded_state or self.record_on_static: - action = self.prepare_action(self.last_recorded_state, state) - self.recording_queue.put( - ( - self.last_image, - self.task, - action, - self.last_recorded_state, - ), - ) - self.last_image = image - self.last_recorded_state = state - - def record_last_state(self) -> None: - """Records the final pose and image after the movement completes.""" - self.record_current_state() diff --git a/build/lib/dimos/robot/robot.py b/build/lib/dimos/robot/robot.py deleted file mode 100644 index 58526b5f0c..0000000000 --- a/build/lib/dimos/robot/robot.py +++ /dev/null @@ -1,435 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Base module for all DIMOS robots. 
- -This module provides the foundation for all DIMOS robots, including both physical -and simulated implementations, with common functionality for movement, control, -and video streaming. -""" - -from abc import ABC, abstractmethod -import os -from typing import Optional, List, Union, Dict, Any - -from dimos.hardware.interface import HardwareInterface -from dimos.perception.spatial_perception import SpatialMemory -from dimos.manipulation.manipulation_interface import ManipulationInterface -from dimos.types.robot_capabilities import RobotCapability -from dimos.types.vector import Vector -from dimos.utils.logging_config import setup_logger -from dimos.robot.connection_interface import ConnectionInterface - -from dimos.skills.skills import SkillLibrary -from reactivex import Observable, operators as ops -from reactivex.disposable import CompositeDisposable -from reactivex.scheduler import ThreadPoolScheduler - -from dimos.utils.threadpool import get_scheduler -from dimos.utils.reactive import backpressure -from dimos.stream.video_provider import VideoProvider - -logger = setup_logger("dimos.robot.robot") - - -class Robot(ABC): - """Base class for all DIMOS robots. - - This abstract base class defines the common interface and functionality for all - DIMOS robots, whether physical or simulated. It provides methods for movement, - rotation, video streaming, and hardware configuration management. - - Attributes: - agent_config: Configuration for the robot's agent. - hardware_interface: Interface to the robot's hardware components. - ros_control: ROS-based control system for the robot. - output_dir: Directory for storing output files. - disposables: Collection of disposable resources for cleanup. - pool_scheduler: Thread pool scheduler for managing concurrent operations. 
- """ - - def __init__( - self, - hardware_interface: HardwareInterface = None, - connection_interface: ConnectionInterface = None, - output_dir: str = os.path.join(os.getcwd(), "assets", "output"), - pool_scheduler: ThreadPoolScheduler = None, - skill_library: SkillLibrary = None, - spatial_memory_collection: str = "spatial_memory", - new_memory: bool = False, - capabilities: List[RobotCapability] = None, - video_stream: Optional[Observable] = None, - enable_perception: bool = True, - ): - """Initialize a Robot instance. - - Args: - hardware_interface: Interface to the robot's hardware. Defaults to None. - connection_interface: Connection interface for robot control and communication. - output_dir: Directory for storing output files. Defaults to "./assets/output". - pool_scheduler: Thread pool scheduler. If None, one will be created. - skill_library: Skill library instance. If None, one will be created. - spatial_memory_collection: Name of the collection in the ChromaDB database. - new_memory: If True, creates a new spatial memory from scratch. Defaults to False. - capabilities: List of robot capabilities. Defaults to None. - video_stream: Optional video stream. Defaults to None. - enable_perception: If True, enables perception streams and spatial memory. Defaults to True. 
- """ - self.hardware_interface = hardware_interface - self.connection_interface = connection_interface - self.output_dir = output_dir - self.disposables = CompositeDisposable() - self.pool_scheduler = pool_scheduler if pool_scheduler else get_scheduler() - self.skill_library = skill_library if skill_library else SkillLibrary() - self.enable_perception = enable_perception - - # Initialize robot capabilities - self.capabilities = capabilities or [] - - # Create output directory if it doesn't exist - os.makedirs(self.output_dir, exist_ok=True) - logger.info(f"Robot outputs will be saved to: {self.output_dir}") - - # Initialize memory properties - self.memory_dir = os.path.join(self.output_dir, "memory") - os.makedirs(self.memory_dir, exist_ok=True) - - # Initialize spatial memory properties - self.spatial_memory_dir = os.path.join(self.memory_dir, "spatial_memory") - self.spatial_memory_collection = spatial_memory_collection - self.db_path = os.path.join(self.spatial_memory_dir, "chromadb_data") - self.visual_memory_path = os.path.join(self.spatial_memory_dir, "visual_memory.pkl") - - # Create spatial memory directory - os.makedirs(self.spatial_memory_dir, exist_ok=True) - os.makedirs(self.db_path, exist_ok=True) - - # Initialize spatial memory properties - self._video_stream = video_stream - - # Only create video stream if connection interface is available - if self.connection_interface is not None: - # Get video stream - always create this, regardless of enable_perception - self._video_stream = self.get_video_stream(fps=10) # Lower FPS for processing - - # Create SpatialMemory instance only if perception is enabled - if self.enable_perception: - self._spatial_memory = SpatialMemory( - collection_name=self.spatial_memory_collection, - db_path=self.db_path, - visual_memory_path=self.visual_memory_path, - new_memory=new_memory, - output_dir=self.spatial_memory_dir, - video_stream=self._video_stream, - get_pose=self.get_pose, - ) - logger.info("Spatial memory 
initialized") - else: - self._spatial_memory = None - logger.info("Spatial memory disabled (enable_perception=False)") - - # Initialize manipulation interface if the robot has manipulation capability - self._manipulation_interface = None - if RobotCapability.MANIPULATION in self.capabilities: - # Initialize manipulation memory properties if the robot has manipulation capability - self.manipulation_memory_dir = os.path.join(self.memory_dir, "manipulation_memory") - - # Create manipulation memory directory - os.makedirs(self.manipulation_memory_dir, exist_ok=True) - - self._manipulation_interface = ManipulationInterface( - output_dir=self.output_dir, # Use the main output directory - new_memory=new_memory, - ) - logger.info("Manipulation interface initialized") - - def get_video_stream(self, fps: int = 30) -> Observable: - """Get the video stream with rate limiting and frame processing. - - Args: - fps: Frames per second for the video stream. Defaults to 30. - - Returns: - Observable: An observable stream of video frames. - - Raises: - RuntimeError: If no connection interface is available for video streaming. - """ - if self.connection_interface is None: - raise RuntimeError("No connection interface available for video streaming") - - stream = self.connection_interface.get_video_stream(fps) - if stream is None: - raise RuntimeError("No video stream available from connection interface") - - return stream.pipe( - ops.observe_on(self.pool_scheduler), - ) - - def move(self, velocity: Vector, duration: float = 0.0) -> bool: - """Move the robot using velocity commands. - - Args: - velocity: Velocity vector [x, y, yaw] where: - x: Linear velocity in x direction (m/s) - y: Linear velocity in y direction (m/s) - yaw: Angular velocity (rad/s) - duration: Duration to apply command (seconds). If 0, apply once. - - Returns: - bool: True if movement succeeded. - - Raises: - RuntimeError: If no connection interface is available. 
- """ - if self.connection_interface is None: - raise RuntimeError("No connection interface available for movement") - - return self.connection_interface.move(velocity, duration) - - def spin(self, degrees: float, speed: float = 45.0) -> bool: - """Rotate the robot by a specified angle. - - Args: - degrees: Angle to rotate in degrees (positive for counter-clockwise, - negative for clockwise). - speed: Angular speed in degrees/second. Defaults to 45.0. - - Returns: - bool: True if rotation succeeded. - - Raises: - RuntimeError: If no connection interface is available. - """ - if self.connection_interface is None: - raise RuntimeError("No connection interface available for rotation") - - # Convert degrees to radians - import math - - angular_velocity = math.radians(speed) - duration = abs(degrees) / speed if speed > 0 else 0 - - # Set direction based on sign of degrees - if degrees < 0: - angular_velocity = -angular_velocity - - velocity = Vector(0.0, 0.0, angular_velocity) - return self.connection_interface.move(velocity, duration) - - @abstractmethod - def get_pose(self) -> dict: - """ - Get the current pose (position and rotation) of the robot. - - Returns: - Dictionary containing: - - position: Tuple[float, float, float] (x, y, z) - - rotation: Tuple[float, float, float] (roll, pitch, yaw) in radians - """ - pass - - def webrtc_req( - self, - api_id: int, - topic: str = None, - parameter: str = "", - priority: int = 0, - request_id: str = None, - data=None, - timeout: float = 1000.0, - ): - """Send a WebRTC request command to the robot. - - Args: - api_id: The API ID for the command. - topic: The API topic to publish to. Defaults to ROSControl.webrtc_api_topic. - parameter: Additional parameter data. Defaults to "". - priority: Priority of the request. Defaults to 0. - request_id: Unique identifier for the request. If None, one will be generated. - data: Additional data to include with the request. Defaults to None. 
- timeout: Timeout for the request in milliseconds. Defaults to 1000.0. - - Returns: - The result of the WebRTC request. - - Raises: - RuntimeError: If no connection interface with WebRTC capability is available. - """ - if self.connection_interface is None: - raise RuntimeError("No connection interface available for WebRTC commands") - - # WebRTC requests are only available on ROS control interfaces - if hasattr(self.connection_interface, "queue_webrtc_req"): - return self.connection_interface.queue_webrtc_req( - api_id=api_id, - topic=topic, - parameter=parameter, - priority=priority, - request_id=request_id, - data=data, - timeout=timeout, - ) - else: - raise RuntimeError("WebRTC requests not supported by this connection interface") - - def pose_command(self, roll: float, pitch: float, yaw: float) -> bool: - """Send a pose command to the robot. - - Args: - roll: Roll angle in radians. - pitch: Pitch angle in radians. - yaw: Yaw angle in radians. - - Returns: - bool: True if command was sent successfully. - - Raises: - RuntimeError: If no connection interface with pose command capability is available. - """ - # Pose commands are only available on ROS control interfaces - if hasattr(self.connection_interface, "pose_command"): - return self.connection_interface.pose_command(roll, pitch, yaw) - else: - raise RuntimeError("Pose commands not supported by this connection interface") - - def update_hardware_interface(self, new_hardware_interface: HardwareInterface): - """Update the hardware interface with a new configuration. - - Args: - new_hardware_interface: New hardware interface to use for the robot. - """ - self.hardware_interface = new_hardware_interface - - def get_hardware_configuration(self): - """Retrieve the current hardware configuration. - - Returns: - The current hardware configuration from the hardware interface. - - Raises: - AttributeError: If hardware_interface is None. 
- """ - return self.hardware_interface.get_configuration() - - def set_hardware_configuration(self, configuration): - """Set a new hardware configuration. - - Args: - configuration: The new hardware configuration to set. - - Raises: - AttributeError: If hardware_interface is None. - """ - self.hardware_interface.set_configuration(configuration) - - @property - def spatial_memory(self) -> Optional[SpatialMemory]: - """Get the robot's spatial memory. - - Returns: - SpatialMemory: The robot's spatial memory system, or None if perception is disabled. - """ - return self._spatial_memory - - @property - def manipulation_interface(self) -> Optional[ManipulationInterface]: - """Get the robot's manipulation interface. - - Returns: - ManipulationInterface: The robot's manipulation interface or None if not available. - """ - return self._manipulation_interface - - def has_capability(self, capability: RobotCapability) -> bool: - """Check if the robot has a specific capability. - - Args: - capability: The capability to check for - - Returns: - bool: True if the robot has the capability, False otherwise - """ - return capability in self.capabilities - - def get_spatial_memory(self) -> Optional[SpatialMemory]: - """Simple getter for the spatial memory instance. - (For backwards compatibility) - - Returns: - The spatial memory instance or None if not set. - """ - return self._spatial_memory if self._spatial_memory else None - - @property - def video_stream(self) -> Optional[Observable]: - """Get the robot's video stream. - - Returns: - Observable: The robot's video stream or None if not available. - """ - return self._video_stream - - def get_skills(self): - """Get the robot's skill library. - - Returns: - The robot's skill library for adding/managing skills. - """ - return self.skill_library - - def cleanup(self): - """Clean up resources used by the robot. 
- - This method should be called when the robot is no longer needed to - ensure proper release of resources such as ROS connections and - subscriptions. - """ - # Dispose of resources - if self.disposables: - self.disposables.dispose() - - # Clean up connection interface - if self.connection_interface: - self.connection_interface.disconnect() - - self.disposables.dispose() - - -class MockRobot(Robot): - def __init__(self): - super().__init__() - self.ros_control = None - self.hardware_interface = None - self.skill_library = SkillLibrary() - - def my_print(self): - print("Hello, world!") - - -class MockManipulationRobot(Robot): - def __init__(self, skill_library: Optional[SkillLibrary] = None): - video_provider = VideoProvider("webcam", video_source=0) # Default camera - video_stream = backpressure( - video_provider.capture_video_as_observable(realtime=True, fps=30) - ) - - super().__init__( - capabilities=[RobotCapability.MANIPULATION], - video_stream=video_stream, - skill_library=skill_library, - ) - self.camera_intrinsics = [489.33, 367.0, 320.0, 240.0] - self.ros_control = None - self.hardware_interface = None diff --git a/build/lib/dimos/robot/ros_command_queue.py b/build/lib/dimos/robot/ros_command_queue.py deleted file mode 100644 index fc48ce5cde..0000000000 --- a/build/lib/dimos/robot/ros_command_queue.py +++ /dev/null @@ -1,471 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -""" -Queue-based command management system for robot commands. - -This module provides a unified approach to queueing and processing all robot commands, -including WebRTC requests and action client commands. -Commands are processed sequentially and only when the robot is in IDLE state. -""" - -import threading -import time -import uuid -from enum import Enum, auto -from queue import PriorityQueue, Empty -from typing import Callable, Optional, NamedTuple, Dict, Any -from dimos.utils.logging_config import setup_logger - -# Initialize logger for the ros command queue module -logger = setup_logger("dimos.robot.ros_command_queue") - - -class CommandType(Enum): - """Types of commands that can be queued""" - - WEBRTC = auto() # WebRTC API requests - ACTION = auto() # Any action client or function call - - -class WebRTCRequest(NamedTuple): - """Class to represent a WebRTC request in the queue""" - - id: str # Unique ID for tracking - api_id: int # API ID for the command - topic: str # Topic to publish to - parameter: str # Optional parameter string - priority: int # Priority level - timeout: float # How long to wait for this request to complete - - -class ROSCommand(NamedTuple): - """Class to represent a command in the queue""" - - id: str # Unique ID for tracking - cmd_type: CommandType # Type of command - execute_func: Callable # Function to execute the command - params: Dict[str, Any] # Parameters for the command (for debugging/logging) - priority: int # Priority level (lower is higher priority) - timeout: float # How long to wait for this command to complete - - -class ROSCommandQueue: - """ - Manages a queue of commands for the robot. - - Commands are executed sequentially, with only one command being processed at a time. - Commands are only executed when the robot is in the IDLE state. 
- """ - - def __init__( - self, - webrtc_func: Callable, - is_ready_func: Callable[[], bool] = None, - is_busy_func: Optional[Callable[[], bool]] = None, - debug: bool = True, - ): - """ - Initialize the ROSCommandQueue. - - Args: - webrtc_func: Function to send WebRTC requests - is_ready_func: Function to check if the robot is ready for a command - is_busy_func: Function to check if the robot is busy - debug: Whether to enable debug logging - """ - self._webrtc_func = webrtc_func - self._is_ready_func = is_ready_func or (lambda: True) - self._is_busy_func = is_busy_func - self._debug = debug - - # Queue of commands to process - self._queue = PriorityQueue() - self._current_command = None - self._last_command_time = 0 - - # Last known robot state - self._last_ready_state = None - self._last_busy_state = None - self._stuck_in_busy_since = None - - # Command execution status - self._should_stop = False - self._queue_thread = None - - # Stats - self._command_count = 0 - self._success_count = 0 - self._failure_count = 0 - self._command_history = [] - - self._max_queue_wait_time = ( - 30.0 # Maximum time to wait for robot to be ready before forcing - ) - - logger.info("ROSCommandQueue initialized") - - def start(self): - """Start the queue processing thread""" - if self._queue_thread is not None and self._queue_thread.is_alive(): - logger.warning("Queue processing thread already running") - return - - self._should_stop = False - self._queue_thread = threading.Thread(target=self._process_queue, daemon=True) - self._queue_thread.start() - logger.info("Queue processing thread started") - - def stop(self, timeout=2.0): - """ - Stop the queue processing thread - - Args: - timeout: Maximum time to wait for the thread to stop - """ - if self._queue_thread is None or not self._queue_thread.is_alive(): - logger.warning("Queue processing thread not running") - return - - self._should_stop = True - try: - self._queue_thread.join(timeout=timeout) - if self._queue_thread.is_alive(): 
- logger.warning(f"Queue processing thread did not stop within {timeout}s") - else: - logger.info("Queue processing thread stopped") - except Exception as e: - logger.error(f"Error stopping queue processing thread: {e}") - - def queue_webrtc_request( - self, - api_id: int, - topic: str = None, - parameter: str = "", - request_id: str = None, - data: Dict[str, Any] = None, - priority: int = 0, - timeout: float = 30.0, - ) -> str: - """ - Queue a WebRTC request - - Args: - api_id: API ID for the command - topic: Topic to publish to - parameter: Optional parameter string - request_id: Unique ID for the request (will be generated if not provided) - data: Data to include in the request - priority: Priority level (lower is higher priority) - timeout: Maximum time to wait for the command to complete - - Returns: - str: Unique ID for the request - """ - request_id = request_id or str(uuid.uuid4()) - - # Create a function that will execute this WebRTC request - def execute_webrtc(): - try: - logger.info(f"Executing WebRTC request: {api_id} (ID: {request_id})") - if self._debug: - logger.debug(f"[WebRTC Queue] SENDING request: API ID {api_id}") - - result = self._webrtc_func( - api_id=api_id, - topic=topic, - parameter=parameter, - request_id=request_id, - data=data, - ) - if not result: - logger.warning(f"WebRTC request failed: {api_id} (ID: {request_id})") - if self._debug: - logger.debug(f"[WebRTC Queue] Request API ID {api_id} FAILED to send") - return False - - if self._debug: - logger.debug(f"[WebRTC Queue] Request API ID {api_id} sent SUCCESSFULLY") - - # Allow time for the robot to process the command - start_time = time.time() - stabilization_delay = 0.5 # Half-second delay for stabilization - time.sleep(stabilization_delay) - - # Wait for the robot to complete the command (timeout check) - while self._is_busy_func() and (time.time() - start_time) < timeout: - if ( - self._debug and (time.time() - start_time) % 5 < 0.1 - ): # Print every ~5 seconds - logger.debug( - 
f"[WebRTC Queue] Still waiting on API ID {api_id} - elapsed: {time.time() - start_time:.1f}s" - ) - time.sleep(0.1) - - # Check if we timed out - if self._is_busy_func() and (time.time() - start_time) >= timeout: - logger.warning(f"WebRTC request timed out: {api_id} (ID: {request_id})") - return False - - wait_time = time.time() - start_time - if self._debug: - logger.debug( - f"[WebRTC Queue] Request API ID {api_id} completed after {wait_time:.1f}s" - ) - - logger.info(f"WebRTC request completed: {api_id} (ID: {request_id})") - return True - except Exception as e: - logger.error(f"Error executing WebRTC request: {e}") - if self._debug: - logger.debug(f"[WebRTC Queue] ERROR processing request: {e}") - return False - - # Create the command and queue it - command = ROSCommand( - id=request_id, - cmd_type=CommandType.WEBRTC, - execute_func=execute_webrtc, - params={"api_id": api_id, "topic": topic, "request_id": request_id}, - priority=priority, - timeout=timeout, - ) - - # Queue the command - self._queue.put((priority, self._command_count, command)) - self._command_count += 1 - if self._debug: - logger.debug( - f"[WebRTC Queue] Added request ID {request_id} for API ID {api_id} - Queue size now: {self.queue_size}" - ) - logger.info(f"Queued WebRTC request: {api_id} (ID: {request_id}, Priority: {priority})") - - return request_id - - def queue_action_client_request( - self, - action_name: str, - execute_func: Callable, - priority: int = 0, - timeout: float = 30.0, - **kwargs, - ) -> str: - """ - Queue any action client request or function - - Args: - action_name: Name of the action for logging/tracking - execute_func: Function to execute the command - priority: Priority level (lower is higher priority) - timeout: Maximum time to wait for the command to complete - **kwargs: Additional parameters to pass to the execute function - - Returns: - str: Unique ID for the request - """ - request_id = str(uuid.uuid4()) - - # Create the command - command = ROSCommand( - 
id=request_id, - cmd_type=CommandType.ACTION, - execute_func=execute_func, - params={"action_name": action_name, **kwargs}, - priority=priority, - timeout=timeout, - ) - - # Queue the command - self._queue.put((priority, self._command_count, command)) - self._command_count += 1 - - action_params = ", ".join([f"{k}={v}" for k, v in kwargs.items()]) - logger.info( - f"Queued action request: {action_name} (ID: {request_id}, Priority: {priority}, Params: {action_params})" - ) - - return request_id - - def _process_queue(self): - """Process commands in the queue""" - logger.info("Starting queue processing") - logger.info("[WebRTC Queue] Processing thread started") - - while not self._should_stop: - # Print queue status - self._print_queue_status() - - # Check if we're ready to process a command - if not self._queue.empty() and self._current_command is None: - current_time = time.time() - is_ready = self._is_ready_func() - is_busy = self._is_busy_func() if self._is_busy_func else False - - if self._debug: - logger.debug( - f"[WebRTC Queue] Status: {self.queue_size} requests waiting | Robot ready: {is_ready} | Robot busy: {is_busy}" - ) - - # Track robot state changes - if is_ready != self._last_ready_state: - logger.debug( - f"Robot ready state changed: {self._last_ready_state} -> {is_ready}" - ) - self._last_ready_state = is_ready - - if is_busy != self._last_busy_state: - logger.debug(f"Robot busy state changed: {self._last_busy_state} -> {is_busy}") - self._last_busy_state = is_busy - - # If the robot has transitioned to busy, record the time - if is_busy: - self._stuck_in_busy_since = current_time - else: - self._stuck_in_busy_since = None - - # Check if we've been waiting too long for the robot to be ready - force_processing = False - if ( - not is_ready - and is_busy - and self._stuck_in_busy_since is not None - and current_time - self._stuck_in_busy_since > self._max_queue_wait_time - ): - logger.warning( - f"Robot has been busy for {current_time - 
self._stuck_in_busy_since:.1f}s, " - f"forcing queue to continue" - ) - force_processing = True - - # Process the next command if ready or forcing - if is_ready or force_processing: - if self._debug and is_ready: - logger.debug("[WebRTC Queue] Robot is READY for next command") - - try: - # Get the next command - _, _, command = self._queue.get(block=False) - self._current_command = command - self._last_command_time = current_time - - # Log the command - cmd_info = f"ID: {command.id}, Type: {command.cmd_type.name}" - if command.cmd_type == CommandType.WEBRTC: - api_id = command.params.get("api_id") - cmd_info += f", API: {api_id}" - if self._debug: - logger.debug(f"[WebRTC Queue] DEQUEUED request: API ID {api_id}") - elif command.cmd_type == CommandType.ACTION: - action_name = command.params.get("action_name") - cmd_info += f", Action: {action_name}" - if self._debug: - logger.debug(f"[WebRTC Queue] DEQUEUED action: {action_name}") - - forcing_str = " (FORCED)" if force_processing else "" - logger.info(f"Processing command{forcing_str}: {cmd_info}") - - # Execute the command - try: - # Where command execution occurs - success = command.execute_func() - - if success: - self._success_count += 1 - logger.info(f"Command succeeded: {cmd_info}") - if self._debug: - logger.debug( - f"[WebRTC Queue] Command {command.id} marked as COMPLETED" - ) - else: - self._failure_count += 1 - logger.warning(f"Command failed: {cmd_info}") - if self._debug: - logger.debug(f"[WebRTC Queue] Command {command.id} FAILED") - - # Record command history - self._command_history.append( - { - "id": command.id, - "type": command.cmd_type.name, - "params": command.params, - "success": success, - "time": time.time() - self._last_command_time, - } - ) - - except Exception as e: - self._failure_count += 1 - logger.error(f"Error executing command: {e}") - if self._debug: - logger.debug(f"[WebRTC Queue] ERROR executing command: {e}") - - # Mark the command as complete - self._current_command = None - if 
self._debug: - logger.debug( - "[WebRTC Queue] Adding 0.5s stabilization delay before next command" - ) - time.sleep(0.5) - - except Empty: - pass - - # Sleep to avoid busy-waiting - time.sleep(0.1) - - logger.info("Queue processing stopped") - - def _print_queue_status(self): - """Print the current queue status""" - current_time = time.time() - - # Only print once per second to avoid spamming the log - if current_time - self._last_command_time < 1.0 and self._current_command is None: - return - - is_ready = self._is_ready_func() - is_busy = self._is_busy_func() if self._is_busy_func else False - queue_size = self.queue_size - - # Get information about the current command - current_command_info = "None" - if self._current_command is not None: - current_command_info = f"{self._current_command.cmd_type.name}" - if self._current_command.cmd_type == CommandType.WEBRTC: - api_id = self._current_command.params.get("api_id") - current_command_info += f" (API: {api_id})" - elif self._current_command.cmd_type == CommandType.ACTION: - action_name = self._current_command.params.get("action_name") - current_command_info += f" (Action: {action_name})" - - # Print the status - status = ( - f"Queue: {queue_size} items | " - f"Robot: {'READY' if is_ready else 'BUSY'} | " - f"Current: {current_command_info} | " - f"Stats: {self._success_count} OK, {self._failure_count} FAIL" - ) - - logger.debug(status) - self._last_command_time = current_time - - @property - def queue_size(self) -> int: - """Get the number of commands in the queue""" - return self._queue.qsize() - - @property - def current_command(self) -> Optional[ROSCommand]: - """Get the current command being processed""" - return self._current_command diff --git a/build/lib/dimos/robot/ros_control.py b/build/lib/dimos/robot/ros_control.py deleted file mode 100644 index 6aa51fc3a8..0000000000 --- a/build/lib/dimos/robot/ros_control.py +++ /dev/null @@ -1,867 +0,0 @@ -# Copyright 2025 Dimensional Inc. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import rclpy -from rclpy.node import Node -from rclpy.executors import MultiThreadedExecutor -from rclpy.action import ActionClient -from geometry_msgs.msg import Twist -from nav2_msgs.action import Spin - -from sensor_msgs.msg import Image, CompressedImage -from cv_bridge import CvBridge -from enum import Enum, auto -import threading -import time -from typing import Optional, Dict, Any, Type -from abc import ABC, abstractmethod -from rclpy.qos import ( - QoSProfile, - QoSReliabilityPolicy, - QoSHistoryPolicy, - QoSDurabilityPolicy, -) -from dimos.stream.ros_video_provider import ROSVideoProvider -import math -from builtin_interfaces.msg import Duration -from geometry_msgs.msg import Point, Vector3 -from dimos.robot.ros_command_queue import ROSCommandQueue -from dimos.utils.logging_config import setup_logger - -from nav_msgs.msg import OccupancyGrid - -import tf2_ros -from dimos.robot.ros_transform import ROSTransformAbility -from dimos.robot.ros_observable_topic import ROSObservableTopicAbility -from dimos.robot.connection_interface import ConnectionInterface -from dimos.types.vector import Vector - -from nav_msgs.msg import Odometry - -logger = setup_logger("dimos.robot.ros_control") - -__all__ = ["ROSControl", "RobotMode"] - - -class RobotMode(Enum): - """Enum for robot modes""" - - UNKNOWN = auto() - INITIALIZING = auto() - IDLE = auto() - MOVING = auto() - ERROR = auto() - - -class ROSControl(ROSTransformAbility, 
ROSObservableTopicAbility, ConnectionInterface, ABC): - """Abstract base class for ROS-controlled robots""" - - def __init__( - self, - node_name: str, - camera_topics: Dict[str, str] = None, - max_linear_velocity: float = 1.0, - mock_connection: bool = False, - max_angular_velocity: float = 2.0, - state_topic: str = None, - imu_topic: str = None, - state_msg_type: Type = None, - imu_msg_type: Type = None, - webrtc_topic: str = None, - webrtc_api_topic: str = None, - webrtc_msg_type: Type = None, - move_vel_topic: str = None, - pose_topic: str = None, - odom_topic: str = "/odom", - global_costmap_topic: str = "map", - costmap_topic: str = "/local_costmap/costmap", - debug: bool = False, - ): - """ - Initialize base ROS control interface - Args: - node_name: Name for the ROS node - camera_topics: Dictionary of camera topics - max_linear_velocity: Maximum linear velocity (m/s) - max_angular_velocity: Maximum angular velocity (rad/s) - state_topic: Topic name for robot state (optional) - imu_topic: Topic name for IMU data (optional) - state_msg_type: The ROS message type for state data - imu_msg_type: The ROS message type for IMU data - webrtc_topic: Topic for WebRTC commands - webrtc_api_topic: Topic for WebRTC API commands - webrtc_msg_type: The ROS message type for webrtc data - move_vel_topic: Topic for direct movement commands - pose_topic: Topic for pose commands - odom_topic: Topic for odometry data - costmap_topic: Topic for local costmap data - """ - # Initialize rclpy and ROS node if not already running - if not rclpy.ok(): - rclpy.init() - - self._state_topic = state_topic - self._imu_topic = imu_topic - self._odom_topic = odom_topic - self._costmap_topic = costmap_topic - self._state_msg_type = state_msg_type - self._imu_msg_type = imu_msg_type - self._webrtc_msg_type = webrtc_msg_type - self._webrtc_topic = webrtc_topic - self._webrtc_api_topic = webrtc_api_topic - self._node = Node(node_name) - self._global_costmap_topic = global_costmap_topic - 
self._debug = debug - - # Prepare a multi-threaded executor - self._executor = MultiThreadedExecutor() - - # Movement constraints - self.MAX_LINEAR_VELOCITY = max_linear_velocity - self.MAX_ANGULAR_VELOCITY = max_angular_velocity - - self._subscriptions = [] - - # Track State variables - self._robot_state = None # Full state message - self._imu_state = None # Full IMU message - self._odom_data = None # Odometry data - self._costmap_data = None # Costmap data - self._mode = RobotMode.INITIALIZING - - # Create sensor data QoS profile - sensor_qos = QoSProfile( - reliability=QoSReliabilityPolicy.BEST_EFFORT, - history=QoSHistoryPolicy.KEEP_LAST, - durability=QoSDurabilityPolicy.VOLATILE, - depth=1, - ) - - command_qos = QoSProfile( - reliability=QoSReliabilityPolicy.RELIABLE, - history=QoSHistoryPolicy.KEEP_LAST, - durability=QoSDurabilityPolicy.VOLATILE, - depth=10, # Higher depth for commands to ensure delivery - ) - - if self._global_costmap_topic: - self._global_costmap_data = None - self._global_costmap_sub = self._node.create_subscription( - OccupancyGrid, - self._global_costmap_topic, - self._global_costmap_callback, - sensor_qos, - ) - self._subscriptions.append(self._global_costmap_sub) - else: - logger.warning("No costmap topic provided - costmap data tracking will be unavailable") - - # Initialize data handling - self._video_provider = None - self._bridge = None - if camera_topics: - self._bridge = CvBridge() - self._video_provider = ROSVideoProvider(dev_name=f"{node_name}_video") - - # Create subscribers for each topic with sensor QoS - for camera_config in camera_topics.values(): - topic = camera_config["topic"] - msg_type = camera_config["type"] - - logger.info( - f"Subscribing to {topic} with BEST_EFFORT QoS using message type {msg_type.__name__}" - ) - _camera_subscription = self._node.create_subscription( - msg_type, topic, self._image_callback, sensor_qos - ) - self._subscriptions.append(_camera_subscription) - - # Subscribe to state topic if 
provided - if self._state_topic and self._state_msg_type: - logger.info(f"Subscribing to {state_topic} with BEST_EFFORT QoS") - self._state_sub = self._node.create_subscription( - self._state_msg_type, - self._state_topic, - self._state_callback, - qos_profile=sensor_qos, - ) - self._subscriptions.append(self._state_sub) - else: - logger.warning( - "No state topic andor message type provided - robot state tracking will be unavailable" - ) - - if self._imu_topic and self._imu_msg_type: - self._imu_sub = self._node.create_subscription( - self._imu_msg_type, self._imu_topic, self._imu_callback, sensor_qos - ) - self._subscriptions.append(self._imu_sub) - else: - logger.warning( - "No IMU topic and/or message type provided - IMU data tracking will be unavailable" - ) - - if self._odom_topic: - self._odom_sub = self._node.create_subscription( - Odometry, self._odom_topic, self._odom_callback, sensor_qos - ) - self._subscriptions.append(self._odom_sub) - else: - logger.warning( - "No odometry topic provided - odometry data tracking will be unavailable" - ) - - if self._costmap_topic: - self._costmap_sub = self._node.create_subscription( - OccupancyGrid, self._costmap_topic, self._costmap_callback, sensor_qos - ) - self._subscriptions.append(self._costmap_sub) - else: - logger.warning("No costmap topic provided - costmap data tracking will be unavailable") - - # Nav2 Action Clients - self._spin_client = ActionClient(self._node, Spin, "spin") - - # Wait for action servers - if not mock_connection: - self._spin_client.wait_for_server() - - # Publishers - self._move_vel_pub = self._node.create_publisher(Twist, move_vel_topic, command_qos) - self._pose_pub = self._node.create_publisher(Vector3, pose_topic, command_qos) - - if webrtc_msg_type: - self._webrtc_pub = self._node.create_publisher( - webrtc_msg_type, webrtc_topic, qos_profile=command_qos - ) - - # Initialize command queue - self._command_queue = ROSCommandQueue( - webrtc_func=self.webrtc_req, - is_ready_func=lambda: 
self._mode == RobotMode.IDLE, - is_busy_func=lambda: self._mode == RobotMode.MOVING, - ) - # Start the queue processing thread - self._command_queue.start() - else: - logger.warning("No WebRTC message type provided - WebRTC commands will be unavailable") - - # Initialize TF Buffer and Listener for transform abilities - self._tf_buffer = tf2_ros.Buffer() - self._tf_listener = tf2_ros.TransformListener(self._tf_buffer, self._node) - logger.info(f"TF Buffer and Listener initialized for {node_name}") - - # Start ROS spin in a background thread via the executor - self._spin_thread = threading.Thread(target=self._ros_spin, daemon=True) - self._spin_thread.start() - - logger.info(f"{node_name} initialized with multi-threaded executor") - print(f"{node_name} initialized with multi-threaded executor") - - def get_global_costmap(self) -> Optional[OccupancyGrid]: - """ - Get current global_costmap data - - Returns: - Optional[OccupancyGrid]: Current global_costmap data or None if not available - """ - if not self._global_costmap_topic: - logger.warning( - "No global_costmap topic provided - global_costmap data tracking will be unavailable" - ) - return None - - if self._global_costmap_data: - return self._global_costmap_data - else: - return None - - def _global_costmap_callback(self, msg): - """Callback for costmap data""" - self._global_costmap_data = msg - - def _imu_callback(self, msg): - """Callback for IMU data""" - self._imu_state = msg - # Log IMU state (very verbose) - # logger.debug(f"IMU state updated: {self._imu_state}") - - def _odom_callback(self, msg): - """Callback for odometry data""" - self._odom_data = msg - - def _costmap_callback(self, msg): - """Callback for costmap data""" - self._costmap_data = msg - - def _state_callback(self, msg): - """Callback for state messages to track mode and progress""" - - # Call the abstract method to update RobotMode enum based on the received state - self._robot_state = msg - self._update_mode(msg) - # Log state changes 
(very verbose) - # logger.debug(f"Robot state updated: {self._robot_state}") - - @property - def robot_state(self) -> Optional[Any]: - """Get the full robot state message""" - return self._robot_state - - def _ros_spin(self): - """Background thread for spinning the multi-threaded executor.""" - self._executor.add_node(self._node) - try: - self._executor.spin() - finally: - self._executor.shutdown() - - def _clamp_velocity(self, velocity: float, max_velocity: float) -> float: - """Clamp velocity within safe limits""" - return max(min(velocity, max_velocity), -max_velocity) - - @abstractmethod - def _update_mode(self, *args, **kwargs): - """Update robot mode based on state - to be implemented by child classes""" - pass - - def get_state(self) -> Optional[Any]: - """ - Get current robot state - - Base implementation provides common state fields. Child classes should - extend this method to include their specific state information. - - Returns: - ROS msg containing the robot state information - """ - if not self._state_topic: - logger.warning("No state topic provided - robot state tracking will be unavailable") - return None - - return self._robot_state - - def get_imu_state(self) -> Optional[Any]: - """ - Get current IMU state - - Base implementation provides common state fields. Child classes should - extend this method to include their specific state information. 
- - Returns: - ROS msg containing the IMU state information - """ - if not self._imu_topic: - logger.warning("No IMU topic provided - IMU data tracking will be unavailable") - return None - return self._imu_state - - def get_odometry(self) -> Optional[Odometry]: - """ - Get current odometry data - - Returns: - Optional[Odometry]: Current odometry data or None if not available - """ - if not self._odom_topic: - logger.warning( - "No odometry topic provided - odometry data tracking will be unavailable" - ) - return None - return self._odom_data - - def get_costmap(self) -> Optional[OccupancyGrid]: - """ - Get current costmap data - - Returns: - Optional[OccupancyGrid]: Current costmap data or None if not available - """ - if not self._costmap_topic: - logger.warning("No costmap topic provided - costmap data tracking will be unavailable") - return None - return self._costmap_data - - def _image_callback(self, msg): - """Convert ROS image to numpy array and push to data stream""" - if self._video_provider and self._bridge: - try: - if isinstance(msg, CompressedImage): - frame = self._bridge.compressed_imgmsg_to_cv2(msg) - elif isinstance(msg, Image): - frame = self._bridge.imgmsg_to_cv2(msg, "bgr8") - else: - logger.error(f"Unsupported image message type: {type(msg)}") - return - self._video_provider.push_data(frame) - except Exception as e: - logger.error(f"Error converting image: {e}") - print(f"Full conversion error: {str(e)}") - - @property - def video_provider(self) -> Optional[ROSVideoProvider]: - """Data provider property for streaming data""" - return self._video_provider - - def get_video_stream(self, fps: int = 30) -> Optional[Observable]: - """Get the video stream from the robot's camera. 
- - Args: - fps: Frames per second for the video stream - - Returns: - Observable: An observable stream of video frames or None if not available - """ - if not self.video_provider: - return None - - return self.video_provider.get_stream(fps=fps) - - def _send_action_client_goal(self, client, goal_msg, description=None, time_allowance=20.0): - """ - Generic function to send any action client goal and wait for completion. - - Args: - client: The action client to use - goal_msg: The goal message to send - description: Optional description for logging - time_allowance: Maximum time to wait for completion - - Returns: - bool: True if action succeeded, False otherwise - """ - if description: - logger.info(description) - - print(f"[ROSControl] Sending action client goal: {description}") - print(f"[ROSControl] Goal message: {goal_msg}") - - # Reset action result tracking - self._action_success = None - - # Send the goal - send_goal_future = client.send_goal_async(goal_msg, feedback_callback=lambda feedback: None) - send_goal_future.add_done_callback(self._goal_response_callback) - - # Wait for completion - start_time = time.time() - while self._action_success is None and time.time() - start_time < time_allowance: - time.sleep(0.1) - - elapsed = time.time() - start_time - print( - f"[ROSControl] Action completed in {elapsed:.2f}s with result: {self._action_success}" - ) - - # Check result - if self._action_success is None: - logger.error(f"Action timed out after {time_allowance}s") - return False - elif self._action_success: - logger.info("Action succeeded") - return True - else: - logger.error("Action failed") - return False - - def move(self, velocity: Vector, duration: float = 0.0) -> bool: - """Send velocity commands to the robot. - - Args: - velocity: Velocity vector [x, y, yaw] where: - x: Linear velocity in x direction (m/s) - y: Linear velocity in y direction (m/s) - yaw: Angular velocity around z axis (rad/s) - duration: Duration to apply command (seconds). 
If 0, apply once. - - Returns: - bool: True if command was sent successfully - """ - x, y, yaw = velocity.x, velocity.y, velocity.z - - # Clamp velocities to safe limits - x = self._clamp_velocity(x, self.MAX_LINEAR_VELOCITY) - y = self._clamp_velocity(y, self.MAX_LINEAR_VELOCITY) - yaw = self._clamp_velocity(yaw, self.MAX_ANGULAR_VELOCITY) - - # Create and send command - cmd = Twist() - cmd.linear.x = float(x) - cmd.linear.y = float(y) - cmd.angular.z = float(yaw) - - try: - if duration > 0: - start_time = time.time() - while time.time() - start_time < duration: - self._move_vel_pub.publish(cmd) - time.sleep(0.1) # 10Hz update rate - # Stop after duration - self.stop() - else: - self._move_vel_pub.publish(cmd) - return True - - except Exception as e: - self._logger.error(f"Failed to send movement command: {e}") - return False - - def reverse(self, distance: float, speed: float = 0.5, time_allowance: float = 120) -> bool: - """ - Move the robot backward by a specified distance - - Args: - distance: Distance to move backward in meters (must be positive) - speed: Speed to move at in m/s (default 0.5) - time_allowance: Maximum time to wait for the request to complete - - Returns: - bool: True if movement succeeded - """ - try: - if distance <= 0: - logger.error("Distance must be positive") - return False - - speed = min(abs(speed), self.MAX_LINEAR_VELOCITY) - - # Define function to execute the reverse - def execute_reverse(): - # Create BackUp goal - goal = BackUp.Goal() - goal.target = Point() - goal.target.x = -distance # Negative for backward motion - goal.target.y = 0.0 - goal.target.z = 0.0 - goal.speed = speed # BackUp expects positive speed - goal.time_allowance = Duration(sec=time_allowance) - - print( - f"[ROSControl] execute_reverse: Creating BackUp goal with distance={distance}m, speed={speed}m/s" - ) - print( - f"[ROSControl] execute_reverse: Goal details: x={goal.target.x}, y={goal.target.y}, z={goal.target.z}, speed={goal.speed}" - ) - - 
logger.info(f"Moving backward: distance={distance}m, speed={speed}m/s") - - result = self._send_action_client_goal( - self._backup_client, - goal, - f"Moving backward {distance}m at {speed}m/s", - time_allowance, - ) - - print(f"[ROSControl] execute_reverse: BackUp action result: {result}") - return result - - # Queue the action - cmd_id = self._command_queue.queue_action_client_request( - action_name="reverse", - execute_func=execute_reverse, - priority=0, - timeout=time_allowance, - distance=distance, - speed=speed, - ) - logger.info( - f"Queued reverse command: {cmd_id} - Distance: {distance}m, Speed: {speed}m/s" - ) - return True - - except Exception as e: - logger.error(f"Backward movement failed: {e}") - import traceback - - logger.error(traceback.format_exc()) - return False - - def spin(self, degrees: float, speed: float = 45.0, time_allowance: float = 120) -> bool: - """ - Rotate the robot by a specified angle - - Args: - degrees: Angle to rotate in degrees (positive for counter-clockwise, negative for clockwise) - speed: Angular speed in degrees/second (default 45.0) - time_allowance: Maximum time to wait for the request to complete - - Returns: - bool: True if movement succeeded - """ - try: - # Convert degrees to radians - angle = math.radians(degrees) - angular_speed = math.radians(abs(speed)) - - # Clamp angular speed - angular_speed = min(angular_speed, self.MAX_ANGULAR_VELOCITY) - time_allowance = max( - int(abs(angle) / angular_speed * 2), 20 - ) # At least 20 seconds or double the expected time - - # Define function to execute the spin - def execute_spin(): - # Create Spin goal - goal = Spin.Goal() - goal.target_yaw = angle # Nav2 Spin action expects radians - goal.time_allowance = Duration(sec=time_allowance) - - logger.info(f"Spinning: angle={degrees}deg ({angle:.2f}rad)") - - return self._send_action_client_goal( - self._spin_client, - goal, - f"Spinning {degrees} degrees at {speed} deg/s", - time_allowance, - ) - - # Queue the action - cmd_id 
= self._command_queue.queue_action_client_request( - action_name="spin", - execute_func=execute_spin, - priority=0, - timeout=time_allowance, - degrees=degrees, - speed=speed, - ) - logger.info(f"Queued spin command: {cmd_id} - Degrees: {degrees}, Speed: {speed}deg/s") - return True - - except Exception as e: - logger.error(f"Spin movement failed: {e}") - import traceback - - logger.error(traceback.format_exc()) - return False - - def stop(self) -> bool: - """Stop all robot movement""" - try: - # self.navigator.cancelTask() - self._current_velocity = {"x": 0.0, "y": 0.0, "z": 0.0} - self._is_moving = False - return True - except Exception as e: - logger.error(f"Failed to stop movement: {e}") - return False - - def cleanup(self): - """Cleanup the executor, ROS node, and stop robot.""" - self.stop() - - # Stop the WebRTC queue manager - if self._command_queue: - logger.info("Stopping WebRTC queue manager...") - self._command_queue.stop() - - # Shut down the executor to stop spin loop cleanly - self._executor.shutdown() - - # Destroy node and shutdown rclpy - self._node.destroy_node() - rclpy.shutdown() - - def disconnect(self) -> None: - """Disconnect from the robot and clean up resources.""" - self.cleanup() - - def webrtc_req( - self, - api_id: int, - topic: str = None, - parameter: str = "", - priority: int = 0, - request_id: str = None, - data=None, - ) -> bool: - """ - Send a WebRTC request command to the robot - - Args: - api_id: The API ID for the command - topic: The API topic to publish to (defaults to self._webrtc_api_topic) - parameter: Optional parameter string - priority: Priority level (0 or 1) - request_id: Optional request ID for tracking (not used in ROS implementation) - data: Optional data dictionary (not used in ROS implementation) - params: Optional params dictionary (not used in ROS implementation) - - Returns: - bool: True if command was sent successfully - """ - try: - # Create and send command - cmd = self._webrtc_msg_type() - cmd.api_id = 
api_id - cmd.topic = topic if topic is not None else self._webrtc_api_topic - cmd.parameter = parameter - cmd.priority = priority - - self._webrtc_pub.publish(cmd) - logger.info(f"Sent WebRTC request: api_id={api_id}, topic={cmd.topic}") - return True - - except Exception as e: - logger.error(f"Failed to send WebRTC request: {e}") - return False - - def get_robot_mode(self) -> RobotMode: - """ - Get the current robot mode - - Returns: - RobotMode: The current robot mode enum value - """ - return self._mode - - def print_robot_mode(self): - """Print the current robot mode to the console""" - mode = self.get_robot_mode() - print(f"Current RobotMode: {mode.name}") - print(f"Mode enum: {mode}") - - def queue_webrtc_req( - self, - api_id: int, - topic: str = None, - parameter: str = "", - priority: int = 0, - timeout: float = 90.0, - request_id: str = None, - data=None, - ) -> str: - """ - Queue a WebRTC request to be sent when the robot is IDLE - - Args: - api_id: The API ID for the command - topic: The topic to publish to (defaults to self._webrtc_api_topic) - parameter: Optional parameter string - priority: Priority level (0 or 1) - timeout: Maximum time to wait for the request to complete - request_id: Optional request ID (if None, one will be generated) - data: Optional data dictionary (not used in ROS implementation) - - Returns: - str: Request ID that can be used to track the request - """ - return self._command_queue.queue_webrtc_request( - api_id=api_id, - topic=topic if topic is not None else self._webrtc_api_topic, - parameter=parameter, - priority=priority, - timeout=timeout, - request_id=request_id, - data=data, - ) - - def move_vel_control(self, x: float, y: float, yaw: float) -> bool: - """ - Send a single velocity command without duration handling. 
- - Args: - x: Forward/backward velocity (m/s) - y: Left/right velocity (m/s) - yaw: Rotational velocity (rad/s) - - Returns: - bool: True if command was sent successfully - """ - # Clamp velocities to safe limits - x = self._clamp_velocity(x, self.MAX_LINEAR_VELOCITY) - y = self._clamp_velocity(y, self.MAX_LINEAR_VELOCITY) - yaw = self._clamp_velocity(yaw, self.MAX_ANGULAR_VELOCITY) - - # Create and send command - cmd = Twist() - cmd.linear.x = float(x) - cmd.linear.y = float(y) - cmd.angular.z = float(yaw) - - try: - self._move_vel_pub.publish(cmd) - return True - except Exception as e: - logger.error(f"Failed to send velocity command: {e}") - return False - - def pose_command(self, roll: float, pitch: float, yaw: float) -> bool: - """ - Send a pose command to the robot to adjust its body orientation - - Args: - roll: Roll angle in radians - pitch: Pitch angle in radians - yaw: Yaw angle in radians - - Returns: - bool: True if command was sent successfully - """ - # Create the pose command message - cmd = Vector3() - cmd.x = float(roll) # Roll - cmd.y = float(pitch) # Pitch - cmd.z = float(yaw) # Yaw - - try: - self._pose_pub.publish(cmd) - logger.debug(f"Sent pose command: roll={roll}, pitch={pitch}, yaw={yaw}") - return True - except Exception as e: - logger.error(f"Failed to send pose command: {e}") - return False - - def get_position_stream(self): - """ - Get a stream of position updates from ROS. 
- - Returns: - Observable that emits (x, y) tuples representing the robot's position - """ - from dimos.robot.position_stream import PositionStreamProvider - - # Create a position stream provider - position_provider = PositionStreamProvider( - ros_node=self._node, - odometry_topic="/odom", # Default odometry topic - use_odometry=True, - ) - - return position_provider.get_position_stream() - - def _goal_response_callback(self, future): - """Handle the goal response.""" - goal_handle = future.result() - if not goal_handle.accepted: - logger.warn("Goal was rejected!") - print("[ROSControl] Goal was REJECTED by the action server") - self._action_success = False - return - - logger.info("Goal accepted") - print("[ROSControl] Goal was ACCEPTED by the action server") - result_future = goal_handle.get_result_async() - result_future.add_done_callback(self._goal_result_callback) - - def _goal_result_callback(self, future): - """Handle the goal result.""" - try: - result = future.result().result - logger.info("Goal completed") - print(f"[ROSControl] Goal COMPLETED with result: {result}") - self._action_success = True - except Exception as e: - logger.error(f"Goal failed with error: {e}") - print(f"[ROSControl] Goal FAILED with error: {e}") - self._action_success = False diff --git a/build/lib/dimos/robot/ros_observable_topic.py b/build/lib/dimos/robot/ros_observable_topic.py deleted file mode 100644 index 697ddff398..0000000000 --- a/build/lib/dimos/robot/ros_observable_topic.py +++ /dev/null @@ -1,240 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import asyncio -import functools -import enum -import reactivex as rx -from reactivex import operators as ops -from reactivex.disposable import Disposable -from reactivex.scheduler import ThreadPoolScheduler -from rxpy_backpressure import BackPressure - -from nav_msgs import msg -from dimos.utils.logging_config import setup_logger -from dimos.utils.threadpool import get_scheduler -from dimos.types.costmap import Costmap -from dimos.types.vector import Vector - -from typing import Union, Callable, Any - -from rclpy.qos import ( - QoSProfile, - QoSReliabilityPolicy, - QoSHistoryPolicy, - QoSDurabilityPolicy, -) - -__all__ = ["ROSObservableTopicAbility", "QOS"] - -ConversionType = Costmap -TopicType = Union[ConversionType, msg.OccupancyGrid, msg.Odometry] - - -class QOS(enum.Enum): - SENSOR = "sensor" - COMMAND = "command" - - def to_profile(self) -> QoSProfile: - if self == QOS.SENSOR: - return QoSProfile( - reliability=QoSReliabilityPolicy.BEST_EFFORT, - history=QoSHistoryPolicy.KEEP_LAST, - durability=QoSDurabilityPolicy.VOLATILE, - depth=1, - ) - if self == QOS.COMMAND: - return QoSProfile( - reliability=QoSReliabilityPolicy.RELIABLE, - history=QoSHistoryPolicy.KEEP_LAST, - durability=QoSDurabilityPolicy.VOLATILE, - depth=10, # Higher depth for commands to ensure delivery - ) - - raise ValueError(f"Unknown QoS enum value: {self}") - - -logger = setup_logger("dimos.robot.ros_control.observable_topic") - - -class ROSObservableTopicAbility: - # Ensures that we can return multiple observables which have multiple subscribers - # consuming the same topic at 
different (blocking) rates while: - # - # - immediately returning latest value received to new subscribers - # - allowing slow subscribers to consume the topic without blocking fast ones - # - dealing with backpressure from slow subscribers (auto dropping unprocessed messages) - # - # (for more details see corresponding test file) - # - # ROS thread ─► ReplaySubject─► observe_on(pool) ─► backpressure.latest ─► sub1 (fast) - # ├──► observe_on(pool) ─► backpressure.latest ─► sub2 (slow) - # └──► observe_on(pool) ─► backpressure.latest ─► sub3 (slower) - # - def _maybe_conversion(self, msg_type: TopicType, callback) -> Callable[[TopicType], Any]: - if msg_type == Costmap: - return lambda msg: callback(Costmap.from_msg(msg)) - # just for test, not sure if this Vector auto-instantiation is used irl - if msg_type == Vector: - return lambda msg: callback(Vector.from_msg(msg)) - return callback - - def _sub_msg_type(self, msg_type): - if msg_type == Costmap: - return msg.OccupancyGrid - - if msg_type == Vector: - return msg.Odometry - - return msg_type - - @functools.lru_cache(maxsize=None) - def topic( - self, - topic_name: str, - msg_type: TopicType, - qos=QOS.SENSOR, - scheduler: ThreadPoolScheduler | None = None, - drop_unprocessed: bool = True, - ) -> rx.Observable: - if scheduler is None: - scheduler = get_scheduler() - - # Convert QOS to QoSProfile - qos_profile = qos.to_profile() - - # upstream ROS callback - def _on_subscribe(obs, _): - ros_sub = self._node.create_subscription( - self._sub_msg_type(msg_type), - topic_name, - self._maybe_conversion(msg_type, obs.on_next), - qos_profile, - ) - return Disposable(lambda: self._node.destroy_subscription(ros_sub)) - - upstream = rx.create(_on_subscribe) - - # hot, latest-cached core - core = upstream.pipe( - ops.replay(buffer_size=1), - ops.ref_count(), # still synchronous! 
- ) - - # per-subscriber factory - def per_sub(): - # hop off the ROS thread into the pool - base = core.pipe(ops.observe_on(scheduler)) - - # optional back-pressure handling - if not drop_unprocessed: - return base - - def _subscribe(observer, sch=None): - return base.subscribe(BackPressure.LATEST(observer), scheduler=sch) - - return rx.create(_subscribe) - - # each `.subscribe()` call gets its own async backpressure chain - return rx.defer(lambda *_: per_sub()) - - # If you are not interested in processing streams, just want to fetch the latest stream - # value use this function. It runs a subscription in the background. - # caches latest value for you, always ready to return. - # - # odom = robot.topic_latest("/odom", msg.Odometry) - # the initial call to odom() will block until the first message is received - # - # any time you'd like you can call: - # - # print(f"Latest odom: {odom()}") - # odom.dispose() # clean up the subscription - # - # see test_ros_observable_topic.py test_topic_latest for more details - def topic_latest( - self, topic_name: str, msg_type: TopicType, timeout: float | None = 100.0, qos=QOS.SENSOR - ): - """ - Blocks the current thread until the first message is received, then - returns `reader()` (sync) and keeps one ROS subscription alive - in the background. - - latest_scan = robot.ros_control.topic_latest_blocking("scan", LaserScan) - do_something(latest_scan()) # instant - latest_scan.dispose() # clean up - """ - # one shared observable with a 1-element replay buffer - core = self.topic(topic_name, msg_type, qos=qos).pipe(ops.replay(buffer_size=1)) - conn = core.connect() # starts the ROS subscription immediately - - try: - first_val = core.pipe( - ops.first(), *([ops.timeout(timeout)] if timeout is not None else []) - ).run() - except Exception: - conn.dispose() - msg = f"{topic_name} message not received after {timeout} seconds. Is robot connected?" 
- logger.error(msg) - raise Exception(msg) - - cache = {"val": first_val} - sub = core.subscribe(lambda v: cache.__setitem__("val", v)) - - def reader(): - return cache["val"] - - reader.dispose = lambda: (sub.dispose(), conn.dispose()) - return reader - - # If you are not interested in processing streams, just want to fetch the latest stream - # value use this function. It runs a subscription in the background. - # caches latest value for you, always ready to return - # - # odom = await robot.topic_latest_async("/odom", msg.Odometry) - # - # async nature of this function allows you to do other stuff while you wait - # for a first message to arrive - # - # any time you'd like you can call: - # - # print(f"Latest odom: {odom()}") - # odom.dispose() # clean up the subscription - # - # see test_ros_observable_topic.py test_topic_latest for more details - async def topic_latest_async( - self, topic_name: str, msg_type: TopicType, qos=QOS.SENSOR, timeout: float = 30.0 - ): - loop = asyncio.get_running_loop() - first = loop.create_future() - cache = {"val": None} - - core = self.topic(topic_name, msg_type, qos=qos) # single ROS callback - - def _on_next(v): - cache["val"] = v - if not first.done(): - loop.call_soon_threadsafe(first.set_result, v) - - subscription = core.subscribe(_on_next) - - try: - await asyncio.wait_for(first, timeout) - except Exception: - subscription.dispose() - raise - - def reader(): - return cache["val"] - - reader.dispose = subscription.dispose - return reader diff --git a/build/lib/dimos/robot/ros_transform.py b/build/lib/dimos/robot/ros_transform.py deleted file mode 100644 index b0c46fd275..0000000000 --- a/build/lib/dimos/robot/ros_transform.py +++ /dev/null @@ -1,243 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import rclpy -from typing import Optional -from geometry_msgs.msg import TransformStamped -from tf2_ros import Buffer -import tf2_ros -from tf2_geometry_msgs import PointStamped -from dimos.utils.logging_config import setup_logger -from dimos.types.vector import Vector -from dimos.types.path import Path -from scipy.spatial.transform import Rotation as R - -logger = setup_logger("dimos.robot.ros_transform") - -__all__ = ["ROSTransformAbility"] - - -def to_euler_rot(msg: TransformStamped) -> [Vector, Vector]: - q = msg.transform.rotation - rotation = R.from_quat([q.x, q.y, q.z, q.w]) - return Vector(rotation.as_euler("xyz", degrees=False)) - - -def to_euler_pos(msg: TransformStamped) -> [Vector, Vector]: - return Vector(msg.transform.translation).to_2d() - - -def to_euler(msg: TransformStamped) -> [Vector, Vector]: - return [to_euler_pos(msg), to_euler_rot(msg)] - - -class ROSTransformAbility: - """Mixin class for handling ROS transforms between coordinate frames""" - - @property - def tf_buffer(self) -> Buffer: - if not hasattr(self, "_tf_buffer"): - self._tf_buffer = tf2_ros.Buffer() - self._tf_listener = tf2_ros.TransformListener(self._tf_buffer, self._node) - logger.info("Transform listener initialized") - - return self._tf_buffer - - def transform_euler_pos( - self, source_frame: str, target_frame: str = "map", timeout: float = 1.0 - ): - return to_euler_pos(self.transform(source_frame, target_frame, timeout)) - - def transform_euler_rot( - self, source_frame: str, target_frame: str = "map", timeout: float = 1.0 - ): - return 
to_euler_rot(self.transform(source_frame, target_frame, timeout)) - - def transform_euler(self, source_frame: str, target_frame: str = "map", timeout: float = 1.0): - res = self.transform(source_frame, target_frame, timeout) - return to_euler(res) - - def transform( - self, source_frame: str, target_frame: str = "map", timeout: float = 1.0 - ) -> Optional[TransformStamped]: - try: - transform = self.tf_buffer.lookup_transform( - target_frame, - source_frame, - rclpy.time.Time(), - rclpy.duration.Duration(seconds=timeout), - ) - return transform - except ( - tf2_ros.LookupException, - tf2_ros.ConnectivityException, - tf2_ros.ExtrapolationException, - ) as e: - logger.error(f"Transform lookup failed: {e}") - return None - - def transform_point( - self, point: Vector, source_frame: str, target_frame: str = "map", timeout: float = 1.0 - ): - """Transform a point from source_frame to target_frame. - - Args: - point: The point to transform (x, y, z) - source_frame: The source frame of the point - target_frame: The target frame to transform to - timeout: Time to wait for the transform to become available (seconds) - - Returns: - The transformed point as a Vector, or None if the transform failed - """ - try: - # Wait for transform to become available - self.tf_buffer.can_transform( - target_frame, - source_frame, - rclpy.time.Time(), - rclpy.duration.Duration(seconds=timeout), - ) - - # Create a PointStamped message - ps = PointStamped() - ps.header.frame_id = source_frame - ps.header.stamp = rclpy.time.Time().to_msg() # Latest available transform - ps.point.x = point[0] - ps.point.y = point[1] - ps.point.z = point[2] if len(point) > 2 else 0.0 - - # Transform point - transformed_ps = self.tf_buffer.transform( - ps, target_frame, rclpy.duration.Duration(seconds=timeout) - ) - - # Return as Vector type - if len(point) > 2: - return Vector( - transformed_ps.point.x, transformed_ps.point.y, transformed_ps.point.z - ) - else: - return Vector(transformed_ps.point.x, 
transformed_ps.point.y) - except ( - tf2_ros.LookupException, - tf2_ros.ConnectivityException, - tf2_ros.ExtrapolationException, - ) as e: - logger.error(f"Transform from {source_frame} to {target_frame} failed: {e}") - return None - - def transform_path( - self, path: Path, source_frame: str, target_frame: str = "map", timeout: float = 1.0 - ): - """Transform a path from source_frame to target_frame. - - Args: - path: The path to transform - source_frame: The source frame of the path - target_frame: The target frame to transform to - timeout: Time to wait for the transform to become available (seconds) - - Returns: - The transformed path as a Path, or None if the transform failed - """ - transformed_path = Path() - for point in path: - transformed_point = self.transform_point(point, source_frame, target_frame, timeout) - if transformed_point is not None: - transformed_path.append(transformed_point) - return transformed_path - - def transform_rot( - self, rotation: Vector, source_frame: str, target_frame: str = "map", timeout: float = 1.0 - ): - """Transform a rotation from source_frame to target_frame. 
- - Args: - rotation: The rotation to transform as Euler angles (x, y, z) in radians - source_frame: The source frame of the rotation - target_frame: The target frame to transform to - timeout: Time to wait for the transform to become available (seconds) - - Returns: - The transformed rotation as a Vector of Euler angles (x, y, z), or None if the transform failed - """ - try: - # Wait for transform to become available - self.tf_buffer.can_transform( - target_frame, - source_frame, - rclpy.time.Time(), - rclpy.duration.Duration(seconds=timeout), - ) - - # Create a rotation matrix from the input Euler angles - input_rotation = R.from_euler("xyz", rotation, degrees=False) - - # Get the transform from source to target frame - transform = self.transform(source_frame, target_frame, timeout) - if transform is None: - return None - - # Extract the rotation from the transform - q = transform.transform.rotation - transform_rotation = R.from_quat([q.x, q.y, q.z, q.w]) - - # Compose the rotations - # The resulting rotation is the composition of the transform rotation and input rotation - result_rotation = transform_rotation * input_rotation - - # Convert back to Euler angles - euler_angles = result_rotation.as_euler("xyz", degrees=False) - - # Return as Vector type - return Vector(euler_angles) - - except ( - tf2_ros.LookupException, - tf2_ros.ConnectivityException, - tf2_ros.ExtrapolationException, - ) as e: - logger.error(f"Transform rotation from {source_frame} to {target_frame} failed: {e}") - return None - - def transform_pose( - self, - position: Vector, - rotation: Vector, - source_frame: str, - target_frame: str = "map", - timeout: float = 1.0, - ): - """Transform a pose from source_frame to target_frame. 
- - Args: - position: The position to transform - rotation: The rotation to transform - source_frame: The source frame of the pose - target_frame: The target frame to transform to - timeout: Time to wait for the transform to become available (seconds) - - Returns: - Tuple of (transformed_position, transformed_rotation) as Vectors, - or (None, None) if either transform failed - """ - # Transform position - transformed_position = self.transform_point(position, source_frame, target_frame, timeout) - - # Transform rotation - transformed_rotation = self.transform_rot(rotation, source_frame, target_frame, timeout) - - # Return results (both might be None if transforms failed) - return transformed_position, transformed_rotation diff --git a/build/lib/dimos/robot/test_ros_observable_topic.py b/build/lib/dimos/robot/test_ros_observable_topic.py deleted file mode 100644 index 71a1484de3..0000000000 --- a/build/lib/dimos/robot/test_ros_observable_topic.py +++ /dev/null @@ -1,255 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import threading -import time -import pytest -from dimos.utils.logging_config import setup_logger -from dimos.types.vector import Vector -import asyncio - - -class MockROSNode: - def __init__(self): - self.logger = setup_logger("ROS") - - self.sub_id_cnt = 0 - self.subs = {} - - def _get_sub_id(self): - sub_id = self.sub_id_cnt - self.sub_id_cnt += 1 - return sub_id - - def create_subscription(self, msg_type, topic_name, callback, qos): - # Mock implementation of ROS subscription - - sub_id = self._get_sub_id() - stop_event = threading.Event() - self.subs[sub_id] = stop_event - self.logger.info(f"Subscribed {topic_name} subid {sub_id}") - - # Create message simulation thread - def simulate_messages(): - message_count = 0 - while not stop_event.is_set(): - message_count += 1 - time.sleep(0.1) # 20Hz default publication rate - if topic_name == "/vector": - callback([message_count, message_count]) - else: - callback(message_count) - # cleanup - self.subs.pop(sub_id) - - thread = threading.Thread(target=simulate_messages, daemon=True) - thread.start() - return sub_id - - def destroy_subscription(self, subscription): - if subscription in self.subs: - self.subs[subscription].set() - self.logger.info(f"Destroyed subscription: {subscription}") - else: - self.logger.info(f"Unknown subscription: {subscription}") - - -# we are doing this in order to avoid importing ROS dependencies if ros tests aren't runnin -@pytest.fixture -def robot(): - from dimos.robot.ros_observable_topic import ROSObservableTopicAbility - - class MockRobot(ROSObservableTopicAbility): - def __init__(self): - self.logger = setup_logger("ROBOT") - # Initialize the mock ROS node - self._node = MockROSNode() - - return MockRobot() - - -# This test verifies a bunch of basics: -# -# 1. that the system creates a single ROS sub for multiple reactivex subs -# 2. that the system creates a single ROS sub for multiple observers -# 3. that the system unsubscribes from ROS when observers are disposed -# 4. 
that the system replays the last message to new observers, -# before the new ROS sub starts producing -@pytest.mark.ros -def test_parallel_and_cleanup(robot): - from nav_msgs import msg - - received_messages = [] - - obs1 = robot.topic("/odom", msg.Odometry) - - print(f"Created subscription: {obs1}") - - subscription1 = obs1.subscribe(lambda x: received_messages.append(x + 2)) - - subscription2 = obs1.subscribe(lambda x: received_messages.append(x + 3)) - - obs2 = robot.topic("/odom", msg.Odometry) - subscription3 = obs2.subscribe(lambda x: received_messages.append(x + 5)) - - time.sleep(0.25) - - # We have 2 messages and 3 subscribers - assert len(received_messages) == 6, "Should have received exactly 6 messages" - - # [1, 1, 1, 2, 2, 2] + - # [2, 3, 5, 2, 3, 5] - # = - for i in [3, 4, 6, 4, 5, 7]: - assert i in received_messages, f"Expected {i} in received messages, got {received_messages}" - - # ensure that ROS end has only a single subscription - assert len(robot._node.subs) == 1, ( - f"Expected 1 subscription, got {len(robot._node.subs)}: {robot._node.subs}" - ) - - subscription1.dispose() - subscription2.dispose() - subscription3.dispose() - - # Make sure that ros end was unsubscribed, thread terminated - time.sleep(0.1) - assert not robot._node.subs, f"Expected empty subs dict, got: {robot._node.subs}" - - # Ensure we replay the last message - second_received = [] - second_sub = obs1.subscribe(lambda x: second_received.append(x)) - - time.sleep(0.075) - # we immediately receive the stored topic message - assert len(second_received) == 1 - - # now that sub is hot, we wait for a second one - time.sleep(0.2) - - # we expect 2, 1 since first message was preserved from a previous ros topic sub - # second one is the first message of the second ros topic sub - assert second_received == [2, 1, 2] - - print(f"Second subscription immediately received {len(second_received)} message(s)") - - second_sub.dispose() - - time.sleep(0.1) - assert not robot._node.subs, 
f"Expected empty subs dict, got: {robot._node.subs}" - - print("Test completed successfully") - - -# here we test parallel subs and slow observers hogging our topic -# we expect slow observers to skip messages by default -# -# ROS thread ─► ReplaySubject─► observe_on(pool) ─► backpressure.latest ─► sub1 (fast) -# ├──► observe_on(pool) ─► backpressure.latest ─► sub2 (slow) -# └──► observe_on(pool) ─► backpressure.latest ─► sub3 (slower) -@pytest.mark.ros -def test_parallel_and_hog(robot): - from nav_msgs import msg - - obs1 = robot.topic("/odom", msg.Odometry) - obs2 = robot.topic("/odom", msg.Odometry) - - subscriber1_messages = [] - subscriber2_messages = [] - subscriber3_messages = [] - - subscription1 = obs1.subscribe(lambda x: subscriber1_messages.append(x)) - subscription2 = obs1.subscribe(lambda x: time.sleep(0.15) or subscriber2_messages.append(x)) - subscription3 = obs2.subscribe(lambda x: time.sleep(0.25) or subscriber3_messages.append(x)) - - assert len(robot._node.subs) == 1 - - time.sleep(2) - - subscription1.dispose() - subscription2.dispose() - subscription3.dispose() - - print("Subscriber 1 messages:", len(subscriber1_messages), subscriber1_messages) - print("Subscriber 2 messages:", len(subscriber2_messages), subscriber2_messages) - print("Subscriber 3 messages:", len(subscriber3_messages), subscriber3_messages) - - assert len(subscriber1_messages) == 19 - assert len(subscriber2_messages) == 12 - assert len(subscriber3_messages) == 7 - - assert subscriber2_messages[1] != [2] - assert subscriber3_messages[1] != [2] - - time.sleep(0.1) - - assert robot._node.subs == {} - - -@pytest.mark.asyncio -@pytest.mark.ros -async def test_topic_latest_async(robot): - from nav_msgs import msg - - odom = await robot.topic_latest_async("/odom", msg.Odometry) - assert odom() == 1 - await asyncio.sleep(0.45) - assert odom() == 5 - odom.dispose() - await asyncio.sleep(0.1) - assert robot._node.subs == {} - - -@pytest.mark.ros -def test_topic_auto_conversion(robot): - 
odom = robot.topic("/vector", Vector).subscribe(lambda x: print(x)) - time.sleep(0.5) - odom.dispose() - - -@pytest.mark.ros -def test_topic_latest_sync(robot): - from nav_msgs import msg - - odom = robot.topic_latest("/odom", msg.Odometry) - assert odom() == 1 - time.sleep(0.45) - assert odom() == 5 - odom.dispose() - time.sleep(0.1) - assert robot._node.subs == {} - - -@pytest.mark.ros -def test_topic_latest_sync_benchmark(robot): - from nav_msgs import msg - - odom = robot.topic_latest("/odom", msg.Odometry) - - start_time = time.time() - for i in range(100): - odom() - end_time = time.time() - elapsed = end_time - start_time - avg_time = elapsed / 100 - - print("avg time", avg_time) - - assert odom() == 1 - time.sleep(0.45) - assert odom() >= 5 - odom.dispose() - time.sleep(0.1) - assert robot._node.subs == {} diff --git a/build/lib/dimos/robot/unitree/__init__.py b/build/lib/dimos/robot/unitree/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/build/lib/dimos/robot/unitree/unitree_go2.py b/build/lib/dimos/robot/unitree/unitree_go2.py deleted file mode 100644 index ca878e7134..0000000000 --- a/build/lib/dimos/robot/unitree/unitree_go2.py +++ /dev/null @@ -1,208 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import multiprocessing -from typing import Optional, Union, List -import numpy as np -from dimos.robot.robot import Robot -from dimos.robot.unitree.unitree_skills import MyUnitreeSkills -from dimos.skills.skills import AbstractRobotSkill, AbstractSkill, SkillLibrary -from reactivex.disposable import CompositeDisposable -import logging -import os -from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl -from reactivex.scheduler import ThreadPoolScheduler -from dimos.utils.logging_config import setup_logger -from dimos.perception.person_tracker import PersonTrackingStream -from dimos.perception.object_tracker import ObjectTrackingStream -from dimos.robot.local_planner.local_planner import navigate_path_local -from dimos.robot.local_planner.vfh_local_planner import VFHPurePursuitPlanner -from dimos.robot.global_planner.planner import AstarPlanner -from dimos.types.costmap import Costmap -from dimos.types.robot_capabilities import RobotCapability -from dimos.types.vector import Vector - -# Set up logging -logger = setup_logger("dimos.robot.unitree.unitree_go2", level=logging.DEBUG) - -# UnitreeGo2 Print Colors (Magenta) -UNITREE_GO2_PRINT_COLOR = "\033[35m" -UNITREE_GO2_RESET_COLOR = "\033[0m" - - -class UnitreeGo2(Robot): - """Unitree Go2 robot implementation using ROS2 control interface. - - This class extends the base Robot class to provide specific functionality - for the Unitree Go2 quadruped robot using ROS2 for communication and control. - """ - - def __init__( - self, - video_provider=None, - output_dir: str = os.path.join(os.getcwd(), "assets", "output"), - skill_library: SkillLibrary = None, - robot_capabilities: List[RobotCapability] = None, - spatial_memory_collection: str = "spatial_memory", - new_memory: bool = False, - disable_video_stream: bool = False, - mock_connection: bool = False, - enable_perception: bool = True, - ): - """Initialize UnitreeGo2 robot with ROS control interface. 
- - Args: - video_provider: Provider for video streams - output_dir: Directory for output files - skill_library: Library of robot skills - robot_capabilities: List of robot capabilities - spatial_memory_collection: Collection name for spatial memory - new_memory: Whether to create new memory collection - disable_video_stream: Whether to disable video streaming - mock_connection: Whether to use mock connection for testing - enable_perception: Whether to enable perception streams and spatial memory - """ - # Create ROS control interface - ros_control = UnitreeROSControl( - node_name="unitree_go2", - video_provider=video_provider, - disable_video_stream=disable_video_stream, - mock_connection=mock_connection, - ) - - # Initialize skill library if not provided - if skill_library is None: - skill_library = MyUnitreeSkills() - - # Initialize base robot with connection interface - super().__init__( - connection_interface=ros_control, - output_dir=output_dir, - skill_library=skill_library, - capabilities=robot_capabilities - or [ - RobotCapability.LOCOMOTION, - RobotCapability.VISION, - RobotCapability.AUDIO, - ], - spatial_memory_collection=spatial_memory_collection, - new_memory=new_memory, - enable_perception=enable_perception, - ) - - if self.skill_library is not None: - for skill in self.skill_library: - if isinstance(skill, AbstractRobotSkill): - self.skill_library.create_instance(skill.__name__, robot=self) - if isinstance(self.skill_library, MyUnitreeSkills): - self.skill_library._robot = self - self.skill_library.init() - self.skill_library.initialize_skills() - - # Camera stuff - self.camera_intrinsics = [819.553492, 820.646595, 625.284099, 336.808987] - self.camera_pitch = np.deg2rad(0) # negative for downward pitch - self.camera_height = 0.44 # meters - - # Initialize UnitreeGo2-specific attributes - self.disposables = CompositeDisposable() - self.main_stream_obs = None - - # Initialize thread pool scheduler - self.optimal_thread_count = 
multiprocessing.cpu_count() - self.thread_pool_scheduler = ThreadPoolScheduler(self.optimal_thread_count // 2) - - # Initialize visual servoing if enabled - if not disable_video_stream: - self.video_stream_ros = self.get_video_stream(fps=8) - if enable_perception: - self.person_tracker = PersonTrackingStream( - camera_intrinsics=self.camera_intrinsics, - camera_pitch=self.camera_pitch, - camera_height=self.camera_height, - ) - self.object_tracker = ObjectTrackingStream( - camera_intrinsics=self.camera_intrinsics, - camera_pitch=self.camera_pitch, - camera_height=self.camera_height, - ) - person_tracking_stream = self.person_tracker.create_stream(self.video_stream_ros) - object_tracking_stream = self.object_tracker.create_stream(self.video_stream_ros) - - self.person_tracking_stream = person_tracking_stream - self.object_tracking_stream = object_tracking_stream - else: - # Video stream is available but perception tracking is disabled - self.person_tracker = None - self.object_tracker = None - self.person_tracking_stream = None - self.object_tracking_stream = None - else: - # Video stream is disabled - self.video_stream_ros = None - self.person_tracker = None - self.object_tracker = None - self.person_tracking_stream = None - self.object_tracking_stream = None - - # Initialize the local planner and create BEV visualization stream - # Note: These features require ROS-specific methods that may not be available on all connection interfaces - if hasattr(self.connection_interface, "topic_latest") and hasattr( - self.connection_interface, "transform_euler" - ): - self.local_planner = VFHPurePursuitPlanner( - get_costmap=self.connection_interface.topic_latest( - "/local_costmap/costmap", Costmap - ), - transform=self.connection_interface, - move_vel_control=self.connection_interface.move_vel_control, - robot_width=0.36, # Unitree Go2 width in meters - robot_length=0.6, # Unitree Go2 length in meters - max_linear_vel=0.5, - lookahead_distance=2.0, - visualization_size=500, # 
500x500 pixel visualization - ) - - self.global_planner = AstarPlanner( - conservativism=20, # how close to obstacles robot is allowed to path plan - set_local_nav=lambda path, stop_event=None, goal_theta=None: navigate_path_local( - self, path, timeout=120.0, goal_theta=goal_theta, stop_event=stop_event - ), - get_costmap=self.connection_interface.topic_latest("map", Costmap), - get_robot_pos=lambda: self.connection_interface.transform_euler_pos("base_link"), - ) - - # Create the visualization stream at 5Hz - self.local_planner_viz_stream = self.local_planner.create_stream(frequency_hz=5.0) - else: - self.local_planner = None - self.global_planner = None - self.local_planner_viz_stream = None - - def get_skills(self) -> Optional[SkillLibrary]: - return self.skill_library - - def get_pose(self) -> dict: - """ - Get the current pose (position and rotation) of the robot in the map frame. - - Returns: - Dictionary containing: - - position: Vector (x, y, z) - - rotation: Vector (roll, pitch, yaw) in radians - """ - position_tuple, orientation_tuple = self.connection_interface.get_pose_odom_transform() - position = Vector(position_tuple[0], position_tuple[1], position_tuple[2]) - rotation = Vector(orientation_tuple[0], orientation_tuple[1], orientation_tuple[2]) - return {"position": position, "rotation": rotation} diff --git a/build/lib/dimos/robot/unitree/unitree_ros_control.py b/build/lib/dimos/robot/unitree/unitree_ros_control.py deleted file mode 100644 index 56e83cb30f..0000000000 --- a/build/lib/dimos/robot/unitree/unitree_ros_control.py +++ /dev/null @@ -1,157 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from go2_interfaces.msg import Go2State, IMU -from unitree_go.msg import WebRtcReq -from typing import Type -from sensor_msgs.msg import Image, CompressedImage, CameraInfo -from dimos.robot.ros_control import ROSControl, RobotMode -from dimos.utils.logging_config import setup_logger - -logger = setup_logger("dimos.robot.unitree.unitree_ros_control") - - -class UnitreeROSControl(ROSControl): - """Hardware interface for Unitree Go2 robot using ROS2""" - - # ROS Camera Topics - CAMERA_TOPICS = { - "raw": {"topic": "camera/image_raw", "type": Image}, - "compressed": {"topic": "camera/compressed", "type": CompressedImage}, - "info": {"topic": "camera/camera_info", "type": CameraInfo}, - } - # Hard coded ROS Message types and Topic names for Unitree Go2 - DEFAULT_STATE_MSG_TYPE = Go2State - DEFAULT_IMU_MSG_TYPE = IMU - DEFAULT_WEBRTC_MSG_TYPE = WebRtcReq - DEFAULT_STATE_TOPIC = "go2_states" - DEFAULT_IMU_TOPIC = "imu" - DEFAULT_WEBRTC_TOPIC = "webrtc_req" - DEFAULT_CMD_VEL_TOPIC = "cmd_vel_out" - DEFAULT_POSE_TOPIC = "pose_cmd" - DEFAULT_ODOM_TOPIC = "odom" - DEFAULT_COSTMAP_TOPIC = "local_costmap/costmap" - DEFAULT_MAX_LINEAR_VELOCITY = 1.0 - DEFAULT_MAX_ANGULAR_VELOCITY = 2.0 - - # Hard coded WebRTC API parameters for Unitree Go2 - DEFAULT_WEBRTC_API_TOPIC = "rt/api/sport/request" - - def __init__( - self, - node_name: str = "unitree_hardware_interface", - state_topic: str = None, - imu_topic: str = None, - webrtc_topic: str = None, - webrtc_api_topic: str = None, - move_vel_topic: str = None, - pose_topic: str = None, - odom_topic: str = None, - 
costmap_topic: str = None, - state_msg_type: Type = None, - imu_msg_type: Type = None, - webrtc_msg_type: Type = None, - max_linear_velocity: float = None, - max_angular_velocity: float = None, - use_raw: bool = False, - debug: bool = False, - disable_video_stream: bool = False, - mock_connection: bool = False, - ): - """ - Initialize Unitree ROS control interface with default values for Unitree Go2 - - Args: - node_name: Name for the ROS node - state_topic: ROS Topic name for robot state (defaults to DEFAULT_STATE_TOPIC) - imu_topic: ROS Topic name for IMU data (defaults to DEFAULT_IMU_TOPIC) - webrtc_topic: ROS Topic for WebRTC commands (defaults to DEFAULT_WEBRTC_TOPIC) - cmd_vel_topic: ROS Topic for direct movement velocity commands (defaults to DEFAULT_CMD_VEL_TOPIC) - pose_topic: ROS Topic for pose commands (defaults to DEFAULT_POSE_TOPIC) - odom_topic: ROS Topic for odometry data (defaults to DEFAULT_ODOM_TOPIC) - costmap_topic: ROS Topic for local costmap data (defaults to DEFAULT_COSTMAP_TOPIC) - state_msg_type: ROS Message type for state data (defaults to DEFAULT_STATE_MSG_TYPE) - imu_msg_type: ROS message type for IMU data (defaults to DEFAULT_IMU_MSG_TYPE) - webrtc_msg_type: ROS message type for webrtc data (defaults to DEFAULT_WEBRTC_MSG_TYPE) - max_linear_velocity: Maximum linear velocity in m/s (defaults to DEFAULT_MAX_LINEAR_VELOCITY) - max_angular_velocity: Maximum angular velocity in rad/s (defaults to DEFAULT_MAX_ANGULAR_VELOCITY) - use_raw: Whether to use raw camera topics (defaults to False) - debug: Whether to enable debug logging - disable_video_stream: Whether to run without video stream for testing. - mock_connection: Whether to run without active ActionClient servers for testing. 
- """ - - logger.info("Initializing Unitree ROS control interface") - # Select which camera topics to use - active_camera_topics = None - if not disable_video_stream: - active_camera_topics = {"main": self.CAMERA_TOPICS["raw" if use_raw else "compressed"]} - - # Use default values if not provided - state_topic = state_topic or self.DEFAULT_STATE_TOPIC - imu_topic = imu_topic or self.DEFAULT_IMU_TOPIC - webrtc_topic = webrtc_topic or self.DEFAULT_WEBRTC_TOPIC - move_vel_topic = move_vel_topic or self.DEFAULT_CMD_VEL_TOPIC - pose_topic = pose_topic or self.DEFAULT_POSE_TOPIC - odom_topic = odom_topic or self.DEFAULT_ODOM_TOPIC - costmap_topic = costmap_topic or self.DEFAULT_COSTMAP_TOPIC - webrtc_api_topic = webrtc_api_topic or self.DEFAULT_WEBRTC_API_TOPIC - state_msg_type = state_msg_type or self.DEFAULT_STATE_MSG_TYPE - imu_msg_type = imu_msg_type or self.DEFAULT_IMU_MSG_TYPE - webrtc_msg_type = webrtc_msg_type or self.DEFAULT_WEBRTC_MSG_TYPE - max_linear_velocity = max_linear_velocity or self.DEFAULT_MAX_LINEAR_VELOCITY - max_angular_velocity = max_angular_velocity or self.DEFAULT_MAX_ANGULAR_VELOCITY - - super().__init__( - node_name=node_name, - camera_topics=active_camera_topics, - mock_connection=mock_connection, - state_topic=state_topic, - imu_topic=imu_topic, - state_msg_type=state_msg_type, - imu_msg_type=imu_msg_type, - webrtc_msg_type=webrtc_msg_type, - webrtc_topic=webrtc_topic, - webrtc_api_topic=webrtc_api_topic, - move_vel_topic=move_vel_topic, - pose_topic=pose_topic, - odom_topic=odom_topic, - costmap_topic=costmap_topic, - max_linear_velocity=max_linear_velocity, - max_angular_velocity=max_angular_velocity, - debug=debug, - ) - - # Unitree-specific RobotMode State update conditons - def _update_mode(self, msg: Go2State): - """ - Implementation of abstract method to update robot mode - - Logic: - - If progress is 0 and mode is 1, then state is IDLE - - If progress is 1 OR mode is NOT equal to 1, then state is MOVING - """ - # Direct access to 
protected instance variables from the parent class - mode = msg.mode - progress = msg.progress - - if progress == 0 and mode == 1: - self._mode = RobotMode.IDLE - logger.debug("Robot mode set to IDLE (progress=0, mode=1)") - elif progress == 1 or mode != 1: - self._mode = RobotMode.MOVING - logger.debug(f"Robot mode set to MOVING (progress={progress}, mode={mode})") - else: - self._mode = RobotMode.UNKNOWN - logger.debug(f"Robot mode set to UNKNOWN (progress={progress}, mode={mode})") diff --git a/build/lib/dimos/robot/unitree/unitree_skills.py b/build/lib/dimos/robot/unitree/unitree_skills.py deleted file mode 100644 index 5029123ed1..0000000000 --- a/build/lib/dimos/robot/unitree/unitree_skills.py +++ /dev/null @@ -1,314 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import annotations - -from typing import TYPE_CHECKING, List, Optional, Tuple, Union -import time -from pydantic import Field - -if TYPE_CHECKING: - from dimos.robot.robot import Robot, MockRobot -else: - Robot = "Robot" - MockRobot = "MockRobot" - -from dimos.skills.skills import AbstractRobotSkill, AbstractSkill, SkillLibrary -from dimos.types.constants import Colors -from dimos.types.vector import Vector - -# Module-level constant for Unitree ROS control definitions -UNITREE_ROS_CONTROLS: List[Tuple[str, int, str]] = [ - ("Damp", 1001, "Lowers the robot to the ground fully."), - ( - "BalanceStand", - 1002, - "Activates a mode that maintains the robot in a balanced standing position.", - ), - ( - "StandUp", - 1004, - "Commands the robot to transition from a sitting or prone position to a standing posture.", - ), - ( - "StandDown", - 1005, - "Instructs the robot to move from a standing position to a sitting or prone posture.", - ), - ( - "RecoveryStand", - 1006, - "Recovers the robot to a state from which it can take more commands. 
Useful to run after multiple dynamic commands like front flips.", - ), - # ( - # "Euler", - # 1007, - # "Adjusts the robot's orientation using Euler angles, providing precise control over its rotation.", - # ), - # ("Move", 1008, "Move the robot using velocity commands."), # Intentionally omitted - ("Sit", 1009, "Commands the robot to sit down from a standing or moving stance."), - # ( - # "RiseSit", - # 1010, - # "Commands the robot to rise back to a standing position from a sitting posture.", - # ), - # ( - # "SwitchGait", - # 1011, - # "Switches the robot's walking pattern or style dynamically, suitable for different terrains or speeds.", - # ), - # ("Trigger", 1012, "Triggers a specific action or custom routine programmed into the robot."), - # ( - # "BodyHeight", - # 1013, - # "Adjusts the height of the robot's body from the ground, useful for navigating various obstacles.", - # ), - # ( - # "FootRaiseHeight", - # 1014, - # "Controls how high the robot lifts its feet during movement, which can be adjusted for different surfaces.", - # ), - ( - "SpeedLevel", - 1015, - "Sets or adjusts the speed at which the robot moves, with various levels available for different operational needs.", - ), - ( - "ShakeHand", - 1016, - "Performs a greeting action, which could involve a wave or other friendly gesture.", - ), - ("Stretch", 1017, "Engages the robot in a stretching routine."), - # ( - # "TrajectoryFollow", - # 1018, - # "Directs the robot to follow a predefined trajectory, which could involve complex paths or maneuvers.", - # ), - # ( - # "ContinuousGait", - # 1019, - # "Enables a mode for continuous walking or running, ideal for long-distance travel.", - # ), - ("Content", 1020, "To display or trigger when the robot is happy."), - ("Wallow", 1021, "The robot falls onto its back and rolls around."), - ( - "Dance1", - 1022, - "Performs a predefined dance routine 1, programmed for entertainment or demonstration.", - ), - ("Dance2", 1023, "Performs another variant of a 
predefined dance routine 2."), - # ("GetBodyHeight", 1024, "Retrieves the current height of the robot's body from the ground."), - # ( - # "GetFootRaiseHeight", - # 1025, - # "Retrieves the current height at which the robot's feet are being raised during movement.", - # ), - # ("GetSpeedLevel", 1026, "Returns the current speed level at which the robot is operating."), - # ( - # "SwitchJoystick", - # 1027, - # "Toggles the control mode to joystick input, allowing for manual direction of the robot's movements.", - # ), - ( - "Pose", - 1028, - "Directs the robot to take a specific pose or stance, which could be used for tasks or performances.", - ), - ( - "Scrape", - 1029, - "Robot falls to its hind legs and makes scraping motions with its front legs.", - ), - ("FrontFlip", 1030, "Executes a front flip, a complex and dynamic maneuver."), - ("FrontJump", 1031, "Commands the robot to perform a forward jump."), - ( - "FrontPounce", - 1032, - "Initiates a pouncing movement forward, mimicking animal-like pouncing behavior.", - ), - # ("WiggleHips", 1033, "Causes the robot to wiggle its hips."), - # ( - # "GetState", - # 1034, - # "Retrieves the current operational state of the robot, including status reports or diagnostic information.", - # ), - # ( - # "EconomicGait", - # 1035, - # "Engages a more energy-efficient walking or running mode to conserve battery life.", - # ), - # ("FingerHeart", 1036, "Performs a finger heart gesture while on its hind legs."), - # ( - # "Handstand", - # 1301, - # "Commands the robot to perform a handstand, demonstrating balance and control.", - # ), - # ( - # "CrossStep", - # 1302, - # "Engages the robot in a cross-stepping routine, useful for complex locomotion or dance moves.", - # ), - # ( - # "OnesidedStep", - # 1303, - # "Commands the robot to perform a stepping motion that predominantly uses one side.", - # ), - # ( - # "Bound", - # 1304, - # "Initiates a bounding motion, similar to a light, repetitive hopping or leaping.", - # ), - # ( 
- # "LeadFollow", - # 1045, - # "Engages follow-the-leader behavior, where the robot follows a designated leader or follows a signal.", - # ), - # ("LeftFlip", 1042, "Executes a flip towards the left side."), - # ("RightFlip", 1043, "Performs a flip towards the right side."), - # ("Backflip", 1044, "Executes a backflip, a complex and dynamic maneuver."), -] - -# region MyUnitreeSkills - - -class MyUnitreeSkills(SkillLibrary): - """My Unitree Skills.""" - - _robot: Optional[Robot] = None - - @classmethod - def register_skills(cls, skill_classes: Union["AbstractSkill", list["AbstractSkill"]]): - """Add multiple skill classes as class attributes. - - Args: - skill_classes: List of skill classes to add - """ - if isinstance(skill_classes, list): - for skill_class in skill_classes: - setattr(cls, skill_class.__name__, skill_class) - else: - setattr(cls, skill_classes.__name__, skill_classes) - - def __init__(self, robot: Optional[Robot] = None): - super().__init__() - self._robot: Robot = None - - # Add dynamic skills to this class - self.register_skills(self.create_skills_live()) - - if robot is not None: - self._robot = robot - self.initialize_skills() - - def initialize_skills(self): - # Create the skills and add them to the list of skills - self.register_skills(self.create_skills_live()) - - # Provide the robot instance to each skill - for skill_class in self: - print( - f"{Colors.GREEN_PRINT_COLOR}Creating instance for skill: {skill_class}{Colors.RESET_COLOR}" - ) - self.create_instance(skill_class.__name__, robot=self._robot) - - # Refresh the class skills - self.refresh_class_skills() - - def create_skills_live(self) -> List[AbstractRobotSkill]: - # ================================================ - # Procedurally created skills - # ================================================ - class BaseUnitreeSkill(AbstractRobotSkill): - """Base skill for dynamic skill creation.""" - - def __call__(self): - string = f"{Colors.GREEN_PRINT_COLOR}This is a base skill, created 
for the specific skill: {self._app_id}{Colors.RESET_COLOR}" - print(string) - super().__call__() - if self._app_id is None: - raise RuntimeError( - f"{Colors.RED_PRINT_COLOR}" - f"No App ID provided to {self.__class__.__name__} Skill" - f"{Colors.RESET_COLOR}" - ) - else: - self._robot.webrtc_req(api_id=self._app_id) - string = f"{Colors.GREEN_PRINT_COLOR}{self.__class__.__name__} was successful: id={self._app_id}{Colors.RESET_COLOR}" - print(string) - return string - - skills_classes = [] - for name, app_id, description in UNITREE_ROS_CONTROLS: - skill_class = type( - name, # Name of the class - (BaseUnitreeSkill,), # Base classes - {"__doc__": description, "_app_id": app_id}, - ) - skills_classes.append(skill_class) - - return skills_classes - - # region Class-based Skills - - class Move(AbstractRobotSkill): - """Move the robot using direct velocity commands. Determine duration required based on user distance instructions.""" - - x: float = Field(..., description="Forward velocity (m/s).") - y: float = Field(default=0.0, description="Left/right velocity (m/s)") - yaw: float = Field(default=0.0, description="Rotational velocity (rad/s)") - duration: float = Field(default=0.0, description="How long to move (seconds).") - - def __call__(self): - super().__call__() - return self._robot.move(Vector(self.x, self.y, self.yaw), duration=self.duration) - - class Reverse(AbstractRobotSkill): - """Reverse the robot using direct velocity commands. Determine duration required based on user distance instructions.""" - - x: float = Field(..., description="Backward velocity (m/s). 
Positive values move backward.") - y: float = Field(default=0.0, description="Left/right velocity (m/s)") - yaw: float = Field(default=0.0, description="Rotational velocity (rad/s)") - duration: float = Field(default=0.0, description="How long to move (seconds).") - - def __call__(self): - super().__call__() - # Use move with negative x for backward movement - return self._robot.move(Vector(-self.x, self.y, self.yaw), duration=self.duration) - - class SpinLeft(AbstractRobotSkill): - """Spin the robot left using degree commands.""" - - degrees: float = Field(..., description="Distance to spin left in degrees") - - def __call__(self): - super().__call__() - return self._robot.spin(degrees=self.degrees) # Spinning left is positive degrees - - class SpinRight(AbstractRobotSkill): - """Spin the robot right using degree commands.""" - - degrees: float = Field(..., description="Distance to spin right in degrees") - - def __call__(self): - super().__call__() - return self._robot.spin(degrees=-self.degrees) # Spinning right is negative degrees - - class Wait(AbstractSkill): - """Wait for a specified amount of time.""" - - seconds: float = Field(..., description="Seconds to wait") - - def __call__(self): - time.sleep(self.seconds) - return f"Wait completed with length={self.seconds}s" diff --git a/build/lib/dimos/robot/unitree_webrtc/__init__.py b/build/lib/dimos/robot/unitree_webrtc/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/build/lib/dimos/robot/unitree_webrtc/connection.py b/build/lib/dimos/robot/unitree_webrtc/connection.py deleted file mode 100644 index 86fe5f6a85..0000000000 --- a/build/lib/dimos/robot/unitree_webrtc/connection.py +++ /dev/null @@ -1,309 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio -import functools -import threading -import time -from typing import Literal, TypeAlias - -import numpy as np -from aiortc import MediaStreamTrack -from go2_webrtc_driver.constants import RTC_TOPIC, SPORT_CMD, VUI_COLOR -from go2_webrtc_driver.webrtc_driver import ( # type: ignore[import-not-found] - Go2WebRTCConnection, - WebRTCConnectionMethod, -) -from reactivex import operators as ops -from reactivex.observable import Observable -from reactivex.subject import Subject - -from dimos.core import In, Module, Out, rpc -from dimos.msgs.sensor_msgs import Image -from dimos.robot.connection_interface import ConnectionInterface -from dimos.robot.unitree_webrtc.type.lidar import LidarMessage -from dimos.robot.unitree_webrtc.type.lowstate import LowStateMsg -from dimos.robot.unitree_webrtc.type.odometry import Odometry -from dimos.types.pose import Pose -from dimos.types.vector import Vector -from dimos.utils.reactive import backpressure, callback_to_observable - -VideoMessage: TypeAlias = np.ndarray[tuple[int, int, Literal[3]], np.uint8] - - -class WebRTCRobot(ConnectionInterface): - def __init__(self, ip: str, mode: str = "ai"): - self.ip = ip - self.mode = mode - self.conn = Go2WebRTCConnection(WebRTCConnectionMethod.LocalSTA, ip=self.ip) - self.connect() - - def connect(self): - self.loop = asyncio.new_event_loop() - self.task = None - self.connected_event = asyncio.Event() - self.connection_ready = threading.Event() - - async def async_connect(): - await self.conn.connect() - await self.conn.datachannel.disableTrafficSaving(True) - - 
self.conn.datachannel.set_decoder(decoder_type="native") - - await self.conn.datachannel.pub_sub.publish_request_new( - RTC_TOPIC["MOTION_SWITCHER"], {"api_id": 1002, "parameter": {"name": self.mode}} - ) - - self.connected_event.set() - self.connection_ready.set() - - while True: - await asyncio.sleep(1) - - def start_background_loop(): - asyncio.set_event_loop(self.loop) - self.task = self.loop.create_task(async_connect()) - self.loop.run_forever() - - self.loop = asyncio.new_event_loop() - self.thread = threading.Thread(target=start_background_loop, daemon=True) - self.thread.start() - self.connection_ready.wait() - - def move(self, velocity: Vector, duration: float = 0.0) -> bool: - """Send movement command to the robot using velocity commands. - - Args: - velocity: Velocity vector [x, y, yaw] where: - x: Forward/backward velocity (m/s) - y: Left/right velocity (m/s) - yaw: Rotational velocity (rad/s) - duration: How long to move (seconds). If 0, command is continuous - - Returns: - bool: True if command was sent successfully - """ - x, y, yaw = velocity.x, velocity.y, velocity.z - - # WebRTC coordinate mapping: - # x - Positive right, negative left - # y - positive forward, negative backwards - # yaw - Positive rotate right, negative rotate left - async def async_move(): - self.conn.datachannel.pub_sub.publish_without_callback( - RTC_TOPIC["WIRELESS_CONTROLLER"], - data={"lx": y, "ly": x, "rx": -yaw, "ry": 0}, - ) - - async def async_move_duration(): - """Send movement commands continuously for the specified duration.""" - start_time = time.time() - sleep_time = 0.01 - - while time.time() - start_time < duration: - await async_move() - await asyncio.sleep(sleep_time) - - try: - if duration > 0: - # Send continuous move commands for the duration - future = asyncio.run_coroutine_threadsafe(async_move_duration(), self.loop) - future.result() - # Stop after duration - self.stop() - else: - # Single command for continuous movement - future = 
asyncio.run_coroutine_threadsafe(async_move(), self.loop) - future.result() - return True - except Exception as e: - print(f"Failed to send movement command: {e}") - return False - - # Generic conversion of unitree subscription to Subject (used for all subs) - def unitree_sub_stream(self, topic_name: str): - def subscribe_in_thread(cb): - # Run the subscription in the background thread that has the event loop - def run_subscription(): - self.conn.datachannel.pub_sub.subscribe(topic_name, cb) - - # Use call_soon_threadsafe to run in the background thread - self.loop.call_soon_threadsafe(run_subscription) - - def unsubscribe_in_thread(cb): - # Run the unsubscription in the background thread that has the event loop - def run_unsubscription(): - self.conn.datachannel.pub_sub.unsubscribe(topic_name) - - # Use call_soon_threadsafe to run in the background thread - self.loop.call_soon_threadsafe(run_unsubscription) - - return callback_to_observable( - start=subscribe_in_thread, - stop=unsubscribe_in_thread, - ) - - # Generic sync API call (we jump into the client thread) - def publish_request(self, topic: str, data: dict): - future = asyncio.run_coroutine_threadsafe( - self.conn.datachannel.pub_sub.publish_request_new(topic, data), self.loop - ) - return future.result() - - @functools.cache - def raw_lidar_stream(self) -> Subject[LidarMessage]: - return backpressure(self.unitree_sub_stream(RTC_TOPIC["ULIDAR_ARRAY"])) - - @functools.cache - def raw_odom_stream(self) -> Subject[Pose]: - return backpressure(self.unitree_sub_stream(RTC_TOPIC["ROBOTODOM"])) - - @functools.cache - def lidar_stream(self) -> Subject[LidarMessage]: - return backpressure( - self.raw_lidar_stream().pipe( - ops.map(lambda raw_frame: LidarMessage.from_msg(raw_frame)) - ) - ) - - @functools.cache - def odom_stream(self) -> Subject[Pose]: - return backpressure(self.raw_odom_stream().pipe(ops.map(Odometry.from_msg))) - - @functools.cache - def lowstate_stream(self) -> Subject[LowStateMsg]: - return 
backpressure(self.unitree_sub_stream(RTC_TOPIC["LOW_STATE"])) - - def standup_ai(self): - return self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["BalanceStand"]}) - - def standup_normal(self): - self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["StandUp"]}) - time.sleep(0.5) - self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["RecoveryStand"]}) - return True - - @rpc - def standup(self): - if self.mode == "ai": - return self.standup_ai() - else: - return self.standup_normal() - - @rpc - def liedown(self): - return self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["StandDown"]}) - - async def handstand(self): - return self.publish_request( - RTC_TOPIC["SPORT_MOD"], - {"api_id": SPORT_CMD["Standup"], "parameter": {"data": True}}, - ) - - @rpc - def color(self, color: VUI_COLOR = VUI_COLOR.RED, colortime: int = 60) -> bool: - return self.publish_request( - RTC_TOPIC["VUI"], - { - "api_id": 1001, - "parameter": { - "color": color, - "time": colortime, - }, - }, - ) - - @functools.lru_cache(maxsize=None) - def video_stream(self) -> Observable[VideoMessage]: - subject: Subject[VideoMessage] = Subject() - stop_event = threading.Event() - - async def accept_track(track: MediaStreamTrack) -> VideoMessage: - while True: - if stop_event.is_set(): - return - frame = await track.recv() - subject.on_next(Image.from_numpy(frame.to_ndarray(format="bgr24"))) - - self.conn.video.add_track_callback(accept_track) - - # Run the video channel switching in the background thread - def switch_video_channel(): - self.conn.video.switchVideoChannel(True) - - self.loop.call_soon_threadsafe(switch_video_channel) - - def stop(cb): - stop_event.set() # Signal the loop to stop - self.conn.video.track_callbacks.remove(accept_track) - - # Run the video channel switching off in the background thread - def switch_video_channel_off(): - self.conn.video.switchVideoChannel(False) - - self.loop.call_soon_threadsafe(switch_video_channel_off) 
- - return subject.pipe(ops.finally_action(stop)) - - def get_video_stream(self, fps: int = 30) -> Observable[VideoMessage]: - """Get the video stream from the robot's camera. - - Implements the AbstractRobot interface method. - - Args: - fps: Frames per second. This parameter is included for API compatibility, - but doesn't affect the actual frame rate which is determined by the camera. - - Returns: - Observable: An observable stream of video frames or None if video is not available. - """ - try: - print("Starting WebRTC video stream...") - stream = self.video_stream() - if stream is None: - print("Warning: Video stream is not available") - return stream - - except Exception as e: - print(f"Error getting video stream: {e}") - return None - - def stop(self) -> bool: - """Stop the robot's movement. - - Returns: - bool: True if stop command was sent successfully - """ - return self.move(Vector(0.0, 0.0, 0.0)) - - def disconnect(self) -> None: - """Disconnect from the robot and clean up resources.""" - if hasattr(self, "task") and self.task: - self.task.cancel() - if hasattr(self, "conn"): - - async def async_disconnect(): - try: - await self.conn.disconnect() - except: - pass - - if hasattr(self, "loop") and self.loop.is_running(): - asyncio.run_coroutine_threadsafe(async_disconnect(), self.loop) - - if hasattr(self, "loop") and self.loop.is_running(): - self.loop.call_soon_threadsafe(self.loop.stop) - - if hasattr(self, "thread") and self.thread.is_alive(): - self.thread.join(timeout=2.0) diff --git a/build/lib/dimos/robot/unitree_webrtc/testing/__init__.py b/build/lib/dimos/robot/unitree_webrtc/testing/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/build/lib/dimos/robot/unitree_webrtc/testing/helpers.py b/build/lib/dimos/robot/unitree_webrtc/testing/helpers.py deleted file mode 100644 index 8d01cb76cc..0000000000 --- a/build/lib/dimos/robot/unitree_webrtc/testing/helpers.py +++ /dev/null @@ -1,168 +0,0 @@ -# Copyright 2025 
Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import time -import open3d as o3d -from typing import Callable, Union, Any, Protocol, Iterable -from reactivex.observable import Observable - -color1 = [1, 0.706, 0] -color2 = [0, 0.651, 0.929] -color3 = [0.8, 0.196, 0.6] -color4 = [0.235, 0.702, 0.443] -color = [color1, color2, color3, color4] - - -# benchmarking function can return int, which will be applied to the time. -# -# (in case there is some preparation within the fuction and this time needs to be subtracted -# from the benchmark target) -def benchmark(calls: int, targetf: Callable[[], Union[int, None]]) -> float: - start = time.time() - timemod = 0 - for _ in range(calls): - res = targetf() - if res is not None: - timemod += res - end = time.time() - return (end - start + timemod) * 1000 / calls - - -O3dDrawable = ( - o3d.geometry.Geometry - | o3d.geometry.LineSet - | o3d.geometry.TriangleMesh - | o3d.geometry.PointCloud -) - - -class ReturnsDrawable(Protocol): - def o3d_geometry(self) -> O3dDrawable: ... 
- - -Drawable = O3dDrawable | ReturnsDrawable - - -def show3d(*components: Iterable[Drawable], title: str = "open3d") -> o3d.visualization.Visualizer: - vis = o3d.visualization.Visualizer() - vis.create_window(window_name=title) - for component in components: - # our custom drawable components should return an open3d geometry - if hasattr(component, "o3d_geometry"): - vis.add_geometry(component.o3d_geometry) - else: - vis.add_geometry(component) - - opt = vis.get_render_option() - opt.background_color = [0, 0, 0] - opt.point_size = 10 - vis.poll_events() - vis.update_renderer() - return vis - - -def multivis(*vis: o3d.visualization.Visualizer) -> None: - while True: - for v in vis: - v.poll_events() - v.update_renderer() - - -def show3d_stream( - geometry_observable: Observable[Any], - clearframe: bool = False, - title: str = "open3d", -) -> o3d.visualization.Visualizer: - """ - Visualize a stream of geometries using Open3D. The first geometry initializes the visualizer. - Subsequent geometries update the visualizer. If no new geometry, just poll events. 
- geometry_observable: Observable of objects with .o3d_geometry or Open3D geometry - """ - import threading - import queue - import time - from typing import Any - - q: queue.Queue[Any] = queue.Queue() - stop_flag = threading.Event() - - def on_next(geometry: O3dDrawable) -> None: - q.put(geometry) - - def on_error(e: Exception) -> None: - print(f"Visualization error: {e}") - stop_flag.set() - - def on_completed() -> None: - print("Geometry stream completed") - stop_flag.set() - - subscription = geometry_observable.subscribe( - on_next=on_next, - on_error=on_error, - on_completed=on_completed, - ) - - def geom(geometry: Drawable) -> O3dDrawable: - """Extracts the Open3D geometry from the given object.""" - return geometry.o3d_geometry if hasattr(geometry, "o3d_geometry") else geometry - - # Wait for the first geometry - first_geometry = None - while first_geometry is None and not stop_flag.is_set(): - try: - first_geometry = q.get(timeout=100) - except queue.Empty: - print("No geometry received to visualize.") - return - - scene_geometries = [] - first_geom_obj = geom(first_geometry) - - scene_geometries.append(first_geom_obj) - - vis = show3d(first_geom_obj, title=title) - - try: - while not stop_flag.is_set(): - try: - geometry = q.get_nowait() - geom_obj = geom(geometry) - if clearframe: - scene_geometries = [] - vis.clear_geometries() - - vis.add_geometry(geom_obj) - scene_geometries.append(geom_obj) - else: - if geom_obj in scene_geometries: - print("updating existing geometry") - vis.update_geometry(geom_obj) - else: - print("new geometry") - vis.add_geometry(geom_obj) - scene_geometries.append(geom_obj) - except queue.Empty: - pass - vis.poll_events() - vis.update_renderer() - time.sleep(0.1) - - except KeyboardInterrupt: - print("closing visualizer...") - stop_flag.set() - vis.destroy_window() - subscription.dispose() - - return vis diff --git a/build/lib/dimos/robot/unitree_webrtc/testing/mock.py b/build/lib/dimos/robot/unitree_webrtc/testing/mock.py 
deleted file mode 100644 index f929d33c5c..0000000000 --- a/build/lib/dimos/robot/unitree_webrtc/testing/mock.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import pickle -import glob -from typing import Union, Iterator, cast, overload -from dimos.robot.unitree_webrtc.type.lidar import LidarMessage, RawLidarMsg - -from reactivex import operators as ops -from reactivex import interval, from_iterable -from reactivex.observable import Observable - - -class Mock: - def __init__(self, root="office", autocast: bool = True): - current_dir = os.path.dirname(os.path.abspath(__file__)) - self.root = os.path.join(current_dir, f"mockdata/{root}") - self.autocast = autocast - self.cnt = 0 - - @overload - def load(self, name: Union[int, str], /) -> LidarMessage: ... - @overload - def load(self, *names: Union[int, str]) -> list[LidarMessage]: ... 
- - def load(self, *names: Union[int, str]) -> Union[LidarMessage, list[LidarMessage]]: - if len(names) == 1: - return self.load_one(names[0]) - return list(map(lambda name: self.load_one(name), names)) - - def load_one(self, name: Union[int, str]) -> LidarMessage: - if isinstance(name, int): - file_name = f"/lidar_data_{name:03d}.pickle" - else: - file_name = f"/{name}.pickle" - - full_path = self.root + file_name - with open(full_path, "rb") as f: - return LidarMessage.from_msg(cast(RawLidarMsg, pickle.load(f))) - - def iterate(self) -> Iterator[LidarMessage]: - pattern = os.path.join(self.root, "lidar_data_*.pickle") - print("loading data", pattern) - for file_path in sorted(glob.glob(pattern)): - basename = os.path.basename(file_path) - filename = os.path.splitext(basename)[0] - yield self.load_one(filename) - - def stream(self, rate_hz=10.0): - sleep_time = 1.0 / rate_hz - - return from_iterable(self.iterate()).pipe( - ops.zip(interval(sleep_time)), - ops.map(lambda x: x[0] if isinstance(x, tuple) else x), - ) - - def save_stream(self, observable: Observable[LidarMessage]): - return observable.pipe(ops.map(lambda frame: self.save_one(frame))) - - def save(self, *frames): - [self.save_one(frame) for frame in frames] - return self.cnt - - def save_one(self, frame): - file_name = f"/lidar_data_{self.cnt:03d}.pickle" - full_path = self.root + file_name - - self.cnt += 1 - - if os.path.isfile(full_path): - raise Exception(f"file {full_path} exists") - - if frame.__class__ == LidarMessage: - frame = frame.raw_msg - - with open(full_path, "wb") as f: - pickle.dump(frame, f) - - return self.cnt diff --git a/build/lib/dimos/robot/unitree_webrtc/testing/multimock.py b/build/lib/dimos/robot/unitree_webrtc/testing/multimock.py deleted file mode 100644 index cfc2688129..0000000000 --- a/build/lib/dimos/robot/unitree_webrtc/testing/multimock.py +++ /dev/null @@ -1,142 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025 Dimensional Inc. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Multimock – lightweight persistence & replay helper built on RxPy. - -A directory of pickle files acts as a tiny append-only log of (timestamp, data) -pairs. You can: - • save() / consume(): append new frames - • iterate(): read them back lazily - • interval_stream(): emit at a fixed cadence - • stream(): replay with original timing (optionally scaled) - -The implementation keeps memory usage constant by relying on reactive -operators instead of pre-materialising lists. Timing is reproduced via -`rx.timer`, and drift is avoided with `concat_map`. 
-""" - -from __future__ import annotations - -import glob -import os -import pickle -import time -from typing import Any, Generic, Iterator, List, Tuple, TypeVar, Union, Optional -from reactivex.scheduler import ThreadPoolScheduler - -from reactivex import from_iterable, interval, operators as ops -from reactivex.observable import Observable -from dimos.utils.threadpool import get_scheduler -from dimos.robot.unitree_webrtc.type.timeseries import TEvent, Timeseries - -T = TypeVar("T") - - -class Multimock(Generic[T], Timeseries[TEvent[T]]): - """Persist frames as pickle files and replay them with RxPy.""" - - def __init__(self, root: str = "office", file_prefix: str = "msg") -> None: - current_dir = os.path.dirname(os.path.abspath(__file__)) - self.root = os.path.join(current_dir, f"multimockdata/{root}") - self.file_prefix = file_prefix - - os.makedirs(self.root, exist_ok=True) - self.cnt: int = 0 - - def save(self, *frames: Any) -> int: - """Persist one or more frames; returns the new counter value.""" - for frame in frames: - self.save_one(frame) - return self.cnt - - def save_one(self, frame: Any) -> int: - """Persist a single frame and return the running count.""" - file_name = f"/{self.file_prefix}_{self.cnt:03d}.pickle" - full_path = os.path.join(self.root, file_name.lstrip("/")) - self.cnt += 1 - - if os.path.isfile(full_path): - raise FileExistsError(f"file {full_path} exists") - - # Optional convinience magic to extract raw messages from advanced types - # trying to deprecate for now - # if hasattr(frame, "raw_msg"): - # frame = frame.raw_msg # type: ignore[attr-defined] - - with open(full_path, "wb") as f: - pickle.dump([time.time(), frame], f) - - return self.cnt - - def load(self, *names: Union[int, str]) -> List[Tuple[float, T]]: - """Load multiple items by name or index.""" - return list(map(self.load_one, names)) - - def load_one(self, name: Union[int, str]) -> TEvent[T]: - """Load a single item by name or index.""" - if isinstance(name, int): - 
file_name = f"/{self.file_prefix}_{name:03d}.pickle" - else: - file_name = f"/{name}.pickle" - - full_path = os.path.join(self.root, file_name.lstrip("/")) - - with open(full_path, "rb") as f: - timestamp, data = pickle.load(f) - - return TEvent(timestamp, data) - - def iterate(self) -> Iterator[TEvent[T]]: - """Yield all persisted TEvent(timestamp, data) pairs lazily in order.""" - pattern = os.path.join(self.root, f"{self.file_prefix}_*.pickle") - for file_path in sorted(glob.glob(pattern)): - with open(file_path, "rb") as f: - timestamp, data = pickle.load(f) - yield TEvent(timestamp, data) - - def list(self) -> List[TEvent[T]]: - return list(self.iterate()) - - def interval_stream(self, rate_hz: float = 10.0) -> Observable[T]: - """Emit frames at a fixed rate, ignoring recorded timing.""" - sleep_time = 1.0 / rate_hz - return from_iterable(self.iterate()).pipe( - ops.zip(interval(sleep_time)), - ops.map(lambda pair: pair[1]), # keep only the frame - ) - - def stream( - self, - replay_speed: float = 1.0, - scheduler: Optional[ThreadPoolScheduler] = None, - ) -> Observable[T]: - def _generator(): - prev_ts: float | None = None - for event in self.iterate(): - if prev_ts is not None: - delay = (event.ts - prev_ts).total_seconds() / replay_speed - time.sleep(delay) - prev_ts = event.ts - yield event.data - - return from_iterable(_generator(), scheduler=scheduler or get_scheduler()) - - def consume(self, observable: Observable[Any]) -> Observable[int]: - """Side-effect: save every frame that passes through.""" - return observable.pipe(ops.map(self.save_one)) - - def __iter__(self) -> Iterator[TEvent[T]]: - """Allow iteration over the Multimock instance to yield TEvent(timestamp, data) pairs.""" - return self.iterate() diff --git a/build/lib/dimos/robot/unitree_webrtc/testing/test_mock.py b/build/lib/dimos/robot/unitree_webrtc/testing/test_mock.py deleted file mode 100644 index 4852392943..0000000000 --- a/build/lib/dimos/robot/unitree_webrtc/testing/test_mock.py +++ 
/dev/null @@ -1,62 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import time -import pytest -from dimos.robot.unitree_webrtc.type.lidar import LidarMessage -from dimos.robot.unitree_webrtc.testing.mock import Mock - - -@pytest.mark.needsdata -def test_mock_load_cast(): - mock = Mock("test") - - # Load a frame with type casting - frame = mock.load("a") - - # Verify it's a LidarMessage object - assert frame.__class__.__name__ == "LidarMessage" - assert hasattr(frame, "timestamp") - assert hasattr(frame, "origin") - assert hasattr(frame, "resolution") - assert hasattr(frame, "pointcloud") - - # Verify pointcloud has points - assert frame.pointcloud.has_points() - assert len(frame.pointcloud.points) > 0 - - -@pytest.mark.needsdata -def test_mock_iterate(): - """Test the iterate method of the Mock class.""" - mock = Mock("office") - - # Test iterate method - frames = list(mock.iterate()) - assert len(frames) > 0 - for frame in frames: - assert isinstance(frame, LidarMessage) - assert frame.pointcloud.has_points() - - -@pytest.mark.needsdata -def test_mock_stream(): - frames = [] - sub1 = Mock("office").stream(rate_hz=30.0).subscribe(on_next=frames.append) - time.sleep(0.1) - sub1.dispose() - - assert len(frames) >= 2 - assert isinstance(frames[0], LidarMessage) diff --git a/build/lib/dimos/robot/unitree_webrtc/testing/test_multimock.py b/build/lib/dimos/robot/unitree_webrtc/testing/test_multimock.py 
deleted file mode 100644 index 1d64cbd3a0..0000000000 --- a/build/lib/dimos/robot/unitree_webrtc/testing/test_multimock.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import time -import pytest - -from reactivex import operators as ops - -from dimos.utils.reactive import backpressure -from dimos.robot.unitree_webrtc.testing.helpers import show3d_stream -from dimos.web.websocket_vis.server import WebsocketVis -from dimos.robot.unitree_webrtc.type.map import Map -from dimos.robot.unitree_webrtc.type.lidar import LidarMessage -from dimos.robot.unitree_webrtc.type.odometry import Odometry -from dimos.robot.unitree_webrtc.type.timeseries import to_datetime -from dimos.robot.unitree_webrtc.testing.multimock import Multimock - - -@pytest.mark.needsdata -@pytest.mark.vis -def test_multimock_stream(): - backpressure(Multimock("athens_odom").stream().pipe(ops.map(Odometry.from_msg))).subscribe( - lambda x: print(x) - ) - map = Map() - - def lidarmsg(msg): - frame = LidarMessage.from_msg(msg) - map.add_frame(frame) - return [map, map.costmap.smudge()] - - mapstream = Multimock("athens_lidar").stream().pipe(ops.map(lidarmsg)) - show3d_stream(mapstream.pipe(ops.map(lambda x: x[0])), clearframe=True).run() - time.sleep(5) - - -@pytest.mark.needsdata -def test_clock_mismatch(): - for odometry_raw in Multimock("athens_odom").iterate(): - print( - odometry_raw.ts - 
to_datetime(odometry_raw.data["data"]["header"]["stamp"]), - odometry_raw.data["data"]["header"]["stamp"], - ) - - -@pytest.mark.needsdata -def test_odom_stream(): - for odometry_raw in Multimock("athens_odom").iterate(): - print(Odometry.from_msg(odometry_raw.data)) - - -@pytest.mark.needsdata -def test_lidar_stream(): - for lidar_raw in Multimock("athens_lidar").iterate(): - lidarmsg = LidarMessage.from_msg(lidar_raw.data) - print(lidarmsg) - print(lidar_raw) - - -@pytest.mark.needsdata -def test_multimock_timeseries(): - odom = Odometry.from_msg(Multimock("athens_odom").load_one(1).data) - lidar_raw = Multimock("athens_lidar").load_one(1).data - lidar = LidarMessage.from_msg(lidar_raw) - map = Map() - map.add_frame(lidar) - print(odom) - print(lidar) - print(lidar_raw) - print(map.costmap) - - -@pytest.mark.needsdata -def test_origin_changes(): - for lidar_raw in Multimock("athens_lidar").iterate(): - print(LidarMessage.from_msg(lidar_raw.data).origin) - - -@pytest.mark.needsdata -@pytest.mark.vis -def test_webui_multistream(): - websocket_vis = WebsocketVis() - websocket_vis.start() - - odom_stream = Multimock("athens_odom").stream().pipe(ops.map(Odometry.from_msg)) - lidar_stream = backpressure( - Multimock("athens_lidar").stream().pipe(ops.map(LidarMessage.from_msg)) - ) - - map = Map() - map_stream = map.consume(lidar_stream) - - costmap_stream = map_stream.pipe( - ops.map(lambda x: ["costmap", map.costmap.smudge(preserve_unknown=False)]) - ) - - websocket_vis.connect(costmap_stream) - websocket_vis.connect(odom_stream.pipe(ops.map(lambda pos: ["robot_pos", pos.pos.to_2d()]))) - - show3d_stream(lidar_stream, clearframe=True).run() diff --git a/build/lib/dimos/robot/unitree_webrtc/type/__init__.py b/build/lib/dimos/robot/unitree_webrtc/type/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/build/lib/dimos/robot/unitree_webrtc/type/lidar.py b/build/lib/dimos/robot/unitree_webrtc/type/lidar.py deleted file mode 100644 index 
f45cb8dfe7..0000000000 --- a/build/lib/dimos/robot/unitree_webrtc/type/lidar.py +++ /dev/null @@ -1,138 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from copy import copy -from typing import List, Optional, TypedDict - -import numpy as np -import open3d as o3d - -from dimos.msgs.geometry_msgs import Vector3 -from dimos.msgs.sensor_msgs import PointCloud2 -from dimos.robot.unitree_webrtc.type.timeseries import to_human_readable -from dimos.types.costmap import Costmap, pointcloud_to_costmap -from dimos.types.vector import Vector - - -class RawLidarPoints(TypedDict): - points: np.ndarray # Shape (N, 3) array of 3D points [x, y, z] - - -class RawLidarData(TypedDict): - """Data portion of the LIDAR message""" - - frame_id: str - origin: List[float] - resolution: float - src_size: int - stamp: float - width: List[int] - data: RawLidarPoints - - -class RawLidarMsg(TypedDict): - """Static type definition for raw LIDAR message""" - - type: str - topic: str - data: RawLidarData - - -class LidarMessage(PointCloud2): - resolution: float # we lose resolution when encoding PointCloud2 - origin: Vector3 - raw_msg: Optional[RawLidarMsg] - _costmap: Optional[Costmap] = None - - def __init__(self, **kwargs): - super().__init__( - pointcloud=kwargs.get("pointcloud"), - ts=kwargs.get("ts"), - frame_id="lidar", - ) - - self.origin = kwargs.get("origin") - self.resolution = kwargs.get("resolution") - - @classmethod - def from_msg(cls: 
"LidarMessage", raw_message: RawLidarMsg) -> "LidarMessage": - data = raw_message["data"] - points = data["data"]["points"] - pointcloud = o3d.geometry.PointCloud() - pointcloud.points = o3d.utility.Vector3dVector(points) - - origin = Vector3(data["origin"]) - # webrtc decoding via native decompression doesn't require us - # to shift the pointcloud by it's origin - # - # pointcloud.translate((origin / 2).to_tuple()) - - return cls( - origin=origin, - resolution=data["resolution"], - pointcloud=pointcloud, - ts=data["stamp"], - raw_msg=raw_message, - ) - - def to_pointcloud2(self) -> PointCloud2: - """Convert to PointCloud2 message format.""" - return PointCloud2( - pointcloud=self.pointcloud, - frame_id=self.frame_id, - ts=self.ts, - ) - - def __repr__(self): - return f"LidarMessage(ts={to_human_readable(self.ts)}, origin={self.origin}, resolution={self.resolution}, {self.pointcloud})" - - def __iadd__(self, other: "LidarMessage") -> "LidarMessage": - self.pointcloud += other.pointcloud - return self - - def __add__(self, other: "LidarMessage") -> "LidarMessage": - # Determine which message is more recent - if self.ts >= other.ts: - ts = self.ts - origin = self.origin - resolution = self.resolution - else: - ts = other.ts - origin = other.origin - resolution = other.resolution - - # Return a new LidarMessage with combined data - return LidarMessage( - ts=ts, - origin=origin, - resolution=resolution, - pointcloud=self.pointcloud + other.pointcloud, - ).estimate_normals() - - @property - def o3d_geometry(self): - return self.pointcloud - - def costmap(self, voxel_size: float = 0.2) -> Costmap: - if not self._costmap: - down_sampled_pointcloud = self.pointcloud.voxel_down_sample(voxel_size=voxel_size) - inflate_radius_m = 1.0 * voxel_size if voxel_size > self.resolution else 0.0 - grid, origin_xy = pointcloud_to_costmap( - down_sampled_pointcloud, - resolution=self.resolution, - inflate_radius_m=inflate_radius_m, - ) - self._costmap = Costmap(grid=grid, 
origin=[*origin_xy, 0.0], resolution=self.resolution) - - return self._costmap diff --git a/build/lib/dimos/robot/unitree_webrtc/type/lowstate.py b/build/lib/dimos/robot/unitree_webrtc/type/lowstate.py deleted file mode 100644 index 9c4d8edee5..0000000000 --- a/build/lib/dimos/robot/unitree_webrtc/type/lowstate.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import TypedDict, List, Literal - -raw_odom_msg_sample = { - "type": "msg", - "topic": "rt/lf/lowstate", - "data": { - "imu_state": {"rpy": [0.008086, -0.007515, 2.981771]}, - "motor_state": [ - {"q": 0.098092, "temperature": 40, "lost": 0, "reserve": [0, 674]}, - {"q": 0.757921, "temperature": 32, "lost": 0, "reserve": [0, 674]}, - {"q": -1.490911, "temperature": 38, "lost": 6, "reserve": [0, 674]}, - {"q": -0.072477, "temperature": 42, "lost": 0, "reserve": [0, 674]}, - {"q": 1.020276, "temperature": 32, "lost": 5, "reserve": [0, 674]}, - {"q": -2.007172, "temperature": 38, "lost": 5, "reserve": [0, 674]}, - {"q": 0.071382, "temperature": 50, "lost": 5, "reserve": [0, 674]}, - {"q": 0.963379, "temperature": 36, "lost": 6, "reserve": [0, 674]}, - {"q": -1.978311, "temperature": 40, "lost": 5, "reserve": [0, 674]}, - {"q": -0.051066, "temperature": 48, "lost": 0, "reserve": [0, 674]}, - {"q": 0.73103, "temperature": 34, "lost": 10, "reserve": [0, 674]}, - {"q": -1.466473, "temperature": 38, "lost": 6, "reserve": [0, 674]}, - {"q": 
0, "temperature": 0, "lost": 0, "reserve": [0, 0]}, - {"q": 0, "temperature": 0, "lost": 0, "reserve": [0, 0]}, - {"q": 0, "temperature": 0, "lost": 0, "reserve": [0, 0]}, - {"q": 0, "temperature": 0, "lost": 0, "reserve": [0, 0]}, - {"q": 0, "temperature": 0, "lost": 0, "reserve": [0, 0]}, - {"q": 0, "temperature": 0, "lost": 0, "reserve": [0, 0]}, - {"q": 0, "temperature": 0, "lost": 0, "reserve": [0, 0]}, - {"q": 0, "temperature": 0, "lost": 0, "reserve": [0, 0]}, - ], - "bms_state": { - "version_high": 1, - "version_low": 18, - "soc": 55, - "current": -2481, - "cycle": 56, - "bq_ntc": [30, 29], - "mcu_ntc": [33, 32], - }, - "foot_force": [97, 84, 81, 81], - "temperature_ntc1": 48, - "power_v": 28.331045, - }, -} - - -class MotorState(TypedDict): - q: float - temperature: int - lost: int - reserve: List[int] - - -class ImuState(TypedDict): - rpy: List[float] - - -class BmsState(TypedDict): - version_high: int - version_low: int - soc: int - current: int - cycle: int - bq_ntc: List[int] - mcu_ntc: List[int] - - -class LowStateData(TypedDict): - imu_state: ImuState - motor_state: List[MotorState] - bms_state: BmsState - foot_force: List[int] - temperature_ntc1: int - power_v: float - - -class LowStateMsg(TypedDict): - type: Literal["msg"] - topic: str - data: LowStateData diff --git a/build/lib/dimos/robot/unitree_webrtc/type/map.py b/build/lib/dimos/robot/unitree_webrtc/type/map.py deleted file mode 100644 index 898bd473b5..0000000000 --- a/build/lib/dimos/robot/unitree_webrtc/type/map.py +++ /dev/null @@ -1,150 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dataclasses import dataclass -from typing import Optional, Tuple -import time -import numpy as np -import open3d as o3d -import reactivex.operators as ops -from reactivex import interval -from reactivex.observable import Observable - -from dimos.core import In, Module, Out, rpc -from dimos.robot.unitree_webrtc.type.lidar import LidarMessage -from dimos.types.costmap import Costmap, pointcloud_to_costmap - - -class Map(Module): - lidar: In[LidarMessage] = None - global_map: Out[LidarMessage] = None - pointcloud: o3d.geometry.PointCloud = o3d.geometry.PointCloud() - - def __init__( - self, - voxel_size: float = 0.05, - cost_resolution: float = 0.05, - global_publish_interval: Optional[float] = None, - **kwargs, - ): - self.voxel_size = voxel_size - self.cost_resolution = cost_resolution - self.global_publish_interval = global_publish_interval - super().__init__(**kwargs) - - @rpc - def start(self): - self.lidar.subscribe(self.add_frame) - - if self.global_publish_interval is not None: - interval(self.global_publish_interval).subscribe( - lambda _: self.global_map.publish(self.to_lidar_message()) - ) - - def to_lidar_message(self) -> LidarMessage: - return LidarMessage( - pointcloud=self.pointcloud, - origin=[0.0, 0.0, 0.0], - resolution=self.voxel_size, - ts=time.time(), - ) - - @rpc - def add_frame(self, frame: LidarMessage) -> "Map": - """Voxelise *frame* and splice it into the running map.""" - new_pct = frame.pointcloud.voxel_down_sample(voxel_size=self.voxel_size) - self.pointcloud = splice_cylinder(self.pointcloud, new_pct, shrink=0.5) - - def 
consume(self, observable: Observable[LidarMessage]) -> Observable["Map"]: - """Reactive operator that folds a stream of `LidarMessage` into the map.""" - return observable.pipe(ops.map(self.add_frame)) - - @property - def o3d_geometry(self) -> o3d.geometry.PointCloud: - return self.pointcloud - - @rpc - def costmap(self) -> Costmap: - """Return a fully inflated cost-map in a `Costmap` wrapper.""" - inflate_radius_m = 0.5 * self.voxel_size if self.voxel_size > self.cost_resolution else 0.0 - grid, origin_xy = pointcloud_to_costmap( - self.pointcloud, - resolution=self.cost_resolution, - inflate_radius_m=inflate_radius_m, - ) - - return Costmap(grid=grid, origin=[*origin_xy, 0.0], resolution=self.cost_resolution) - - -def splice_sphere( - map_pcd: o3d.geometry.PointCloud, - patch_pcd: o3d.geometry.PointCloud, - shrink: float = 0.95, -) -> o3d.geometry.PointCloud: - center = patch_pcd.get_center() - radius = np.linalg.norm(np.asarray(patch_pcd.points) - center, axis=1).max() * shrink - dists = np.linalg.norm(np.asarray(map_pcd.points) - center, axis=1) - victims = np.nonzero(dists < radius)[0] - survivors = map_pcd.select_by_index(victims, invert=True) - return survivors + patch_pcd - - -def splice_cylinder( - map_pcd: o3d.geometry.PointCloud, - patch_pcd: o3d.geometry.PointCloud, - axis: int = 2, - shrink: float = 0.95, -) -> o3d.geometry.PointCloud: - center = patch_pcd.get_center() - patch_pts = np.asarray(patch_pcd.points) - - # Axes perpendicular to cylinder - axes = [0, 1, 2] - axes.remove(axis) - - planar_dists = np.linalg.norm(patch_pts[:, axes] - center[axes], axis=1) - radius = planar_dists.max() * shrink - - axis_min = (patch_pts[:, axis].min() - center[axis]) * shrink + center[axis] - axis_max = (patch_pts[:, axis].max() - center[axis]) * shrink + center[axis] - - map_pts = np.asarray(map_pcd.points) - planar_dists_map = np.linalg.norm(map_pts[:, axes] - center[axes], axis=1) - - victims = np.nonzero( - (planar_dists_map < radius) - & (map_pts[:, axis] >= 
axis_min) - & (map_pts[:, axis] <= axis_max) - )[0] - - survivors = map_pcd.select_by_index(victims, invert=True) - return survivors + patch_pcd - - -def _inflate_lethal(costmap: np.ndarray, radius: int, lethal_val: int = 100) -> np.ndarray: - """Return *costmap* with lethal cells dilated by *radius* grid steps (circular).""" - if radius <= 0 or not np.any(costmap == lethal_val): - return costmap - - mask = costmap == lethal_val - dilated = mask.copy() - for dy in range(-radius, radius + 1): - for dx in range(-radius, radius + 1): - if dx * dx + dy * dy > radius * radius or (dx == 0 and dy == 0): - continue - dilated |= np.roll(mask, shift=(dy, dx), axis=(0, 1)) - - out = costmap.copy() - out[dilated] = lethal_val - return out diff --git a/build/lib/dimos/robot/unitree_webrtc/type/odometry.py b/build/lib/dimos/robot/unitree_webrtc/type/odometry.py deleted file mode 100644 index 76def232e4..0000000000 --- a/build/lib/dimos/robot/unitree_webrtc/type/odometry.py +++ /dev/null @@ -1,108 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import math -from datetime import datetime -from io import BytesIO -from typing import BinaryIO, Literal, TypeAlias, TypedDict - -from scipy.spatial.transform import Rotation as R - -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Vector3 -from dimos.robot.unitree_webrtc.type.timeseries import ( - EpochLike, - Timestamped, - to_datetime, - to_human_readable, -) -from dimos.types.timestamped import to_timestamp -from dimos.types.vector import Vector, VectorLike - -raw_odometry_msg_sample = { - "type": "msg", - "topic": "rt/utlidar/robot_pose", - "data": { - "header": {"stamp": {"sec": 1746565669, "nanosec": 448350564}, "frame_id": "odom"}, - "pose": { - "position": {"x": 5.961965, "y": -2.916958, "z": 0.319509}, - "orientation": {"x": 0.002787, "y": -0.000902, "z": -0.970244, "w": -0.242112}, - }, - }, -} - - -class TimeStamp(TypedDict): - sec: int - nanosec: int - - -class Header(TypedDict): - stamp: TimeStamp - frame_id: str - - -class RawPosition(TypedDict): - x: float - y: float - z: float - - -class Orientation(TypedDict): - x: float - y: float - z: float - w: float - - -class PoseData(TypedDict): - position: RawPosition - orientation: Orientation - - -class OdometryData(TypedDict): - header: Header - pose: PoseData - - -class RawOdometryMessage(TypedDict): - type: Literal["msg"] - topic: str - data: OdometryData - - -class Odometry(PoseStamped, Timestamped): - name = "geometry_msgs.PoseStamped" - - @classmethod - def from_msg(cls, msg: RawOdometryMessage) -> "Odometry": - pose = msg["data"]["pose"] - - # Extract position - pos = Vector3( - pose["position"].get("x"), - pose["position"].get("y"), - pose["position"].get("z"), - ) - - rot = Quaternion( - pose["orientation"].get("x"), - pose["orientation"].get("y"), - pose["orientation"].get("z"), - pose["orientation"].get("w"), - ) - - ts = to_timestamp(msg["data"]["header"]["stamp"]) - return Odometry(position=pos, orientation=rot, ts=ts, frame_id="lidar") - - def __repr__(self) -> str: - return 
f"Odom pos({self.position}), rot({self.orientation})" diff --git a/build/lib/dimos/robot/unitree_webrtc/type/test_lidar.py b/build/lib/dimos/robot/unitree_webrtc/type/test_lidar.py deleted file mode 100644 index 912740a71a..0000000000 --- a/build/lib/dimos/robot/unitree_webrtc/type/test_lidar.py +++ /dev/null @@ -1,51 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import itertools -import time - -import pytest - -from dimos.msgs.sensor_msgs import PointCloud2 -from dimos.protocol.pubsub.lcmpubsub import LCM, Topic -from dimos.robot.unitree_webrtc.type.lidar import LidarMessage -from dimos.utils.testing import SensorReplay - - -def test_init(): - lidar = SensorReplay("office_lidar") - - for raw_frame in itertools.islice(lidar.iterate(), 5): - assert isinstance(raw_frame, dict) - frame = LidarMessage.from_msg(raw_frame) - assert isinstance(frame, LidarMessage) - data = frame.to_pointcloud2().lcm_encode() - assert len(data) > 0 - assert isinstance(data, bytes) - - -@pytest.mark.tool -def test_publish(): - lcm = LCM() - lcm.start() - - topic = Topic(topic="/lidar", lcm_type=PointCloud2) - lidar = SensorReplay("office_lidar", autocast=LidarMessage.from_msg) - - while True: - for frame in lidar.iterate(): - print(frame) - lcm.publish(topic, frame.to_pointcloud2()) - time.sleep(0.1) diff --git a/build/lib/dimos/robot/unitree_webrtc/type/test_map.py b/build/lib/dimos/robot/unitree_webrtc/type/test_map.py 
deleted file mode 100644 index d705bb965b..0000000000 --- a/build/lib/dimos/robot/unitree_webrtc/type/test_map.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pytest - -from dimos.robot.unitree_webrtc.testing.helpers import show3d, show3d_stream -from dimos.robot.unitree_webrtc.testing.mock import Mock -from dimos.robot.unitree_webrtc.type.lidar import LidarMessage -from dimos.robot.unitree_webrtc.type.map import Map, splice_sphere -from dimos.utils.reactive import backpressure -from dimos.utils.testing import SensorReplay - - -@pytest.mark.vis -def test_costmap_vis(): - map = Map() - for frame in Mock("office").iterate(): - print(frame) - map.add_frame(frame) - costmap = map.costmap - print(costmap) - show3d(costmap.smudge().pointcloud, title="Costmap").run() - - -@pytest.mark.vis -def test_reconstruction_with_realtime_vis(): - show3d_stream(Map().consume(Mock("office").stream(rate_hz=60.0)), clearframe=True).run() - - -@pytest.mark.vis -def test_splice_vis(): - mock = Mock("test") - target = mock.load("a") - insert = mock.load("b") - show3d(splice_sphere(target.pointcloud, insert.pointcloud, shrink=0.7)).run() - - -@pytest.mark.vis -def test_robot_vis(): - show3d_stream( - Map().consume(backpressure(Mock("office").stream())), - clearframe=True, - title="gloal dynamic map test", - ) - - -def test_robot_mapping(): - lidar_stream = SensorReplay("office_lidar", autocast=LidarMessage.from_msg) 
- map = Map(voxel_size=0.5) - - # this will block until map has consumed the whole stream - map.consume(lidar_stream.stream()).run() - - # we investigate built map - costmap = map.costmap() - - assert costmap.grid.shape == (404, 276) - - assert 70 <= costmap.unknown_percent <= 80, ( - f"Unknown percent {costmap.unknown_percent} is not within the range 70-80" - ) - - assert 5 < costmap.free_percent < 10, ( - f"Free percent {costmap.free_percent} is not within the range 5-10" - ) - - assert 8 < costmap.occupied_percent < 15, ( - f"Occupied percent {costmap.occupied_percent} is not within the range 8-15" - ) diff --git a/build/lib/dimos/robot/unitree_webrtc/type/test_odometry.py b/build/lib/dimos/robot/unitree_webrtc/type/test_odometry.py deleted file mode 100644 index 0bd76f1900..0000000000 --- a/build/lib/dimos/robot/unitree_webrtc/type/test_odometry.py +++ /dev/null @@ -1,109 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import annotations - -import os -import threading -from operator import add, sub -from typing import Optional - -import pytest -import reactivex.operators as ops -from dotenv import load_dotenv - -from dimos.robot.unitree_webrtc.type.odometry import Odometry -from dimos.utils.testing import SensorReplay, SensorStorage - -_EXPECTED_TOTAL_RAD = -4.05212 - - -def test_dataset_size() -> None: - """Ensure the replay contains the expected number of messages.""" - assert sum(1 for _ in SensorReplay(name="raw_odometry_rotate_walk").iterate()) == 179 - - -def test_odometry_conversion_and_count() -> None: - """Each replay entry converts to :class:`Odometry` and count is correct.""" - for raw in SensorReplay(name="raw_odometry_rotate_walk").iterate(): - odom = Odometry.from_msg(raw) - assert isinstance(raw, dict) - assert isinstance(odom, Odometry) - - -def test_last_yaw_value() -> None: - """Verify yaw of the final message (regression guard).""" - last_msg = SensorReplay(name="raw_odometry_rotate_walk").stream().pipe(ops.last()).run() - - assert last_msg is not None, "Replay is empty" - assert last_msg["data"]["pose"]["orientation"] == { - "x": 0.01077, - "y": 0.008505, - "z": 0.499171, - "w": -0.866395, - } - - -def test_total_rotation_travel_iterate() -> None: - total_rad = 0.0 - prev_yaw: Optional[float] = None - - for odom in SensorReplay(name="raw_odometry_rotate_walk", autocast=Odometry.from_msg).iterate(): - yaw = odom.orientation.radians.z - if prev_yaw is not None: - diff = yaw - prev_yaw - total_rad += diff - prev_yaw = yaw - - assert total_rad == pytest.approx(_EXPECTED_TOTAL_RAD, abs=0.001) - - -def test_total_rotation_travel_rxpy() -> None: - total_rad = ( - SensorReplay(name="raw_odometry_rotate_walk", autocast=Odometry.from_msg) - .stream() - .pipe( - ops.map(lambda odom: odom.orientation.radians.z), - ops.pairwise(), # [1,2,3,4] -> [[1,2], [2,3], [3,4]] - ops.starmap(sub), # [sub(1,2), sub(2,3), sub(3,4)] - ops.reduce(add), - ) - .run() - ) 
- - assert total_rad == pytest.approx(4.05, abs=0.01) - - -# data collection tool -@pytest.mark.tool -def test_store_odometry_stream() -> None: - from dimos.robot.unitree_webrtc.unitree_go2 import UnitreeGo2 - - load_dotenv() - - robot = UnitreeGo2(ip=os.getenv("ROBOT_IP"), mode="ai") - robot.standup() - - storage = SensorStorage("raw_odometry_rotate_walk") - storage.save_stream(robot.raw_odom_stream()) - - shutdown = threading.Event() - - try: - while not shutdown.wait(0.1): - pass - except KeyboardInterrupt: - shutdown.set() - finally: - robot.liedown() diff --git a/build/lib/dimos/robot/unitree_webrtc/type/test_timeseries.py b/build/lib/dimos/robot/unitree_webrtc/type/test_timeseries.py deleted file mode 100644 index fe96d75eaf..0000000000 --- a/build/lib/dimos/robot/unitree_webrtc/type/test_timeseries.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from datetime import timedelta, datetime -from dimos.robot.unitree_webrtc.type.timeseries import TEvent, TList - - -fixed_date = datetime(2025, 5, 13, 15, 2, 5).astimezone() -start_event = TEvent(fixed_date, 1) -end_event = TEvent(fixed_date + timedelta(seconds=10), 9) - -sample_list = TList([start_event, TEvent(fixed_date + timedelta(seconds=2), 5), end_event]) - - -def test_repr(): - assert ( - str(sample_list) - == "Timeseries(date=2025-05-13, start=15:02:05, end=15:02:15, duration=0:00:10, events=3, freq=0.30Hz)" - ) - - -def test_equals(): - assert start_event == TEvent(start_event.ts, 1) - assert start_event != TEvent(start_event.ts, 2) - assert start_event != TEvent(start_event.ts + timedelta(seconds=1), 1) - - -def test_range(): - assert sample_list.time_range() == (start_event.ts, end_event.ts) - - -def test_duration(): - assert sample_list.duration() == timedelta(seconds=10) diff --git a/build/lib/dimos/robot/unitree_webrtc/type/timeseries.py b/build/lib/dimos/robot/unitree_webrtc/type/timeseries.py deleted file mode 100644 index 48dfddcac5..0000000000 --- a/build/lib/dimos/robot/unitree_webrtc/type/timeseries.py +++ /dev/null @@ -1,146 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import annotations - -from abc import ABC, abstractmethod -from datetime import datetime, timedelta, timezone -from typing import Generic, Iterable, Tuple, TypedDict, TypeVar, Union - -PAYLOAD = TypeVar("PAYLOAD") - - -class RosStamp(TypedDict): - sec: int - nanosec: int - - -EpochLike = Union[int, float, datetime, RosStamp] - - -def from_ros_stamp(stamp: dict[str, int], tz: timezone = None) -> datetime: - """Convert ROS-style timestamp {'sec': int, 'nanosec': int} to datetime.""" - return datetime.fromtimestamp(stamp["sec"] + stamp["nanosec"] / 1e9, tz=tz) - - -def to_human_readable(ts: EpochLike) -> str: - dt = to_datetime(ts) - return dt.strftime("%Y-%m-%d %H:%M:%S") - - -def to_datetime(ts: EpochLike, tz: timezone = None) -> datetime: - if isinstance(ts, datetime): - # if ts.tzinfo is None: - # ts = ts.astimezone(tz) - return ts - if isinstance(ts, (int, float)): - return datetime.fromtimestamp(ts, tz=tz) - if isinstance(ts, dict) and "sec" in ts and "nanosec" in ts: - return datetime.fromtimestamp(ts["sec"] + ts["nanosec"] / 1e9, tz=tz) - raise TypeError("unsupported timestamp type") - - -class Timestamped(ABC): - """Abstract class for an event with a timestamp.""" - - ts: datetime - - def __init__(self, ts: EpochLike): - self.ts = to_datetime(ts) - - -class TEvent(Timestamped, Generic[PAYLOAD]): - """Concrete class for an event with a timestamp and data.""" - - def __init__(self, timestamp: EpochLike, data: PAYLOAD): - super().__init__(timestamp) - self.data = data - - def __eq__(self, other: object) -> bool: - if not isinstance(other, TEvent): - return NotImplemented - return self.ts == other.ts and self.data == other.data - - def __repr__(self) -> str: - return f"TEvent(ts={self.ts}, data={self.data})" - - -EVENT = TypeVar("EVENT", bound=Timestamped) # any object that is a subclass of Timestamped - - -class Timeseries(ABC, Generic[EVENT]): - """Abstract class for an iterable of events with timestamps.""" - - @abstractmethod - def 
__iter__(self) -> Iterable[EVENT]: ... - - @property - def start_time(self) -> datetime: - """Return the timestamp of the earliest event, assuming the data is sorted.""" - return next(iter(self)).ts - - @property - def end_time(self) -> datetime: - """Return the timestamp of the latest event, assuming the data is sorted.""" - return next(reversed(list(self))).ts - - @property - def frequency(self) -> float: - """Calculate the frequency of events in Hz.""" - return len(list(self)) / (self.duration().total_seconds() or 1) - - def time_range(self) -> Tuple[datetime, datetime]: - """Return (earliest_ts, latest_ts). Empty input ⇒ ValueError.""" - return self.start_time, self.end_time - - def duration(self) -> timedelta: - """Total time spanned by the iterable (Δ = last - first).""" - return self.end_time - self.start_time - - def closest_to(self, timestamp: EpochLike) -> EVENT: - """Return the event closest to the given timestamp. Assumes timeseries is sorted.""" - print("closest to", timestamp) - target = to_datetime(timestamp) - print("converted to", target) - target_ts = target.timestamp() - - closest = None - min_dist = float("inf") - - for event in self: - dist = abs(event.ts - target_ts) - if dist > min_dist: - break - - min_dist = dist - closest = event - - print(f"closest: {closest}") - return closest - - def __repr__(self) -> str: - """Return a string representation of the Timeseries.""" - return f"Timeseries(date={self.start_time.strftime('%Y-%m-%d')}, start={self.start_time.strftime('%H:%M:%S')}, end={self.end_time.strftime('%H:%M:%S')}, duration={self.duration()}, events={len(list(self))}, freq={self.frequency:.2f}Hz)" - - def __str__(self) -> str: - """Return a string representation of the Timeseries.""" - return self.__repr__() - - -class TList(list[EVENT], Timeseries[EVENT]): - """A test class that inherits from both list and Timeseries.""" - - def __repr__(self) -> str: - """Return a string representation of the TList using Timeseries repr method.""" - 
return Timeseries.__repr__(self) diff --git a/build/lib/dimos/robot/unitree_webrtc/type/vector.py b/build/lib/dimos/robot/unitree_webrtc/type/vector.py deleted file mode 100644 index 22b00a753d..0000000000 --- a/build/lib/dimos/robot/unitree_webrtc/type/vector.py +++ /dev/null @@ -1,448 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -from typing import ( - Tuple, - List, - TypeVar, - Protocol, - runtime_checkable, - Any, - Iterable, - Union, -) -from numpy.typing import NDArray - -T = TypeVar("T", bound="Vector") - - -class Vector: - """A wrapper around numpy arrays for vector operations with intuitive syntax.""" - - def __init__(self, *args: Any) -> None: - """Initialize a vector from components or another iterable. 
- - Examples: - Vector(1, 2) # 2D vector - Vector(1, 2, 3) # 3D vector - Vector([1, 2, 3]) # From list - Vector(np.array([1, 2, 3])) # From numpy array - """ - if len(args) == 1 and hasattr(args[0], "__iter__"): - self._data = np.array(args[0], dtype=float) - elif len(args) == 1: - self._data = np.array([args[0].x, args[0].y, args[0].z], dtype=float) - - else: - self._data = np.array(args, dtype=float) - - @property - def yaw(self) -> float: - return self.x - - @property - def tuple(self) -> Tuple[float, ...]: - """Tuple representation of the vector.""" - return tuple(self._data) - - @property - def x(self) -> float: - """X component of the vector.""" - return self._data[0] if len(self._data) > 0 else 0.0 - - @property - def y(self) -> float: - """Y component of the vector.""" - return self._data[1] if len(self._data) > 1 else 0.0 - - @property - def z(self) -> float: - """Z component of the vector.""" - return self._data[2] if len(self._data) > 2 else 0.0 - - @property - def dim(self) -> int: - """Dimensionality of the vector.""" - return len(self._data) - - @property - def data(self) -> NDArray[np.float64]: - """Get the underlying numpy array.""" - return self._data - - def __len__(self) -> int: - return len(self._data) - - def __getitem__(self, idx: int) -> float: - return float(self._data[idx]) - - def __iter__(self) -> Iterable[float]: - return iter(self._data) - - def __repr__(self) -> str: - components = ",".join(f"{x:.6g}" for x in self._data) - return f"({components})" - - def __str__(self) -> str: - if self.dim < 2: - return self.__repr__() - - def getArrow() -> str: - repr = ["←", "↖", "↑", "↗", "→", "↘", "↓", "↙"] - - if self.y == 0 and self.x == 0: - return "·" - - # Calculate angle in radians and convert to directional index - angle = np.arctan2(self.y, self.x) - # Map angle to 0-7 index (8 directions) with proper orientation - dir_index = int(((angle + np.pi) * 4 / np.pi) % 8) - # Get directional arrow symbol - return repr[dir_index] - - return 
f"{getArrow()} Vector {self.__repr__()}" - - def serialize(self) -> dict: - """Serialize the vector to a dictionary.""" - return {"type": "vector", "c": self._data.tolist()} - - def __eq__(self, other: Any) -> bool: - if isinstance(other, Vector): - return np.array_equal(self._data, other._data) - return np.array_equal(self._data, np.array(other, dtype=float)) - - def __add__(self: T, other: Union["Vector", Iterable[float]]) -> T: - if isinstance(other, Vector): - return self.__class__(self._data + other._data) - return self.__class__(self._data + np.array(other, dtype=float)) - - def __sub__(self: T, other: Union["Vector", Iterable[float]]) -> T: - if isinstance(other, Vector): - return self.__class__(self._data - other._data) - return self.__class__(self._data - np.array(other, dtype=float)) - - def __mul__(self: T, scalar: float) -> T: - return self.__class__(self._data * scalar) - - def __rmul__(self: T, scalar: float) -> T: - return self.__mul__(scalar) - - def __truediv__(self: T, scalar: float) -> T: - return self.__class__(self._data / scalar) - - def __neg__(self: T) -> T: - return self.__class__(-self._data) - - def dot(self, other: Union["Vector", Iterable[float]]) -> float: - """Compute dot product.""" - if isinstance(other, Vector): - return float(np.dot(self._data, other._data)) - return float(np.dot(self._data, np.array(other, dtype=float))) - - def cross(self: T, other: Union["Vector", Iterable[float]]) -> T: - """Compute cross product (3D vectors only).""" - if self.dim != 3: - raise ValueError("Cross product is only defined for 3D vectors") - - if isinstance(other, Vector): - other_data = other._data - else: - other_data = np.array(other, dtype=float) - - if len(other_data) != 3: - raise ValueError("Cross product requires two 3D vectors") - - return self.__class__(np.cross(self._data, other_data)) - - def length(self) -> float: - """Compute the Euclidean length (magnitude) of the vector.""" - return float(np.linalg.norm(self._data)) - - def 
length_squared(self) -> float: - """Compute the squared length of the vector (faster than length()).""" - return float(np.sum(self._data * self._data)) - - def normalize(self: T) -> T: - """Return a normalized unit vector in the same direction.""" - length = self.length() - if length < 1e-10: # Avoid division by near-zero - return self.__class__(np.zeros_like(self._data)) - return self.__class__(self._data / length) - - def to_2d(self: T) -> T: - """Convert a vector to a 2D vector by taking only the x and y components.""" - return self.__class__(self._data[:2]) - - def distance(self, other: Union["Vector", Iterable[float]]) -> float: - """Compute Euclidean distance to another vector.""" - if isinstance(other, Vector): - return float(np.linalg.norm(self._data - other._data)) - return float(np.linalg.norm(self._data - np.array(other, dtype=float))) - - def distance_squared(self, other: Union["Vector", Iterable[float]]) -> float: - """Compute squared Euclidean distance to another vector (faster than distance()).""" - if isinstance(other, Vector): - diff = self._data - other._data - else: - diff = self._data - np.array(other, dtype=float) - return float(np.sum(diff * diff)) - - def angle(self, other: Union["Vector", Iterable[float]]) -> float: - """Compute the angle (in radians) between this vector and another.""" - if self.length() < 1e-10 or (isinstance(other, Vector) and other.length() < 1e-10): - return 0.0 - - if isinstance(other, Vector): - other_data = other._data - else: - other_data = np.array(other, dtype=float) - - cos_angle = np.clip( - np.dot(self._data, other_data) - / (np.linalg.norm(self._data) * np.linalg.norm(other_data)), - -1.0, - 1.0, - ) - return float(np.arccos(cos_angle)) - - def project(self: T, onto: Union["Vector", Iterable[float]]) -> T: - """Project this vector onto another vector.""" - if isinstance(onto, Vector): - onto_data = onto._data - else: - onto_data = np.array(onto, dtype=float) - - onto_length_sq = np.sum(onto_data * onto_data) - 
if onto_length_sq < 1e-10: - return self.__class__(np.zeros_like(self._data)) - - scalar_projection = np.dot(self._data, onto_data) / onto_length_sq - return self.__class__(scalar_projection * onto_data) - - # this is here to test ros_observable_topic - # doesn't happen irl afaik that we want a vector from ros message - @classmethod - def from_msg(cls: type[T], msg: Any) -> T: - return cls(*msg) - - @classmethod - def zeros(cls: type[T], dim: int) -> T: - """Create a zero vector of given dimension.""" - return cls(np.zeros(dim)) - - @classmethod - def ones(cls: type[T], dim: int) -> T: - """Create a vector of ones with given dimension.""" - return cls(np.ones(dim)) - - @classmethod - def unit_x(cls: type[T], dim: int = 3) -> T: - """Create a unit vector in the x direction.""" - v = np.zeros(dim) - v[0] = 1.0 - return cls(v) - - @classmethod - def unit_y(cls: type[T], dim: int = 3) -> T: - """Create a unit vector in the y direction.""" - v = np.zeros(dim) - v[1] = 1.0 - return cls(v) - - @classmethod - def unit_z(cls: type[T], dim: int = 3) -> T: - """Create a unit vector in the z direction.""" - v = np.zeros(dim) - if dim > 2: - v[2] = 1.0 - return cls(v) - - def to_list(self) -> List[float]: - """Convert the vector to a list.""" - return [float(x) for x in self._data] - - def to_tuple(self) -> Tuple[float, ...]: - """Convert the vector to a tuple.""" - return tuple(self._data) - - def to_numpy(self) -> NDArray[np.float64]: - """Convert the vector to a numpy array.""" - return self._data - - -# Protocol approach for static type checking -@runtime_checkable -class VectorLike(Protocol): - """Protocol for types that can be treated as vectors.""" - - def __getitem__(self, key: int) -> float: ... - def __len__(self) -> int: ... - def __iter__(self) -> Iterable[float]: ... - - -def to_numpy(value: VectorLike) -> NDArray[np.float64]: - """Convert a vector-compatible value to a numpy array. 
- - Args: - value: Any vector-like object (Vector, numpy array, tuple, list) - - Returns: - Numpy array representation - """ - if isinstance(value, Vector): - return value.data - elif isinstance(value, np.ndarray): - return value - else: - return np.array(value, dtype=float) - - -def to_vector(value: VectorLike) -> Vector: - """Convert a vector-compatible value to a Vector object. - - Args: - value: Any vector-like object (Vector, numpy array, tuple, list) - - Returns: - Vector object - """ - if isinstance(value, Vector): - return value - else: - return Vector(value) - - -def to_tuple(value: VectorLike) -> Tuple[float, ...]: - """Convert a vector-compatible value to a tuple. - - Args: - value: Any vector-like object (Vector, numpy array, tuple, list) - - Returns: - Tuple of floats - """ - if isinstance(value, Vector): - return tuple(float(x) for x in value.data) - elif isinstance(value, np.ndarray): - return tuple(float(x) for x in value) - elif isinstance(value, tuple): - return tuple(float(x) for x in value) - else: - # Convert to list first to ensure we have an indexable sequence - data = [value[i] for i in range(len(value))] - return tuple(float(x) for x in data) - - -def to_list(value: VectorLike) -> List[float]: - """Convert a vector-compatible value to a list. - - Args: - value: Any vector-like object (Vector, numpy array, tuple, list) - - Returns: - List of floats - """ - if isinstance(value, Vector): - return [float(x) for x in value.data] - elif isinstance(value, np.ndarray): - return [float(x) for x in value] - elif isinstance(value, list): - return [float(x) for x in value] - else: - # Convert to list using indexing - return [float(value[i]) for i in range(len(value))] - - -# Helper functions to check dimensionality -def is_2d(value: VectorLike) -> bool: - """Check if a vector-compatible value is 2D. 
- - Args: - value: Any vector-like object (Vector, numpy array, tuple, list) - - Returns: - True if the value is 2D - """ - if isinstance(value, Vector): - return len(value) == 2 - elif isinstance(value, np.ndarray): - return value.shape[-1] == 2 or value.size == 2 - else: - return len(value) == 2 - - -def is_3d(value: VectorLike) -> bool: - """Check if a vector-compatible value is 3D. - - Args: - value: Any vector-like object (Vector, numpy array, tuple, list) - - Returns: - True if the value is 3D - """ - if isinstance(value, Vector): - return len(value) == 3 - elif isinstance(value, np.ndarray): - return value.shape[-1] == 3 or value.size == 3 - else: - return len(value) == 3 - - -# Extraction functions for XYZ components -def x(value: VectorLike) -> float: - """Get the X component of a vector-compatible value. - - Args: - value: Any vector-like object (Vector, numpy array, tuple, list) - - Returns: - X component as a float - """ - if isinstance(value, Vector): - return value.x - else: - return float(to_numpy(value)[0]) - - -def y(value: VectorLike) -> float: - """Get the Y component of a vector-compatible value. - - Args: - value: Any vector-like object (Vector, numpy array, tuple, list) - - Returns: - Y component as a float - """ - if isinstance(value, Vector): - return value.y - else: - arr = to_numpy(value) - return float(arr[1]) if len(arr) > 1 else 0.0 - - -def z(value: VectorLike) -> float: - """Get the Z component of a vector-compatible value. 
- - Args: - value: Any vector-like object (Vector, numpy array, tuple, list) - - Returns: - Z component as a float - """ - if isinstance(value, Vector): - return value.z - else: - arr = to_numpy(value) - return float(arr[2]) if len(arr) > 2 else 0.0 diff --git a/build/lib/dimos/robot/unitree_webrtc/unitree_go2.py b/build/lib/dimos/robot/unitree_webrtc/unitree_go2.py deleted file mode 100644 index 94676bfffc..0000000000 --- a/build/lib/dimos/robot/unitree_webrtc/unitree_go2.py +++ /dev/null @@ -1,224 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from typing import Union, Optional, List -import time -import numpy as np -import os -from dimos.robot.robot import Robot -from dimos.robot.unitree_webrtc.type.map import Map -from dimos.robot.unitree_webrtc.connection import WebRTCRobot -from dimos.robot.global_planner.planner import AstarPlanner -from dimos.utils.reactive import getter_streaming -from dimos.skills.skills import AbstractRobotSkill, SkillLibrary -from go2_webrtc_driver.constants import VUI_COLOR -from go2_webrtc_driver.webrtc_driver import WebRTCConnectionMethod -from dimos.perception.person_tracker import PersonTrackingStream -from dimos.perception.object_tracker import ObjectTrackingStream -from dimos.robot.local_planner.local_planner import navigate_path_local -from dimos.robot.local_planner.vfh_local_planner import VFHPurePursuitPlanner -from dimos.types.robot_capabilities import RobotCapability -from dimos.types.vector import Vector -from dimos.robot.unitree_webrtc.unitree_skills import MyUnitreeSkills -from dimos.robot.frontier_exploration.qwen_frontier_predictor import QwenFrontierPredictor -from dimos.robot.frontier_exploration.wavefront_frontier_goal_selector import ( - WavefrontFrontierExplorer, -) -import threading - - -class Color(VUI_COLOR): ... - - -class UnitreeGo2(Robot): - def __init__( - self, - ip: str, - mode: str = "ai", - output_dir: str = os.path.join(os.getcwd(), "assets", "output"), - skill_library: SkillLibrary = None, - robot_capabilities: List[RobotCapability] = None, - spatial_memory_collection: str = "spatial_memory", - new_memory: bool = True, - enable_perception: bool = True, - ): - """Initialize Unitree Go2 robot with WebRTC control interface. - - Args: - ip: IP address of the robot - mode: Robot mode (ai, etc.) 
- output_dir: Directory for output files - skill_library: Skill library instance - robot_capabilities: List of robot capabilities - spatial_memory_collection: Collection name for spatial memory - new_memory: Whether to create new spatial memory - enable_perception: Whether to enable perception streams and spatial memory - """ - # Create WebRTC connection interface - self.webrtc_connection = WebRTCRobot( - ip=ip, - mode=mode, - ) - - print("standing up") - self.webrtc_connection.standup() - - # Initialize WebRTC-specific features - self.lidar_stream = self.webrtc_connection.lidar_stream() - self.odom = getter_streaming(self.webrtc_connection.odom_stream()) - self.map = Map(voxel_size=0.2) - self.map_stream = self.map.consume(self.lidar_stream) - self.lidar_message = getter_streaming(self.lidar_stream) - - if skill_library is None: - skill_library = MyUnitreeSkills() - - # Initialize base robot with connection interface - super().__init__( - connection_interface=self.webrtc_connection, - output_dir=output_dir, - skill_library=skill_library, - capabilities=robot_capabilities - or [ - RobotCapability.LOCOMOTION, - RobotCapability.VISION, - RobotCapability.AUDIO, - ], - spatial_memory_collection=spatial_memory_collection, - new_memory=new_memory, - enable_perception=enable_perception, - ) - - if self.skill_library is not None: - for skill in self.skill_library: - if isinstance(skill, AbstractRobotSkill): - self.skill_library.create_instance(skill.__name__, robot=self) - if isinstance(self.skill_library, MyUnitreeSkills): - self.skill_library._robot = self - self.skill_library.init() - self.skill_library.initialize_skills() - - # Camera configuration - self.camera_intrinsics = [819.553492, 820.646595, 625.284099, 336.808987] - self.camera_pitch = np.deg2rad(0) # negative for downward pitch - self.camera_height = 0.44 # meters - - # Initialize visual servoing using connection interface - video_stream = self.get_video_stream() - if video_stream is not None and 
enable_perception: - self.person_tracker = PersonTrackingStream( - camera_intrinsics=self.camera_intrinsics, - camera_pitch=self.camera_pitch, - camera_height=self.camera_height, - ) - self.object_tracker = ObjectTrackingStream( - camera_intrinsics=self.camera_intrinsics, - camera_pitch=self.camera_pitch, - camera_height=self.camera_height, - ) - person_tracking_stream = self.person_tracker.create_stream(video_stream) - object_tracking_stream = self.object_tracker.create_stream(video_stream) - - self.person_tracking_stream = person_tracking_stream - self.object_tracking_stream = object_tracking_stream - else: - # Video stream not available or perception disabled - self.person_tracker = None - self.object_tracker = None - self.person_tracking_stream = None - self.object_tracking_stream = None - - self.global_planner = AstarPlanner( - set_local_nav=lambda path, stop_event=None, goal_theta=None: navigate_path_local( - self, path, timeout=120.0, goal_theta=goal_theta, stop_event=stop_event - ), - get_costmap=lambda: self.map.costmap, - get_robot_pos=lambda: self.odom().pos, - ) - - # Initialize the local planner using WebRTC-specific methods - self.local_planner = VFHPurePursuitPlanner( - get_costmap=lambda: self.lidar_message().costmap(), - get_robot_pose=lambda: self.odom(), - move=self.move, # Use the robot's move method directly - robot_width=0.36, # Unitree Go2 width in meters - robot_length=0.6, # Unitree Go2 length in meters - max_linear_vel=0.7, - max_angular_vel=0.65, - lookahead_distance=1.5, - visualization_size=500, # 500x500 pixel visualization - global_planner_plan=self.global_planner.plan, - ) - - # Initialize frontier exploration - self.frontier_explorer = WavefrontFrontierExplorer( - set_goal=self.global_planner.set_goal, - get_costmap=lambda: self.map.costmap, - get_robot_pos=lambda: self.odom().pos, - ) - - # Create the visualization stream at 5Hz - self.local_planner_viz_stream = self.local_planner.create_stream(frequency_hz=5.0) - - def 
get_pose(self) -> dict: - """ - Get the current pose (position and rotation) of the robot in the map frame. - - Returns: - Dictionary containing: - - position: Vector (x, y, z) - - rotation: Vector (roll, pitch, yaw) in radians - """ - position = Vector(self.odom().pos.x, self.odom().pos.y, self.odom().pos.z) - orientation = Vector(self.odom().rot.x, self.odom().rot.y, self.odom().rot.z) - return {"position": position, "rotation": orientation} - - def explore(self, stop_event: Optional[threading.Event] = None) -> bool: - """ - Start autonomous frontier exploration. - - Args: - stop_event: Optional threading.Event to signal when exploration should stop - - Returns: - bool: True if exploration completed successfully, False if stopped or failed - """ - return self.frontier_explorer.explore(stop_event=stop_event) - - def odom_stream(self): - """Get the odometry stream from the robot. - - Returns: - Observable stream of robot odometry data containing position and orientation. - """ - return self.webrtc_connection.odom_stream() - - def standup(self): - """Make the robot stand up. - - Uses AI mode standup if robot is in AI mode, otherwise uses normal standup. - """ - return self.webrtc_connection.standup() - - def liedown(self): - """Make the robot lie down. - - Commands the robot to lie down on the ground. - """ - return self.webrtc_connection.liedown() - - @property - def costmap(self): - """Access to the costmap for navigation.""" - return self.map.costmap diff --git a/build/lib/dimos/robot/unitree_webrtc/unitree_skills.py b/build/lib/dimos/robot/unitree_webrtc/unitree_skills.py deleted file mode 100644 index f9dfc1eede..0000000000 --- a/build/lib/dimos/robot/unitree_webrtc/unitree_skills.py +++ /dev/null @@ -1,279 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from typing import TYPE_CHECKING, List, Optional, Tuple, Union -import time -from pydantic import Field - -if TYPE_CHECKING: - from dimos.robot.robot import Robot, MockRobot -else: - Robot = "Robot" - MockRobot = "MockRobot" - -from dimos.skills.skills import AbstractRobotSkill, AbstractSkill, SkillLibrary -from dimos.types.constants import Colors -from dimos.types.vector import Vector -from go2_webrtc_driver.constants import RTC_TOPIC, SPORT_CMD - -# Module-level constant for Unitree WebRTC control definitions -UNITREE_WEBRTC_CONTROLS: List[Tuple[str, int, str]] = [ - ("Damp", 1001, "Lowers the robot to the ground fully."), - ( - "BalanceStand", - 1002, - "Activates a mode that maintains the robot in a balanced standing position.", - ), - ( - "StandUp", - 1004, - "Commands the robot to transition from a sitting or prone position to a standing posture.", - ), - ( - "StandDown", - 1005, - "Instructs the robot to move from a standing position to a sitting or prone posture.", - ), - ( - "RecoveryStand", - 1006, - "Recovers the robot to a state from which it can take more commands. 
Useful to run after multiple dynamic commands like front flips, Must run after skills like sit and jump and standup.", - ), - ("Sit", 1009, "Commands the robot to sit down from a standing or moving stance."), - ( - "RiseSit", - 1010, - "Commands the robot to rise back to a standing position from a sitting posture.", - ), - ( - "SwitchGait", - 1011, - "Switches the robot's walking pattern or style dynamically, suitable for different terrains or speeds.", - ), - ("Trigger", 1012, "Triggers a specific action or custom routine programmed into the robot."), - ( - "BodyHeight", - 1013, - "Adjusts the height of the robot's body from the ground, useful for navigating various obstacles.", - ), - ( - "FootRaiseHeight", - 1014, - "Controls how high the robot lifts its feet during movement, which can be adjusted for different surfaces.", - ), - ( - "SpeedLevel", - 1015, - "Sets or adjusts the speed at which the robot moves, with various levels available for different operational needs.", - ), - ( - "Hello", - 1016, - "Performs a greeting action, which could involve a wave or other friendly gesture.", - ), - ("Stretch", 1017, "Engages the robot in a stretching routine."), - ( - "TrajectoryFollow", - 1018, - "Directs the robot to follow a predefined trajectory, which could involve complex paths or maneuvers.", - ), - ( - "ContinuousGait", - 1019, - "Enables a mode for continuous walking or running, ideal for long-distance travel.", - ), - ("Content", 1020, "To display or trigger when the robot is happy."), - ("Wallow", 1021, "The robot falls onto its back and rolls around."), - ( - "Dance1", - 1022, - "Performs a predefined dance routine 1, programmed for entertainment or demonstration.", - ), - ("Dance2", 1023, "Performs another variant of a predefined dance routine 2."), - ("GetBodyHeight", 1024, "Retrieves the current height of the robot's body from the ground."), - ( - "GetFootRaiseHeight", - 1025, - "Retrieves the current height at which the robot's feet are being raised 
during movement.", - ), - ( - "GetSpeedLevel", - 1026, - "Retrieves the current speed level setting of the robot.", - ), - ( - "SwitchJoystick", - 1027, - "Switches the robot's control mode to respond to joystick input for manual operation.", - ), - ( - "Pose", - 1028, - "Commands the robot to assume a specific pose or posture as predefined in its programming.", - ), - ("Scrape", 1029, "The robot performs a scraping motion."), - ( - "FrontFlip", - 1030, - "Commands the robot to perform a front flip, showcasing its agility and dynamic movement capabilities.", - ), - ( - "FrontJump", - 1031, - "Instructs the robot to jump forward, demonstrating its explosive movement capabilities.", - ), - ( - "FrontPounce", - 1032, - "Commands the robot to perform a pouncing motion forward.", - ), - ( - "WiggleHips", - 1033, - "The robot performs a hip wiggling motion, often used for entertainment or demonstration purposes.", - ), - ( - "GetState", - 1034, - "Retrieves the current operational state of the robot, including its mode, position, and status.", - ), - ( - "EconomicGait", - 1035, - "Engages a more energy-efficient walking or running mode to conserve battery life.", - ), - ("FingerHeart", 1036, "Performs a finger heart gesture while on its hind legs."), - ( - "Handstand", - 1301, - "Commands the robot to perform a handstand, demonstrating balance and control.", - ), - ( - "CrossStep", - 1302, - "Commands the robot to perform cross-step movements.", - ), - ( - "OnesidedStep", - 1303, - "Commands the robot to perform one-sided step movements.", - ), - ("Bound", 1304, "Commands the robot to perform bounding movements."), - ("MoonWalk", 1305, "Commands the robot to perform a moonwalk motion."), - ("LeftFlip", 1042, "Executes a flip towards the left side."), - ("RightFlip", 1043, "Performs a flip towards the right side."), - ("Backflip", 1044, "Executes a backflip, a complex and dynamic maneuver."), -] - -# region MyUnitreeSkills - - -class MyUnitreeSkills(SkillLibrary): - """My 
Unitree Skills for WebRTC interface.""" - - def __init__(self, robot: Optional[Robot] = None): - super().__init__() - self._robot: Robot = None - - # Add dynamic skills to this class - dynamic_skills = self.create_skills_live() - self.register_skills(dynamic_skills) - - @classmethod - def register_skills(cls, skill_classes: Union["AbstractSkill", list["AbstractSkill"]]): - """Add multiple skill classes as class attributes. - - Args: - skill_classes: List of skill classes to add - """ - if not isinstance(skill_classes, list): - skill_classes = [skill_classes] - - for skill_class in skill_classes: - # Add to the class as a skill - setattr(cls, skill_class.__name__, skill_class) - - def initialize_skills(self): - for skill_class in self.get_class_skills(): - self.create_instance(skill_class.__name__, robot=self._robot) - - # Refresh the class skills - self.refresh_class_skills() - - def create_skills_live(self) -> List[AbstractRobotSkill]: - # ================================================ - # Procedurally created skills - # ================================================ - class BaseUnitreeSkill(AbstractRobotSkill): - """Base skill for dynamic skill creation.""" - - def __call__(self): - string = f"{Colors.GREEN_PRINT_COLOR}This is a base skill, created for the specific skill: {self._app_id}{Colors.RESET_COLOR}" - print(string) - super().__call__() - if self._app_id is None: - raise RuntimeError( - f"{Colors.RED_PRINT_COLOR}" - f"No App ID provided to {self.__class__.__name__} Skill" - f"{Colors.RESET_COLOR}" - ) - else: - # Use WebRTC publish_request interface through the robot's webrtc_connection - result = self._robot.webrtc_connection.publish_request( - RTC_TOPIC["SPORT_MOD"], {"api_id": self._app_id} - ) - string = f"{Colors.GREEN_PRINT_COLOR}{self.__class__.__name__} was successful: id={self._app_id}{Colors.RESET_COLOR}" - print(string) - return string - - skills_classes = [] - for name, app_id, description in UNITREE_WEBRTC_CONTROLS: - if name not in 
["Reverse", "Spin"]: # Exclude reverse and spin skills - skill_class = type( - name, # Name of the class - (BaseUnitreeSkill,), # Base classes - {"__doc__": description, "_app_id": app_id}, - ) - skills_classes.append(skill_class) - - return skills_classes - - # region Class-based Skills - - class Move(AbstractRobotSkill): - """Move the robot using direct velocity commands. Determine duration required based on user distance instructions.""" - - x: float = Field(..., description="Forward velocity (m/s).") - y: float = Field(default=0.0, description="Left/right velocity (m/s)") - yaw: float = Field(default=0.0, description="Rotational velocity (rad/s)") - duration: float = Field(default=0.0, description="How long to move (seconds).") - - def __call__(self): - return self._robot.move(Vector(self.x, self.y, self.yaw), duration=self.duration) - - class Wait(AbstractSkill): - """Wait for a specified amount of time.""" - - seconds: float = Field(..., description="Seconds to wait") - - def __call__(self): - time.sleep(self.seconds) - return f"Wait completed with length={self.seconds}s" - - # endregion - - -# endregion diff --git a/build/lib/dimos/simulation/__init__.py b/build/lib/dimos/simulation/__init__.py deleted file mode 100644 index 3d25363b30..0000000000 --- a/build/lib/dimos/simulation/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Try to import Isaac Sim components -try: - from .isaac import IsaacSimulator, IsaacStream -except ImportError: - IsaacSimulator = None # type: ignore - IsaacStream = None # type: ignore - -# Try to import Genesis components -try: - from .genesis import GenesisSimulator, GenesisStream -except ImportError: - GenesisSimulator = None # type: ignore - GenesisStream = None # type: ignore - -__all__ = ["IsaacSimulator", "IsaacStream", "GenesisSimulator", "GenesisStream"] diff --git a/build/lib/dimos/simulation/base/__init__.py b/build/lib/dimos/simulation/base/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git 
a/build/lib/dimos/simulation/base/simulator_base.py b/build/lib/dimos/simulation/base/simulator_base.py deleted file mode 100644 index 91633bb53a..0000000000 --- a/build/lib/dimos/simulation/base/simulator_base.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Optional, Union, List, Dict -from abc import ABC, abstractmethod - - -class SimulatorBase(ABC): - """Base class for simulators.""" - - @abstractmethod - def __init__( - self, - headless: bool = True, - open_usd: Optional[str] = None, # Keep for Isaac compatibility - entities: Optional[List[Dict[str, Union[str, dict]]]] = None, # Add for Genesis - ): - """Initialize the simulator. - - Args: - headless: Whether to run without visualization - open_usd: Path to USD file (for Isaac) - entities: List of entity configurations (for Genesis) - """ - self.headless = headless - self.open_usd = open_usd - self.stage = None - - @abstractmethod - def get_stage(self): - """Get the current stage/scene.""" - pass - - @abstractmethod - def close(self): - """Close the simulation.""" - pass diff --git a/build/lib/dimos/simulation/base/stream_base.py b/build/lib/dimos/simulation/base/stream_base.py deleted file mode 100644 index d20af296e2..0000000000 --- a/build/lib/dimos/simulation/base/stream_base.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright 2025 Dimensional Inc. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from abc import ABC, abstractmethod -from typing import Literal, Optional, Union -from pathlib import Path -import subprocess - -AnnotatorType = Literal["rgb", "normals", "bounding_box_3d", "motion_vectors"] -TransportType = Literal["tcp", "udp"] - - -class StreamBase(ABC): - """Base class for simulation streaming.""" - - @abstractmethod - def __init__( - self, - simulator, - width: int = 1920, - height: int = 1080, - fps: int = 60, - camera_path: str = "/World/camera", - annotator_type: AnnotatorType = "rgb", - transport: TransportType = "tcp", - rtsp_url: str = "rtsp://mediamtx:8554/stream", - usd_path: Optional[Union[str, Path]] = None, - ): - """Initialize the stream. 
- - Args: - simulator: Simulator instance - width: Stream width in pixels - height: Stream height in pixels - fps: Frames per second - camera_path: Camera path in scene - annotator: Type of annotator to use - transport: Transport protocol - rtsp_url: RTSP stream URL - usd_path: Optional USD file path to load - """ - self.simulator = simulator - self.width = width - self.height = height - self.fps = fps - self.camera_path = camera_path - self.annotator_type = annotator_type - self.transport = transport - self.rtsp_url = rtsp_url - self.proc = None - - @abstractmethod - def _load_stage(self, usd_path: Union[str, Path]): - """Load stage from file.""" - pass - - @abstractmethod - def _setup_camera(self): - """Setup and validate camera.""" - pass - - def _setup_ffmpeg(self): - """Setup FFmpeg process for streaming.""" - command = [ - "ffmpeg", - "-y", - "-f", - "rawvideo", - "-vcodec", - "rawvideo", - "-pix_fmt", - "bgr24", - "-s", - f"{self.width}x{self.height}", - "-r", - str(self.fps), - "-i", - "-", - "-an", - "-c:v", - "h264_nvenc", - "-preset", - "fast", - "-f", - "rtsp", - "-rtsp_transport", - self.transport, - self.rtsp_url, - ] - self.proc = subprocess.Popen(command, stdin=subprocess.PIPE) - - @abstractmethod - def _setup_annotator(self): - """Setup annotator.""" - pass - - @abstractmethod - def stream(self): - """Start streaming.""" - pass - - @abstractmethod - def cleanup(self): - """Cleanup resources.""" - pass diff --git a/build/lib/dimos/simulation/genesis/__init__.py b/build/lib/dimos/simulation/genesis/__init__.py deleted file mode 100644 index 5657d9167b..0000000000 --- a/build/lib/dimos/simulation/genesis/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .simulator import GenesisSimulator -from .stream import GenesisStream - -__all__ = ["GenesisSimulator", "GenesisStream"] diff --git a/build/lib/dimos/simulation/genesis/simulator.py b/build/lib/dimos/simulation/genesis/simulator.py deleted file mode 100644 index e531e6b422..0000000000 --- 
a/build/lib/dimos/simulation/genesis/simulator.py +++ /dev/null @@ -1,158 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Optional, Union, List, Dict -import genesis as gs # type: ignore -from ..base.simulator_base import SimulatorBase - - -class GenesisSimulator(SimulatorBase): - """Genesis simulator implementation.""" - - def __init__( - self, - headless: bool = True, - open_usd: Optional[str] = None, # Keep for compatibility - entities: Optional[List[Dict[str, Union[str, dict]]]] = None, - ): - """Initialize the Genesis simulation. - - Args: - headless: Whether to run without visualization - open_usd: Path to USD file (for Isaac) - entities: List of entity configurations to load. 
Each entity is a dict with: - - type: str ('mesh', 'urdf', 'mjcf', 'primitive') - - path: str (file path for mesh/urdf/mjcf) - - params: dict (parameters for primitives or loading options) - """ - super().__init__(headless, open_usd, entities) - - # Initialize Genesis - gs.init() - - # Create scene with viewer options - self.scene = gs.Scene( - show_viewer=not headless, - viewer_options=gs.options.ViewerOptions( - res=(1280, 960), - camera_pos=(3.5, 0.0, 2.5), - camera_lookat=(0.0, 0.0, 0.5), - camera_fov=40, - max_FPS=60, - ), - vis_options=gs.options.VisOptions( - show_world_frame=True, - world_frame_size=1.0, - show_link_frame=False, - show_cameras=False, - plane_reflection=True, - ambient_light=(0.1, 0.1, 0.1), - ), - renderer=gs.renderers.Rasterizer(), - ) - - # Handle USD parameter for compatibility - if open_usd: - print(f"[Warning] USD files not supported in Genesis. Ignoring: {open_usd}") - - # Load entities if provided - if entities: - self._load_entities(entities) - - # Don't build scene yet - let stream add camera first - self.is_built = False - - def _load_entities(self, entities: List[Dict[str, Union[str, dict]]]): - """Load multiple entities into the scene.""" - for entity in entities: - entity_type = entity.get("type", "").lower() - path = entity.get("path", "") - params = entity.get("params", {}) - - try: - if entity_type == "mesh": - mesh = gs.morphs.Mesh( - file=path, # Explicit file argument - **params, - ) - self.scene.add_entity(mesh) - print(f"[Genesis] Added mesh from {path}") - - elif entity_type == "urdf": - robot = gs.morphs.URDF( - file=path, # Explicit file argument - **params, - ) - self.scene.add_entity(robot) - print(f"[Genesis] Added URDF robot from {path}") - - elif entity_type == "mjcf": - mujoco = gs.morphs.MJCF( - file=path, # Explicit file argument - **params, - ) - self.scene.add_entity(mujoco) - print(f"[Genesis] Added MJCF model from {path}") - - elif entity_type == "primitive": - shape_type = params.pop("shape", "plane") - 
if shape_type == "plane": - morph = gs.morphs.Plane(**params) - elif shape_type == "box": - morph = gs.morphs.Box(**params) - elif shape_type == "sphere": - morph = gs.morphs.Sphere(**params) - else: - raise ValueError(f"Unsupported primitive shape: {shape_type}") - - # Add position if not specified - if "pos" not in params: - if shape_type == "plane": - morph.pos = [0, 0, 0] - else: - morph.pos = [0, 0, 1] # Lift objects above ground - - self.scene.add_entity(morph) - print(f"[Genesis] Added {shape_type} at position {morph.pos}") - - else: - raise ValueError(f"Unsupported entity type: {entity_type}") - - except Exception as e: - print(f"[Warning] Failed to load entity {entity}: {str(e)}") - - def add_entity(self, entity_type: str, path: str = "", **params): - """Add a single entity to the scene. - - Args: - entity_type: Type of entity ('mesh', 'urdf', 'mjcf', 'primitive') - path: File path for mesh/urdf/mjcf entities - **params: Additional parameters for entity creation - """ - self._load_entities([{"type": entity_type, "path": path, "params": params}]) - - def get_stage(self): - """Get the current stage/scene.""" - return self.scene - - def build(self): - """Build the scene if not already built.""" - if not self.is_built: - self.scene.build() - self.is_built = True - - def close(self): - """Close the simulation.""" - # Genesis handles cleanup automatically - pass diff --git a/build/lib/dimos/simulation/genesis/stream.py b/build/lib/dimos/simulation/genesis/stream.py deleted file mode 100644 index fbb70fea13..0000000000 --- a/build/lib/dimos/simulation/genesis/stream.py +++ /dev/null @@ -1,143 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import cv2 -import numpy as np -import time -from typing import Optional, Union -from pathlib import Path -from ..base.stream_base import StreamBase, AnnotatorType, TransportType - - -class GenesisStream(StreamBase): - """Genesis stream implementation.""" - - def __init__( - self, - simulator, - width: int = 1920, - height: int = 1080, - fps: int = 60, - camera_path: str = "/camera", - annotator_type: AnnotatorType = "rgb", - transport: TransportType = "tcp", - rtsp_url: str = "rtsp://mediamtx:8554/stream", - usd_path: Optional[Union[str, Path]] = None, - ): - """Initialize the Genesis stream.""" - super().__init__( - simulator=simulator, - width=width, - height=height, - fps=fps, - camera_path=camera_path, - annotator_type=annotator_type, - transport=transport, - rtsp_url=rtsp_url, - usd_path=usd_path, - ) - - self.scene = simulator.get_stage() - - # Initialize components - if usd_path: - self._load_stage(usd_path) - self._setup_camera() - self._setup_ffmpeg() - self._setup_annotator() - - # Build scene after camera is set up - simulator.build() - - def _load_stage(self, usd_path: Union[str, Path]): - """Load stage from file.""" - # Genesis handles stage loading through simulator - pass - - def _setup_camera(self): - """Setup and validate camera.""" - self.camera = self.scene.add_camera( - res=(self.width, self.height), - pos=(3.5, 0.0, 2.5), - lookat=(0, 0, 0.5), - fov=30, - GUI=False, - ) - - def _setup_annotator(self): - """Setup the specified annotator.""" - # Genesis handles different render types through camera.render() - pass - - def 
stream(self): - """Start the streaming loop.""" - try: - print("[Stream] Starting Genesis camera stream...") - frame_count = 0 - start_time = time.time() - - while True: - frame_start = time.time() - - # Step simulation and get frame - step_start = time.time() - self.scene.step() - step_time = time.time() - step_start - print(f"[Stream] Simulation step took {step_time * 1000:.2f}ms") - - # Get frame based on annotator type - if self.annotator_type == "rgb": - frame, _, _, _ = self.camera.render(rgb=True) - elif self.annotator_type == "normals": - _, _, _, frame = self.camera.render(normal=True) - else: - frame, _, _, _ = self.camera.render(rgb=True) # Default to RGB - - # Convert frame format if needed - if isinstance(frame, np.ndarray): - frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) - - # Write to FFmpeg - self.proc.stdin.write(frame.tobytes()) - self.proc.stdin.flush() - - # Log metrics - frame_time = time.time() - frame_start - print(f"[Stream] Total frame processing took {frame_time * 1000:.2f}ms") - frame_count += 1 - - if frame_count % 100 == 0: - elapsed_time = time.time() - start_time - current_fps = frame_count / elapsed_time - print( - f"[Stream] Processed {frame_count} frames | Current FPS: {current_fps:.2f}" - ) - - except KeyboardInterrupt: - print("\n[Stream] Received keyboard interrupt, stopping stream...") - finally: - self.cleanup() - - def cleanup(self): - """Cleanup resources.""" - print("[Cleanup] Stopping FFmpeg process...") - if hasattr(self, "proc"): - self.proc.stdin.close() - self.proc.wait() - print("[Cleanup] Closing simulation...") - try: - self.simulator.close() - except AttributeError: - print("[Cleanup] Warning: Could not close simulator properly") - print("[Cleanup] Successfully cleaned up resources") diff --git a/build/lib/dimos/simulation/isaac/__init__.py b/build/lib/dimos/simulation/isaac/__init__.py deleted file mode 100644 index 2b9bdc082d..0000000000 --- a/build/lib/dimos/simulation/isaac/__init__.py +++ /dev/null @@ -1,4 
+0,0 @@ -from .simulator import IsaacSimulator -from .stream import IsaacStream - -__all__ = ["IsaacSimulator", "IsaacStream"] diff --git a/build/lib/dimos/simulation/isaac/simulator.py b/build/lib/dimos/simulation/isaac/simulator.py deleted file mode 100644 index ba6fe319b4..0000000000 --- a/build/lib/dimos/simulation/isaac/simulator.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Optional, List, Dict, Union -from isaacsim import SimulationApp -from ..base.simulator_base import SimulatorBase - - -class IsaacSimulator(SimulatorBase): - """Isaac Sim simulator implementation.""" - - def __init__( - self, - headless: bool = True, - open_usd: Optional[str] = None, - entities: Optional[List[Dict[str, Union[str, dict]]]] = None, # Add but ignore - ): - """Initialize the Isaac Sim simulation.""" - super().__init__(headless, open_usd) - self.app = SimulationApp({"headless": headless, "open_usd": open_usd}) - - def get_stage(self): - """Get the current USD stage.""" - import omni.usd - - self.stage = omni.usd.get_context().get_stage() - return self.stage - - def close(self): - """Close the simulation.""" - if hasattr(self, "app"): - self.app.close() diff --git a/build/lib/dimos/simulation/isaac/stream.py b/build/lib/dimos/simulation/isaac/stream.py deleted file mode 100644 index 44560783bd..0000000000 --- a/build/lib/dimos/simulation/isaac/stream.py +++ /dev/null @@ -1,136 +0,0 @@ -# 
Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import cv2 -import time -from typing import Optional, Union -from pathlib import Path -from ..base.stream_base import StreamBase, AnnotatorType, TransportType - - -class IsaacStream(StreamBase): - """Isaac Sim stream implementation.""" - - def __init__( - self, - simulator, - width: int = 1920, - height: int = 1080, - fps: int = 60, - camera_path: str = "/World/alfred_parent_prim/alfred_base_descr/chest_cam_rgb_camera_frame/chest_cam", - annotator_type: AnnotatorType = "rgb", - transport: TransportType = "tcp", - rtsp_url: str = "rtsp://mediamtx:8554/stream", - usd_path: Optional[Union[str, Path]] = None, - ): - """Initialize the Isaac Sim stream.""" - super().__init__( - simulator=simulator, - width=width, - height=height, - fps=fps, - camera_path=camera_path, - annotator_type=annotator_type, - transport=transport, - rtsp_url=rtsp_url, - usd_path=usd_path, - ) - - # Import omni.replicator after SimulationApp initialization - import omni.replicator.core as rep - - self.rep = rep - - # Initialize components - if usd_path: - self._load_stage(usd_path) - self._setup_camera() - self._setup_ffmpeg() - self._setup_annotator() - - def _load_stage(self, usd_path: Union[str, Path]): - """Load USD stage from file.""" - import omni.usd - - abs_path = str(Path(usd_path).resolve()) - omni.usd.get_context().open_stage(abs_path) - self.stage = self.simulator.get_stage() - if not self.stage: - raise 
RuntimeError(f"Failed to load stage: {abs_path}") - - def _setup_camera(self): - """Setup and validate camera.""" - self.stage = self.simulator.get_stage() - camera_prim = self.stage.GetPrimAtPath(self.camera_path) - if not camera_prim: - raise RuntimeError(f"Failed to find camera at path: {self.camera_path}") - - self.render_product = self.rep.create.render_product( - self.camera_path, resolution=(self.width, self.height) - ) - - def _setup_annotator(self): - """Setup the specified annotator.""" - self.annotator = self.rep.AnnotatorRegistry.get_annotator(self.annotator_type) - self.annotator.attach(self.render_product) - - def stream(self): - """Start the streaming loop.""" - try: - print("[Stream] Starting camera stream loop...") - frame_count = 0 - start_time = time.time() - - while True: - frame_start = time.time() - - # Step simulation and get frame - step_start = time.time() - self.rep.orchestrator.step() - step_time = time.time() - step_start - print(f"[Stream] Simulation step took {step_time * 1000:.2f}ms") - - frame = self.annotator.get_data() - frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2BGR) - - # Write to FFmpeg - self.proc.stdin.write(frame.tobytes()) - self.proc.stdin.flush() - - # Log metrics - frame_time = time.time() - frame_start - print(f"[Stream] Total frame processing took {frame_time * 1000:.2f}ms") - frame_count += 1 - - if frame_count % 100 == 0: - elapsed_time = time.time() - start_time - current_fps = frame_count / elapsed_time - print( - f"[Stream] Processed {frame_count} frames | Current FPS: {current_fps:.2f}" - ) - - except KeyboardInterrupt: - print("\n[Stream] Received keyboard interrupt, stopping stream...") - finally: - self.cleanup() - - def cleanup(self): - """Cleanup resources.""" - print("[Cleanup] Stopping FFmpeg process...") - if hasattr(self, "proc"): - self.proc.stdin.close() - self.proc.wait() - print("[Cleanup] Closing simulation...") - self.simulator.close() - print("[Cleanup] Successfully cleaned up resources") diff 
--git a/build/lib/dimos/skills/__init__.py b/build/lib/dimos/skills/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/build/lib/dimos/skills/kill_skill.py b/build/lib/dimos/skills/kill_skill.py deleted file mode 100644 index f7eb63e807..0000000000 --- a/build/lib/dimos/skills/kill_skill.py +++ /dev/null @@ -1,62 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Kill skill for terminating running skills. - -This module provides a skill that can terminate other running skills, -particularly those running in separate threads like the monitor skill. -""" - -from typing import Optional -from pydantic import Field - -from dimos.skills.skills import AbstractSkill, SkillLibrary -from dimos.utils.logging_config import setup_logger - -logger = setup_logger("dimos.skills.kill_skill") - - -class KillSkill(AbstractSkill): - """ - A skill that terminates other running skills. - - This skill can be used to stop long-running or background skills - like the monitor skill. It uses the centralized process management - in the SkillLibrary to track and terminate skills. - """ - - skill_name: str = Field(..., description="Name of the skill to terminate") - - def __init__(self, skill_library: Optional[SkillLibrary] = None, **data): - """ - Initialize the kill skill. 
- - Args: - skill_library: The skill library instance - **data: Additional data for configuration - """ - super().__init__(**data) - self._skill_library = skill_library - - def __call__(self): - """ - Terminate the specified skill. - - Returns: - A message indicating whether the skill was successfully terminated - """ - print("running skills", self._skill_library.get_running_skills()) - # Terminate the skill using the skill library - return self._skill_library.terminate_skill(self.skill_name) diff --git a/build/lib/dimos/skills/navigation.py b/build/lib/dimos/skills/navigation.py deleted file mode 100644 index 6d67ae04f2..0000000000 --- a/build/lib/dimos/skills/navigation.py +++ /dev/null @@ -1,699 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Semantic map skills for building and navigating spatial memory maps. - -This module provides two skills: -1. BuildSemanticMap - Builds a semantic map by recording video frames at different locations -2. 
Navigate - Queries an existing semantic map using natural language -""" - -import os -import time -import threading -from typing import Optional, Tuple -from dimos.utils.threadpool import get_scheduler - -from reactivex import operators as ops -from pydantic import Field - -from dimos.skills.skills import AbstractRobotSkill -from dimos.types.robot_location import RobotLocation -from dimos.utils.logging_config import setup_logger -from dimos.models.qwen.video_query import get_bbox_from_qwen_frame -from dimos.utils.transform_utils import distance_angle_to_goal_xy -from dimos.robot.local_planner.local_planner import navigate_to_goal_local - -logger = setup_logger("dimos.skills.semantic_map_skills") - - -def get_dimos_base_path(): - """ - Get the DiMOS base path from DIMOS_PATH environment variable or default to user's home directory. - - Returns: - Base path to use for DiMOS assets - """ - dimos_path = os.environ.get("DIMOS_PATH") - if dimos_path: - return dimos_path - # Get the current user's username - user = os.environ.get("USER", os.path.basename(os.path.expanduser("~"))) - return f"/home/{user}/dimos" - - -class NavigateWithText(AbstractRobotSkill): - """ - A skill that queries an existing semantic map using natural language or tries to navigate to an object in view. - - This skill first attempts to locate an object in the robot's camera view using vision. - If the object is found, it navigates to it. If not, it falls back to querying the - semantic map for a location matching the description. For example, "Find the Teddy Bear" - will first look for a Teddy Bear in view, then check the semantic map coordinates where - a Teddy Bear was previously observed. - - CALL THIS SKILL FOR ONE SUBJECT AT A TIME. For example: "Go to the person wearing a blue shirt in the living room", - you should call this skill twice, once for the person wearing a blue shirt and once for the living room. 
- - If skip_visual_search is True, this skill will skip the visual search for the object in view. - This is useful if you want to navigate to a general location such as a kitchen or office. - For example, "Go to the kitchen" will not look for a kitchen in view, but will check the semantic map coordinates where - a kitchen was previously observed. - """ - - query: str = Field("", description="Text query to search for in the semantic map") - - limit: int = Field(1, description="Maximum number of results to return") - distance: float = Field(1.0, description="Desired distance to maintain from object in meters") - skip_visual_search: bool = Field(False, description="Skip visual search for object in view") - timeout: float = Field(40.0, description="Maximum time to spend navigating in seconds") - - def __init__(self, robot=None, **data): - """ - Initialize the Navigate skill. - - Args: - robot: The robot instance - **data: Additional data for configuration - """ - super().__init__(robot=robot, **data) - self._stop_event = threading.Event() - self._spatial_memory = None - self._scheduler = get_scheduler() # Use the shared DiMOS thread pool - self._navigation_disposable = None # Disposable returned by scheduler.schedule() - self._tracking_subscriber = None # For object tracking - self._similarity_threshold = 0.25 - - def _navigate_to_object(self): - """ - Helper method that attempts to navigate to an object visible in the camera view. - - Returns: - dict: Result dictionary with success status and details - """ - # Stop any existing operation - self._stop_event.clear() - - try: - logger.warning( - f"Attempting to navigate to visible object: {self.query} with desired distance {self.distance}m, timeout {self.timeout} seconds..." 
- ) - - # Try to get a bounding box from Qwen - only try once - bbox = None - try: - # Use the robot's existing video stream instead of creating a new one - frame = self._robot.get_video_stream().pipe(ops.take(1)).run() - # Use the frame-based function - bbox, object_size = get_bbox_from_qwen_frame(frame, object_name=self.query) - except Exception as e: - logger.error(f"Error querying Qwen: {e}") - return { - "success": False, - "failure_reason": "Perception", - "error": f"Could not detect {self.query} in view: {e}", - } - - if bbox is None or self._stop_event.is_set(): - logger.error(f"Failed to get bounding box for {self.query}") - return { - "success": False, - "failure_reason": "Perception", - "error": f"Could not find {self.query} in view", - } - - logger.info(f"Found {self.query} at {bbox} with size {object_size}") - - # Start the object tracker with the detected bbox - self._robot.object_tracker.track(bbox, frame=frame) - - # Get the first tracking data with valid distance and angle - start_time = time.time() - target_acquired = False - goal_x_robot = 0 - goal_y_robot = 0 - goal_angle = 0 - - while ( - time.time() - start_time < 10.0 - and not self._stop_event.is_set() - and not target_acquired - ): - # Get the latest tracking data - tracking_data = self._robot.object_tracking_stream.pipe(ops.take(1)).run() - - if tracking_data and tracking_data.get("targets") and tracking_data["targets"]: - target = tracking_data["targets"][0] - - if "distance" in target and "angle" in target: - # Convert target distance and angle to xy coordinates in robot frame - goal_distance = ( - target["distance"] - self.distance - ) # Subtract desired distance to stop short - goal_angle = -target["angle"] - logger.info(f"Target distance: {goal_distance}, Target angle: {goal_angle}") - - goal_x_robot, goal_y_robot = distance_angle_to_goal_xy( - goal_distance, goal_angle - ) - target_acquired = True - break - - else: - logger.warning("No valid target tracking data found.") - - else: - 
logger.warning("No valid target tracking data found.") - - time.sleep(0.1) - - if not target_acquired: - logger.error("Failed to acquire valid target tracking data") - return { - "success": False, - "failure_reason": "Perception", - "error": "Failed to track object", - } - - logger.info( - f"Navigating to target at local coordinates: ({goal_x_robot:.2f}, {goal_y_robot:.2f}), angle: {goal_angle:.2f}" - ) - - # Use navigate_to_goal_local instead of directly controlling the local planner - success = navigate_to_goal_local( - robot=self._robot, - goal_xy_robot=(goal_x_robot, goal_y_robot), - goal_theta=goal_angle, - distance=0.0, # We already accounted for desired distance - timeout=self.timeout, - stop_event=self._stop_event, - ) - - if success: - logger.info(f"Successfully navigated to {self.query}") - return { - "success": True, - "failure_reason": None, - "query": self.query, - "message": f"Successfully navigated to {self.query} in view", - } - else: - logger.warning( - f"Failed to reach {self.query} within timeout or operation was stopped" - ) - return { - "success": False, - "failure_reason": "Navigation", - "error": f"Failed to reach {self.query} within timeout", - } - - except Exception as e: - logger.error(f"Error in navigate to object: {e}") - return {"success": False, "failure_reason": "Code Error", "error": f"Error: {e}"} - finally: - # Clean up - self._robot.object_tracker.cleanup() - - def _navigate_using_semantic_map(self): - """ - Helper method that attempts to navigate using the semantic map query. 
- - Returns: - dict: Result dictionary with success status and details - """ - logger.info(f"Querying semantic map for: '{self.query}'") - - try: - self._spatial_memory = self._robot.get_spatial_memory() - - # Run the query - results = self._spatial_memory.query_by_text(self.query, limit=self.limit) - - if not results: - logger.warning(f"No results found for query: '{self.query}'") - return { - "success": False, - "query": self.query, - "error": "No matching location found in semantic map", - } - - # Get the best match - best_match = results[0] - metadata = best_match.get("metadata", {}) - - if isinstance(metadata, list) and metadata: - metadata = metadata[0] - - # Extract coordinates from metadata - if ( - isinstance(metadata, dict) - and "pos_x" in metadata - and "pos_y" in metadata - and "rot_z" in metadata - ): - pos_x = metadata.get("pos_x", 0) - pos_y = metadata.get("pos_y", 0) - theta = metadata.get("rot_z", 0) - - # Calculate similarity score (distance is inverse of similarity) - similarity = 1.0 - ( - best_match.get("distance", 0) if best_match.get("distance") is not None else 0 - ) - - logger.info( - f"Found match for '{self.query}' at ({pos_x:.2f}, {pos_y:.2f}, rotation {theta:.2f}) with similarity: {similarity:.4f}" - ) - - # Check if similarity is below the threshold - if similarity < self._similarity_threshold: - logger.warning( - f"Match found but similarity score ({similarity:.4f}) is below threshold ({self._similarity_threshold})" - ) - return { - "success": False, - "query": self.query, - "position": (pos_x, pos_y), - "rotation": theta, - "similarity": similarity, - "error": f"Match found but similarity score ({similarity:.4f}) is below threshold ({self._similarity_threshold})", - } - - # Reset the stop event before starting navigation - self._stop_event.clear() - - # The scheduler approach isn't working, switch to direct threading - # Define a navigation function that will run on a separate thread - def run_navigation(): - skill_library = 
self._robot.get_skills() - self.register_as_running("Navigate", skill_library) - - try: - logger.info( - f"Starting navigation to ({pos_x:.2f}, {pos_y:.2f}) with rotation {theta:.2f}" - ) - # Pass our stop_event to allow cancellation - result = False - try: - result = self._robot.global_planner.set_goal( - (pos_x, pos_y), goal_theta=theta, stop_event=self._stop_event - ) - except Exception as e: - logger.error(f"Error calling global_planner.set_goal: {e}") - - if result: - logger.info("Navigation completed successfully") - else: - logger.error("Navigation did not complete successfully") - return result - except Exception as e: - logger.error(f"Unexpected error in navigation thread: {e}") - return False - finally: - self.stop() - - # Cancel any existing navigation before starting a new one - # Signal stop to any running navigation - self._stop_event.set() - # Clear stop event for new navigation - self._stop_event.clear() - - # Run the navigation in the main thread - run_navigation() - - return { - "success": True, - "query": self.query, - "position": (pos_x, pos_y), - "rotation": theta, - "similarity": similarity, - "metadata": metadata, - } - else: - logger.warning(f"No valid position data found for query: '{self.query}'") - return { - "success": False, - "query": self.query, - "error": "No valid position data found in semantic map", - } - - except Exception as e: - logger.error(f"Error in semantic map navigation: {e}") - return {"success": False, "error": f"Semantic map error: {e}"} - - def __call__(self): - """ - First attempts to navigate to an object in view, then falls back to querying the semantic map. 
- - Returns: - A dictionary with the result of the navigation attempt - """ - super().__call__() - - if not self.query: - error_msg = "No query provided to Navigate skill" - logger.error(error_msg) - return {"success": False, "error": error_msg} - - # First, try to find and navigate to the object in camera view - logger.info(f"First attempting to find and navigate to visible object: '{self.query}'") - - if not self.skip_visual_search: - object_result = self._navigate_to_object() - - if object_result and object_result["success"]: - logger.info(f"Successfully navigated to {self.query} in view") - return object_result - - elif object_result and object_result["failure_reason"] == "Navigation": - logger.info( - f"Failed to navigate to {self.query} in view: {object_result.get('error', 'Unknown error')}" - ) - return object_result - - # If object navigation failed, fall back to semantic map - logger.info( - f"Object not found in view. Falling back to semantic map query for: '{self.query}'" - ) - - return self._navigate_using_semantic_map() - - def stop(self): - """ - Stop the navigation skill and clean up resources. - - Returns: - A message indicating whether the navigation was stopped successfully - """ - logger.info("Stopping Navigate skill") - - # Signal any running processes to stop via the shared event - self._stop_event.set() - - skill_library = self._robot.get_skills() - self.unregister_as_running("Navigate", skill_library) - - # Dispose of any existing navigation task - if hasattr(self, "_navigation_disposable") and self._navigation_disposable: - logger.info("Disposing navigation task") - try: - self._navigation_disposable.dispose() - except Exception as e: - logger.error(f"Error disposing navigation task: {e}") - self._navigation_disposable = None - - return "Navigate skill stopped successfully." - - -class GetPose(AbstractRobotSkill): - """ - A skill that returns the current position and orientation of the robot. 
- - This skill is useful for getting the current pose of the robot in the map frame. You call this skill - if you want to remember a location, for example, "remember this is where my favorite chair is" and then - call this skill to get the position and rotation of approximately where the chair is. You can then use - the position to navigate to the chair. - - When location_name is provided, this skill will also remember the current location with that name, - allowing you to navigate back to it later using the Navigate skill. - """ - - location_name: str = Field( - "", description="Optional name to assign to this location (e.g., 'kitchen', 'office')" - ) - - def __init__(self, robot=None, **data): - """ - Initialize the GetPose skill. - - Args: - robot: The robot instance - **data: Additional data for configuration - """ - super().__init__(robot=robot, **data) - - def __call__(self): - """ - Get the current pose of the robot. - - Returns: - A dictionary containing the position and rotation of the robot - """ - super().__call__() - - if self._robot is None: - error_msg = "No robot instance provided to GetPose skill" - logger.error(error_msg) - return {"success": False, "error": error_msg} - - try: - # Get the current pose using the robot's get_pose method - pose_data = self._robot.get_pose() - - # Extract position and rotation from the new dictionary format - position = pose_data["position"] - rotation = pose_data["rotation"] - - # Format the response - result = { - "success": True, - "position": { - "x": position.x, - "y": position.y, - "z": position.z, - }, - "rotation": {"roll": rotation.x, "pitch": rotation.y, "yaw": rotation.z}, - } - - # If location_name is provided, remember this location - if self.location_name: - # Get the spatial memory instance - spatial_memory = self._robot.get_spatial_memory() - - # Create a RobotLocation object - location = RobotLocation( - name=self.location_name, - position=(position.x, position.y, position.z), - rotation=(rotation.x, 
rotation.y, rotation.z), - ) - - # Add to spatial memory - if spatial_memory.add_robot_location(location): - result["location_saved"] = True - result["location_name"] = self.location_name - logger.info(f"Location '{self.location_name}' saved at {position}") - else: - result["location_saved"] = False - logger.error(f"Failed to save location '{self.location_name}'") - - return result - except Exception as e: - error_msg = f"Error getting robot pose: {e}" - logger.error(error_msg) - return {"success": False, "error": error_msg} - - -class NavigateToGoal(AbstractRobotSkill): - """ - A skill that navigates the robot to a specified position and orientation. - - This skill uses the global planner to generate a path to the target position - and then uses navigate_path_local to follow that path, achieving the desired - orientation at the goal position. - """ - - position: Tuple[float, float] = Field( - (0.0, 0.0), description="Target position (x, y) in map frame" - ) - rotation: Optional[float] = Field(None, description="Target orientation (yaw) in radians") - frame: str = Field("map", description="Reference frame for the position and rotation") - timeout: float = Field(120.0, description="Maximum time (in seconds) allowed for navigation") - - def __init__(self, robot=None, **data): - """ - Initialize the NavigateToGoal skill. - - Args: - robot: The robot instance - **data: Additional data for configuration - """ - super().__init__(robot=robot, **data) - self._stop_event = threading.Event() - - def __call__(self): - """ - Navigate to the specified goal position and orientation. 
- - Returns: - A dictionary containing the result of the navigation attempt - """ - super().__call__() - - if self._robot is None: - error_msg = "No robot instance provided to NavigateToGoal skill" - logger.error(error_msg) - return {"success": False, "error": error_msg} - - # Reset stop event to make sure we don't immediately abort - self._stop_event.clear() - - skill_library = self._robot.get_skills() - self.register_as_running("NavigateToGoal", skill_library) - - logger.info( - f"Starting navigation to position=({self.position[0]:.2f}, {self.position[1]:.2f}) " - f"with rotation={self.rotation if self.rotation is not None else 'None'} " - f"in frame={self.frame}" - ) - - try: - # Use the global planner to set the goal and generate a path - result = self._robot.global_planner.set_goal( - self.position, goal_theta=self.rotation, stop_event=self._stop_event - ) - - if result: - logger.info("Navigation completed successfully") - return { - "success": True, - "position": self.position, - "rotation": self.rotation, - "message": "Goal reached successfully", - } - else: - logger.warning("Navigation did not complete successfully") - return { - "success": False, - "position": self.position, - "rotation": self.rotation, - "message": "Goal could not be reached", - } - - except Exception as e: - error_msg = f"Error during navigation: {e}" - logger.error(error_msg) - return { - "success": False, - "position": self.position, - "rotation": self.rotation, - "error": error_msg, - } - finally: - self.stop() - - def stop(self): - """ - Stop the navigation. - - Returns: - A message indicating that the navigation was stopped - """ - logger.info("Stopping NavigateToGoal") - skill_library = self._robot.get_skills() - self.unregister_as_running("NavigateToGoal", skill_library) - self._stop_event.set() - return "Navigation stopped" - - -class Explore(AbstractRobotSkill): - """ - A skill that performs autonomous frontier exploration. 
- - This skill continuously finds and navigates to unknown frontiers in the environment - until no more frontiers are found or the exploration is stopped. - - Don't save GetPose locations when frontier exploring. Don't call any other skills except stop skill when needed. - """ - - timeout: float = Field(60.0, description="Maximum time (in seconds) allowed for exploration") - - def __init__(self, robot=None, **data): - """ - Initialize the Explore skill. - - Args: - robot: The robot instance - **data: Additional data for configuration - """ - super().__init__(robot=robot, **data) - self._stop_event = threading.Event() - - def __call__(self): - """ - Start autonomous frontier exploration. - - Returns: - A dictionary containing the result of the exploration - """ - super().__call__() - - if self._robot is None: - error_msg = "No robot instance provided to Explore skill" - logger.error(error_msg) - return {"success": False, "error": error_msg} - - # Reset stop event to make sure we don't immediately abort - self._stop_event.clear() - - skill_library = self._robot.get_skills() - self.register_as_running("Explore", skill_library) - - logger.info("Starting autonomous frontier exploration") - - try: - # Start exploration using the robot's explore method - result = self._robot.explore(stop_event=self._stop_event) - - if result: - logger.info("Exploration completed successfully - no more frontiers found") - return { - "success": True, - "message": "Exploration completed - all accessible areas explored", - } - else: - if self._stop_event.is_set(): - logger.info("Exploration stopped by user") - return { - "success": False, - "message": "Exploration stopped by user", - } - else: - logger.warning("Exploration did not complete successfully") - return { - "success": False, - "message": "Exploration failed or was interrupted", - } - - except Exception as e: - error_msg = f"Error during exploration: {e}" - logger.error(error_msg) - return { - "success": False, - "error": error_msg, 
- } - finally: - self.stop() - - def stop(self): - """ - Stop the exploration. - - Returns: - A message indicating that the exploration was stopped - """ - logger.info("Stopping Explore") - skill_library = self._robot.get_skills() - self.unregister_as_running("Explore", skill_library) - self._stop_event.set() - return "Exploration stopped" diff --git a/build/lib/dimos/skills/observe.py b/build/lib/dimos/skills/observe.py deleted file mode 100644 index 067307353a..0000000000 --- a/build/lib/dimos/skills/observe.py +++ /dev/null @@ -1,192 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Observer skill for an agent. - -This module provides a skill that sends a single image from any -Robot Data Stream to the Qwen VLM for inference and adds the response -to the agent's conversation history. -""" - -import time -from typing import Optional -import base64 -import cv2 -import numpy as np -import reactivex as rx -from reactivex import operators as ops -from pydantic import Field - -from dimos.skills.skills import AbstractRobotSkill -from dimos.agents.agent import LLMAgent -from dimos.models.qwen.video_query import query_single_frame -from dimos.utils.logging_config import setup_logger - -logger = setup_logger("dimos.skills.observe") - - -class Observe(AbstractRobotSkill): - """ - A skill that captures a single frame from a Robot Video Stream, sends it to a VLM, - and adds the response to the agent's conversation history. 
- - This skill is used for visual reasoning, spatial understanding, or any queries involving visual information that require critical thinking. - """ - - query_text: str = Field( - "What do you see in this image? Describe the environment in detail.", - description="Query text to send to the VLM model with the image", - ) - - def __init__(self, robot=None, agent: Optional[LLMAgent] = None, **data): - """ - Initialize the Observe skill. - - Args: - robot: The robot instance - agent: The agent to store results in - **data: Additional data for configuration - """ - super().__init__(robot=robot, **data) - self._agent = agent - self._model_name = "qwen2.5-vl-72b-instruct" - - # Get the video stream from the robot - self._video_stream = self._robot.video_stream - if self._video_stream is None: - logger.error("Failed to get video stream from robot") - - def __call__(self): - """ - Capture a single frame, process it with Qwen, and add the result to conversation history. - - Returns: - A message indicating the observation result - """ - super().__call__() - - if self._agent is None: - error_msg = "No agent provided to Observe skill" - logger.error(error_msg) - return error_msg - - if self._robot is None: - error_msg = "No robot instance provided to Observe skill" - logger.error(error_msg) - return error_msg - - if self._video_stream is None: - error_msg = "No video stream available" - logger.error(error_msg) - return error_msg - - try: - logger.info("Capturing frame for Qwen observation") - - # Get a single frame from the video stream - frame = self._get_frame_from_stream() - - if frame is None: - error_msg = "Failed to capture frame from video stream" - logger.error(error_msg) - return error_msg - - # Process the frame with Qwen - response = self._process_frame_with_qwen(frame) - - logger.info(f"Added Qwen observation to conversation history") - return f"Observation complete: {response}" - - except Exception as e: - error_msg = f"Error in Observe skill: {e}" - 
logger.error(error_msg) - return error_msg - - def _get_frame_from_stream(self): - """ - Get a single frame from the video stream. - - Returns: - A single frame from the video stream, or None if no frame is available - """ - if self._video_stream is None: - logger.error("Video stream is None") - return None - - frame = None - frame_subject = rx.subject.Subject() - - subscription = self._video_stream.pipe( - ops.take(1) # Take just one frame - ).subscribe( - on_next=lambda x: frame_subject.on_next(x), - on_error=lambda e: logger.error(f"Error getting frame: {e}"), - ) - - # Wait up to 5 seconds for a frame - timeout = 5.0 - start_time = time.time() - - def on_frame(f): - nonlocal frame - frame = f - - frame_subject.subscribe(on_frame) - - while frame is None and time.time() - start_time < timeout: - time.sleep(0.1) - - subscription.dispose() - return frame - - def _process_frame_with_qwen(self, frame): - """ - Process a frame with the Qwen model using query_single_frame. - - Args: - frame: The video frame to process (numpy array) - - Returns: - The response from Qwen - """ - logger.info(f"Processing frame with Qwen model: {self._model_name}") - - try: - # Convert numpy array to PIL Image if needed - from PIL import Image - - if isinstance(frame, np.ndarray): - # OpenCV uses BGR, PIL uses RGB - if frame.shape[-1] == 3: # Check if it has color channels - frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - pil_image = Image.fromarray(frame_rgb) - else: - pil_image = Image.fromarray(frame) - else: - pil_image = frame - - # Query Qwen with the frame (direct function call) - response = query_single_frame( - pil_image, - self.query_text, - model_name=self._model_name, - ) - - logger.info(f"Qwen response received: {response[:100]}...") - return response - - except Exception as e: - logger.error(f"Error processing frame with Qwen: {e}") - raise diff --git a/build/lib/dimos/skills/observe_stream.py b/build/lib/dimos/skills/observe_stream.py deleted file mode 100644 index 
7b4e08874e..0000000000 --- a/build/lib/dimos/skills/observe_stream.py +++ /dev/null @@ -1,245 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Observer skill for an agent. - -This module provides a skill that periodically sends images from any -Robot Data Stream to an agent for inference. -""" - -import time -import threading -from typing import Optional -import base64 -import cv2 -import numpy as np -import reactivex as rx -from reactivex import operators as ops -from pydantic import Field -from PIL import Image - -from dimos.skills.skills import AbstractRobotSkill -from dimos.agents.agent import LLMAgent -from dimos.models.qwen.video_query import query_single_frame -from dimos.utils.threadpool import get_scheduler -from dimos.utils.logging_config import setup_logger - -logger = setup_logger("dimos.skills.observe_stream") - - -class ObserveStream(AbstractRobotSkill): - """ - A skill that periodically Observes a Robot Video Stream and sends images to current instance of an agent for context. - - This skill runs in a non-halting manner, allowing other skills to run concurrently. - It can be used for continuous perception and passive monitoring, such as waiting for a person to enter a room - or to monitor changes in the environment. - """ - - timestep: float = Field( - 60.0, description="Time interval in seconds between observation queries" - ) - query_text: str = Field( - "What do you see in this image? 
Alert me if you see any people or important changes.", - description="Query text to send to agent with each image", - ) - max_duration: float = Field( - 0.0, description="Maximum duration to run the observer in seconds (0 for indefinite)" - ) - - def __init__(self, robot=None, agent: Optional[LLMAgent] = None, video_stream=None, **data): - """ - Initialize the ObserveStream skill. - - Args: - robot: The robot instance - agent: The agent to send queries to - **data: Additional data for configuration - """ - super().__init__(robot=robot, **data) - self._agent = agent - self._stop_event = threading.Event() - self._monitor_thread = None - self._scheduler = get_scheduler() - self._subscription = None - - # Get the video stream - # TODO: Use the video stream provided in the constructor for dynamic video_stream selection by the agent - self._video_stream = self._robot.video_stream - if self._video_stream is None: - logger.error("Failed to get video stream from robot") - return - - def __call__(self): - """ - Start the observing process in a separate thread using the threadpool. 
- - Returns: - A message indicating the observer has started - """ - super().__call__() - - if self._agent is None: - error_msg = "No agent provided to ObserveStream" - logger.error(error_msg) - return error_msg - - if self._robot is None: - error_msg = "No robot instance provided to ObserveStream" - logger.error(error_msg) - return error_msg - - self.stop() - - self._stop_event.clear() - - # Initialize start time for duration tracking - self._start_time = time.time() - - interval_observable = rx.interval(self.timestep, scheduler=self._scheduler).pipe( - ops.take_while(lambda _: not self._stop_event.is_set()) - ) - - # Subscribe to the interval observable - self._subscription = interval_observable.subscribe( - on_next=self._monitor_iteration, - on_error=lambda e: logger.error(f"Error in monitor observable: {e}"), - on_completed=lambda: logger.info("Monitor observable completed"), - ) - - skill_library = self._robot.get_skills() - self.register_as_running("ObserveStream", skill_library, self._subscription) - - logger.info(f"Observer started with timestep={self.timestep}s, query='{self.query_text}'") - return f"Observer started with timestep={self.timestep}s, query='{self.query_text}'" - - def _monitor_iteration(self, iteration): - """ - Execute a single observer iteration. 
- - Args: - iteration: The iteration number (provided by rx.interval) - """ - try: - if self.max_duration > 0: - elapsed_time = time.time() - self._start_time - if elapsed_time > self.max_duration: - logger.info(f"Observer reached maximum duration of {self.max_duration}s") - self.stop() - return - - logger.info(f"Observer iteration {iteration} executing") - - # Get a frame from the video stream - frame = self._get_frame_from_stream() - - if frame is not None: - self._process_frame(frame) - else: - logger.warning("Failed to get frame from video stream") - - except Exception as e: - logger.error(f"Error in monitor iteration {iteration}: {e}") - - def _get_frame_from_stream(self): - """ - Get a single frame from the video stream. - - Args: - video_stream: The ROS video stream observable - - Returns: - A single frame from the video stream, or None if no frame is available - """ - frame = None - - frame_subject = rx.subject.Subject() - - subscription = self._video_stream.pipe( - ops.take(1) # Take just one frame - ).subscribe( - on_next=lambda x: frame_subject.on_next(x), - on_error=lambda e: logger.error(f"Error getting frame: {e}"), - ) - - timeout = 5.0 # 5 seconds timeout - start_time = time.time() - - def on_frame(f): - nonlocal frame - frame = f - - frame_subject.subscribe(on_frame) - - while frame is None and time.time() - start_time < timeout: - time.sleep(0.1) - - subscription.dispose() - - return frame - - def _process_frame(self, frame): - """ - Process a frame with the Qwen VLM and add the response to conversation history. 
- - Args: - frame: The video frame to process - """ - logger.info("Processing frame with Qwen VLM") - - try: - # Convert frame to PIL Image format - if isinstance(frame, np.ndarray): - # OpenCV uses BGR, PIL uses RGB - if frame.shape[-1] == 3: # Check if it has color channels - frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - pil_image = Image.fromarray(frame_rgb) - else: - pil_image = Image.fromarray(frame) - else: - pil_image = frame - - # Use Qwen to process the frame - model_name = "qwen2.5-vl-72b-instruct" # Using the most capable model - response = query_single_frame(pil_image, self.query_text, model_name=model_name) - - logger.info(f"Qwen response received: {response[:100]}...") - - # Add the response to the conversation history - # self._agent.append_to_history( - # f"Observation: {response}", - # ) - response = self._agent.run_observable_query(f"Observation: {response}") - - logger.info("Added Qwen observation to conversation history") - - except Exception as e: - logger.error(f"Error processing frame with Qwen VLM: {e}") - - def stop(self): - """ - Stop the ObserveStream monitoring process. - - Returns: - A message indicating the observer has stopped - """ - if self._subscription is not None and not self._subscription.is_disposed: - logger.info("Stopping ObserveStream") - self._stop_event.set() - self._subscription.dispose() - self._subscription = None - - return "Observer stopped" - return "Observer was not running" diff --git a/build/lib/dimos/skills/rest/__init__.py b/build/lib/dimos/skills/rest/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/build/lib/dimos/skills/rest/rest.py b/build/lib/dimos/skills/rest/rest.py deleted file mode 100644 index 3e7c7426cc..0000000000 --- a/build/lib/dimos/skills/rest/rest.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import requests -from dimos.skills.skills import AbstractSkill -from pydantic import Field -import logging - -logger = logging.getLogger(__name__) - - -class GenericRestSkill(AbstractSkill): - """Performs a configurable REST API call. - - This skill executes an HTTP request based on the provided parameters. It - supports various HTTP methods and allows specifying URL, timeout. - - Attributes: - url: The target URL for the API call. - method: The HTTP method (e.g., 'GET', 'POST'). Case-insensitive. - timeout: Request timeout in seconds. - """ - - # TODO: Add query parameters, request body data (form-encoded or JSON), and headers. - # , query - # parameters, request body data (form-encoded or JSON), and headers. - # params: Optional dictionary of URL query parameters. - # data: Optional dictionary for form-encoded request body data. - # json_payload: Optional dictionary for JSON request body data. Use the - # alias 'json' when initializing. - # headers: Optional dictionary of HTTP headers. 
- url: str = Field(..., description="The target URL for the API call.") - method: str = Field(..., description="HTTP method (e.g., 'GET', 'POST').") - timeout: int = Field(..., description="Request timeout in seconds.") - # params: Optional[Dict[str, Any]] = Field(default=None, description="URL query parameters.") - # data: Optional[Dict[str, Any]] = Field(default=None, description="Form-encoded request body.") - # json_payload: Optional[Dict[str, Any]] = Field(default=None, alias="json", description="JSON request body.") - # headers: Optional[Dict[str, str]] = Field(default=None, description="HTTP headers.") - - def __call__(self) -> str: - """Executes the configured REST API call. - - Returns: - The text content of the response on success (HTTP 2xx). - - Raises: - requests.exceptions.RequestException: If a connection error, timeout, - or other request-related issue occurs. - requests.exceptions.HTTPError: If the server returns an HTTP 4xx or - 5xx status code. - Exception: For any other unexpected errors during execution. - - Returns: - A string representing the success or failure outcome. If successful, - returns the response body text. If an error occurs, returns a - descriptive error message. - """ - try: - logger.debug( - f"Executing {self.method.upper()} request to {self.url} " - f"with timeout={self.timeout}" # , params={self.params}, " - # f"data={self.data}, json={self.json_payload}, headers={self.headers}" - ) - response = requests.request( - method=self.method.upper(), # Normalize method to uppercase - url=self.url, - # params=self.params, - # data=self.data, - # json=self.json_payload, # Use the attribute name defined in Pydantic - # headers=self.headers, - timeout=self.timeout, - ) - response.raise_for_status() # Raises HTTPError for bad responses (4xx or 5xx) - logger.debug( - f"Request successful. Status: {response.status_code}, Response: {response.text[:100]}..." 
- ) - return response.text # Return text content directly - except requests.exceptions.HTTPError as http_err: - logger.error( - f"HTTP error occurred: {http_err} - Status Code: {http_err.response.status_code}" - ) - return f"HTTP error making {self.method.upper()} request to {self.url}: {http_err.response.status_code} {http_err.response.reason}" - except requests.exceptions.RequestException as req_err: - logger.error(f"Request exception occurred: {req_err}") - return f"Error making {self.method.upper()} request to {self.url}: {req_err}" - except Exception as e: - logger.exception(f"An unexpected error occurred: {e}") # Log the full traceback - return f"An unexpected error occurred: {type(e).__name__}: {e}" diff --git a/build/lib/dimos/skills/skills.py b/build/lib/dimos/skills/skills.py deleted file mode 100644 index f6c7456d24..0000000000 --- a/build/lib/dimos/skills/skills.py +++ /dev/null @@ -1,340 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import logging -from typing import Any, Optional -from pydantic import BaseModel -from openai import pydantic_function_tool - -from dimos.types.constants import Colors - -# Configure logging for the module -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) - -# region SkillLibrary - - -class SkillLibrary: - # ==== Flat Skill Library ==== - - def __init__(self): - self.registered_skills: list["AbstractSkill"] = [] - self.class_skills: list["AbstractSkill"] = [] - self._running_skills = {} # {skill_name: (instance, subscription)} - - self.init() - - def init(self): - # Collect all skills from the parent class and update self.skills - self.refresh_class_skills() - - # Temporary - self.registered_skills = self.class_skills.copy() - - def get_class_skills(self) -> list["AbstractSkill"]: - """Extract all AbstractSkill subclasses from a class. - - Returns: - List of skill classes found within the class - """ - skills = [] - - # Loop through all attributes of the class - for attr_name in dir(self.__class__): - # Skip special/dunder attributes - if attr_name.startswith("__"): - continue - - try: - attr = getattr(self.__class__, attr_name) - - # Check if it's a class and inherits from AbstractSkill - if ( - isinstance(attr, type) - and issubclass(attr, AbstractSkill) - and attr is not AbstractSkill - ): - skills.append(attr) - except (AttributeError, TypeError): - # Skip attributes that can't be accessed or aren't classes - continue - - return skills - - def refresh_class_skills(self): - self.class_skills = self.get_class_skills() - - def add(self, skill: "AbstractSkill") -> None: - if skill not in self.registered_skills: - self.registered_skills.append(skill) - - def get(self) -> list["AbstractSkill"]: - return self.registered_skills.copy() - - def remove(self, skill: "AbstractSkill") -> None: - try: - self.registered_skills.remove(skill) - except ValueError: - logger.warning(f"Attempted to remove non-existent 
skill: {skill}") - - def clear(self) -> None: - self.registered_skills.clear() - - def __iter__(self): - return iter(self.registered_skills) - - def __len__(self) -> int: - return len(self.registered_skills) - - def __contains__(self, skill: "AbstractSkill") -> bool: - return skill in self.registered_skills - - def __getitem__(self, index): - return self.registered_skills[index] - - # ==== Calling a Function ==== - - _instances: dict[str, dict] = {} - - def create_instance(self, name, **kwargs): - # Key based only on the name - key = name - - print(f"Preparing to create instance with name: {name} and args: {kwargs}") - - if key not in self._instances: - # Instead of creating an instance, store the args for later use - self._instances[key] = kwargs - print(f"Stored args for later instance creation: {name} with args: {kwargs}") - - def call(self, name, **args): - try: - # Get the stored args if available; otherwise, use an empty dict - stored_args = self._instances.get(name, {}) - - # Merge the arguments with priority given to stored arguments - complete_args = {**args, **stored_args} - - # Dynamically get the class from the module or current script - skill_class = getattr(self, name, None) - if skill_class is None: - for skill in self.get(): - if name == skill.__name__: - skill_class = skill - break - if skill_class is None: - error_msg = f"Skill '{name}' is not available. Please check if it's properly registered." 
- logger.error(f"Skill class not found: {name}") - return error_msg - - # Initialize the instance with the merged arguments - instance = skill_class(**complete_args) - print(f"Instance created and function called for: {name} with args: {complete_args}") - - # Call the instance directly - return instance() - except Exception as e: - error_msg = f"Error executing skill '{name}': {str(e)}" - logger.error(error_msg) - return error_msg - - # ==== Tools ==== - - def get_tools(self) -> Any: - tools_json = self.get_list_of_skills_as_json(list_of_skills=self.registered_skills) - # print(f"{Colors.YELLOW_PRINT_COLOR}Tools JSON: {tools_json}{Colors.RESET_COLOR}") - return tools_json - - def get_list_of_skills_as_json(self, list_of_skills: list["AbstractSkill"]) -> list[str]: - return list(map(pydantic_function_tool, list_of_skills)) - - def register_running_skill(self, name: str, instance: Any, subscription=None): - """ - Register a running skill with its subscription. - - Args: - name: Name of the skill (will be converted to lowercase) - instance: Instance of the running skill - subscription: Optional subscription associated with the skill - """ - name = name.lower() - self._running_skills[name] = (instance, subscription) - logger.info(f"Registered running skill: {name}") - - def unregister_running_skill(self, name: str): - """ - Unregister a running skill. - - Args: - name: Name of the skill to remove (will be converted to lowercase) - - Returns: - True if the skill was found and removed, False otherwise - """ - name = name.lower() - if name in self._running_skills: - del self._running_skills[name] - logger.info(f"Unregistered running skill: {name}") - return True - return False - - def get_running_skills(self): - """ - Get all running skills. - - Returns: - A dictionary of running skill names and their (instance, subscription) tuples - """ - return self._running_skills.copy() - - def terminate_skill(self, name: str): - """ - Terminate a running skill. 
- - Args: - name: Name of the skill to terminate (will be converted to lowercase) - - Returns: - A message indicating whether the skill was successfully terminated - """ - name = name.lower() - if name in self._running_skills: - instance, subscription = self._running_skills[name] - - try: - # Call the stop method if it exists - if hasattr(instance, "stop") and callable(instance.stop): - result = instance.stop() - logger.info(f"Stopped skill: {name}") - else: - logger.warning(f"Skill {name} does not have a stop method") - - # Also dispose the subscription if it exists - if ( - subscription is not None - and hasattr(subscription, "dispose") - and callable(subscription.dispose) - ): - subscription.dispose() - logger.info(f"Disposed subscription for skill: {name}") - elif subscription is not None: - logger.warning(f"Skill {name} has a subscription but it's not disposable") - - # unregister the skill - self.unregister_running_skill(name) - return f"Successfully terminated skill: {name}" - - except Exception as e: - error_msg = f"Error terminating skill {name}: {e}" - logger.error(error_msg) - # Even on error, try to unregister the skill - self.unregister_running_skill(name) - return error_msg - else: - return f"No running skill found with name: {name}" - - -# endregion SkillLibrary - -# region AbstractSkill - - -class AbstractSkill(BaseModel): - def __init__(self, *args, **kwargs): - print("Initializing AbstractSkill Class") - super().__init__(*args, **kwargs) - self._instances = {} - self._list_of_skills = [] # Initialize the list of skills - print(f"Instances: {self._instances}") - - def clone(self) -> "AbstractSkill": - return AbstractSkill() - - def register_as_running(self, name: str, skill_library: SkillLibrary, subscription=None): - """ - Register this skill as running in the skill library. 
- - Args: - name: Name of the skill (will be converted to lowercase) - skill_library: The skill library to register with - subscription: Optional subscription associated with the skill - """ - skill_library.register_running_skill(name, self, subscription) - - def unregister_as_running(self, name: str, skill_library: SkillLibrary): - """ - Unregister this skill from the skill library. - - Args: - name: Name of the skill to remove (will be converted to lowercase) - skill_library: The skill library to unregister from - """ - skill_library.unregister_running_skill(name) - - # ==== Tools ==== - def get_tools(self) -> Any: - tools_json = self.get_list_of_skills_as_json(list_of_skills=self._list_of_skills) - # print(f"Tools JSON: {tools_json}") - return tools_json - - def get_list_of_skills_as_json(self, list_of_skills: list["AbstractSkill"]) -> list[str]: - return list(map(pydantic_function_tool, list_of_skills)) - - -# endregion AbstractSkill - -# region Abstract Robot Skill - -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from dimos.robot.robot import Robot -else: - Robot = "Robot" - - -class AbstractRobotSkill(AbstractSkill): - _robot: Robot = None - - def __init__(self, *args, robot: Optional[Robot] = None, **kwargs): - super().__init__(*args, **kwargs) - self._robot = robot - print( - f"{Colors.BLUE_PRINT_COLOR}Robot Skill Initialized with Robot: {robot}{Colors.RESET_COLOR}" - ) - - def set_robot(self, robot: Robot) -> None: - """Set the robot reference for this skills instance. - - Args: - robot: The robot instance to associate with these skills. 
- """ - self._robot = robot - - def __call__(self): - if self._robot is None: - raise RuntimeError( - f"{Colors.RED_PRINT_COLOR}" - f"No Robot instance provided to Robot Skill: {self.__class__.__name__}" - f"{Colors.RESET_COLOR}" - ) - else: - print( - f"{Colors.BLUE_PRINT_COLOR}Robot Instance provided to Robot Skill: {self.__class__.__name__}{Colors.RESET_COLOR}" - ) - - -# endregion Abstract Robot Skill diff --git a/build/lib/dimos/skills/speak.py b/build/lib/dimos/skills/speak.py deleted file mode 100644 index e73b9e792a..0000000000 --- a/build/lib/dimos/skills/speak.py +++ /dev/null @@ -1,166 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from dimos.skills.skills import AbstractSkill -from pydantic import Field -from reactivex import Subject -from typing import Optional, Any, List -import time -import threading -import queue -from dimos.utils.logging_config import setup_logger - -logger = setup_logger("dimos.skills.speak") - -# Global lock to prevent multiple simultaneous audio playbacks -_audio_device_lock = threading.RLock() - -# Global queue for sequential audio processing -_audio_queue = queue.Queue() -_queue_processor_thread = None -_queue_running = False - - -def _process_audio_queue(): - """Background thread to process audio requests sequentially""" - global _queue_running - - while _queue_running: - try: - # Get the next queued audio task with a timeout - task = _audio_queue.get(timeout=1.0) - if task is None: # Sentinel value to stop the thread - break - - # Execute the task (which is a function to be called) - task() - _audio_queue.task_done() - - except queue.Empty: - # No tasks in queue, just continue waiting - continue - except Exception as e: - logger.error(f"Error in audio queue processor: {e}") - # Continue processing other tasks - - -def start_audio_queue_processor(): - """Start the background thread for processing audio requests""" - global _queue_processor_thread, _queue_running - - if _queue_processor_thread is None or not _queue_processor_thread.is_alive(): - _queue_running = True - _queue_processor_thread = threading.Thread( - target=_process_audio_queue, daemon=True, name="AudioQueueProcessor" - ) - _queue_processor_thread.start() - logger.info("Started audio queue processor thread") - - -# Start the queue processor when module is imported -start_audio_queue_processor() - - -class Speak(AbstractSkill): - """Speak text out loud to humans nearby or to other robots.""" - - text: str = Field(..., description="Text to speak") - - def __init__(self, tts_node: Optional[Any] = None, **data): - super().__init__(**data) - self._tts_node = tts_node - self._audio_complete = 
threading.Event() - self._subscription = None - self._subscriptions: List = [] # Track all subscriptions - - def __call__(self): - if not self._tts_node: - logger.error("No TTS node provided to Speak skill") - return "Error: No TTS node available" - - # Create a result queue to get the result back from the audio thread - result_queue = queue.Queue(1) - - # Define the speech task to run in the audio queue - def speak_task(): - try: - # Using a lock to ensure exclusive access to audio device - with _audio_device_lock: - text_subject = Subject() - self._audio_complete.clear() - self._subscriptions = [] - - # This function will be called when audio processing is complete - def on_complete(): - logger.info(f"TTS audio playback completed for: {self.text}") - self._audio_complete.set() - - # This function will be called if there's an error - def on_error(error): - logger.error(f"Error in TTS processing: {error}") - self._audio_complete.set() - - # Connect the Subject to the TTS node and keep the subscription - self._tts_node.consume_text(text_subject) - - # Subscribe to the audio output to know when it's done - self._subscription = self._tts_node.emit_text().subscribe( - on_next=lambda text: logger.debug(f"TTS processing: {text}"), - on_completed=on_complete, - on_error=on_error, - ) - self._subscriptions.append(self._subscription) - - # Emit the text to the Subject - text_subject.on_next(self.text) - text_subject.on_completed() # Signal that we're done sending text - - # Wait for audio playback to complete with a timeout - # Using a dynamic timeout based on text length - timeout = max(5, len(self.text) * 0.1) - logger.debug(f"Waiting for TTS completion with timeout {timeout:.1f}s") - - if not self._audio_complete.wait(timeout=timeout): - logger.warning(f"TTS timeout reached for: {self.text}") - else: - # Add a small delay after audio completes to ensure buffers are fully flushed - time.sleep(0.3) - - # Clean up all subscriptions - for sub in self._subscriptions: - if 
sub: - sub.dispose() - self._subscriptions = [] - - # Successfully completed - result_queue.put(f"Spoke: {self.text} successfully") - except Exception as e: - logger.error(f"Error in speak task: {e}") - result_queue.put(f"Error speaking text: {str(e)}") - - # Add our speech task to the global queue for sequential processing - display_text = self.text[:50] + "..." if len(self.text) > 50 else self.text - logger.info(f"Queueing speech task: '{display_text}'") - _audio_queue.put(speak_task) - - # Wait for the result with a timeout - try: - # Use a longer timeout than the audio playback itself - text_len_timeout = len(self.text) * 0.15 # 150ms per character - max_timeout = max(10, text_len_timeout) # At least 10 seconds - - return result_queue.get(timeout=max_timeout) - except queue.Empty: - logger.error("Timed out waiting for speech task to complete") - return f"Error: Timed out while speaking: {self.text}" diff --git a/build/lib/dimos/skills/unitree/__init__.py b/build/lib/dimos/skills/unitree/__init__.py deleted file mode 100644 index 8b13789179..0000000000 --- a/build/lib/dimos/skills/unitree/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/build/lib/dimos/skills/unitree/unitree_speak.py b/build/lib/dimos/skills/unitree/unitree_speak.py deleted file mode 100644 index 05004398f9..0000000000 --- a/build/lib/dimos/skills/unitree/unitree_speak.py +++ /dev/null @@ -1,280 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from dimos.skills.skills import AbstractRobotSkill -from pydantic import Field -import time -import tempfile -import os -import json -import base64 -import hashlib -import soundfile as sf -import numpy as np -from openai import OpenAI -from dimos.utils.logging_config import setup_logger -from go2_webrtc_driver.constants import RTC_TOPIC - -logger = setup_logger("dimos.skills.unitree.unitree_speak") - -# Audio API constants (from go2_webrtc_driver) -AUDIO_API = { - "GET_AUDIO_LIST": 1001, - "SELECT_START_PLAY": 1002, - "PAUSE": 1003, - "UNSUSPEND": 1004, - "SET_PLAY_MODE": 1007, - "UPLOAD_AUDIO_FILE": 2001, - "ENTER_MEGAPHONE": 4001, - "EXIT_MEGAPHONE": 4002, - "UPLOAD_MEGAPHONE": 4003, -} - -PLAY_MODES = {"NO_CYCLE": "no_cycle", "SINGLE_CYCLE": "single_cycle", "LIST_LOOP": "list_loop"} - - -class UnitreeSpeak(AbstractRobotSkill): - """Speak text out loud through the robot's speakers using WebRTC audio upload.""" - - text: str = Field(..., description="Text to speak") - voice: str = Field( - default="echo", description="Voice to use (alloy, echo, fable, onyx, nova, shimmer)" - ) - speed: float = Field(default=1.2, description="Speech speed (0.25 to 4.0)") - use_megaphone: bool = Field( - default=False, description="Use megaphone mode for lower latency (experimental)" - ) - - def __init__(self, **data): - super().__init__(**data) - self._openai_client = None - - def _get_openai_client(self): - if self._openai_client is None: - self._openai_client = OpenAI() - return self._openai_client - - def _generate_audio(self, text: str) -> bytes: - try: - client = self._get_openai_client() - response = client.audio.speech.create( - model="tts-1", voice=self.voice, input=text, speed=self.speed, response_format="mp3" - ) - return response.content - except Exception as e: - logger.error(f"Error generating audio: {e}") - raise - - def _webrtc_request(self, api_id: int, parameter: dict = None): - if parameter is None: - parameter = {} - - request_data = {"api_id": api_id, 
"parameter": json.dumps(parameter) if parameter else "{}"} - - return self._robot.webrtc_connection.publish_request( - RTC_TOPIC["AUDIO_HUB_REQ"], request_data - ) - - def _upload_audio_to_robot(self, audio_data: bytes, filename: str) -> str: - try: - file_md5 = hashlib.md5(audio_data).hexdigest() - b64_data = base64.b64encode(audio_data).decode("utf-8") - - chunk_size = 61440 - chunks = [b64_data[i : i + chunk_size] for i in range(0, len(b64_data), chunk_size)] - total_chunks = len(chunks) - - logger.info(f"Uploading audio '{filename}' in {total_chunks} chunks (optimized)") - - for i, chunk in enumerate(chunks, 1): - parameter = { - "file_name": filename, - "file_type": "wav", - "file_size": len(audio_data), - "current_block_index": i, - "total_block_number": total_chunks, - "block_content": chunk, - "current_block_size": len(chunk), - "file_md5": file_md5, - "create_time": int(time.time() * 1000), - } - - logger.debug(f"Sending chunk {i}/{total_chunks}") - response = self._webrtc_request(AUDIO_API["UPLOAD_AUDIO_FILE"], parameter) - - logger.info(f"Audio upload completed for '{filename}'") - - list_response = self._webrtc_request(AUDIO_API["GET_AUDIO_LIST"], {}) - - if list_response and "data" in list_response: - data_str = list_response.get("data", {}).get("data", "{}") - audio_list = json.loads(data_str).get("audio_list", []) - - for audio in audio_list: - if audio.get("CUSTOM_NAME") == filename: - return audio.get("UNIQUE_ID") - - logger.warning( - f"Could not find uploaded audio '{filename}' in list, using filename as UUID" - ) - return filename - - except Exception as e: - logger.error(f"Error uploading audio to robot: {e}") - raise - - def _play_audio_on_robot(self, uuid: str): - try: - self._webrtc_request(AUDIO_API["SET_PLAY_MODE"], {"play_mode": PLAY_MODES["NO_CYCLE"]}) - time.sleep(0.1) - - parameter = {"unique_id": uuid} - - logger.info(f"Playing audio with UUID: {uuid}") - self._webrtc_request(AUDIO_API["SELECT_START_PLAY"], parameter) - - except 
Exception as e: - logger.error(f"Error playing audio on robot: {e}") - raise - - def _stop_audio_playback(self): - try: - logger.debug("Stopping audio playback") - self._webrtc_request(AUDIO_API["PAUSE"], {}) - except Exception as e: - logger.warning(f"Error stopping audio playback: {e}") - - def _upload_and_play_megaphone(self, audio_data: bytes, duration: float): - try: - logger.debug("Entering megaphone mode") - self._webrtc_request(AUDIO_API["ENTER_MEGAPHONE"], {}) - - time.sleep(0.2) - - b64_data = base64.b64encode(audio_data).decode("utf-8") - - chunk_size = 4096 - chunks = [b64_data[i : i + chunk_size] for i in range(0, len(b64_data), chunk_size)] - total_chunks = len(chunks) - - logger.info(f"Uploading megaphone audio in {total_chunks} chunks") - - for i, chunk in enumerate(chunks, 1): - parameter = { - "current_block_size": len(chunk), - "block_content": chunk, - "current_block_index": i, - "total_block_number": total_chunks, - } - - logger.debug(f"Sending megaphone chunk {i}/{total_chunks}") - self._webrtc_request(AUDIO_API["UPLOAD_MEGAPHONE"], parameter) - - if i < total_chunks: - time.sleep(0.05) - - logger.info("Megaphone audio upload completed, waiting for playback") - - time.sleep(duration + 1.0) - - except Exception as e: - logger.error(f"Error in megaphone mode: {e}") - try: - self._webrtc_request(AUDIO_API["EXIT_MEGAPHONE"], {}) - except: - pass - raise - finally: - try: - logger.debug("Exiting megaphone mode") - self._webrtc_request(AUDIO_API["EXIT_MEGAPHONE"], {}) - time.sleep(0.1) - except Exception as e: - logger.warning(f"Error exiting megaphone mode: {e}") - - def __call__(self): - super().__call__() - - if not self._robot: - logger.error("No robot instance provided to UnitreeSpeak skill") - return "Error: No robot instance available" - - try: - display_text = self.text[:50] + "..." 
if len(self.text) > 50 else self.text - logger.info(f"Speaking: '{display_text}'") - - logger.debug("Generating audio with OpenAI TTS") - audio_data = self._generate_audio(self.text) - - with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as tmp_mp3: - tmp_mp3.write(audio_data) - tmp_mp3_path = tmp_mp3.name - - try: - audio_array, sample_rate = sf.read(tmp_mp3_path) - - if audio_array.ndim > 1: - audio_array = np.mean(audio_array, axis=1) - - target_sample_rate = 22050 - if sample_rate != target_sample_rate: - logger.debug(f"Resampling from {sample_rate}Hz to {target_sample_rate}Hz") - old_length = len(audio_array) - new_length = int(old_length * target_sample_rate / sample_rate) - old_indices = np.arange(old_length) - new_indices = np.linspace(0, old_length - 1, new_length) - audio_array = np.interp(new_indices, old_indices, audio_array) - sample_rate = target_sample_rate - - audio_array = audio_array / np.max(np.abs(audio_array)) - - with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_wav: - sf.write(tmp_wav.name, audio_array, sample_rate, format="WAV", subtype="PCM_16") - tmp_wav.seek(0) - wav_data = open(tmp_wav.name, "rb").read() - os.unlink(tmp_wav.name) - - logger.info( - f"Audio size: {len(wav_data) / 1024:.1f}KB, duration: {len(audio_array) / sample_rate:.1f}s" - ) - - finally: - os.unlink(tmp_mp3_path) - - if self.use_megaphone: - logger.debug("Using megaphone mode for lower latency") - duration = len(audio_array) / sample_rate - self._upload_and_play_megaphone(wav_data, duration) - - return f"Spoke: '{display_text}' on robot successfully (megaphone mode)" - else: - filename = f"speak_{int(time.time() * 1000)}" - - logger.debug("Uploading audio to robot") - uuid = self._upload_audio_to_robot(wav_data, filename) - - logger.debug("Playing audio on robot") - self._play_audio_on_robot(uuid) - - duration = len(audio_array) / sample_rate - logger.debug(f"Waiting {duration:.1f}s for playback to complete") - # time.sleep(duration + 0.2) 
- - # self._stop_audio_playback() - - return f"Spoke: '{display_text}' on robot successfully" - - except Exception as e: - logger.error(f"Error in speak skill: {e}") - return f"Error speaking text: {str(e)}" diff --git a/build/lib/dimos/skills/visual_navigation_skills.py b/build/lib/dimos/skills/visual_navigation_skills.py deleted file mode 100644 index 96e21eb92d..0000000000 --- a/build/lib/dimos/skills/visual_navigation_skills.py +++ /dev/null @@ -1,148 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Visual navigation skills for robot interaction. - -This module provides skills for visual navigation, including following humans -and navigating to specific objects using computer vision. -""" - -import time -import logging -import threading -from typing import Optional, Tuple - -from dimos.skills.skills import AbstractRobotSkill -from dimos.utils.logging_config import setup_logger -from dimos.perception.visual_servoing import VisualServoing -from pydantic import Field -from dimos.types.vector import Vector - -logger = setup_logger("dimos.skills.visual_navigation", level=logging.DEBUG) - - -class FollowHuman(AbstractRobotSkill): - """ - A skill that makes the robot follow a human using visual servoing continuously. - - This skill uses the robot's person tracking stream to follow a human - while maintaining a specified distance. It will keep following the human - until the timeout is reached or the skill is stopped. 
Don't use this skill - if you want to navigate to a specific person, use NavigateTo instead. - """ - - distance: float = Field( - 1.5, description="Desired distance to maintain from the person in meters" - ) - timeout: float = Field(20.0, description="Maximum time to follow the person in seconds") - point: Optional[Tuple[int, int]] = Field( - None, description="Optional point to start tracking (x,y pixel coordinates)" - ) - - def __init__(self, robot=None, **data): - super().__init__(robot=robot, **data) - self._stop_event = threading.Event() - self._visual_servoing = None - - def __call__(self): - """ - Start following a human using visual servoing. - - Returns: - bool: True if successful, False otherwise - """ - super().__call__() - - if ( - not hasattr(self._robot, "person_tracking_stream") - or self._robot.person_tracking_stream is None - ): - logger.error("Robot does not have a person tracking stream") - return False - - # Stop any existing operation - self.stop() - self._stop_event.clear() - - success = False - - try: - # Initialize visual servoing - self._visual_servoing = VisualServoing( - tracking_stream=self._robot.person_tracking_stream - ) - - logger.warning(f"Following human for {self.timeout} seconds...") - start_time = time.time() - - # Start tracking - track_success = self._visual_servoing.start_tracking( - point=self.point, desired_distance=self.distance - ) - - if not track_success: - logger.error("Failed to start tracking") - return False - - # Main follow loop - while ( - self._visual_servoing.running - and time.time() - start_time < self.timeout - and not self._stop_event.is_set() - ): - output = self._visual_servoing.updateTracking() - x_vel = output.get("linear_vel") - z_vel = output.get("angular_vel") - logger.debug(f"Following human: x_vel: {x_vel}, z_vel: {z_vel}") - self._robot.move(Vector(x_vel, 0, z_vel)) - time.sleep(0.05) - - # If we completed the full timeout duration, consider it success - if time.time() - start_time >= 
self.timeout: - success = True - logger.info("Human following completed successfully") - elif self._stop_event.is_set(): - logger.info("Human following stopped externally") - else: - logger.info("Human following stopped due to tracking loss") - - return success - - except Exception as e: - logger.error(f"Error in follow human: {e}") - return False - finally: - # Clean up - if self._visual_servoing: - self._visual_servoing.stop_tracking() - self._visual_servoing = None - - def stop(self): - """ - Stop the human following process. - - Returns: - bool: True if stopped, False if it wasn't running - """ - if self._visual_servoing is not None: - logger.info("Stopping FollowHuman skill") - self._stop_event.set() - - # Clean up visual servoing if it exists - self._visual_servoing.stop_tracking() - self._visual_servoing = None - - return True - return False diff --git a/build/lib/dimos/stream/__init__.py b/build/lib/dimos/stream/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/build/lib/dimos/stream/audio/__init__.py b/build/lib/dimos/stream/audio/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/build/lib/dimos/stream/audio/base.py b/build/lib/dimos/stream/audio/base.py deleted file mode 100644 index a22e6606d6..0000000000 --- a/build/lib/dimos/stream/audio/base.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from abc import ABC, abstractmethod -from reactivex import Observable -import numpy as np - - -class AbstractAudioEmitter(ABC): - """Base class for components that emit audio.""" - - @abstractmethod - def emit_audio(self) -> Observable: - """Create an observable that emits audio frames. - - Returns: - Observable emitting audio frames - """ - pass - - -class AbstractAudioConsumer(ABC): - """Base class for components that consume audio.""" - - @abstractmethod - def consume_audio(self, audio_observable: Observable) -> "AbstractAudioConsumer": - """Set the audio observable to consume. - - Args: - audio_observable: Observable emitting audio frames - - Returns: - Self for method chaining - """ - pass - - -class AbstractAudioTransform(AbstractAudioConsumer, AbstractAudioEmitter): - """Base class for components that both consume and emit audio. - - This represents a transform in an audio processing pipeline. - """ - - pass - - -class AudioEvent: - """Class to represent an audio frame event with metadata.""" - - def __init__(self, data: np.ndarray, sample_rate: int, timestamp: float, channels: int = 1): - """ - Initialize an AudioEvent. 
- - Args: - data: Audio data as numpy array - sample_rate: Audio sample rate in Hz - timestamp: Unix timestamp when the audio was captured - channels: Number of audio channels - """ - self.data = data - self.sample_rate = sample_rate - self.timestamp = timestamp - self.channels = channels - self.dtype = data.dtype - self.shape = data.shape - - def to_float32(self) -> "AudioEvent": - """Convert audio data to float32 format normalized to [-1.0, 1.0].""" - if self.data.dtype == np.float32: - return self - - new_data = self.data.astype(np.float32) - if self.data.dtype == np.int16: - new_data /= 32768.0 - - return AudioEvent( - data=new_data, - sample_rate=self.sample_rate, - timestamp=self.timestamp, - channels=self.channels, - ) - - def to_int16(self) -> "AudioEvent": - """Convert audio data to int16 format.""" - if self.data.dtype == np.int16: - return self - - new_data = self.data - if self.data.dtype == np.float32: - new_data = (new_data * 32767).astype(np.int16) - - return AudioEvent( - data=new_data, - sample_rate=self.sample_rate, - timestamp=self.timestamp, - channels=self.channels, - ) - - def __repr__(self) -> str: - return ( - f"AudioEvent(shape={self.shape}, dtype={self.dtype}, " - f"sample_rate={self.sample_rate}, channels={self.channels})" - ) diff --git a/build/lib/dimos/stream/audio/node_key_recorder.py b/build/lib/dimos/stream/audio/node_key_recorder.py deleted file mode 100644 index 6494dcbef9..0000000000 --- a/build/lib/dimos/stream/audio/node_key_recorder.py +++ /dev/null @@ -1,336 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import List -import numpy as np -import time -import threading -import sys -import select -from reactivex import Observable -from reactivex.subject import Subject, ReplaySubject - -from dimos.stream.audio.base import AbstractAudioTransform, AudioEvent - -from dimos.utils.logging_config import setup_logger - -logger = setup_logger("dimos.audio.key_recorder") - - -class KeyRecorder(AbstractAudioTransform): - """ - Audio recorder that captures audio events and combines them. - Press a key to toggle recording on/off. - """ - - def __init__( - self, - max_recording_time: float = 120.0, - always_subscribe: bool = False, - ): - """ - Initialize KeyRecorder. 
- - Args: - max_recording_time: Maximum recording time in seconds - always_subscribe: If True, subscribe to audio source continuously, - If False, only subscribe when recording (more efficient - but some audio devices may need time to initialize) - """ - self.max_recording_time = max_recording_time - self.always_subscribe = always_subscribe - - self._audio_buffer = [] - self._is_recording = False - self._recording_start_time = 0 - self._sample_rate = None # Will be updated from incoming audio - self._channels = None # Will be set from first event - - self._audio_observable = None - self._subscription = None - self._output_subject = Subject() # For record-time passthrough - self._recording_subject = ReplaySubject(1) # For full completed recordings - - # Start a thread to monitor for input - self._running = True - self._input_thread = threading.Thread(target=self._input_monitor, daemon=True) - self._input_thread.start() - - logger.info("Started audio recorder (press any key to start/stop recording)") - - def consume_audio(self, audio_observable: Observable) -> "KeyRecorder": - """ - Set the audio observable to use when recording. - If always_subscribe is True, subscribes immediately. - Otherwise, subscribes only when recording starts. - - Args: - audio_observable: Observable emitting AudioEvent objects - - Returns: - Self for method chaining - """ - self._audio_observable = audio_observable - - # If configured to always subscribe, do it now - if self.always_subscribe and not self._subscription: - self._subscription = audio_observable.subscribe( - on_next=self._process_audio_event, - on_error=self._handle_error, - on_completed=self._handle_completion, - ) - logger.debug("Subscribed to audio source (always_subscribe=True)") - - return self - - def emit_audio(self) -> Observable: - """ - Create an observable that emits audio events in real-time (pass-through). 
- - Returns: - Observable emitting AudioEvent objects in real-time - """ - return self._output_subject - - def emit_recording(self) -> Observable: - """ - Create an observable that emits combined audio recordings when recording stops. - - Returns: - Observable emitting AudioEvent objects with complete recordings - """ - return self._recording_subject - - def stop(self): - """Stop recording and clean up resources.""" - logger.info("Stopping audio recorder") - - # If recording is in progress, stop it first - if self._is_recording: - self._stop_recording() - - # Always clean up subscription on full stop - if self._subscription: - self._subscription.dispose() - self._subscription = None - - # Stop input monitoring thread - self._running = False - if self._input_thread.is_alive(): - self._input_thread.join(1.0) - - def _input_monitor(self): - """Monitor for key presses to toggle recording.""" - logger.info("Press Enter to start/stop recording...") - - while self._running: - # Check if there's input available - if select.select([sys.stdin], [], [], 0.1)[0]: - sys.stdin.readline() - - if self._is_recording: - self._stop_recording() - else: - self._start_recording() - - # Sleep a bit to reduce CPU usage - time.sleep(0.1) - - def _start_recording(self): - """Start recording audio and subscribe to the audio source if not always subscribed.""" - if not self._audio_observable: - logger.error("Cannot start recording: No audio source has been set") - return - - # Subscribe to the observable if not using always_subscribe - if not self._subscription: - self._subscription = self._audio_observable.subscribe( - on_next=self._process_audio_event, - on_error=self._handle_error, - on_completed=self._handle_completion, - ) - logger.debug("Subscribed to audio source for recording") - - self._is_recording = True - self._recording_start_time = time.time() - self._audio_buffer = [] - logger.info("Recording... 
(press Enter to stop)") - - def _stop_recording(self): - """Stop recording, unsubscribe from audio source if not always subscribed, and emit the combined audio event.""" - self._is_recording = False - recording_duration = time.time() - self._recording_start_time - - # Unsubscribe from the audio source if not using always_subscribe - if not self.always_subscribe and self._subscription: - self._subscription.dispose() - self._subscription = None - logger.debug("Unsubscribed from audio source after recording") - - logger.info(f"Recording stopped after {recording_duration:.2f} seconds") - - # Combine all audio events into one - if len(self._audio_buffer) > 0: - combined_audio = self._combine_audio_events(self._audio_buffer) - self._recording_subject.on_next(combined_audio) - else: - logger.warning("No audio was recorded") - - def _process_audio_event(self, audio_event): - """Process incoming audio events.""" - - # Only buffer if recording - if not self._is_recording: - return - - # Pass through audio events in real-time - self._output_subject.on_next(audio_event) - - # First audio event - determine channel count/sample rate - if self._channels is None: - self._channels = audio_event.channels - self._sample_rate = audio_event.sample_rate - logger.info(f"Setting channel count to {self._channels}") - - # Add to buffer - self._audio_buffer.append(audio_event) - - # Check if we've exceeded max recording time - if time.time() - self._recording_start_time > self.max_recording_time: - logger.warning(f"Max recording time ({self.max_recording_time}s) reached") - self._stop_recording() - - def _combine_audio_events(self, audio_events: List[AudioEvent]) -> AudioEvent: - """Combine multiple audio events into a single event.""" - if not audio_events: - logger.warning("Attempted to combine empty audio events list") - return None - - # Filter out any empty events that might cause broadcasting errors - valid_events = [ - event - for event in audio_events - if event is not None - and 
(hasattr(event, "data") and event.data is not None and event.data.size > 0) - ] - - if not valid_events: - logger.warning("No valid audio events to combine") - return None - - first_event = valid_events[0] - channels = first_event.channels - dtype = first_event.data.dtype - - # Calculate total samples only from valid events - total_samples = sum(event.data.shape[0] for event in valid_events) - - # Safety check - if somehow we got no samples - if total_samples <= 0: - logger.warning(f"Combined audio would have {total_samples} samples - aborting") - return None - - # For multichannel audio, data shape could be (samples,) or (samples, channels) - if len(first_event.data.shape) == 1: - # 1D audio data (mono) - combined_data = np.zeros(total_samples, dtype=dtype) - - # Copy data - offset = 0 - for event in valid_events: - samples = event.data.shape[0] - if samples > 0: # Extra safety check - combined_data[offset : offset + samples] = event.data - offset += samples - else: - # Multichannel audio data (stereo or more) - combined_data = np.zeros((total_samples, channels), dtype=dtype) - - # Copy data - offset = 0 - for event in valid_events: - samples = event.data.shape[0] - if samples > 0 and offset + samples <= total_samples: # Safety check - try: - combined_data[offset : offset + samples] = event.data - offset += samples - except ValueError as e: - logger.error( - f"Error combining audio events: {e}. 
" - f"Event shape: {event.data.shape}, " - f"Combined shape: {combined_data.shape}, " - f"Offset: {offset}, Samples: {samples}" - ) - # Continue with next event instead of failing completely - - # Create new audio event with the combined data - if combined_data.size > 0: - return AudioEvent( - data=combined_data, - sample_rate=self._sample_rate, - timestamp=valid_events[0].timestamp, - channels=channels, - ) - else: - logger.warning("Failed to create valid combined audio event") - return None - - def _handle_error(self, error): - """Handle errors from the observable.""" - logger.error(f"Error in audio observable: {error}") - - def _handle_completion(self): - """Handle completion of the observable.""" - logger.info("Audio observable completed") - self.stop() - - -if __name__ == "__main__": - from dimos.stream.audio.node_microphone import ( - SounddeviceAudioSource, - ) - from dimos.stream.audio.node_output import SounddeviceAudioOutput - from dimos.stream.audio.node_volume_monitor import monitor - from dimos.stream.audio.node_normalizer import AudioNormalizer - from dimos.stream.audio.utils import keepalive - - # Create microphone source, recorder, and audio output - mic = SounddeviceAudioSource() - - # my audio device needs time to init, so for smoother ux we constantly listen - recorder = KeyRecorder(always_subscribe=True) - - normalizer = AudioNormalizer() - speaker = SounddeviceAudioOutput() - - # Connect the components - normalizer.consume_audio(mic.emit_audio()) - recorder.consume_audio(normalizer.emit_audio()) - # recorder.consume_audio(mic.emit_audio()) - - # Monitor microphone input levels (real-time pass-through) - monitor(recorder.emit_audio()) - - # Connect the recorder output to the speakers to hear recordings when completed - playback_speaker = SounddeviceAudioOutput() - playback_speaker.consume_audio(recorder.emit_recording()) - - # TODO: we should be able to run normalizer post hoc on the recording as well, - # it's not working, this needs a review - 
# - # normalizer.consume_audio(recorder.emit_recording()) - # playback_speaker.consume_audio(normalizer.emit_audio()) - - keepalive() diff --git a/build/lib/dimos/stream/audio/node_microphone.py b/build/lib/dimos/stream/audio/node_microphone.py deleted file mode 100644 index bdb9b32180..0000000000 --- a/build/lib/dimos/stream/audio/node_microphone.py +++ /dev/null @@ -1,131 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dimos.stream.audio.base import ( - AbstractAudioEmitter, - AudioEvent, -) - -import numpy as np -from typing import Optional, List, Dict, Any -from reactivex import Observable, create, disposable -import time -import sounddevice as sd - -from dimos.utils.logging_config import setup_logger - -logger = setup_logger("dimos.audio.node_microphone") - - -class SounddeviceAudioSource(AbstractAudioEmitter): - """Audio source implementation using the sounddevice library.""" - - def __init__( - self, - device_index: Optional[int] = None, - sample_rate: int = 16000, - channels: int = 1, - block_size: int = 1024, - dtype: np.dtype = np.float32, - ): - """ - Initialize SounddeviceAudioSource. 
- - Args: - device_index: Audio device index (None for default) - sample_rate: Audio sample rate in Hz - channels: Number of audio channels (1=mono, 2=stereo) - block_size: Number of samples per audio frame - dtype: Data type for audio samples (np.float32 or np.int16) - """ - self.device_index = device_index - self.sample_rate = sample_rate - self.channels = channels - self.block_size = block_size - self.dtype = dtype - - self._stream = None - self._running = False - - def emit_audio(self) -> Observable: - """ - Create an observable that emits audio frames. - - Returns: - Observable emitting AudioEvent objects - """ - - def on_subscribe(observer, scheduler): - # Callback function to process audio data - def audio_callback(indata, frames, time_info, status): - if status: - logger.warning(f"Audio callback status: {status}") - - # Create audio event - audio_event = AudioEvent( - data=indata.copy(), - sample_rate=self.sample_rate, - timestamp=time.time(), - channels=self.channels, - ) - - observer.on_next(audio_event) - - # Start the audio stream - try: - self._stream = sd.InputStream( - device=self.device_index, - samplerate=self.sample_rate, - channels=self.channels, - blocksize=self.block_size, - dtype=self.dtype, - callback=audio_callback, - ) - self._stream.start() - self._running = True - - logger.info( - f"Started audio capture: {self.sample_rate}Hz, " - f"{self.channels} channels, {self.block_size} samples per frame" - ) - - except Exception as e: - logger.error(f"Error starting audio stream: {e}") - observer.on_error(e) - - # Return a disposable to clean up resources - def dispose(): - logger.info("Stopping audio capture") - self._running = False - if self._stream: - self._stream.stop() - self._stream.close() - self._stream = None - - return disposable.Disposable(dispose) - - return create(on_subscribe) - - def get_available_devices(self) -> List[Dict[str, Any]]: - """Get a list of available audio input devices.""" - return sd.query_devices() - - -if __name__ 
== "__main__": - from dimos.stream.audio.node_volume_monitor import monitor - from dimos.stream.audio.utils import keepalive - - monitor(SounddeviceAudioSource().emit_audio()) - keepalive() diff --git a/build/lib/dimos/stream/audio/node_normalizer.py b/build/lib/dimos/stream/audio/node_normalizer.py deleted file mode 100644 index db9557a5b1..0000000000 --- a/build/lib/dimos/stream/audio/node_normalizer.py +++ /dev/null @@ -1,220 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Callable - -import numpy as np -from reactivex import Observable, create, disposable - -from dimos.utils.logging_config import setup_logger -from dimos.stream.audio.volume import ( - calculate_rms_volume, - calculate_peak_volume, -) -from dimos.stream.audio.base import ( - AbstractAudioTransform, - AudioEvent, -) - - -logger = setup_logger("dimos.stream.audio.node_normalizer") - - -class AudioNormalizer(AbstractAudioTransform): - """ - Audio normalizer that remembers max volume and rescales audio to normalize it. - - This class applies dynamic normalization to audio frames. It keeps track of - the max volume encountered and uses that to normalize the audio to a target level. 
- """ - - def __init__( - self, - target_level: float = 1.0, - min_volume_threshold: float = 0.01, - max_gain: float = 10.0, - decay_factor: float = 0.999, - adapt_speed: float = 0.05, - volume_func: Callable[[np.ndarray], float] = calculate_peak_volume, - ): - """ - Initialize AudioNormalizer. - - Args: - target_level: Target normalization level (0.0 to 1.0) - min_volume_threshold: Minimum volume to apply normalization - max_gain: Maximum allowed gain to prevent excessive amplification - decay_factor: Decay factor for max volume (0.0-1.0, higher = slower decay) - adapt_speed: How quickly to adapt to new volume levels (0.0-1.0) - volume_func: Function to calculate volume (default: peak volume) - """ - self.target_level = target_level - self.min_volume_threshold = min_volume_threshold - self.max_gain = max_gain - self.decay_factor = decay_factor - self.adapt_speed = adapt_speed - self.volume_func = volume_func - - # Internal state - self.max_volume = 0.0 - self.current_gain = 1.0 - self.audio_observable = None - - def _normalize_audio(self, audio_event: AudioEvent) -> AudioEvent: - """ - Normalize audio data based on tracked max volume. 
- - Args: - audio_event: Input audio event - - Returns: - Normalized audio event - """ - # Convert to float32 for processing if needed - if audio_event.data.dtype != np.float32: - audio_event = audio_event.to_float32() - - # Calculate current volume using provided function - current_volume = self.volume_func(audio_event.data) - - # Update max volume with decay - self.max_volume = max(current_volume, self.max_volume * self.decay_factor) - - # Calculate ideal gain - if self.max_volume > self.min_volume_threshold: - ideal_gain = self.target_level / self.max_volume - else: - ideal_gain = 1.0 # No normalization needed for very quiet audio - - # Limit gain to max_gain - ideal_gain = min(ideal_gain, self.max_gain) - - # Smoothly adapt current gain towards ideal gain - self.current_gain = ( - 1 - self.adapt_speed - ) * self.current_gain + self.adapt_speed * ideal_gain - - # Apply gain to audio data - normalized_data = audio_event.data * self.current_gain - - # Clip to prevent distortion (values should stay within -1.0 to 1.0) - normalized_data = np.clip(normalized_data, -1.0, 1.0) - - # Create new audio event with normalized data - return AudioEvent( - data=normalized_data, - sample_rate=audio_event.sample_rate, - timestamp=audio_event.timestamp, - channels=audio_event.channels, - ) - - def consume_audio(self, audio_observable: Observable) -> "AudioNormalizer": - """ - Set the audio source observable to consume. - - Args: - audio_observable: Observable emitting AudioEvent objects - - Returns: - Self for method chaining - """ - self.audio_observable = audio_observable - return self - - def emit_audio(self) -> Observable: - """ - Create an observable that emits normalized audio frames. - - Returns: - Observable emitting normalized AudioEvent objects - """ - if self.audio_observable is None: - raise ValueError("No audio source provided. 
Call consume_audio() first.") - - def on_subscribe(observer, scheduler): - # Subscribe to the audio observable - audio_subscription = self.audio_observable.subscribe( - on_next=lambda event: observer.on_next(self._normalize_audio(event)), - on_error=lambda error: observer.on_error(error), - on_completed=lambda: observer.on_completed(), - ) - - logger.info( - f"Started audio normalizer with target level: {self.target_level}, max gain: {self.max_gain}" - ) - - # Return a disposable to clean up resources - def dispose(): - logger.info("Stopping audio normalizer") - audio_subscription.dispose() - - return disposable.Disposable(dispose) - - return create(on_subscribe) - - -if __name__ == "__main__": - import sys - from dimos.stream.audio.node_microphone import ( - SounddeviceAudioSource, - ) - from dimos.stream.audio.node_simulated import SimulatedAudioSource - from dimos.stream.audio.node_volume_monitor import monitor - from dimos.stream.audio.node_output import SounddeviceAudioOutput - from dimos.stream.audio.utils import keepalive - - # Parse command line arguments - volume_method = "peak" # Default to peak - use_mic = False # Default to microphone input - target_level = 1 # Default target level - - # Process arguments - for arg in sys.argv[1:]: - if arg == "rms": - volume_method = "rms" - elif arg == "peak": - volume_method = "peak" - elif arg == "mic": - use_mic = True - elif arg.startswith("level="): - try: - target_level = float(arg.split("=")[1]) - except ValueError: - print(f"Invalid target level: {arg}") - sys.exit(1) - - # Create appropriate audio source - if use_mic: - audio_source = SounddeviceAudioSource() - print("Using microphone input") - else: - audio_source = SimulatedAudioSource(volume_oscillation=True) - print("Using simulated audio source") - - # Select volume function - volume_func = calculate_rms_volume if volume_method == "rms" else calculate_peak_volume - - # Create normalizer - normalizer = AudioNormalizer(target_level=target_level, 
volume_func=volume_func) - - # Connect the audio source to the normalizer - normalizer.consume_audio(audio_source.emit_audio()) - - print(f"Using {volume_method} volume method with target level {target_level}") - SounddeviceAudioOutput().consume_audio(normalizer.emit_audio()) - - # Monitor the normalized audio - monitor(normalizer.emit_audio()) - keepalive() diff --git a/build/lib/dimos/stream/audio/node_output.py b/build/lib/dimos/stream/audio/node_output.py deleted file mode 100644 index ee2e2c5ec2..0000000000 --- a/build/lib/dimos/stream/audio/node_output.py +++ /dev/null @@ -1,187 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Optional, List, Dict, Any -import numpy as np -import sounddevice as sd -from reactivex import Observable - -from dimos.utils.logging_config import setup_logger -from dimos.stream.audio.base import ( - AbstractAudioTransform, -) - -logger = setup_logger("dimos.stream.audio.node_output") - - -class SounddeviceAudioOutput(AbstractAudioTransform): - """ - Audio output implementation using the sounddevice library. - - This class implements AbstractAudioTransform so it can both play audio and - optionally pass audio events through to other components (for example, to - record audio while playing it, or to visualize the waveform while playing). 
- """ - - def __init__( - self, - device_index: Optional[int] = None, - sample_rate: int = 16000, - channels: int = 1, - block_size: int = 1024, - dtype: np.dtype = np.float32, - ): - """ - Initialize SounddeviceAudioOutput. - - Args: - device_index: Audio device index (None for default) - sample_rate: Audio sample rate in Hz - channels: Number of audio channels (1=mono, 2=stereo) - block_size: Number of samples per audio frame - dtype: Data type for audio samples (np.float32 or np.int16) - """ - self.device_index = device_index - self.sample_rate = sample_rate - self.channels = channels - self.block_size = block_size - self.dtype = dtype - - self._stream = None - self._running = False - self._subscription = None - self.audio_observable = None - - def consume_audio(self, audio_observable: Observable) -> "SounddeviceAudioOutput": - """ - Subscribe to an audio observable and play the audio through the speakers. - - Args: - audio_observable: Observable emitting AudioEvent objects - - Returns: - Self for method chaining - """ - self.audio_observable = audio_observable - - # Create and start the output stream - try: - self._stream = sd.OutputStream( - device=self.device_index, - samplerate=self.sample_rate, - channels=self.channels, - blocksize=self.block_size, - dtype=self.dtype, - ) - self._stream.start() - self._running = True - - logger.info( - f"Started audio output: {self.sample_rate}Hz, " - f"{self.channels} channels, {self.block_size} samples per frame" - ) - - except Exception as e: - logger.error(f"Error starting audio output stream: {e}") - raise e - - # Subscribe to the observable - self._subscription = audio_observable.subscribe( - on_next=self._play_audio_event, - on_error=self._handle_error, - on_completed=self._handle_completion, - ) - - return self - - def emit_audio(self) -> Observable: - """ - Pass through the audio observable to allow chaining with other components. 
- - Returns: - The same Observable that was provided to consume_audio - """ - if self.audio_observable is None: - raise ValueError("No audio source provided. Call consume_audio() first.") - - return self.audio_observable - - def stop(self): - """Stop audio output and clean up resources.""" - logger.info("Stopping audio output") - self._running = False - - if self._subscription: - self._subscription.dispose() - self._subscription = None - - if self._stream: - self._stream.stop() - self._stream.close() - self._stream = None - - def _play_audio_event(self, audio_event): - """Play audio from an AudioEvent.""" - if not self._running or not self._stream: - return - - try: - # Ensure data type matches our stream - if audio_event.dtype != self.dtype: - if self.dtype == np.float32: - audio_event = audio_event.to_float32() - elif self.dtype == np.int16: - audio_event = audio_event.to_int16() - - # Write audio data to the stream - self._stream.write(audio_event.data) - except Exception as e: - logger.error(f"Error playing audio: {e}") - - def _handle_error(self, error): - """Handle errors from the observable.""" - logger.error(f"Error in audio observable: {error}") - - def _handle_completion(self): - """Handle completion of the observable.""" - logger.info("Audio observable completed") - self._running = False - if self._stream: - self._stream.stop() - self._stream.close() - self._stream = None - - def get_available_devices(self) -> List[Dict[str, Any]]: - """Get a list of available audio output devices.""" - return sd.query_devices() - - -if __name__ == "__main__": - from dimos.stream.audio.node_microphone import ( - SounddeviceAudioSource, - ) - from dimos.stream.audio.node_normalizer import AudioNormalizer - from dimos.stream.audio.utils import keepalive - - # Create microphone source, normalizer and audio output - mic = SounddeviceAudioSource() - normalizer = AudioNormalizer() - speaker = SounddeviceAudioOutput() - - # Connect the components in a pipeline - 
normalizer.consume_audio(mic.emit_audio()) - speaker.consume_audio(normalizer.emit_audio()) - - keepalive() diff --git a/build/lib/dimos/stream/audio/node_simulated.py b/build/lib/dimos/stream/audio/node_simulated.py deleted file mode 100644 index c9aff9a32d..0000000000 --- a/build/lib/dimos/stream/audio/node_simulated.py +++ /dev/null @@ -1,221 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dimos.stream.audio.abstract import ( - AbstractAudioEmitter, - AudioEvent, -) -import numpy as np -from reactivex import Observable, create, disposable -import threading -import time - -from dimos.utils.logging_config import setup_logger - -logger = setup_logger("dimos.stream.audio.node_simulated") - - -class SimulatedAudioSource(AbstractAudioEmitter): - """Audio source that generates simulated audio for testing.""" - - def __init__( - self, - sample_rate: int = 16000, - frame_length: int = 1024, - channels: int = 1, - dtype: np.dtype = np.float32, - frequency: float = 440.0, # A4 note - waveform: str = "sine", # Type of waveform - modulation_rate: float = 0.5, # Modulation rate in Hz - volume_oscillation: bool = True, # Enable sinusoidal volume changes - volume_oscillation_rate: float = 0.2, # Volume oscillation rate in Hz - ): - """ - Initialize SimulatedAudioSource. 
- - Args: - sample_rate: Audio sample rate in Hz - frame_length: Number of samples per frame - channels: Number of audio channels - dtype: Data type for audio samples - frequency: Frequency of the sine wave in Hz - waveform: Type of waveform ("sine", "square", "triangle", "sawtooth") - modulation_rate: Frequency modulation rate in Hz - volume_oscillation: Whether to oscillate volume sinusoidally - volume_oscillation_rate: Rate of volume oscillation in Hz - """ - self.sample_rate = sample_rate - self.frame_length = frame_length - self.channels = channels - self.dtype = dtype - self.frequency = frequency - self.waveform = waveform.lower() - self.modulation_rate = modulation_rate - self.volume_oscillation = volume_oscillation - self.volume_oscillation_rate = volume_oscillation_rate - self.phase = 0.0 - self.volume_phase = 0.0 - - self._running = False - self._thread = None - - def _generate_sine_wave(self, time_points: np.ndarray) -> np.ndarray: - """Generate a waveform based on selected type.""" - # Generate base time points with phase - t = time_points + self.phase - - # Add frequency modulation for more interesting sounds - if self.modulation_rate > 0: - # Modulate frequency between 0.5x and 1.5x the base frequency - freq_mod = self.frequency * (1.0 + 0.5 * np.sin(2 * np.pi * self.modulation_rate * t)) - else: - freq_mod = np.ones_like(t) * self.frequency - - # Create phase argument for oscillators - phase_arg = 2 * np.pi * np.cumsum(freq_mod / self.sample_rate) - - # Generate waveform based on selection - if self.waveform == "sine": - wave = np.sin(phase_arg) - elif self.waveform == "square": - wave = np.sign(np.sin(phase_arg)) - elif self.waveform == "triangle": - wave = ( - 2 * np.abs(2 * (phase_arg / (2 * np.pi) - np.floor(phase_arg / (2 * np.pi) + 0.5))) - - 1 - ) - elif self.waveform == "sawtooth": - wave = 2 * (phase_arg / (2 * np.pi) - np.floor(0.5 + phase_arg / (2 * np.pi))) - else: - # Default to sine wave - wave = np.sin(phase_arg) - - # Apply sinusoidal 
volume oscillation if enabled - if self.volume_oscillation: - # Current time points for volume calculation - vol_t = t + self.volume_phase - - # Volume oscillates between 0.0 and 1.0 using a sine wave (complete silence to full volume) - volume_factor = 0.5 + 0.5 * np.sin(2 * np.pi * self.volume_oscillation_rate * vol_t) - - # Apply the volume factor - wave *= volume_factor * 0.7 - - # Update volume phase for next frame - self.volume_phase += ( - time_points[-1] - time_points[0] + (time_points[1] - time_points[0]) - ) - - # Update phase for next frame - self.phase += time_points[-1] - time_points[0] + (time_points[1] - time_points[0]) - - # Add a second channel if needed - if self.channels == 2: - wave = np.column_stack((wave, wave)) - elif self.channels > 2: - wave = np.tile(wave.reshape(-1, 1), (1, self.channels)) - - # Convert to int16 if needed - if self.dtype == np.int16: - wave = (wave * 32767).astype(np.int16) - - return wave - - def _audio_thread(self, observer, interval: float): - """Thread function for simulated audio generation.""" - try: - sample_index = 0 - self._running = True - - while self._running: - # Calculate time points for this frame - time_points = ( - np.arange(sample_index, sample_index + self.frame_length) / self.sample_rate - ) - - # Generate audio data - audio_data = self._generate_sine_wave(time_points) - - # Create audio event - audio_event = AudioEvent( - data=audio_data, - sample_rate=self.sample_rate, - timestamp=time.time(), - channels=self.channels, - ) - - observer.on_next(audio_event) - - # Update sample index for next frame - sample_index += self.frame_length - - # Sleep to simulate real-time audio - time.sleep(interval) - - except Exception as e: - logger.error(f"Error in simulated audio thread: {e}") - observer.on_error(e) - finally: - self._running = False - observer.on_completed() - - def emit_audio(self, fps: int = 30) -> Observable: - """ - Create an observable that emits simulated audio frames. 
- - Args: - fps: Frames per second to emit - - Returns: - Observable emitting AudioEvent objects - """ - - def on_subscribe(observer, scheduler): - # Calculate interval based on fps - interval = 1.0 / fps - - # Start the audio generation thread - self._thread = threading.Thread( - target=self._audio_thread, args=(observer, interval), daemon=True - ) - self._thread.start() - - logger.info( - f"Started simulated audio source: {self.sample_rate}Hz, " - f"{self.channels} channels, {self.frame_length} samples per frame" - ) - - # Return a disposable to clean up - def dispose(): - logger.info("Stopping simulated audio") - self._running = False - if self._thread and self._thread.is_alive(): - self._thread.join(timeout=1.0) - - return disposable.Disposable(dispose) - - return create(on_subscribe) - - -if __name__ == "__main__": - from dimos.stream.audio.utils import keepalive - from dimos.stream.audio.node_volume_monitor import monitor - from dimos.stream.audio.node_output import SounddeviceAudioOutput - - source = SimulatedAudioSource() - speaker = SounddeviceAudioOutput() - speaker.consume_audio(source.emit_audio()) - monitor(speaker.emit_audio()) - - keepalive() diff --git a/build/lib/dimos/stream/audio/node_volume_monitor.py b/build/lib/dimos/stream/audio/node_volume_monitor.py deleted file mode 100644 index 6510667307..0000000000 --- a/build/lib/dimos/stream/audio/node_volume_monitor.py +++ /dev/null @@ -1,176 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Callable -from reactivex import Observable, create, disposable - -from dimos.stream.audio.base import AudioEvent, AbstractAudioConsumer -from dimos.stream.audio.text.base import AbstractTextEmitter -from dimos.stream.audio.text.node_stdout import TextPrinterNode -from dimos.stream.audio.volume import calculate_peak_volume -from dimos.utils.logging_config import setup_logger - -logger = setup_logger("dimos.stream.audio.node_volume_monitor") - - -class VolumeMonitorNode(AbstractAudioConsumer, AbstractTextEmitter): - """ - A node that monitors audio volume and emits text descriptions. - """ - - def __init__( - self, - threshold: float = 0.01, - bar_length: int = 50, - volume_func: Callable = calculate_peak_volume, - ): - """ - Initialize VolumeMonitorNode. - - Args: - threshold: Threshold for considering audio as active - bar_length: Length of the volume bar in characters - volume_func: Function to calculate volume (defaults to peak volume) - """ - self.threshold = threshold - self.bar_length = bar_length - self.volume_func = volume_func - self.func_name = volume_func.__name__.replace("calculate_", "") - self.audio_observable = None - - def create_volume_text(self, volume: float) -> str: - """ - Create a text representation of the volume level. 
- - Args: - volume: Volume level between 0.0 and 1.0 - - Returns: - String representation of the volume - """ - # Calculate number of filled segments - filled = int(volume * self.bar_length) - - # Create the bar - bar = "█" * filled + "░" * (self.bar_length - filled) - - # Determine if we're above threshold - active = volume >= self.threshold - - # Format the text - percentage = int(volume * 100) - activity = "active" if active else "silent" - return f"{bar} {percentage:3d}% {activity}" - - def consume_audio(self, audio_observable: Observable) -> "VolumeMonitorNode": - """ - Set the audio source observable to consume. - - Args: - audio_observable: Observable emitting AudioEvent objects - - Returns: - Self for method chaining - """ - self.audio_observable = audio_observable - return self - - def emit_text(self) -> Observable: - """ - Create an observable that emits volume text descriptions. - - Returns: - Observable emitting text descriptions of audio volume - """ - if self.audio_observable is None: - raise ValueError("No audio source provided. 
Call consume_audio() first.") - - def on_subscribe(observer, scheduler): - logger.info(f"Starting volume monitor (method: {self.func_name})") - - # Subscribe to the audio source - def on_audio_event(event: AudioEvent): - try: - # Calculate volume - volume = self.volume_func(event.data) - - # Create text representation - text = self.create_volume_text(volume) - - # Emit the text - observer.on_next(text) - except Exception as e: - logger.error(f"Error processing audio event: {e}") - observer.on_error(e) - - # Set up subscription to audio source - subscription = self.audio_observable.subscribe( - on_next=on_audio_event, - on_error=lambda e: observer.on_error(e), - on_completed=lambda: observer.on_completed(), - ) - - # Return a disposable to clean up resources - def dispose(): - logger.info("Stopping volume monitor") - subscription.dispose() - - return disposable.Disposable(dispose) - - return create(on_subscribe) - - -def monitor( - audio_source: Observable, - threshold: float = 0.01, - bar_length: int = 50, - volume_func: Callable = calculate_peak_volume, -) -> VolumeMonitorNode: - """ - Create a volume monitor node connected to a text output node. 
- - Args: - audio_source: The audio source to monitor - threshold: Threshold for considering audio as active - bar_length: Length of the volume bar in characters - volume_func: Function to calculate volume - - Returns: - The configured volume monitor node - """ - # Create the volume monitor node with specified parameters - volume_monitor = VolumeMonitorNode( - threshold=threshold, bar_length=bar_length, volume_func=volume_func - ) - - # Connect the volume monitor to the audio source - volume_monitor.consume_audio(audio_source) - - # Create and connect the text printer node - text_printer = TextPrinterNode() - text_printer.consume_text(volume_monitor.emit_text()) - - # Return the volume monitor node - return volume_monitor - - -if __name__ == "__main__": - from utils import keepalive - from audio.node_simulated import SimulatedAudioSource - - # Use the monitor function to create and connect the nodes - volume_monitor = monitor(SimulatedAudioSource().emit_audio()) - - keepalive() diff --git a/build/lib/dimos/stream/audio/pipelines.py b/build/lib/dimos/stream/audio/pipelines.py deleted file mode 100644 index ee2ae43316..0000000000 --- a/build/lib/dimos/stream/audio/pipelines.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from dimos.stream.audio.node_microphone import SounddeviceAudioSource -from dimos.stream.audio.node_normalizer import AudioNormalizer -from dimos.stream.audio.node_volume_monitor import monitor -from dimos.stream.audio.node_key_recorder import KeyRecorder -from dimos.stream.audio.node_output import SounddeviceAudioOutput -from dimos.stream.audio.stt.node_whisper import WhisperNode -from dimos.stream.audio.tts.node_openai import OpenAITTSNode, Voice -from dimos.stream.audio.text.node_stdout import TextPrinterNode - - -def stt(): - # Create microphone source, recorder, and audio output - mic = SounddeviceAudioSource() - normalizer = AudioNormalizer() - recorder = KeyRecorder(always_subscribe=True) - whisper_node = WhisperNode() # Assign to global variable - - # Connect audio processing pipeline - normalizer.consume_audio(mic.emit_audio()) - recorder.consume_audio(normalizer.emit_audio()) - monitor(recorder.emit_audio()) - whisper_node.consume_audio(recorder.emit_recording()) - - user_text_printer = TextPrinterNode(prefix="USER: ") - user_text_printer.consume_text(whisper_node.emit_text()) - - return whisper_node - - -def tts(): - tts_node = OpenAITTSNode(speed=1.2, voice=Voice.ONYX) - agent_text_printer = TextPrinterNode(prefix="AGENT: ") - agent_text_printer.consume_text(tts_node.emit_text()) - - response_output = SounddeviceAudioOutput(sample_rate=24000) - response_output.consume_audio(tts_node.emit_audio()) - - return tts_node diff --git a/build/lib/dimos/stream/audio/utils.py b/build/lib/dimos/stream/audio/utils.py deleted file mode 100644 index 712086ffd6..0000000000 --- a/build/lib/dimos/stream/audio/utils.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import time - - -def keepalive(): - try: - # Keep the program running - print("Press Ctrl+C to exit") - print("-" * 60) - while True: - time.sleep(0.1) - except KeyboardInterrupt: - print("\nStopping pipeline") diff --git a/build/lib/dimos/stream/audio/volume.py b/build/lib/dimos/stream/audio/volume.py deleted file mode 100644 index f2e50ab72c..0000000000 --- a/build/lib/dimos/stream/audio/volume.py +++ /dev/null @@ -1,108 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np - - -def calculate_rms_volume(audio_data: np.ndarray) -> float: - """ - Calculate RMS (Root Mean Square) volume of audio data. 
- - Args: - audio_data: Audio data as numpy array - - Returns: - RMS volume as a float between 0.0 and 1.0 - """ - # For multi-channel audio, calculate RMS across all channels - if len(audio_data.shape) > 1 and audio_data.shape[1] > 1: - # Flatten all channels - audio_data = audio_data.flatten() - - # Calculate RMS - rms = np.sqrt(np.mean(np.square(audio_data))) - - # For int16 data, normalize to [0, 1] - if audio_data.dtype == np.int16: - rms = rms / 32768.0 - - return rms - - -def calculate_peak_volume(audio_data: np.ndarray) -> float: - """ - Calculate peak volume of audio data. - - Args: - audio_data: Audio data as numpy array - - Returns: - Peak volume as a float between 0.0 and 1.0 - """ - # For multi-channel audio, find max across all channels - if len(audio_data.shape) > 1 and audio_data.shape[1] > 1: - # Flatten all channels - audio_data = audio_data.flatten() - - # Find absolute peak value - peak = np.max(np.abs(audio_data)) - - # For int16 data, normalize to [0, 1] - if audio_data.dtype == np.int16: - peak = peak / 32768.0 - - return peak - - -if __name__ == "__main__": - # Example usage - import time - from .node_simulated import SimulatedAudioSource - - # Create a simulated audio source - audio_source = SimulatedAudioSource() - - # Create observable and subscribe to get a single frame - audio_observable = audio_source.capture_audio_as_observable() - - def process_frame(frame): - # Calculate and print both RMS and peak volumes - rms_vol = calculate_rms_volume(frame.data) - peak_vol = calculate_peak_volume(frame.data) - - print(f"RMS Volume: {rms_vol:.4f}") - print(f"Peak Volume: {peak_vol:.4f}") - print(f"Ratio (Peak/RMS): {peak_vol / rms_vol:.2f}") - - # Set a flag to track when processing is complete - processed = {"done": False} - - def process_frame_wrapper(frame): - # Process the frame - process_frame(frame) - # Mark as processed - processed["done"] = True - - # Subscribe to get a single frame and process it - subscription = 
audio_observable.subscribe( - on_next=process_frame_wrapper, on_completed=lambda: print("Completed") - ) - - # Wait for frame processing to complete - while not processed["done"]: - time.sleep(0.01) - - # Now dispose the subscription from the main thread, not from within the callback - subscription.dispose() diff --git a/build/lib/dimos/stream/data_provider.py b/build/lib/dimos/stream/data_provider.py deleted file mode 100644 index 73e1ba0f20..0000000000 --- a/build/lib/dimos/stream/data_provider.py +++ /dev/null @@ -1,183 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from abc import ABC -from reactivex import Subject, Observable -from reactivex.subject import Subject -from reactivex.scheduler import ThreadPoolScheduler -import multiprocessing -import logging - -import reactivex as rx -from reactivex import operators as ops - -logging.basicConfig(level=logging.INFO) - -# Create a thread pool scheduler for concurrent processing -pool_scheduler = ThreadPoolScheduler(multiprocessing.cpu_count()) - - -class AbstractDataProvider(ABC): - """Abstract base class for data providers using ReactiveX.""" - - def __init__(self, dev_name: str = "NA"): - self.dev_name = dev_name - self._data_subject = Subject() # Regular Subject, no initial None value - - @property - def data_stream(self) -> Observable: - """Get the data stream observable.""" - return self._data_subject - - def push_data(self, data): - """Push new data to the stream.""" - self._data_subject.on_next(data) - - def dispose(self): - """Cleanup resources.""" - self._data_subject.dispose() - - -class ROSDataProvider(AbstractDataProvider): - """ReactiveX data provider for ROS topics.""" - - def __init__(self, dev_name: str = "ros_provider"): - super().__init__(dev_name) - self.logger = logging.getLogger(dev_name) - - def push_data(self, data): - """Push new data to the stream.""" - print(f"ROSDataProvider pushing data of type: {type(data)}") - super().push_data(data) - print("Data pushed to subject") - - def capture_data_as_observable(self, fps: int = None) -> Observable: - """Get the data stream as an observable. 
- - Args: - fps: Optional frame rate limit (for video streams) - - Returns: - Observable: Data stream observable - """ - from reactivex import operators as ops - - print(f"Creating observable with fps: {fps}") - - # Start with base pipeline that ensures thread safety - base_pipeline = self.data_stream.pipe( - # Ensure emissions are handled on thread pool - ops.observe_on(pool_scheduler), - # Add debug logging to track data flow - ops.do_action( - on_next=lambda x: print(f"Got frame in pipeline: {type(x)}"), - on_error=lambda e: print(f"Pipeline error: {e}"), - on_completed=lambda: print("Pipeline completed"), - ), - ) - - # If fps is specified, add rate limiting - if fps and fps > 0: - print(f"Adding rate limiting at {fps} FPS") - return base_pipeline.pipe( - # Use scheduler for time-based operations - ops.sample(1.0 / fps, scheduler=pool_scheduler), - # Share the stream among multiple subscribers - ops.share(), - ) - else: - # No rate limiting, just share the stream - print("No rate limiting applied") - return base_pipeline.pipe(ops.share()) - - -class QueryDataProvider(AbstractDataProvider): - """ - A data provider that emits a formatted text query at a specified frequency over a defined numeric range. - - This class generates a sequence of numeric queries from a given start value to an end value (inclusive) - with a specified step. Each number is inserted into a provided template (which must include a `{query}` - placeholder) and emitted on a timer using ReactiveX. - - Attributes: - dev_name (str): The name of the data provider. - logger (logging.Logger): Logger instance for logging messages. - """ - - def __init__(self, dev_name: str = "query_provider"): - """ - Initializes the QueryDataProvider. - - Args: - dev_name (str): The name of the data provider. Defaults to "query_provider". 
- """ - super().__init__(dev_name) - self.logger = logging.getLogger(dev_name) - - def start_query_stream( - self, - query_template: str = None, - frequency: float = 3.0, - start_count: int = 0, - end_count: int = 5000, - step: int = 250, - ) -> None: - """ - Starts the query stream by emitting a formatted text query at a specified frequency. - - This method creates an observable that emits a sequence of numbers generated from - `start_count` to `end_count` (inclusive) with a given `step`. Each number is then formatted - using the `query_template`. The formatted query is pushed to the internal data stream. - - Args: - query_template (str): The template string for formatting queries. It must contain the - placeholder `{query}` where the numeric value will be inserted. If None, a default - template is used. - frequency (float): The frequency (in seconds) at which queries are emitted. Defaults to 3.0. - start_count (int): The starting number for query generation. Defaults to 0. - end_count (int): The ending number for query generation (inclusive). Defaults to 5000. - step (int): The increment between consecutive query numbers. Defaults to 250. - """ - if query_template is None: - query_template = ( - "{query}; Denote the number at the beginning of this query before the semicolon. " - "Only provide the number, without any other text in your response. " - "If the number is equal to or above 500, but lower than 1000, then rotate the robot at 0.5 rad/s for 1 second. " - "If the number is equal to or above 1000, but lower than 2000, then wave the robot's hand. " - "If the number is equal to or above 2000, then clear debris. " - "IF YOU DO NOT FOLLOW THESE INSTRUCTIONS EXACTLY, YOU WILL DIE!!!" - ) - - # Generate the sequence of numeric queries. - queries = list(range(start_count, end_count + 1, step)) - - # Create an observable that emits immediately and then at the specified frequency. 
- timer = rx.timer(0, frequency) - query_source = rx.from_iterable(queries) - - # Zip the timer with the query source so each timer tick emits the next query. - query_stream = timer.pipe( - ops.zip(query_source), - ops.map(lambda pair: query_template.format(query=pair[1])), - ops.observe_on(pool_scheduler), - # ops.do_action( - # on_next=lambda q: self.logger.info(f"Emitting query: {q}"), - # on_error=lambda e: self.logger.error(f"Query stream error: {e}"), - # on_completed=lambda: self.logger.info("Query stream completed") - # ), - ops.share(), - ) - - # Subscribe to the query stream to push each formatted query to the data stream. - query_stream.subscribe(lambda q: self.push_data(q)) diff --git a/build/lib/dimos/stream/frame_processor.py b/build/lib/dimos/stream/frame_processor.py deleted file mode 100644 index b07a09118b..0000000000 --- a/build/lib/dimos/stream/frame_processor.py +++ /dev/null @@ -1,300 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import cv2 -import numpy as np -import os -from reactivex import Observable -from reactivex import operators as ops -from typing import Tuple, Optional - - -# TODO: Reorganize, filenaming - Consider merger with VideoOperators class -class FrameProcessor: - def __init__(self, output_dir=f"{os.getcwd()}/assets/output/frames", delete_on_init=False): - """Initializes the FrameProcessor. 
- - Sets up the output directory for frame storage and optionally cleans up - existing JPG files. - - Args: - output_dir: Directory path for storing processed frames. - Defaults to '{os.getcwd()}/assets/output/frames'. - delete_on_init: If True, deletes all existing JPG files in output_dir. - Defaults to False. - - Raises: - OSError: If directory creation fails or if file deletion fails. - PermissionError: If lacking permissions for directory/file operations. - """ - self.output_dir = output_dir - os.makedirs(self.output_dir, exist_ok=True) - - if delete_on_init: - try: - jpg_files = [f for f in os.listdir(self.output_dir) if f.lower().endswith(".jpg")] - for file in jpg_files: - file_path = os.path.join(self.output_dir, file) - os.remove(file_path) - print(f"Cleaned up {len(jpg_files)} existing JPG files from {self.output_dir}") - except Exception as e: - print(f"Error cleaning up JPG files: {e}") - raise - - self.image_count = 1 - # TODO: Add randomness to jpg folder storage naming. - # Will overwrite between sessions. 
- - def to_grayscale(self, frame): - if frame is None: - print("Received None frame for grayscale conversion.") - return None - return cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) - - def edge_detection(self, frame): - return cv2.Canny(frame, 100, 200) - - def resize(self, frame, scale=0.5): - return cv2.resize(frame, None, fx=scale, fy=scale, interpolation=cv2.INTER_AREA) - - def export_to_jpeg(self, frame, save_limit=100, loop=False, suffix=""): - if frame is None: - print("Error: Attempted to save a None image.") - return None - - # Check if the image has an acceptable number of channels - if len(frame.shape) == 3 and frame.shape[2] not in [1, 3, 4]: - print(f"Error: Frame with shape {frame.shape} has unsupported number of channels.") - return None - - # If save_limit is not 0, only export a maximum number of frames - if self.image_count > save_limit and save_limit != 0: - if loop: - self.image_count = 1 - else: - return frame - - filepath = os.path.join(self.output_dir, f"{self.image_count}_{suffix}.jpg") - cv2.imwrite(filepath, frame) - self.image_count += 1 - return frame - - def compute_optical_flow( - self, - acc: Tuple[np.ndarray, np.ndarray, Optional[float]], - current_frame: np.ndarray, - compute_relevancy: bool = True, - ) -> Tuple[np.ndarray, np.ndarray, Optional[float]]: - """Computes optical flow between consecutive frames. - - Uses the Farneback algorithm to compute dense optical flow between the - previous and current frame. Optionally calculates a relevancy score - based on the mean magnitude of motion vectors. - - Args: - acc: Accumulator tuple containing: - prev_frame: Previous video frame (np.ndarray) - prev_flow: Previous optical flow (np.ndarray) - prev_relevancy: Previous relevancy score (float or None) - current_frame: Current video frame as BGR image (np.ndarray) - compute_relevancy: If True, calculates mean magnitude of flow vectors. - Defaults to True. 
- - Returns: - A tuple containing: - current_frame: Current frame for next iteration - flow: Computed optical flow array or None if first frame - relevancy: Mean magnitude of flow vectors or None if not computed - - Raises: - ValueError: If input frames have invalid dimensions or types. - TypeError: If acc is not a tuple of correct types. - """ - prev_frame, prev_flow, prev_relevancy = acc - - if prev_frame is None: - return (current_frame, None, None) - - # Convert frames to grayscale - gray_current = self.to_grayscale(current_frame) - gray_prev = self.to_grayscale(prev_frame) - - # Compute optical flow - flow = cv2.calcOpticalFlowFarneback(gray_prev, gray_current, None, 0.5, 3, 15, 3, 5, 1.2, 0) - - # Relevancy calulation (average magnitude of flow vectors) - relevancy = None - if compute_relevancy: - mag, _ = cv2.cartToPolar(flow[..., 0], flow[..., 1]) - relevancy = np.mean(mag) - - # Return the current frame as the new previous frame and the processed optical flow, with relevancy score - return (current_frame, flow, relevancy) - - def visualize_flow(self, flow): - if flow is None: - return None - hsv = np.zeros((flow.shape[0], flow.shape[1], 3), dtype=np.uint8) - hsv[..., 1] = 255 - mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1]) - hsv[..., 0] = ang * 180 / np.pi / 2 - hsv[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX) - rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR) - return rgb - - # ============================== - - def process_stream_edge_detection(self, frame_stream): - return frame_stream.pipe( - ops.map(self.edge_detection), - ) - - def process_stream_resize(self, frame_stream): - return frame_stream.pipe( - ops.map(self.resize), - ) - - def process_stream_to_greyscale(self, frame_stream): - return frame_stream.pipe( - ops.map(self.to_grayscale), - ) - - def process_stream_optical_flow(self, frame_stream: Observable) -> Observable: - """Processes video stream to compute and visualize optical flow. 
- - Computes optical flow between consecutive frames and generates a color-coded - visualization where hue represents flow direction and intensity represents - flow magnitude. This method optimizes performance by disabling relevancy - computation. - - Args: - frame_stream: An Observable emitting video frames as numpy arrays. - Each frame should be in BGR format with shape (height, width, 3). - - Returns: - An Observable emitting visualized optical flow frames as BGR images - (np.ndarray). Hue indicates flow direction, intensity shows magnitude. - - Raises: - TypeError: If frame_stream is not an Observable. - ValueError: If frames have invalid dimensions or format. - - Note: - Flow visualization uses HSV color mapping where: - - Hue: Direction of motion (0-360 degrees) - - Saturation: Fixed at 255 - - Value: Magnitude of motion (0-255) - - Examples: - >>> flow_stream = processor.process_stream_optical_flow(frame_stream) - >>> flow_stream.subscribe(lambda flow: cv2.imshow('Flow', flow)) - """ - return frame_stream.pipe( - ops.scan( - lambda acc, frame: self.compute_optical_flow(acc, frame, compute_relevancy=False), - (None, None, None), - ), - ops.map(lambda result: result[1]), # Extract flow component - ops.filter(lambda flow: flow is not None), - ops.map(self.visualize_flow), - ) - - def process_stream_optical_flow_with_relevancy(self, frame_stream: Observable) -> Observable: - """Processes video stream to compute optical flow with movement relevancy. - - Applies optical flow computation to each frame and returns both the - visualized flow and a relevancy score indicating the amount of movement. - The relevancy score is calculated as the mean magnitude of flow vectors. - This method includes relevancy computation for motion detection. - - Args: - frame_stream: An Observable emitting video frames as numpy arrays. - Each frame should be in BGR format with shape (height, width, 3). 
- - Returns: - An Observable emitting tuples of (visualized_flow, relevancy_score): - visualized_flow: np.ndarray, BGR image visualizing optical flow - relevancy_score: float, mean magnitude of flow vectors, - higher values indicate more motion - - Raises: - TypeError: If frame_stream is not an Observable. - ValueError: If frames have invalid dimensions or format. - - Examples: - >>> flow_stream = processor.process_stream_optical_flow_with_relevancy( - ... frame_stream - ... ) - >>> flow_stream.subscribe( - ... lambda result: print(f"Motion score: {result[1]}") - ... ) - - Note: - Relevancy scores are computed using mean magnitude of flow vectors. - Higher scores indicate more movement in the frame. - """ - return frame_stream.pipe( - ops.scan( - lambda acc, frame: self.compute_optical_flow(acc, frame, compute_relevancy=True), - (None, None, None), - ), - # Result is (current_frame, flow, relevancy) - ops.filter(lambda result: result[1] is not None), # Filter out None flows - ops.map( - lambda result: ( - self.visualize_flow(result[1]), # Visualized flow - result[2], # Relevancy score - ) - ), - ops.filter(lambda result: result[0] is not None), # Ensure valid visualization - ) - - def process_stream_with_jpeg_export( - self, frame_stream: Observable, suffix: str = "", loop: bool = False - ) -> Observable: - """Processes stream by saving frames as JPEGs while passing them through. - - Saves each frame from the stream as a JPEG file and passes the frame - downstream unmodified. Files are saved sequentially with optional suffix - in the configured output directory (self.output_dir). If loop is True, - it will cycle back and overwrite images starting from the first one - after reaching the save_limit. - - Args: - frame_stream: An Observable emitting video frames as numpy arrays. - Each frame should be in BGR format with shape (height, width, 3). - suffix: Optional string to append to filename before index. - Defaults to empty string. 
Example: "optical" -> "optical_1.jpg" - loop: If True, reset the image counter to 1 after reaching - save_limit, effectively looping the saves. Defaults to False. - - Returns: - An Observable emitting the same frames that were saved. Returns None - for frames that could not be saved due to format issues or save_limit - (unless loop is True). - - Raises: - TypeError: If frame_stream is not an Observable. - ValueError: If frames have invalid format or output directory - is not writable. - OSError: If there are file system permission issues. - - Note: - Frames are saved as '{suffix}_{index}.jpg' where index - increments for each saved frame. Saving stops after reaching - the configured save_limit (default: 100) unless loop is True. - """ - return frame_stream.pipe( - ops.map(lambda frame: self.export_to_jpeg(frame, suffix=suffix, loop=loop)), - ) diff --git a/build/lib/dimos/stream/ros_video_provider.py b/build/lib/dimos/stream/ros_video_provider.py deleted file mode 100644 index 7ca6fa4aa7..0000000000 --- a/build/lib/dimos/stream/ros_video_provider.py +++ /dev/null @@ -1,112 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""ROS-based video provider module. - -This module provides a video frame provider that receives frames from ROS (Robot Operating System) -and makes them available as an Observable stream. 
-""" - -from reactivex import Subject, Observable -from reactivex import operators as ops -from reactivex.scheduler import ThreadPoolScheduler -import logging -import time -from typing import Optional -import numpy as np - -from dimos.stream.video_provider import AbstractVideoProvider - -logging.basicConfig(level=logging.INFO) - - -class ROSVideoProvider(AbstractVideoProvider): - """Video provider that uses a Subject to broadcast frames pushed by ROS. - - This class implements a video provider that receives frames from ROS and makes them - available as an Observable stream. It uses ReactiveX's Subject to broadcast frames. - - Attributes: - logger: Logger instance for this provider. - _subject: ReactiveX Subject that broadcasts frames. - _last_frame_time: Timestamp of the last received frame. - """ - - def __init__( - self, dev_name: str = "ros_video", pool_scheduler: Optional[ThreadPoolScheduler] = None - ): - """Initialize the ROS video provider. - - Args: - dev_name: A string identifying this provider. - pool_scheduler: Optional ThreadPoolScheduler for multithreading. - """ - super().__init__(dev_name, pool_scheduler) - self.logger = logging.getLogger(dev_name) - self._subject = Subject() - self._last_frame_time = None - self.logger.info("ROSVideoProvider initialized") - - def push_data(self, frame: np.ndarray) -> None: - """Push a new frame into the provider. - - Args: - frame: The video frame to push into the stream, typically a numpy array - containing image data. - - Raises: - Exception: If there's an error pushing the frame. 
- """ - try: - current_time = time.time() - if self._last_frame_time: - frame_interval = current_time - self._last_frame_time - self.logger.debug( - f"Frame interval: {frame_interval:.3f}s ({1 / frame_interval:.1f} FPS)" - ) - self._last_frame_time = current_time - - self.logger.debug(f"Pushing frame type: {type(frame)}") - self._subject.on_next(frame) - self.logger.debug("Frame pushed") - except Exception as e: - self.logger.error(f"Push error: {e}") - raise - - def capture_video_as_observable(self, fps: int = 30) -> Observable: - """Return an observable of video frames. - - Args: - fps: Frames per second rate limit (default: 30; ignored for now). - - Returns: - Observable: An observable stream of video frames (numpy.ndarray objects), - with each emission containing a single video frame. The frames are - multicast to all subscribers. - - Note: - The fps parameter is currently not enforced. See implementation note below. - """ - self.logger.info(f"Creating observable with {fps} FPS rate limiting") - # TODO: Implement rate limiting using ops.throttle_with_timeout() or - # ops.sample() to restrict emissions to one frame per (1/fps) seconds. - # Example: ops.sample(1.0/fps) - return self._subject.pipe( - # Ensure subscription work happens on the thread pool - ops.subscribe_on(self.pool_scheduler), - # Ensure observer callbacks execute on the thread pool - ops.observe_on(self.pool_scheduler), - # Make the stream hot/multicast so multiple subscribers get the same frames - ops.share(), - ) diff --git a/build/lib/dimos/stream/rtsp_video_provider.py b/build/lib/dimos/stream/rtsp_video_provider.py deleted file mode 100644 index 5926c4f676..0000000000 --- a/build/lib/dimos/stream/rtsp_video_provider.py +++ /dev/null @@ -1,380 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""RTSP video provider using ffmpeg for robust stream handling.""" - -import subprocess -import threading -import time -from typing import Optional - -import ffmpeg # ffmpeg-python wrapper -import numpy as np -import reactivex as rx -from reactivex import operators as ops -from reactivex.disposable import Disposable -from reactivex.observable import Observable -from reactivex.scheduler import ThreadPoolScheduler - -from dimos.utils.logging_config import setup_logger - -# Assuming AbstractVideoProvider and exceptions are in the sibling file -from .video_provider import AbstractVideoProvider, VideoFrameError, VideoSourceError - -logger = setup_logger("dimos.stream.rtsp_video_provider") - - -class RtspVideoProvider(AbstractVideoProvider): - """Video provider implementation for capturing RTSP streams using ffmpeg. - - This provider uses the ffmpeg-python library to interact with ffmpeg, - providing more robust handling of various RTSP streams compared to OpenCV's - built-in VideoCapture for RTSP. - """ - - def __init__( - self, dev_name: str, rtsp_url: str, pool_scheduler: Optional[ThreadPoolScheduler] = None - ) -> None: - """Initializes the RTSP video provider. - - Args: - dev_name: The name of the device or stream (for identification). - rtsp_url: The URL of the RTSP stream (e.g., "rtsp://user:pass@ip:port/path"). - pool_scheduler: The scheduler for thread pool operations. Defaults to global scheduler. 
- """ - super().__init__(dev_name, pool_scheduler) - self.rtsp_url = rtsp_url - # Holds the currently active ffmpeg process Popen object - self._ffmpeg_process: Optional[subprocess.Popen] = None - # Lock to protect access to the ffmpeg process object - self._lock = threading.Lock() - - def _get_stream_info(self) -> dict: - """Probes the RTSP stream to get video dimensions and FPS using ffprobe.""" - logger.info(f"({self.dev_name}) Probing RTSP stream.") - try: - # Probe the stream without the problematic timeout argument - probe = ffmpeg.probe(self.rtsp_url) - except ffmpeg.Error as e: - stderr = e.stderr.decode("utf8", errors="ignore") if e.stderr else "No stderr" - msg = f"({self.dev_name}) Failed to probe RTSP stream {self.rtsp_url}: {stderr}" - logger.error(msg) - raise VideoSourceError(msg) from e - except Exception as e: - msg = f"({self.dev_name}) Unexpected error during probing {self.rtsp_url}: {e}" - logger.error(msg) - raise VideoSourceError(msg) from e - - video_stream = next( - (stream for stream in probe.get("streams", []) if stream.get("codec_type") == "video"), - None, - ) - - if video_stream is None: - msg = f"({self.dev_name}) No video stream found in {self.rtsp_url}" - logger.error(msg) - raise VideoSourceError(msg) - - width = video_stream.get("width") - height = video_stream.get("height") - fps_str = video_stream.get("avg_frame_rate", "0/1") - - if not width or not height: - msg = f"({self.dev_name}) Could not determine resolution for {self.rtsp_url}. Stream info: {video_stream}" - logger.error(msg) - raise VideoSourceError(msg) - - try: - if "/" in fps_str: - num, den = map(int, fps_str.split("/")) - fps = float(num) / den if den != 0 else 30.0 - else: - fps = float(fps_str) - if fps <= 0: - logger.warning( - f"({self.dev_name}) Invalid avg_frame_rate '{fps_str}', defaulting FPS to 30." - ) - fps = 30.0 - except ValueError: - logger.warning( - f"({self.dev_name}) Could not parse FPS '{fps_str}', defaulting FPS to 30." 
- ) - fps = 30.0 - - logger.info(f"({self.dev_name}) Stream info: {width}x{height} @ {fps:.2f} FPS") - return {"width": width, "height": height, "fps": fps} - - def _start_ffmpeg_process(self, width: int, height: int) -> subprocess.Popen: - """Starts the ffmpeg process to capture and decode the stream.""" - logger.info(f"({self.dev_name}) Starting ffmpeg process for rtsp stream.") - try: - # Configure ffmpeg input: prefer TCP, set timeout, reduce buffering/delay - input_options = { - "rtsp_transport": "tcp", - "stimeout": "5000000", # 5 seconds timeout for RTSP server responses - "fflags": "nobuffer", # Reduce input buffering - "flags": "low_delay", # Reduce decoding delay - # 'timeout': '10000000' # Removed: This was misinterpreted as listen timeout - } - process = ( - ffmpeg.input(self.rtsp_url, **input_options) - .output("pipe:", format="rawvideo", pix_fmt="bgr24") # Output raw BGR frames - .global_args("-loglevel", "warning") # Reduce ffmpeg log spam, use 'error' for less - .run_async(pipe_stdout=True, pipe_stderr=True) # Capture stdout and stderr - ) - logger.info(f"({self.dev_name}) ffmpeg process started (PID: {process.pid})") - return process - except ffmpeg.Error as e: - stderr = e.stderr.decode("utf8", errors="ignore") if e.stderr else "No stderr" - msg = f"({self.dev_name}) Failed to start ffmpeg for {self.rtsp_url}: {stderr}" - logger.error(msg) - raise VideoSourceError(msg) from e - except Exception as e: # Catch other errors like ffmpeg executable not found - msg = f"({self.dev_name}) An unexpected error occurred starting ffmpeg: {e}" - logger.error(msg) - raise VideoSourceError(msg) from e - - def capture_video_as_observable(self, fps: int = 0) -> Observable: - """Creates an observable from the RTSP stream using ffmpeg. - - The observable attempts to reconnect if the stream drops. - - Args: - fps: This argument is currently ignored. The provider attempts - to use the stream's native frame rate. 
Future versions might - allow specifying an output FPS via ffmpeg filters. - - Returns: - Observable: An observable emitting video frames as NumPy arrays (BGR format). - - Raises: - VideoSourceError: If the stream cannot be initially probed or the - ffmpeg process fails to start. - VideoFrameError: If there's an error reading or processing frames. - """ - if fps != 0: - logger.warning( - f"({self.dev_name}) The 'fps' argument ({fps}) is currently ignored. Using stream native FPS." - ) - - def emit_frames(observer, scheduler): - """Function executed by rx.create to emit frames.""" - process: Optional[subprocess.Popen] = None - # Event to signal the processing loop should stop (e.g., on dispose) - should_stop = threading.Event() - - def cleanup_process(): - """Safely terminate the ffmpeg process if it's running.""" - nonlocal process - logger.debug(f"({self.dev_name}) Cleanup requested.") - # Use lock to ensure thread safety when accessing/modifying process - with self._lock: - # Check if the process exists and is still running - if process and process.poll() is None: - logger.info( - f"({self.dev_name}) Terminating ffmpeg process (PID: {process.pid})." - ) - try: - process.terminate() # Ask ffmpeg to exit gracefully - process.wait(timeout=1.0) # Wait up to 1 second - except subprocess.TimeoutExpired: - logger.warning( - f"({self.dev_name}) ffmpeg (PID: {process.pid}) did not terminate gracefully, killing." 
- ) - process.kill() # Force kill if it didn't exit - except Exception as e: - logger.error(f"({self.dev_name}) Error during ffmpeg termination: {e}") - finally: - # Ensure we clear the process variable even if wait/kill fails - process = None - # Also clear the shared class attribute if this was the active process - if self._ffmpeg_process and self._ffmpeg_process.pid == process.pid: - self._ffmpeg_process = None - elif process and process.poll() is not None: - # Process exists but already terminated - logger.debug( - f"({self.dev_name}) ffmpeg process (PID: {process.pid}) already terminated (exit code: {process.poll()})." - ) - process = None # Clear the variable - # Clear shared attribute if it matches - if self._ffmpeg_process and self._ffmpeg_process.pid == process.pid: - self._ffmpeg_process = None - else: - # Process variable is already None or doesn't match _ffmpeg_process - logger.debug( - f"({self.dev_name}) No active ffmpeg process found needing termination in cleanup." - ) - - try: - # 1. Probe the stream to get essential info (width, height) - stream_info = self._get_stream_info() - width = stream_info["width"] - height = stream_info["height"] - # Calculate expected bytes per frame (BGR format = 3 bytes per pixel) - frame_size = width * height * 3 - - # 2. Main loop: Start ffmpeg and read frames. Retry on failure. - while not should_stop.is_set(): - try: - # Start or reuse the ffmpeg process - with self._lock: - # Check if another thread/subscription already started the process - if self._ffmpeg_process and self._ffmpeg_process.poll() is None: - logger.warning( - f"({self.dev_name}) ffmpeg process (PID: {self._ffmpeg_process.pid}) seems to be already running. Reusing." - ) - process = self._ffmpeg_process - else: - # Start a new ffmpeg process - process = self._start_ffmpeg_process(width, height) - # Store the new process handle in the shared class attribute - self._ffmpeg_process = process - - # 3. 
Frame reading loop - while not should_stop.is_set(): - # Read exactly one frame's worth of bytes - in_bytes = process.stdout.read(frame_size) - - if len(in_bytes) == 0: - # End of stream or process terminated unexpectedly - logger.warning( - f"({self.dev_name}) ffmpeg stdout returned 0 bytes. EOF or process terminated." - ) - process.wait(timeout=0.5) # Allow stderr to flush - stderr_data = process.stderr.read().decode("utf8", errors="ignore") - exit_code = process.poll() - logger.warning( - f"({self.dev_name}) ffmpeg process (PID: {process.pid}) exited with code {exit_code}. Stderr: {stderr_data}" - ) - # Break inner loop to trigger cleanup and potential restart - with self._lock: - # Clear the shared process handle if it matches the one that just exited - if ( - self._ffmpeg_process - and self._ffmpeg_process.pid == process.pid - ): - self._ffmpeg_process = None - process = None # Clear local process variable - break # Exit frame reading loop - - elif len(in_bytes) != frame_size: - # Received incomplete frame data - indicates a problem - msg = f"({self.dev_name}) Incomplete frame read. Expected {frame_size}, got {len(in_bytes)}. Stopping." - logger.error(msg) - observer.on_error(VideoFrameError(msg)) - should_stop.set() # Signal outer loop to stop - break # Exit frame reading loop - - # Convert the raw bytes to a NumPy array (height, width, channels) - frame = np.frombuffer(in_bytes, np.uint8).reshape((height, width, 3)) - # Emit the frame to subscribers - observer.on_next(frame) - - # 4. Handle ffmpeg process exit/crash (if not stopping deliberately) - if not should_stop.is_set() and process is None: - logger.info( - f"({self.dev_name}) ffmpeg process ended. Attempting reconnection in 5 seconds..." 
- ) - # Wait for a few seconds before trying to restart - time.sleep(5) - # Continue to the next iteration of the outer loop to restart - - except (VideoSourceError, ffmpeg.Error) as e: - # Errors during ffmpeg process start or severe runtime errors - logger.error( - f"({self.dev_name}) Unrecoverable ffmpeg error: {e}. Stopping emission." - ) - observer.on_error(e) - should_stop.set() # Stop retrying - except Exception as e: - # Catch other unexpected errors during frame reading/processing - logger.error( - f"({self.dev_name}) Unexpected error processing stream: {e}", - exc_info=True, - ) - observer.on_error(VideoFrameError(f"Frame processing failed: {e}")) - should_stop.set() # Stop retrying - - # 5. Loop finished (likely due to should_stop being set) - logger.info(f"({self.dev_name}) Frame emission loop stopped.") - observer.on_completed() - - except VideoSourceError as e: - # Handle errors during the initial probing phase - logger.error(f"({self.dev_name}) Failed initial setup: {e}") - observer.on_error(e) - except Exception as e: - # Catch-all for unexpected errors before the main loop starts - logger.error(f"({self.dev_name}) Unexpected setup error: {e}", exc_info=True) - observer.on_error(VideoSourceError(f"Setup failed: {e}")) - finally: - # Crucial: Ensure the ffmpeg process is terminated when the observable - # is completed, errored, or disposed. - logger.debug(f"({self.dev_name}) Entering finally block in emit_frames.") - cleanup_process() - - # Return a Disposable that, when called (by unsubscribe/dispose), - # signals the loop to stop. The finally block handles the actual cleanup. 
- return Disposable(should_stop.set) - - # Create the observable using rx.create, applying scheduling and sharing - return rx.create(emit_frames).pipe( - ops.subscribe_on(self.pool_scheduler), # Run the emit_frames logic on the pool - # ops.observe_on(self.pool_scheduler), # Optional: Switch thread for downstream operators - ops.share(), # Ensure multiple subscribers share the same ffmpeg process - ) - - def dispose_all(self) -> None: - """Disposes of all managed resources, including terminating the ffmpeg process.""" - logger.info(f"({self.dev_name}) dispose_all called.") - # Terminate the ffmpeg process using the same locked logic as cleanup - with self._lock: - process = self._ffmpeg_process # Get the current process handle - if process and process.poll() is None: - logger.info( - f"({self.dev_name}) Terminating ffmpeg process (PID: {process.pid}) via dispose_all." - ) - try: - process.terminate() - process.wait(timeout=1.0) - except subprocess.TimeoutExpired: - logger.warning( - f"({self.dev_name}) ffmpeg process (PID: {process.pid}) kill required in dispose_all." - ) - process.kill() - except Exception as e: - logger.error( - f"({self.dev_name}) Error during ffmpeg termination in dispose_all: {e}" - ) - finally: - self._ffmpeg_process = None # Clear handle after attempting termination - elif process: # Process exists but already terminated - logger.debug( - f"({self.dev_name}) ffmpeg process (PID: {process.pid}) already terminated in dispose_all." - ) - self._ffmpeg_process = None - else: - logger.debug( - f"({self.dev_name}) No active ffmpeg process found during dispose_all." - ) - - # Call the parent class's dispose_all to handle Rx Disposables - super().dispose_all() - - def __del__(self) -> None: - """Destructor attempts to clean up resources if not explicitly disposed.""" - # Logging in __del__ is generally discouraged due to interpreter state issues, - # but can be helpful for debugging resource leaks. Use print for robustness here if needed. 
- # print(f"DEBUG: ({self.dev_name}) __del__ called.") - self.dispose_all() diff --git a/build/lib/dimos/stream/stream_merger.py b/build/lib/dimos/stream/stream_merger.py deleted file mode 100644 index 6f854b2d80..0000000000 --- a/build/lib/dimos/stream/stream_merger.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import List, TypeVar, Tuple -from reactivex import Observable -from reactivex import operators as ops - -T = TypeVar("T") -Q = TypeVar("Q") - - -def create_stream_merger( - data_input_stream: Observable[T], text_query_stream: Observable[Q] -) -> Observable[Tuple[Q, List[T]]]: - """ - Creates a merged stream that combines the latest value from data_input_stream - with each value from text_query_stream. 
- - Args: - data_input_stream: Observable stream of data values - text_query_stream: Observable stream of query values - - Returns: - Observable that emits tuples of (query, latest_data) - """ - # Encompass any data items as a list for safe evaluation - safe_data_stream = data_input_stream.pipe( - # We don't modify the data, just pass it through in a list - # This avoids any boolean evaluation of arrays - ops.map(lambda x: [x]) - ) - - # Use safe_data_stream instead of raw data_input_stream - return text_query_stream.pipe(ops.with_latest_from(safe_data_stream)) diff --git a/build/lib/dimos/stream/video_operators.py b/build/lib/dimos/stream/video_operators.py deleted file mode 100644 index 78ba7518a1..0000000000 --- a/build/lib/dimos/stream/video_operators.py +++ /dev/null @@ -1,622 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from datetime import datetime, timedelta -import cv2 -import numpy as np -from reactivex import Observable, Observer, create -from reactivex import operators as ops -from typing import Any, Callable, Tuple, Optional - -import zmq -import base64 -from enum import Enum - -from dimos.stream.frame_processor import FrameProcessor - - -class VideoOperators: - """Collection of video processing operators for reactive video streams.""" - - @staticmethod - def with_fps_sampling( - fps: int = 25, *, sample_interval: Optional[timedelta] = None, use_latest: bool = True - ) -> Callable[[Observable], Observable]: - """Creates an operator that samples frames at a specified rate. - - Creates a transformation operator that samples frames either by taking - the latest frame or the first frame in each interval. Provides frame - rate control through time-based selection. - - Args: - fps: Desired frames per second, defaults to 25 FPS. - Ignored if sample_interval is provided. - sample_interval: Optional explicit interval between samples. - If provided, overrides the fps parameter. - use_latest: If True, uses the latest frame in interval. - If False, uses the first frame. Defaults to True. - - Returns: - A function that transforms an Observable[np.ndarray] stream to a sampled - Observable[np.ndarray] stream with controlled frame rate. - - Raises: - ValueError: If fps is not positive or sample_interval is negative. - TypeError: If sample_interval is provided but not a timedelta object. - - Examples: - Sample latest frame at 30 FPS (good for real-time): - >>> video_stream.pipe( - ... VideoOperators.with_fps_sampling(fps=30) - ... ) - - Sample first frame with custom interval (good for consistent timing): - >>> video_stream.pipe( - ... VideoOperators.with_fps_sampling( - ... sample_interval=timedelta(milliseconds=40), - ... use_latest=False - ... ) - ... ) - - Note: - This operator helps manage high-speed video streams through time-based - frame selection. 
It reduces the frame rate by selecting frames at - specified intervals. - - When use_latest=True: - - Uses sampling to select the most recent frame at fixed intervals - - Discards intermediate frames, keeping only the latest - - Best for real-time video where latest frame is most relevant - - Uses ops.sample internally - - When use_latest=False: - - Uses throttling to select the first frame in each interval - - Ignores subsequent frames until next interval - - Best for scenarios where you want consistent frame timing - - Uses ops.throttle_first internally - - This is an approropriate solution for managing video frame rates and - memory usage in many scenarios. - """ - if sample_interval is None: - if fps <= 0: - raise ValueError("FPS must be positive") - sample_interval = timedelta(microseconds=int(1_000_000 / fps)) - - def _operator(source: Observable) -> Observable: - return source.pipe( - ops.sample(sample_interval) if use_latest else ops.throttle_first(sample_interval) - ) - - return _operator - - @staticmethod - def with_jpeg_export( - frame_processor: "FrameProcessor", - save_limit: int = 100, - suffix: str = "", - loop: bool = False, - ) -> Callable[[Observable], Observable]: - """Creates an operator that saves video frames as JPEG files. - - Creates a transformation operator that saves each frame from the video - stream as a JPEG file while passing the frame through unchanged. - - Args: - frame_processor: FrameProcessor instance that handles the JPEG export - operations and maintains file count. - save_limit: Maximum number of frames to save before stopping. - Defaults to 100. Set to 0 for unlimited saves. - suffix: Optional string to append to filename before index. - Example: "raw" creates "1_raw.jpg". - Defaults to empty string. - loop: If True, when save_limit is reached, the files saved are - loopbacked and overwritten with the most recent frame. - Defaults to False. 
- Returns: - A function that transforms an Observable of frames into another - Observable of the same frames, with side effect of saving JPEGs. - - Raises: - ValueError: If save_limit is negative. - TypeError: If frame_processor is not a FrameProcessor instance. - - Example: - >>> video_stream.pipe( - ... VideoOperators.with_jpeg_export(processor, suffix="raw") - ... ) - """ - - def _operator(source: Observable) -> Observable: - return source.pipe( - ops.map( - lambda frame: frame_processor.export_to_jpeg(frame, save_limit, loop, suffix) - ) - ) - - return _operator - - @staticmethod - def with_optical_flow_filtering(threshold: float = 1.0) -> Callable[[Observable], Observable]: - """Creates an operator that filters optical flow frames by relevancy score. - - Filters a stream of optical flow results (frame, relevancy_score) tuples, - passing through only frames that meet the relevancy threshold. - - Args: - threshold: Minimum relevancy score required for frames to pass through. - Defaults to 1.0. Higher values mean more motion required. - - Returns: - A function that transforms an Observable of (frame, score) tuples - into an Observable of frames that meet the threshold. - - Raises: - ValueError: If threshold is negative. - TypeError: If input stream items are not (frame, float) tuples. - - Examples: - Basic filtering: - >>> optical_flow_stream.pipe( - ... VideoOperators.with_optical_flow_filtering(threshold=1.0) - ... ) - - With custom threshold: - >>> optical_flow_stream.pipe( - ... VideoOperators.with_optical_flow_filtering(threshold=2.5) - ... ) - - Note: - Input stream should contain tuples of (frame, relevancy_score) where - frame is a numpy array and relevancy_score is a float or None. - None scores are filtered out. 
- """ - return lambda source: source.pipe( - ops.filter(lambda result: result[1] is not None), - ops.filter(lambda result: result[1] > threshold), - ops.map(lambda result: result[0]), - ) - - @staticmethod - def with_edge_detection( - frame_processor: "FrameProcessor", - ) -> Callable[[Observable], Observable]: - return lambda source: source.pipe( - ops.map(lambda frame: frame_processor.edge_detection(frame)) - ) - - @staticmethod - def with_optical_flow( - frame_processor: "FrameProcessor", - ) -> Callable[[Observable], Observable]: - return lambda source: source.pipe( - ops.scan( - lambda acc, frame: frame_processor.compute_optical_flow( - acc, frame, compute_relevancy=False - ), - (None, None, None), - ), - ops.map(lambda result: result[1]), # Extract flow component - ops.filter(lambda flow: flow is not None), - ops.map(frame_processor.visualize_flow), - ) - - @staticmethod - def with_zmq_socket( - socket: zmq.Socket, scheduler: Optional[Any] = None - ) -> Callable[[Observable], Observable]: - def send_frame(frame, socket): - _, img_encoded = cv2.imencode(".jpg", frame) - socket.send(img_encoded.tobytes()) - # print(f"Frame received: {frame.shape}") - - # Use a default scheduler if none is provided - if scheduler is None: - from reactivex.scheduler import ThreadPoolScheduler - - scheduler = ThreadPoolScheduler(1) # Single-threaded pool for isolation - - return lambda source: source.pipe( - ops.observe_on(scheduler), # Ensure this part runs on its own thread - ops.do_action(lambda frame: send_frame(frame, socket)), - ) - - @staticmethod - def encode_image() -> Callable[[Observable], Observable]: - """ - Operator to encode an image to JPEG format and convert it to a Base64 string. - - Returns: - A function that transforms an Observable of images into an Observable - of tuples containing the Base64 string of the encoded image and its dimensions. 
- """ - - def _operator(source: Observable) -> Observable: - def _encode_image(image: np.ndarray) -> Tuple[str, Tuple[int, int]]: - try: - width, height = image.shape[:2] - _, buffer = cv2.imencode(".jpg", image) - if buffer is None: - raise ValueError("Failed to encode image") - base64_image = base64.b64encode(buffer).decode("utf-8") - return base64_image, (width, height) - except Exception as e: - raise e - - return source.pipe(ops.map(_encode_image)) - - return _operator - - -from reactivex.disposable import Disposable -from reactivex import Observable -from threading import Lock - - -class Operators: - @staticmethod - def exhaust_lock(process_item): - """ - For each incoming item, call `process_item(item)` to get an Observable. - - If we're busy processing the previous one, skip new items. - - Use a lock to ensure concurrency safety across threads. - """ - - def _exhaust_lock(source: Observable) -> Observable: - def _subscribe(observer, scheduler=None): - in_flight = False - lock = Lock() - upstream_done = False - - upstream_disp = None - active_inner_disp = None - - def dispose_all(): - if upstream_disp: - upstream_disp.dispose() - if active_inner_disp: - active_inner_disp.dispose() - - def on_next(value): - nonlocal in_flight, active_inner_disp - lock.acquire() - try: - if not in_flight: - in_flight = True - print("Processing new item.") - else: - print("Skipping item, already processing.") - return - finally: - lock.release() - - # We only get here if we grabbed the in_flight slot - try: - inner_source = process_item(value) - except Exception as ex: - observer.on_error(ex) - return - - def inner_on_next(ivalue): - observer.on_next(ivalue) - - def inner_on_error(err): - nonlocal in_flight - with lock: - in_flight = False - observer.on_error(err) - - def inner_on_completed(): - nonlocal in_flight - with lock: - in_flight = False - if upstream_done: - observer.on_completed() - - # Subscribe to the inner observable - nonlocal active_inner_disp - 
active_inner_disp = inner_source.subscribe( - on_next=inner_on_next, - on_error=inner_on_error, - on_completed=inner_on_completed, - scheduler=scheduler, - ) - - def on_error(err): - dispose_all() - observer.on_error(err) - - def on_completed(): - nonlocal upstream_done - with lock: - upstream_done = True - # If we're not busy, we can end now - if not in_flight: - observer.on_completed() - - upstream_disp = source.subscribe( - on_next, on_error, on_completed, scheduler=scheduler - ) - return dispose_all - - return create(_subscribe) - - return _exhaust_lock - - @staticmethod - def exhaust_lock_per_instance(process_item, lock: Lock): - """ - - For each item from upstream, call process_item(item) -> Observable. - - If a frame arrives while one is "in flight", discard it. - - 'lock' ensures we safely check/modify the 'in_flight' state in a multithreaded environment. - """ - - def _exhaust_lock(source: Observable) -> Observable: - def _subscribe(observer, scheduler=None): - in_flight = False - upstream_done = False - - upstream_disp = None - active_inner_disp = None - - def dispose_all(): - if upstream_disp: - upstream_disp.dispose() - if active_inner_disp: - active_inner_disp.dispose() - - def on_next(value): - nonlocal in_flight, active_inner_disp - with lock: - # If not busy, claim the slot - if not in_flight: - in_flight = True - print("\033[34mProcessing new item.\033[0m") - else: - # Already processing => drop - print("\033[34mSkipping item, already processing.\033[0m") - return - - # We only get here if we acquired the slot - try: - inner_source = process_item(value) - except Exception as ex: - observer.on_error(ex) - return - - def inner_on_next(ivalue): - observer.on_next(ivalue) - - def inner_on_error(err): - nonlocal in_flight - with lock: - in_flight = False - print("\033[34mError in inner on error.\033[0m") - observer.on_error(err) - - def inner_on_completed(): - nonlocal in_flight - with lock: - in_flight = False - print("\033[34mInner on 
completed.\033[0m") - if upstream_done: - observer.on_completed() - - # Subscribe to the inner Observable - nonlocal active_inner_disp - active_inner_disp = inner_source.subscribe( - on_next=inner_on_next, - on_error=inner_on_error, - on_completed=inner_on_completed, - scheduler=scheduler, - ) - - def on_error(e): - dispose_all() - observer.on_error(e) - - def on_completed(): - nonlocal upstream_done - with lock: - upstream_done = True - print("\033[34mOn completed.\033[0m") - if not in_flight: - observer.on_completed() - - upstream_disp = source.subscribe( - on_next=on_next, - on_error=on_error, - on_completed=on_completed, - scheduler=scheduler, - ) - - return Disposable(dispose_all) - - return create(_subscribe) - - return _exhaust_lock - - @staticmethod - def exhaust_map(project): - def _exhaust_map(source: Observable): - def subscribe(observer, scheduler=None): - is_processing = False - - def on_next(item): - nonlocal is_processing - if not is_processing: - is_processing = True - print("\033[35mProcessing item.\033[0m") - try: - inner_observable = project(item) # Create the inner observable - inner_observable.subscribe( - on_next=observer.on_next, - on_error=observer.on_error, - on_completed=lambda: set_not_processing(), - scheduler=scheduler, - ) - except Exception as e: - observer.on_error(e) - else: - print("\033[35mSkipping item, already processing.\033[0m") - - def set_not_processing(): - nonlocal is_processing - is_processing = False - print("\033[35mItem processed.\033[0m") - - return source.subscribe( - on_next=on_next, - on_error=observer.on_error, - on_completed=observer.on_completed, - scheduler=scheduler, - ) - - return create(subscribe) - - return _exhaust_map - - @staticmethod - def with_lock(lock: Lock): - def operator(source: Observable): - def subscribe(observer, scheduler=None): - def on_next(item): - if not lock.locked(): # Check if the lock is free - if lock.acquire(blocking=False): # Non-blocking acquire - try: - print("\033[32mAcquired 
lock, processing item.\033[0m") - observer.on_next(item) - finally: # Ensure lock release even if observer.on_next throws - lock.release() - else: - print("\033[34mLock busy, skipping item.\033[0m") - else: - print("\033[34mLock busy, skipping item.\033[0m") - - def on_error(error): - observer.on_error(error) - - def on_completed(): - observer.on_completed() - - return source.subscribe( - on_next=on_next, - on_error=on_error, - on_completed=on_completed, - scheduler=scheduler, - ) - - return Observable(subscribe) - - return operator - - @staticmethod - def with_lock_check(lock: Lock): # Renamed for clarity - def operator(source: Observable): - def subscribe(observer, scheduler=None): - def on_next(item): - if not lock.locked(): # Check if the lock is held WITHOUT acquiring - print(f"\033[32mLock is free, processing item: {item}\033[0m") - observer.on_next(item) - else: - print(f"\033[34mLock is busy, skipping item: {item}\033[0m") - # observer.on_completed() - - def on_error(error): - observer.on_error(error) - - def on_completed(): - observer.on_completed() - - return source.subscribe( - on_next=on_next, - on_error=on_error, - on_completed=on_completed, - scheduler=scheduler, - ) - - return Observable(subscribe) - - return operator - - # PrintColor enum for standardized color formatting - class PrintColor(Enum): - RED = "\033[31m" - GREEN = "\033[32m" - BLUE = "\033[34m" - YELLOW = "\033[33m" - MAGENTA = "\033[35m" - CYAN = "\033[36m" - WHITE = "\033[37m" - RESET = "\033[0m" - - @staticmethod - def print_emission( - id: str, - dev_name: str = "NA", - counts: dict = None, - color: "Operators.PrintColor" = None, - enabled: bool = True, - ): - """ - Creates an operator that prints the emission with optional counts for debugging. - - Args: - id: Identifier for the emission point (e.g., 'A', 'B') - dev_name: Device or component name for context - counts: External dictionary to track emission count across operators. If None, will not print emission count. 
- color: Color for the printed output from PrintColor enum (default is RED) - enabled: Whether to print the emission count (default is True) - Returns: - An operator that counts and prints emissions without modifying the stream - """ - # If enabled is false, return the source unchanged - if not enabled: - return lambda source: source - - # Use RED as default if no color provided - if color is None: - color = Operators.PrintColor.RED - - def _operator(source: Observable) -> Observable: - def _subscribe(observer: Observer, scheduler=None): - def on_next(value): - if counts is not None: - # Initialize count if necessary - if id not in counts: - counts[id] = 0 - - # Increment and print - counts[id] += 1 - print( - f"{color.value}({dev_name} - {id}) Emission Count - {counts[id]} {datetime.now()}{Operators.PrintColor.RESET.value}" - ) - else: - print( - f"{color.value}({dev_name} - {id}) Emitted - {datetime.now()}{Operators.PrintColor.RESET.value}" - ) - - # Pass value through unchanged - observer.on_next(value) - - return source.subscribe( - on_next=on_next, - on_error=observer.on_error, - on_completed=observer.on_completed, - scheduler=scheduler, - ) - - return create(_subscribe) - - return _operator diff --git a/build/lib/dimos/stream/video_provider.py b/build/lib/dimos/stream/video_provider.py deleted file mode 100644 index 050905a024..0000000000 --- a/build/lib/dimos/stream/video_provider.py +++ /dev/null @@ -1,235 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -"""Video provider module for capturing and streaming video frames. - -This module provides classes for capturing video from various sources and -exposing them as reactive observables. It handles resource management, -frame rate control, and thread safety. -""" - -# Standard library imports -import logging -import os -import time -from abc import ABC, abstractmethod -from threading import Lock -from typing import Optional - -# Third-party imports -import cv2 -import reactivex as rx -from reactivex import operators as ops -from reactivex.disposable import CompositeDisposable -from reactivex.observable import Observable -from reactivex.scheduler import ThreadPoolScheduler - -# Local imports -from dimos.utils.threadpool import get_scheduler - -# Note: Logging configuration should ideally be in the application initialization, -# not in a module. Keeping it for now but with a more restricted scope. -logger = logging.getLogger(__name__) - - -# Specific exception classes -class VideoSourceError(Exception): - """Raised when there's an issue with the video source.""" - - pass - - -class VideoFrameError(Exception): - """Raised when there's an issue with frame acquisition.""" - - pass - - -class AbstractVideoProvider(ABC): - """Abstract base class for video providers managing video capture resources.""" - - def __init__( - self, dev_name: str = "NA", pool_scheduler: Optional[ThreadPoolScheduler] = None - ) -> None: - """Initializes the video provider with a device name. - - Args: - dev_name: The name of the device. Defaults to "NA". - pool_scheduler: The scheduler to use for thread pool operations. - If None, the global scheduler from get_scheduler() will be used. 
- """ - self.dev_name = dev_name - self.pool_scheduler = pool_scheduler if pool_scheduler else get_scheduler() - self.disposables = CompositeDisposable() - - @abstractmethod - def capture_video_as_observable(self, fps: int = 30) -> Observable: - """Create an observable from video capture. - - Args: - fps: Frames per second to emit. Defaults to 30fps. - - Returns: - Observable: An observable emitting frames at the specified rate. - - Raises: - VideoSourceError: If the video source cannot be opened. - VideoFrameError: If frames cannot be read properly. - """ - pass - - def dispose_all(self) -> None: - """Disposes of all active subscriptions managed by this provider.""" - if self.disposables: - self.disposables.dispose() - else: - logger.info("No disposables to dispose.") - - def __del__(self) -> None: - """Destructor to ensure resources are cleaned up if not explicitly disposed.""" - self.dispose_all() - - -class VideoProvider(AbstractVideoProvider): - """Video provider implementation for capturing video as an observable.""" - - def __init__( - self, - dev_name: str, - video_source: str = f"{os.getcwd()}/assets/video-f30-480p.mp4", - pool_scheduler: Optional[ThreadPoolScheduler] = None, - ) -> None: - """Initializes the video provider with a device name and video source. - - Args: - dev_name: The name of the device. - video_source: The path to the video source. Defaults to a sample video. - pool_scheduler: The scheduler to use for thread pool operations. - If None, the global scheduler from get_scheduler() will be used. - """ - super().__init__(dev_name, pool_scheduler) - self.video_source = video_source - self.cap = None - self.lock = Lock() - - def _initialize_capture(self) -> None: - """Initializes the video capture object if not already initialized. - - Raises: - VideoSourceError: If the video source cannot be opened. 
- """ - if self.cap is None or not self.cap.isOpened(): - # Release previous capture if it exists but is closed - if self.cap: - self.cap.release() - logger.info("Released previous capture") - - # Attempt to open new capture - self.cap = cv2.VideoCapture(self.video_source) - if self.cap is None or not self.cap.isOpened(): - error_msg = f"Failed to open video source: {self.video_source}" - logger.error(error_msg) - raise VideoSourceError(error_msg) - - logger.info(f"Opened new capture: {self.video_source}") - - def capture_video_as_observable(self, realtime: bool = True, fps: int = 30) -> Observable: - """Creates an observable from video capture. - - Creates an observable that emits frames at specified FPS or the video's - native FPS, with proper resource management and error handling. - - Args: - realtime: If True, use the video's native FPS. Defaults to True. - fps: Frames per second to emit. Defaults to 30fps. Only used if - realtime is False or the video's native FPS is not available. - - Returns: - Observable: An observable emitting frames at the configured rate. - - Raises: - VideoSourceError: If the video source cannot be opened. - VideoFrameError: If frames cannot be read properly. 
- """ - - def emit_frames(observer, scheduler): - try: - self._initialize_capture() - - # Determine the FPS to use based on configuration and availability - local_fps: float = fps - if realtime: - native_fps: float = self.cap.get(cv2.CAP_PROP_FPS) - if native_fps > 0: - local_fps = native_fps - else: - logger.warning("Native FPS not available, defaulting to specified FPS") - - frame_interval: float = 1.0 / local_fps - frame_time: float = time.monotonic() - - while self.cap.isOpened(): - # Thread-safe access to video capture - with self.lock: - ret, frame = self.cap.read() - - if not ret: - # Loop video when we reach the end - logger.warning("End of video reached, restarting playback") - with self.lock: - self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0) - continue - - # Control frame rate to match target FPS - now: float = time.monotonic() - next_frame_time: float = frame_time + frame_interval - sleep_time: float = next_frame_time - now - - if sleep_time > 0: - time.sleep(sleep_time) - - observer.on_next(frame) - frame_time = next_frame_time - - except VideoSourceError as e: - logger.error(f"Video source error: {e}") - observer.on_error(e) - except Exception as e: - logger.error(f"Unexpected error during frame emission: {e}") - observer.on_error(VideoFrameError(f"Frame acquisition failed: {e}")) - finally: - # Clean up resources regardless of success or failure - with self.lock: - if self.cap and self.cap.isOpened(): - self.cap.release() - logger.info("Capture released") - observer.on_completed() - - return rx.create(emit_frames).pipe( - ops.subscribe_on(self.pool_scheduler), - ops.observe_on(self.pool_scheduler), - ops.share(), # Share the stream among multiple subscribers - ) - - def dispose_all(self) -> None: - """Disposes of all resources including video capture.""" - with self.lock: - if self.cap and self.cap.isOpened(): - self.cap.release() - logger.info("Capture released in dispose_all") - super().dispose_all() - - def __del__(self) -> None: - """Destructor to ensure 
resources are cleaned up if not explicitly disposed.""" - self.dispose_all() diff --git a/build/lib/dimos/stream/video_providers/__init__.py b/build/lib/dimos/stream/video_providers/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/build/lib/dimos/stream/video_providers/unitree.py b/build/lib/dimos/stream/video_providers/unitree.py deleted file mode 100644 index e1a7587146..0000000000 --- a/build/lib/dimos/stream/video_providers/unitree.py +++ /dev/null @@ -1,167 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dimos.stream.video_provider import AbstractVideoProvider - -from queue import Queue -from go2_webrtc_driver.webrtc_driver import Go2WebRTCConnection, WebRTCConnectionMethod -from aiortc import MediaStreamTrack -import asyncio -from reactivex import Observable, create, operators as ops -import logging -import threading -import time - - -class UnitreeVideoProvider(AbstractVideoProvider): - def __init__( - self, - dev_name: str = "UnitreeGo2", - connection_method: WebRTCConnectionMethod = WebRTCConnectionMethod.LocalSTA, - serial_number: str = None, - ip: str = None, - ): - """Initialize the Unitree video stream with WebRTC connection. 
- - Args: - dev_name: Name of the device - connection_method: WebRTC connection method (LocalSTA, LocalAP, Remote) - serial_number: Serial number of the robot (required for LocalSTA with serial) - ip: IP address of the robot (required for LocalSTA with IP) - """ - super().__init__(dev_name) - self.frame_queue = Queue() - self.loop = None - self.asyncio_thread = None - - # Initialize WebRTC connection based on method - if connection_method == WebRTCConnectionMethod.LocalSTA: - if serial_number: - self.conn = Go2WebRTCConnection(connection_method, serialNumber=serial_number) - elif ip: - self.conn = Go2WebRTCConnection(connection_method, ip=ip) - else: - raise ValueError( - "Either serial_number or ip must be provided for LocalSTA connection" - ) - elif connection_method == WebRTCConnectionMethod.LocalAP: - self.conn = Go2WebRTCConnection(connection_method) - else: - raise ValueError("Unsupported connection method") - - async def _recv_camera_stream(self, track: MediaStreamTrack): - """Receive video frames from WebRTC and put them in the queue.""" - while True: - frame = await track.recv() - # Convert the frame to a NumPy array in BGR format - img = frame.to_ndarray(format="bgr24") - self.frame_queue.put(img) - - def _run_asyncio_loop(self, loop): - """Run the asyncio event loop in a separate thread.""" - asyncio.set_event_loop(loop) - - async def setup(): - try: - await self.conn.connect() - self.conn.video.switchVideoChannel(True) - self.conn.video.add_track_callback(self._recv_camera_stream) - - await self.conn.datachannel.switchToNormalMode() - # await self.conn.datachannel.sendDamp() - - # await asyncio.sleep(5) - - # await self.conn.datachannel.sendDamp() - # await asyncio.sleep(5) - # await self.conn.datachannel.sendStandUp() - # await asyncio.sleep(5) - - # Wiggle the robot - # await self.conn.datachannel.switchToNormalMode() - # await self.conn.datachannel.sendWiggle() - # await asyncio.sleep(3) - - # Stretch the robot - # await 
self.conn.datachannel.sendStretch() - # await asyncio.sleep(3) - - except Exception as e: - logging.error(f"Error in WebRTC connection: {e}") - raise - - loop.run_until_complete(setup()) - loop.run_forever() - - def capture_video_as_observable(self, fps: int = 30) -> Observable: - """Create an observable that emits video frames at the specified FPS. - - Args: - fps: Frames per second to emit (default: 30) - - Returns: - Observable emitting video frames - """ - frame_interval = 1.0 / fps - - def emit_frames(observer, scheduler): - try: - # Start asyncio loop if not already running - if not self.loop: - self.loop = asyncio.new_event_loop() - self.asyncio_thread = threading.Thread( - target=self._run_asyncio_loop, args=(self.loop,) - ) - self.asyncio_thread.start() - - frame_time = time.monotonic() - - while True: - if not self.frame_queue.empty(): - frame = self.frame_queue.get() - - # Control frame rate - now = time.monotonic() - next_frame_time = frame_time + frame_interval - sleep_time = next_frame_time - now - - if sleep_time > 0: - time.sleep(sleep_time) - - observer.on_next(frame) - frame_time = next_frame_time - else: - time.sleep(0.001) # Small sleep to prevent CPU overuse - - except Exception as e: - logging.error(f"Error during frame emission: {e}") - observer.on_error(e) - finally: - if self.loop: - self.loop.call_soon_threadsafe(self.loop.stop) - if self.asyncio_thread: - self.asyncio_thread.join() - observer.on_completed() - - return create(emit_frames).pipe( - ops.share() # Share the stream among multiple subscribers - ) - - def dispose_all(self): - """Clean up resources.""" - if self.loop: - self.loop.call_soon_threadsafe(self.loop.stop) - if self.asyncio_thread: - self.asyncio_thread.join() - super().dispose_all() diff --git a/build/lib/dimos/stream/videostream.py b/build/lib/dimos/stream/videostream.py deleted file mode 100644 index ee63261ae6..0000000000 --- a/build/lib/dimos/stream/videostream.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright 2025 
Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import cv2 - - -class VideoStream: - def __init__(self, source=0): - """ - Initialize the video stream from a camera source. - - Args: - source (int or str): Camera index or video file path. - """ - self.capture = cv2.VideoCapture(source) - if not self.capture.isOpened(): - raise ValueError(f"Unable to open video source {source}") - - def __iter__(self): - return self - - def __next__(self): - ret, frame = self.capture.read() - if not ret: - self.capture.release() - raise StopIteration - return frame - - def release(self): - self.capture.release() diff --git a/build/lib/dimos/types/__init__.py b/build/lib/dimos/types/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/build/lib/dimos/types/constants.py b/build/lib/dimos/types/constants.py deleted file mode 100644 index 91841e8bef..0000000000 --- a/build/lib/dimos/types/constants.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - - -class Colors: - GREEN_PRINT_COLOR: str = "\033[32m" - YELLOW_PRINT_COLOR: str = "\033[33m" - RED_PRINT_COLOR: str = "\033[31m" - BLUE_PRINT_COLOR: str = "\033[34m" - MAGENTA_PRINT_COLOR: str = "\033[35m" - CYAN_PRINT_COLOR: str = "\033[36m" - WHITE_PRINT_COLOR: str = "\033[37m" - RESET_COLOR: str = "\033[0m" diff --git a/build/lib/dimos/types/costmap.py b/build/lib/dimos/types/costmap.py deleted file mode 100644 index 2d9b1c433e..0000000000 --- a/build/lib/dimos/types/costmap.py +++ /dev/null @@ -1,534 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import base64 -import pickle -import math -import numpy as np -from typing import Optional -from scipy import ndimage -from dimos.types.ros_polyfill import OccupancyGrid -from scipy.ndimage import binary_dilation -from dimos.types.vector import Vector, VectorLike, x, y, to_vector -import open3d as o3d -from matplotlib.path import Path -from PIL import Image -import cv2 - -DTYPE2STR = { - np.float32: "f32", - np.float64: "f64", - np.int32: "i32", - np.int8: "i8", -} - -STR2DTYPE = {v: k for k, v in DTYPE2STR.items()} - - -class CostValues: - """Standard cost values for occupancy grid cells.""" - - FREE = 0 # Free space - UNKNOWN = -1 # Unknown space - OCCUPIED = 100 # Occupied/lethal space - - -def encode_ndarray(arr: np.ndarray, compress: bool = False): - arr_c = np.ascontiguousarray(arr) - payload = arr_c.tobytes() - b64 = base64.b64encode(payload).decode("ascii") - - return { - "type": "grid", - "shape": arr_c.shape, - "dtype": DTYPE2STR[arr_c.dtype.type], - "data": b64, - } - - -class Costmap: - """Class to hold ROS OccupancyGrid data.""" - - def __init__( - self, - grid: np.ndarray, - origin: VectorLike, - origin_theta: float = 0, - resolution: float = 0.05, - ): - """Initialize Costmap with its core attributes.""" - self.grid = grid - self.resolution = resolution - self.origin = to_vector(origin).to_2d() - self.origin_theta = origin_theta - self.width = self.grid.shape[1] - self.height = self.grid.shape[0] - - def serialize(self) -> tuple: - """Serialize the Costmap instance to a tuple.""" - return { - "type": "costmap", - "grid": encode_ndarray(self.grid), - "origin": self.origin.serialize(), - "resolution": self.resolution, - "origin_theta": self.origin_theta, - } - - @classmethod - def from_msg(cls, costmap_msg: OccupancyGrid) -> "Costmap": - """Create a Costmap instance from a ROS OccupancyGrid message.""" - if costmap_msg is None: - raise Exception("need costmap msg") - - # Extract info from the message - width = costmap_msg.info.width - height = 
costmap_msg.info.height - resolution = costmap_msg.info.resolution - - # Get origin position as a vector-like object - origin = ( - costmap_msg.info.origin.position.x, - costmap_msg.info.origin.position.y, - ) - - # Calculate orientation from quaternion - qx = costmap_msg.info.origin.orientation.x - qy = costmap_msg.info.origin.orientation.y - qz = costmap_msg.info.origin.orientation.z - qw = costmap_msg.info.origin.orientation.w - origin_theta = math.atan2(2.0 * (qw * qz + qx * qy), 1.0 - 2.0 * (qy * qy + qz * qz)) - - # Convert to numpy array - data = np.array(costmap_msg.data, dtype=np.int8) - grid = data.reshape((height, width)) - - return cls( - grid=grid, - resolution=resolution, - origin=origin, - origin_theta=origin_theta, - ) - - def save_pickle(self, pickle_path: str): - """Save costmap to a pickle file. - - Args: - pickle_path: Path to save the pickle file - """ - with open(pickle_path, "wb") as f: - pickle.dump(self, f) - - @classmethod - def from_pickle(cls, pickle_path: str) -> "Costmap": - """Load costmap from a pickle file containing either a Costmap object or constructor arguments.""" - with open(pickle_path, "rb") as f: - data = pickle.load(f) - - # Check if data is already a Costmap object - if isinstance(data, cls): - return data - else: - # Assume it's constructor arguments - costmap = cls(*data) - return costmap - - @classmethod - def create_empty( - cls, width: int = 100, height: int = 100, resolution: float = 0.1 - ) -> "Costmap": - """Create an empty costmap with specified dimensions.""" - return cls( - grid=np.zeros((height, width), dtype=np.int8), - resolution=resolution, - origin=(0.0, 0.0), - origin_theta=0.0, - ) - - def world_to_grid(self, point: VectorLike) -> Vector: - """Convert world coordinates to grid coordinates. 
- - Args: - point: A vector-like object containing X,Y coordinates - - Returns: - Tuple of (grid_x, grid_y) as integers - """ - return (to_vector(point) - self.origin) / self.resolution - - def grid_to_world(self, grid_point: VectorLike) -> Vector: - return to_vector(grid_point) * self.resolution + self.origin - - def is_occupied(self, point: VectorLike, threshold: int = 50) -> bool: - """Check if a position in world coordinates is occupied. - - Args: - point: Vector-like object containing X,Y coordinates - threshold: Cost threshold above which a cell is considered occupied (0-100) - - Returns: - True if position is occupied or out of bounds, False otherwise - """ - grid_point = self.world_to_grid(point) - grid_x, grid_y = int(grid_point.x), int(grid_point.y) - if 0 <= grid_x < self.width and 0 <= grid_y < self.height: - # Consider unknown (-1) as unoccupied for navigation purposes - value = self.grid[grid_y, grid_x] - return value >= threshold - return True # Consider out-of-bounds as occupied - - def get_value(self, point: VectorLike) -> Optional[int]: - point = self.world_to_grid(point) - - if 0 <= point.x < self.width and 0 <= point.y < self.height: - return int(self.grid[int(point.y), int(point.x)]) - return None - - def set_value(self, point: VectorLike, value: int = 0) -> bool: - point = self.world_to_grid(point) - - if 0 <= point.x < self.width and 0 <= point.y < self.height: - self.grid[int(point.y), int(point.x)] = value - return value - return False - - def smudge( - self, - kernel_size: int = 7, - iterations: int = 25, - decay_factor: float = 0.9, - threshold: int = 90, - preserve_unknown: bool = False, - ) -> "Costmap": - """ - Creates a new costmap with expanded obstacles (smudged). 
- - Args: - kernel_size: Size of the convolution kernel for dilation (must be odd) - iterations: Number of dilation iterations - decay_factor: Factor to reduce cost as distance increases (0.0-1.0) - threshold: Minimum cost value to consider as an obstacle for expansion - preserve_unknown: Whether to keep unknown (-1) cells as unknown - - Returns: - A new Costmap instance with expanded obstacles - """ - # Make sure kernel size is odd - if kernel_size % 2 == 0: - kernel_size += 1 - - # Create a copy of the grid for processing - grid_copy = self.grid.copy() - - # Create a mask of unknown cells if needed - unknown_mask = None - if preserve_unknown: - unknown_mask = grid_copy == -1 - # Temporarily replace unknown cells with 0 for processing - # This allows smudging to go over unknown areas - grid_copy[unknown_mask] = 0 - - # Create a mask of cells that are above the threshold - obstacle_mask = grid_copy >= threshold - - # Create a binary map of obstacles - binary_map = obstacle_mask.astype(np.uint8) * 100 - - # Create a circular kernel for dilation (instead of square) - y, x = np.ogrid[ - -kernel_size // 2 : kernel_size // 2 + 1, - -kernel_size // 2 : kernel_size // 2 + 1, - ] - kernel = (x * x + y * y <= (kernel_size // 2) * (kernel_size // 2)).astype(np.uint8) - - # Create distance map using dilation - # Each iteration adds one 'ring' of cells around obstacles - dilated_map = binary_map.copy() - - # Store each layer of dilation with decreasing values - layers = [] - - # First layer is the original obstacle cells - layers.append(binary_map.copy()) - - for i in range(iterations): - # Dilate the binary map - dilated = ndimage.binary_dilation( - dilated_map > 0, structure=kernel, iterations=1 - ).astype(np.uint8) - - # Calculate the new layer (cells that were just added in this iteration) - new_layer = (dilated - (dilated_map > 0).astype(np.uint8)) * 100 - - # Apply decay factor based on distance from obstacle - new_layer = new_layer * (decay_factor ** (i + 1)) - - # Add 
to layers list - layers.append(new_layer) - - # Update dilated map for next iteration - dilated_map = dilated * 100 - - # Combine all layers to create a distance-based cost map - smudged_map = np.zeros_like(grid_copy) - for layer in layers: - # For each cell, keep the maximum value across all layers - smudged_map = np.maximum(smudged_map, layer) - - # Preserve original obstacles - smudged_map[obstacle_mask] = grid_copy[obstacle_mask] - - # When preserve_unknown is true, restore all original unknown cells - # This overlays unknown cells on top of the smudged map - if preserve_unknown and unknown_mask is not None: - smudged_map[unknown_mask] = -1 - - # Ensure cost values are in valid range (0-100) except for unknown (-1) - if preserve_unknown: - valid_cells = ~unknown_mask - smudged_map[valid_cells] = np.clip(smudged_map[valid_cells], 0, 100) - else: - smudged_map = np.clip(smudged_map, 0, 100) - - # Create a new costmap with the smudged grid - return Costmap( - grid=smudged_map.astype(np.int8), - resolution=self.resolution, - origin=self.origin, - origin_theta=self.origin_theta, - ) - - def subsample(self, subsample_factor: int = 2) -> "Costmap": - """ - Create a subsampled (lower resolution) version of the costmap. 
- - Args: - subsample_factor: Factor by which to reduce resolution (e.g., 2 = half resolution, 4 = quarter resolution) - - Returns: - New Costmap instance with reduced resolution - """ - if subsample_factor <= 1: - return self # No subsampling needed - - # Calculate new grid dimensions - new_height = self.height // subsample_factor - new_width = self.width // subsample_factor - - # Create new grid by subsampling - subsampled_grid = np.zeros((new_height, new_width), dtype=self.grid.dtype) - - # Sample every subsample_factor-th point - for i in range(new_height): - for j in range(new_width): - orig_i = i * subsample_factor - orig_j = j * subsample_factor - - # Take a small neighborhood and use the most conservative value - # (prioritize occupied > unknown > free for safety) - neighborhood = self.grid[ - orig_i : min(orig_i + subsample_factor, self.height), - orig_j : min(orig_j + subsample_factor, self.width), - ] - - # Priority: Occupied (100) > Unknown (-1) > Free (0) - if np.any(neighborhood == CostValues.OCCUPIED): - subsampled_grid[i, j] = CostValues.OCCUPIED - elif np.any(neighborhood == CostValues.UNKNOWN): - subsampled_grid[i, j] = CostValues.UNKNOWN - else: - subsampled_grid[i, j] = CostValues.FREE - - # Create new costmap with adjusted resolution and origin - new_resolution = self.resolution * subsample_factor - - return Costmap( - grid=subsampled_grid, - resolution=new_resolution, - origin=self.origin, # Origin stays the same - ) - - @property - def total_cells(self) -> int: - return self.width * self.height - - @property - def occupied_cells(self) -> int: - return np.sum(self.grid >= 0.1) - - @property - def unknown_cells(self) -> int: - return np.sum(self.grid == -1) - - @property - def free_cells(self) -> int: - return self.total_cells - self.occupied_cells - self.unknown_cells - - @property - def free_percent(self) -> float: - return (self.free_cells / self.total_cells) * 100 if self.total_cells > 0 else 0.0 - - @property - def occupied_percent(self) 
-> float: - return (self.occupied_cells / self.total_cells) * 100 if self.total_cells > 0 else 0.0 - - @property - def unknown_percent(self) -> float: - return (self.unknown_cells / self.total_cells) * 100 if self.total_cells > 0 else 0.0 - - def __str__(self) -> str: - """ - Create a string representation of the Costmap. - - Returns: - A formatted string with key costmap information - """ - - cell_info = [ - "▦ Costmap", - f"{self.width}x{self.height}", - f"({self.width * self.resolution:.1f}x{self.height * self.resolution:.1f}m @", - f"{1 / self.resolution:.0f}cm res)", - f"Origin: ({x(self.origin):.2f}, {y(self.origin):.2f})", - f"▣ {self.occupied_percent:.1f}%", - f"□ {self.free_percent:.1f}%", - f"◌ {self.unknown_percent:.1f}%", - ] - - return " ".join(cell_info) - - def costmap_to_image(self, image_path: str) -> None: - """ - Convert costmap to JPEG image with ROS-style coloring. - Free space: light grey, Obstacles: black, Unknown: dark gray - - Args: - image_path: Path to save the JPEG image - """ - # Create image array (height, width, 3 for RGB) - img_array = np.zeros((self.height, self.width, 3), dtype=np.uint8) - - # Apply ROS-style coloring based on costmap values - for i in range(self.height): - for j in range(self.width): - value = self.grid[i, j] - if value == CostValues.FREE: # Free space = light grey (205, 205, 205) - img_array[i, j] = [205, 205, 205] - elif value == CostValues.UNKNOWN: # Unknown = dark gray (128, 128, 128) - img_array[i, j] = [128, 128, 128] - elif value >= CostValues.OCCUPIED: # Occupied/obstacles = black (0, 0, 0) - img_array[i, j] = [0, 0, 0] - else: # Any other values (low cost) = light grey - img_array[i, j] = [205, 205, 205] - - # Flip vertically to match ROS convention (origin at bottom-left) - img_array = np.flipud(img_array) - - # Create PIL image and save as JPEG - img = Image.fromarray(img_array, "RGB") - img.save(image_path, "JPEG", quality=95) - print(f"Costmap image saved to: {image_path}") - - -def 
_inflate_lethal(costmap: np.ndarray, radius: int, lethal_val: int = 100) -> np.ndarray: - """Return *costmap* with lethal cells dilated by *radius* grid steps (circular).""" - if radius <= 0 or not np.any(costmap == lethal_val): - return costmap - - mask = costmap == lethal_val - dilated = mask.copy() - for dy in range(-radius, radius + 1): - for dx in range(-radius, radius + 1): - if dx * dx + dy * dy > radius * radius or (dx == 0 and dy == 0): - continue - dilated |= np.roll(mask, shift=(dy, dx), axis=(0, 1)) - - out = costmap.copy() - out[dilated] = lethal_val - return out - - -def pointcloud_to_costmap( - pcd: o3d.geometry.PointCloud, - *, - resolution: float = 0.05, - ground_z: float = 0.0, - obs_min_height: float = 0.15, - max_height: Optional[float] = 0.5, - inflate_radius_m: Optional[float] = None, - default_unknown: int = -1, - cost_free: int = 0, - cost_lethal: int = 100, -) -> tuple[np.ndarray, np.ndarray]: - """Rasterise *pcd* into a 2-D int8 cost-map with optional obstacle inflation. - - Grid origin is **aligned** to the `resolution` lattice so that when - `resolution == voxel_size` every voxel centroid lands squarely inside a cell - (no alternating blank lines). - """ - - pts = np.asarray(pcd.points, dtype=np.float32) - if pts.size == 0: - return np.full((1, 1), default_unknown, np.int8), np.zeros(2, np.float32) - - # 0. Ceiling filter -------------------------------------------------------- - if max_height is not None: - pts = pts[pts[:, 2] <= max_height] - if pts.size == 0: - return np.full((1, 1), default_unknown, np.int8), np.zeros(2, np.float32) - - # 1. 
Bounding box & aligned origin --------------------------------------- - xy_min = pts[:, :2].min(axis=0) - xy_max = pts[:, :2].max(axis=0) - - # Align origin to the resolution grid (anchor = 0,0) - origin = np.floor(xy_min / resolution) * resolution - - # Grid dimensions (inclusive) ------------------------------------------- - Nx, Ny = (np.ceil((xy_max - origin) / resolution).astype(int) + 1).tolist() - - # 2. Bin points ------------------------------------------------------------ - idx_xy = np.floor((pts[:, :2] - origin) / resolution).astype(np.int32) - np.clip(idx_xy[:, 0], 0, Nx - 1, out=idx_xy[:, 0]) - np.clip(idx_xy[:, 1], 0, Ny - 1, out=idx_xy[:, 1]) - - lin = idx_xy[:, 1] * Nx + idx_xy[:, 0] - z_max = np.full(Nx * Ny, -np.inf, np.float32) - np.maximum.at(z_max, lin, pts[:, 2]) - z_max = z_max.reshape(Ny, Nx) - - # 3. Cost rules ----------------------------------------------------------- - costmap = np.full_like(z_max, default_unknown, np.int8) - known = z_max != -np.inf - costmap[known] = cost_free - - lethal = z_max >= (ground_z + obs_min_height) - costmap[lethal] = cost_lethal - - # 4. Optional inflation ---------------------------------------------------- - if inflate_radius_m and inflate_radius_m > 0: - cells = int(np.ceil(inflate_radius_m / resolution)) - costmap = _inflate_lethal(costmap, cells, lethal_val=cost_lethal) - - return costmap, origin.astype(np.float32) - - -if __name__ == "__main__": - costmap = Costmap.from_pickle("costmapMsg.pickle") - print(costmap) - - # Create a smudged version of the costmap for better planning - smudged_costmap = costmap.smudge( - kernel_size=10, iterations=10, threshold=80, preserve_unknown=False - ) - - print(costmap) diff --git a/build/lib/dimos/types/label.py b/build/lib/dimos/types/label.py deleted file mode 100644 index ce037aed7a..0000000000 --- a/build/lib/dimos/types/label.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright 2025 Dimensional Inc. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Dict, Any - - -class LabelType: - def __init__(self, labels: Dict[str, Any], metadata: Any = None): - """ - Initializes a standardized label type. - - Args: - labels (Dict[str, Any]): A dictionary of labels with descriptions. - metadata (Any, optional): Additional metadata related to the labels. - """ - self.labels = labels - self.metadata = metadata - - def get_label_descriptions(self): - """Return a list of label descriptions.""" - return [desc["description"] for desc in self.labels.values()] - - def save_to_json(self, filepath: str): - """Save the labels to a JSON file.""" - import json - - with open(filepath, "w") as f: - json.dump(self.labels, f, indent=4) diff --git a/build/lib/dimos/types/manipulation.py b/build/lib/dimos/types/manipulation.py deleted file mode 100644 index d61d73a7ed..0000000000 --- a/build/lib/dimos/types/manipulation.py +++ /dev/null @@ -1,155 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -from enum import Enum -from typing import Dict, List, Optional, Any, Union, TypedDict, Tuple, Literal -from dataclasses import dataclass, field, fields -from abc import ABC, abstractmethod -import uuid -import numpy as np -import time -from dimos.types.vector import Vector - - -class ConstraintType(Enum): - """Types of manipulation constraints.""" - - TRANSLATION = "translation" - ROTATION = "rotation" - FORCE = "force" - - -@dataclass -class AbstractConstraint(ABC): - """Base class for all manipulation constraints.""" - - description: str = "" - id: str = field(default_factory=lambda: str(uuid.uuid4())[:8]) - - -@dataclass -class TranslationConstraint(AbstractConstraint): - """Constraint parameters for translational movement along a single axis.""" - - translation_axis: Literal["x", "y", "z"] = None # Axis to translate along - reference_point: Optional[Vector] = None - bounds_min: Optional[Vector] = None # For bounded translation - bounds_max: Optional[Vector] = None # For bounded translation - target_point: Optional[Vector] = None # For relative positioning - - -@dataclass -class RotationConstraint(AbstractConstraint): - """Constraint parameters for rotational movement around a single axis.""" - - rotation_axis: Literal["roll", "pitch", "yaw"] = None # Axis to rotate around - start_angle: Optional[Vector] = None # Angle values applied to the specified rotation axis - end_angle: Optional[Vector] = None # Angle values applied to the specified rotation axis - pivot_point: Optional[Vector] = None # Point of rotation - secondary_pivot_point: Optional[Vector] = None # For double point rotations - - -@dataclass -class ForceConstraint(AbstractConstraint): - """Constraint parameters for force application.""" - - max_force: float = 0.0 # Maximum force in newtons - min_force: float = 0.0 # Minimum force in newtons - force_direction: Optional[Vector] = None # Direction 
of force application - - -class ObjectData(TypedDict, total=False): - """Data about an object in the manipulation scene.""" - - object_id: int # Unique ID for the object - bbox: List[float] # Bounding box [x1, y1, x2, y2] - depth: float # Depth in meters from Metric3d - confidence: float # Detection confidence - class_id: int # Class ID from the detector - label: str # Semantic label (e.g., 'cup', 'table') - movement_tolerance: float # (0.0 = immovable, 1.0 = freely movable) - segmentation_mask: np.ndarray # Binary mask of the object's pixels - position: Dict[str, float] # 3D position {x, y, z} - rotation: Dict[str, float] # 3D rotation {roll, pitch, yaw} - size: Dict[str, float] # Object dimensions {width, height} - - -class ManipulationMetadata(TypedDict, total=False): - """Typed metadata for manipulation constraints.""" - - timestamp: float - objects: Dict[str, ObjectData] - - -@dataclass -class ManipulationTaskConstraint: - """Set of constraints for a specific manipulation action.""" - - constraints: List[AbstractConstraint] = field(default_factory=list) - - def add_constraint(self, constraint: AbstractConstraint): - """Add a constraint to this set.""" - if constraint not in self.constraints: - self.constraints.append(constraint) - - def get_constraints(self) -> List[AbstractConstraint]: - """Get all constraints in this set.""" - return self.constraints - - -@dataclass -class ManipulationTask: - """Complete definition of a manipulation task.""" - - description: str - target_object: str # Semantic label of target object - target_point: Optional[Tuple[float, float]] = ( - None # (X,Y) point in pixel-space of the point to manipulate on target object - ) - metadata: ManipulationMetadata = field(default_factory=dict) - timestamp: float = field(default_factory=time.time) - task_id: str = "" - result: Optional[Dict[str, Any]] = None # Any result data from the task execution - constraints: Union[List[AbstractConstraint], ManipulationTaskConstraint, AbstractConstraint] 
= ( - field(default_factory=list) - ) - - def add_constraint(self, constraint: AbstractConstraint): - """Add a constraint to this manipulation task.""" - # If constraints is a ManipulationTaskConstraint object - if isinstance(self.constraints, ManipulationTaskConstraint): - self.constraints.add_constraint(constraint) - return - - # If constraints is a single AbstractConstraint, convert to list - if isinstance(self.constraints, AbstractConstraint): - self.constraints = [self.constraints, constraint] - return - - # If constraints is a list, append to it - # This will also handle empty lists (the default case) - self.constraints.append(constraint) - - def get_constraints(self) -> List[AbstractConstraint]: - """Get all constraints in this manipulation task.""" - # If constraints is a ManipulationTaskConstraint object - if isinstance(self.constraints, ManipulationTaskConstraint): - return self.constraints.get_constraints() - - # If constraints is a single AbstractConstraint, return as list - if isinstance(self.constraints, AbstractConstraint): - return [self.constraints] - - # If constraints is a list (including empty list), return it - return self.constraints diff --git a/build/lib/dimos/types/path.py b/build/lib/dimos/types/path.py deleted file mode 100644 index c87658182f..0000000000 --- a/build/lib/dimos/types/path.py +++ /dev/null @@ -1,414 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import numpy as np -from typing import List, Union, Tuple, Iterator, TypeVar -from dimos.types.vector import Vector - -T = TypeVar("T", bound="Path") - - -class Path: - """A class representing a path as a sequence of points.""" - - def __init__( - self, - points: Union[List[Vector], List[np.ndarray], List[Tuple], np.ndarray, None] = None, - ): - """Initialize a path from a list of points. - - Args: - points: List of Vector objects, numpy arrays, tuples, or a 2D numpy array where each row is a point. - If None, creates an empty path. - - Examples: - Path([Vector(1, 2), Vector(3, 4)]) # from Vector objects - Path([(1, 2), (3, 4)]) # from tuples - Path(np.array([[1, 2], [3, 4]])) # from 2D numpy array - """ - if points is None: - self._points = np.zeros((0, 0), dtype=float) - return - - if isinstance(points, np.ndarray) and points.ndim == 2: - # If already a 2D numpy array, use it directly - self._points = points.astype(float) - else: - # Convert various input types to numpy array - converted = [] - for p in points: - if isinstance(p, Vector): - converted.append(p.data) - else: - converted.append(p) - self._points = np.array(converted, dtype=float) - - def serialize(self) -> Tuple: - """Serialize the vector to a tuple.""" - return { - "type": "path", - "points": self._points.tolist(), - } - - @property - def points(self) -> np.ndarray: - """Get the path points as a numpy array.""" - return self._points - - def as_vectors(self) -> List[Vector]: - """Get the path points as Vector objects.""" - return [Vector(p) for p in self._points] - - def append(self, point: Union[Vector, np.ndarray, Tuple]) -> None: - """Append a point to the path. 
- - Args: - point: A Vector, numpy array, or tuple representing a point - """ - if isinstance(point, Vector): - point_data = point.data - else: - point_data = np.array(point, dtype=float) - - if len(self._points) == 0: - # If empty, create with correct dimensionality - self._points = np.array([point_data]) - else: - self._points = np.vstack((self._points, point_data)) - - def extend(self, points: Union[List[Vector], List[np.ndarray], List[Tuple], "Path"]) -> None: - """Extend the path with more points. - - Args: - points: List of points or another Path object - """ - if isinstance(points, Path): - if len(self._points) == 0: - self._points = points.points.copy() - else: - self._points = np.vstack((self._points, points.points)) - else: - for point in points: - self.append(point) - - def insert(self, index: int, point: Union[Vector, np.ndarray, Tuple]) -> None: - """Insert a point at a specific index. - - Args: - index: The index at which to insert the point - point: A Vector, numpy array, or tuple representing a point - """ - if isinstance(point, Vector): - point_data = point.data - else: - point_data = np.array(point, dtype=float) - - if len(self._points) == 0: - self._points = np.array([point_data]) - else: - self._points = np.insert(self._points, index, point_data, axis=0) - - def remove(self, index: int) -> np.ndarray: - """Remove and return the point at the given index. - - Args: - index: The index of the point to remove - - Returns: - The removed point as a numpy array - """ - point = self._points[index].copy() - self._points = np.delete(self._points, index, axis=0) - return point - - def clear(self) -> None: - """Remove all points from the path.""" - self._points = np.zeros( - (0, self._points.shape[1] if len(self._points) > 0 else 0), dtype=float - ) - - def length(self) -> float: - """Calculate the total length of the path. 
- - Returns: - The sum of the distances between consecutive points - """ - if len(self._points) < 2: - return 0.0 - - # Efficient vector calculation of consecutive point distances - diff = self._points[1:] - self._points[:-1] - segment_lengths = np.sqrt(np.sum(diff * diff, axis=1)) - return float(np.sum(segment_lengths)) - - def resample(self: T, point_spacing: float) -> T: - """Resample the path with approximately uniform spacing between points. - - Args: - point_spacing: The desired distance between consecutive points - - Returns: - A new Path object with resampled points - """ - if len(self._points) < 2 or point_spacing <= 0: - return self.__class__(self._points.copy()) - - resampled_points = [self._points[0].copy()] - accumulated_distance = 0.0 - - for i in range(1, len(self._points)): - current_point = self._points[i] - prev_point = self._points[i - 1] - segment_vector = current_point - prev_point - segment_length = np.linalg.norm(segment_vector) - - if segment_length < 1e-10: - continue - - direction = segment_vector / segment_length - - # Add points along this segment until we reach the end - while accumulated_distance + segment_length >= point_spacing: - # How far along this segment the next point should be - dist_along_segment = point_spacing - accumulated_distance - if dist_along_segment < 0: - break - - # Create the new point - new_point = prev_point + direction * dist_along_segment - resampled_points.append(new_point) - - # Update for next iteration - accumulated_distance = 0 - segment_length -= dist_along_segment - prev_point = new_point - - # Update the accumulated distance for the next segment - accumulated_distance += segment_length - - # Add the last point if it's not already there - if len(self._points) > 1: - last_point = self._points[-1] - if not np.array_equal(resampled_points[-1], last_point): - resampled_points.append(last_point.copy()) - - return self.__class__(np.array(resampled_points)) - - def simplify(self: T, tolerance: float) -> T: - 
"""Simplify the path using the Ramer-Douglas-Peucker algorithm. - - Args: - tolerance: The maximum distance a point can deviate from the simplified path - - Returns: - A new simplified Path object - """ - if len(self._points) <= 2: - return self.__class__(self._points.copy()) - - # Implementation of Ramer-Douglas-Peucker algorithm - def rdp(points, epsilon, start, end): - if end <= start + 1: - return [start] - - # Find point with max distance from line - line_vec = points[end] - points[start] - line_length = np.linalg.norm(line_vec) - - if line_length < 1e-10: # If start and end points are the same - # Distance from next point to start point - max_dist = np.linalg.norm(points[start + 1] - points[start]) - max_idx = start + 1 - else: - max_dist = 0 - max_idx = start - - for i in range(start + 1, end): - # Distance from point to line - p_vec = points[i] - points[start] - - # Project p_vec onto line_vec - proj_scalar = np.dot(p_vec, line_vec) / (line_length * line_length) - proj = points[start] + proj_scalar * line_vec - - # Calculate perpendicular distance - dist = np.linalg.norm(points[i] - proj) - - if dist > max_dist: - max_dist = dist - max_idx = i - - # Recursive call - result = [] - if max_dist > epsilon: - result_left = rdp(points, epsilon, start, max_idx) - result_right = rdp(points, epsilon, max_idx, end) - result = result_left + result_right[1:] - else: - result = [start, end] - - return result - - indices = rdp(self._points, tolerance, 0, len(self._points) - 1) - indices.append(len(self._points) - 1) # Make sure the last point is included - indices = sorted(set(indices)) # Remove duplicates and sort - - return self.__class__(self._points[indices]) - - def smooth(self: T, weight: float = 0.5, iterations: int = 1) -> T: - """Smooth the path using a moving average filter. 
- - Args: - weight: How much to weight the neighboring points (0-1) - iterations: Number of smoothing passes to apply - - Returns: - A new smoothed Path object - """ - if len(self._points) <= 2 or weight <= 0 or iterations <= 0: - return self.__class__(self._points.copy()) - - smoothed_points = self._points.copy() - - for _ in range(iterations): - new_points = np.zeros_like(smoothed_points) - new_points[0] = smoothed_points[0] # Keep first point unchanged - - # Apply weighted average to middle points - for i in range(1, len(smoothed_points) - 1): - neighbor_avg = 0.5 * (smoothed_points[i - 1] + smoothed_points[i + 1]) - new_points[i] = (1 - weight) * smoothed_points[i] + weight * neighbor_avg - - new_points[-1] = smoothed_points[-1] # Keep last point unchanged - smoothed_points = new_points - - return self.__class__(smoothed_points) - - def nearest_point_index(self, point: Union[Vector, np.ndarray, Tuple]) -> int: - """Find the index of the closest point on the path to the given point. - - Args: - point: The reference point - - Returns: - Index of the closest point on the path - """ - if len(self._points) == 0: - raise ValueError("Cannot find nearest point in an empty path") - - if isinstance(point, Vector): - point_data = point.data - else: - point_data = np.array(point, dtype=float) - - # Calculate squared distances to all points - diff = self._points - point_data - sq_distances = np.sum(diff * diff, axis=1) - - # Return index of minimum distance - return int(np.argmin(sq_distances)) - - def reverse(self: T) -> T: - """Reverse the path direction. 
- - Returns: - A new Path object with points in reverse order - """ - return self.__class__(self._points[::-1].copy()) - - def __len__(self) -> int: - """Return the number of points in the path.""" - return len(self._points) - - def __getitem__(self, idx) -> Union[np.ndarray, "Path"]: - """Get a point or slice of points from the path.""" - if isinstance(idx, slice): - return self.__class__(self._points[idx]) - return self._points[idx].copy() - - def get_vector(self, idx: int) -> Vector: - """Get a point at the given index as a Vector object.""" - return Vector(self._points[idx]) - - def last(self) -> Vector: - """Get the first point in the path as a Vector object.""" - if len(self._points) == 0: - return None - return Vector(self._points[-1]) - - def head(self) -> Vector: - """Get the first point in the path as a Vector object.""" - if len(self._points) == 0: - return None - return Vector(self._points[0]) - - def tail(self) -> "Path": - """Get all points except the first point as a new Path object.""" - if len(self._points) <= 1: - return None - return self.__class__(self._points[1:].copy()) - - def __iter__(self) -> Iterator[np.ndarray]: - """Iterate over the points in the path.""" - for point in self._points: - yield point.copy() - - def __repr__(self) -> str: - """String representation of the path.""" - return f"↶ Path ({len(self._points)} Points)" - - def ipush(self, point: Union[Vector, np.ndarray, Tuple]) -> "Path": - """Return a new Path with `point` appended.""" - if isinstance(point, Vector): - p = point.data - else: - p = np.asarray(point, dtype=float) - - if len(self._points) == 0: - new_pts = p.reshape(1, -1) - else: - new_pts = np.vstack((self._points, p)) - return self.__class__(new_pts) - - def iclip_tail(self, max_len: int) -> "Path": - """Return a new Path containing only the last `max_len` points.""" - if max_len < 0: - raise ValueError("max_len must be ≥ 0") - return self.__class__(self._points[-max_len:]) - - def __add__(self, point): - """path 
+ vec -> path.pushed(vec)""" - return self.pushed(point) - - -if __name__ == "__main__": - # Test vectors in various directions - print( - Path( - [ - Vector(1, 0), # Right - Vector(1, 1), # Up-Right - Vector(0, 1), # Up - Vector(-1, 1), # Up-Left - Vector(-1, 0), # Left - Vector(-1, -1), # Down-Left - Vector(0, -1), # Down - Vector(1, -1), # Down-Right - Vector(0.5, 0.5), # Up-Right (shorter) - Vector(-3, 4), # Up-Left (longer) - ] - ) - ) - - print(Path()) diff --git a/build/lib/dimos/types/pose.py b/build/lib/dimos/types/pose.py deleted file mode 100644 index 455f22c189..0000000000 --- a/build/lib/dimos/types/pose.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import TypeVar, Union, Sequence -import numpy as np -from plum import dispatch -import math - -from dimos.types.vector import Vector, to_vector, to_numpy, VectorLike - - -T = TypeVar("T", bound="Pose") - -PoseLike = Union["Pose", VectorLike, Sequence[VectorLike]] - - -def yaw_to_matrix(yaw: float) -> np.ndarray: - """2-D yaw (about Z) to a 3×3 rotation matrix.""" - c, s = math.cos(yaw), math.sin(yaw) - return np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]]) - - -class Pose(Vector): - """A pose in 3D space, consisting of a position vector and a rotation vector. - - Pose inherits from Vector and behaves like a vector for the position component. - The rotation vector is stored separately and can be accessed via the rot property. 
- """ - - _rot: Vector = None - - @dispatch - def __init__(self, pos: VectorLike): - self._data = to_numpy(pos) - self._rot = None - - @dispatch - def __init__(self, pos: VectorLike, rot: VectorLike): - self._data = to_numpy(pos) - self._rot = to_vector(rot) - - def __repr__(self) -> str: - return f"Pose({self.pos.__repr__()}, {self.rot.__repr__()})" - - def __str__(self) -> str: - return self.__repr__() - - def is_zero(self) -> bool: - return super().is_zero() and self.rot.is_zero() - - def __bool__(self) -> bool: - return not self.is_zero() - - def serialize(self): - """Serialize the pose to a dictionary.""" - return {"type": "pose", "pos": self.to_list(), "rot": self.rot.to_list()} - - def vector_to(self, target: Vector) -> Vector: - direction = target - self.pos.to_2d() - - cos_y = math.cos(-self.yaw) - sin_y = math.sin(-self.yaw) - - x = cos_y * direction.x - sin_y * direction.y - y = sin_y * direction.x + cos_y * direction.y - - return Vector(x, y) - - def __eq__(self, other) -> bool: - """Check if two poses are equal using numpy's allclose for floating point comparison.""" - if not isinstance(other, Pose): - return False - return np.allclose(self.pos._data, other.pos._data) and np.allclose( - self.rot._data, other.rot._data - ) - - @property - def rot(self) -> Vector: - if self._rot: - return self._rot - else: - return Vector(0, 0, 0) - - @property - def pos(self) -> Vector: - """Get the position vector (self).""" - return to_vector(self._data) - - def __add__(self: T, other) -> T: - """Override Vector's __add__ to handle Pose objects specially. - - When adding two Pose objects, both position and rotation components are added. 
- """ - if isinstance(other, Pose): - # Add both position and rotation components - result = super().__add__(other) - result._rot = self.rot + other.rot - return result - else: - # For other types, just use Vector's addition - return Pose(super().__add__(other), self.rot) - - @property - def yaw(self) -> float: - """Get the yaw (rotation around Z-axis) from the rotation vector.""" - return self.rot.z - - def __sub__(self: T, other) -> T: - """Override Vector's __sub__ to handle Pose objects specially. - - When subtracting two Pose objects, both position and rotation components are subtracted. - """ - if isinstance(other, Pose): - # Subtract both position and rotation components - result = super().__sub__(other) - result._rot = self.rot - other.rot - return result - else: - # For other types, just use Vector's subtraction - return super().__sub__(other) - - def __mul__(self: T, scalar: float) -> T: - return Pose(self.pos * scalar, self.rot) - - -@dispatch -def to_pose(pos: Pose) -> Pose: - return pos - - -@dispatch -def to_pose(pos: VectorLike) -> Pose: - return Pose(pos) - - -@dispatch -def to_pose(pos: Sequence[VectorLike]) -> Pose: - return Pose(*pos) diff --git a/build/lib/dimos/types/robot_capabilities.py b/build/lib/dimos/types/robot_capabilities.py deleted file mode 100644 index 8c9a7fcd41..0000000000 --- a/build/lib/dimos/types/robot_capabilities.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Robot capabilities module for defining robot functionality.""" - -from enum import Enum, auto - - -class RobotCapability(Enum): - """Enum defining possible robot capabilities.""" - - MANIPULATION = auto() - VISION = auto() - AUDIO = auto() - SPEECH = auto() - LOCOMOTION = auto() diff --git a/build/lib/dimos/types/robot_location.py b/build/lib/dimos/types/robot_location.py deleted file mode 100644 index c69d131a04..0000000000 --- a/build/lib/dimos/types/robot_location.py +++ /dev/null @@ -1,130 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -RobotLocation type definition for storing and managing robot location data. -""" - -from dataclasses import dataclass, field -from typing import Dict, Any, Optional, Tuple -import time -import uuid - - -@dataclass -class RobotLocation: - """ - Represents a named location in the robot's spatial memory. - - This class stores the position, rotation, and descriptive metadata for - locations that the robot can remember and navigate to. 
- - Attributes: - name: Human-readable name of the location (e.g., "kitchen", "office") - position: 3D position coordinates (x, y, z) - rotation: 3D rotation angles in radians (roll, pitch, yaw) - frame_id: ID of the associated video frame if available - timestamp: Time when the location was recorded - location_id: Unique identifier for this location - metadata: Additional metadata for the location - """ - - name: str - position: Tuple[float, float, float] - rotation: Tuple[float, float, float] - frame_id: Optional[str] = None - timestamp: float = field(default_factory=time.time) - location_id: str = field(default_factory=lambda: f"loc_{uuid.uuid4().hex[:8]}") - metadata: Dict[str, Any] = field(default_factory=dict) - - def __post_init__(self): - """Validate and normalize the position and rotation tuples.""" - # Ensure position is a tuple of 3 floats - if len(self.position) == 2: - self.position = (self.position[0], self.position[1], 0.0) - else: - self.position = tuple(float(x) for x in self.position) - - # Ensure rotation is a tuple of 3 floats - if len(self.rotation) == 1: - self.rotation = (0.0, 0.0, self.rotation[0]) - else: - self.rotation = tuple(float(x) for x in self.rotation) - - def to_vector_metadata(self) -> Dict[str, Any]: - """ - Convert the location to metadata format for storing in a vector database. 
- - Returns: - Dictionary with metadata fields compatible with vector DB storage - """ - return { - "pos_x": float(self.position[0]), - "pos_y": float(self.position[1]), - "pos_z": float(self.position[2]), - "rot_x": float(self.rotation[0]), - "rot_y": float(self.rotation[1]), - "rot_z": float(self.rotation[2]), - "timestamp": self.timestamp, - "location_id": self.location_id, - "frame_id": self.frame_id, - "location_name": self.name, - "description": self.name, # Makes it searchable by text - } - - @classmethod - def from_vector_metadata(cls, metadata: Dict[str, Any]) -> "RobotLocation": - """ - Create a RobotLocation object from vector database metadata. - - Args: - metadata: Dictionary with metadata from vector database - - Returns: - RobotLocation object - """ - return cls( - name=metadata.get("location_name", "unknown"), - position=( - metadata.get("pos_x", 0.0), - metadata.get("pos_y", 0.0), - metadata.get("pos_z", 0.0), - ), - rotation=( - metadata.get("rot_x", 0.0), - metadata.get("rot_y", 0.0), - metadata.get("rot_z", 0.0), - ), - frame_id=metadata.get("frame_id"), - timestamp=metadata.get("timestamp", time.time()), - location_id=metadata.get("location_id", f"loc_{uuid.uuid4().hex[:8]}"), - metadata={ - k: v - for k, v in metadata.items() - if k - not in [ - "pos_x", - "pos_y", - "pos_z", - "rot_x", - "rot_y", - "rot_z", - "timestamp", - "location_id", - "frame_id", - "location_name", - "description", - ] - }, - ) diff --git a/build/lib/dimos/types/ros_polyfill.py b/build/lib/dimos/types/ros_polyfill.py deleted file mode 100644 index b5c2bc1d64..0000000000 --- a/build/lib/dimos/types/ros_polyfill.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -try: - from geometry_msgs.msg import Vector3 -except ImportError: - - class Vector3: - def __init__(self, x: float = 0.0, y: float = 0.0, z: float = 0.0): - self.x = float(x) - self.y = float(y) - self.z = float(z) - - def __repr__(self) -> str: - return f"Vector3(x={self.x}, y={self.y}, z={self.z})" - - -try: - from nav_msgs.msg import OccupancyGrid, Odometry - from geometry_msgs.msg import Pose, Point, Quaternion, Twist - from std_msgs.msg import Header -except ImportError: - - class Header: - def __init__(self): - self.stamp = None - self.frame_id = "" - - class Point: - def __init__(self, x: float = 0.0, y: float = 0.0, z: float = 0.0): - self.x = float(x) - self.y = float(y) - self.z = float(z) - - def __repr__(self) -> str: - return f"Point(x={self.x}, y={self.y}, z={self.z})" - - class Quaternion: - def __init__(self, x: float = 0.0, y: float = 0.0, z: float = 0.0, w: float = 1.0): - self.x = float(x) - self.y = float(y) - self.z = float(z) - self.w = float(w) - - def __repr__(self) -> str: - return f"Quaternion(x={self.x}, y={self.y}, z={self.z}, w={self.w})" - - class Pose: - def __init__(self): - self.position = Point() - self.orientation = Quaternion() - - def __repr__(self) -> str: - return f"Pose(position={self.position}, orientation={self.orientation})" - - class MapMetaData: - def __init__(self): - self.map_load_time = None - self.resolution = 0.05 - self.width = 0 - self.height = 0 - self.origin = Pose() - - def __repr__(self) -> str: - return f"MapMetaData(resolution={self.resolution}, width={self.width}, height={self.height}, 
origin={self.origin})" - - class Twist: - def __init__(self): - self.linear = Vector3() - self.angular = Vector3() - - def __repr__(self) -> str: - return f"Twist(linear={self.linear}, angular={self.angular})" - - class OccupancyGrid: - def __init__(self): - self.header = Header() - self.info = MapMetaData() - self.data = [] - - def __repr__(self) -> str: - return f"OccupancyGrid(info={self.info}, data_length={len(self.data)})" - - class Odometry: - def __init__(self): - self.header = Header() - self.child_frame_id = "" - self.pose = Pose() - self.twist = Twist() - - def __repr__(self) -> str: - return f"Odometry(pose={self.pose}, twist={self.twist})" diff --git a/build/lib/dimos/types/sample.py b/build/lib/dimos/types/sample.py deleted file mode 100644 index 5665f7a640..0000000000 --- a/build/lib/dimos/types/sample.py +++ /dev/null @@ -1,572 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import json -import logging -from collections import OrderedDict -from enum import Enum -from pathlib import Path -from typing import Any, Dict, List, Literal, Sequence, Union, get_origin - -import numpy as np -from datasets import Dataset -from gymnasium import spaces -from jsonref import replace_refs -from pydantic import BaseModel, ConfigDict, ValidationError -from pydantic.fields import FieldInfo -from pydantic_core import from_json -from typing_extensions import Annotated - -from mbodied.data.utils import to_features -from mbodied.utils.import_utils import smart_import - -Flattenable = Annotated[Literal["dict", "np", "pt", "list"], "Numpy, PyTorch, list, or dict"] - - -class Sample(BaseModel): - """A base model class for serializing, recording, and manipulating arbitray data. - - It was designed to be extensible, flexible, yet strongly typed. In addition to - supporting any json API out of the box, it can be used to represent - arbitrary action and observation spaces in robotics and integrates seemlessly with H5, Gym, Arrow, - PyTorch, DSPY, numpy, and HuggingFace. - - Methods: - schema: Get a simplified json schema of your data. - to: Convert the Sample instance to a different container type: - - - default_value: Get the default value for the Sample instance. - unflatten: Unflatten a one-dimensional array or dictionary into a Sample instance. - flatten: Flatten the Sample instance into a one-dimensional array or dictionary. - space_for: Default Gym space generation for a given value. - init_from: Initialize a Sample instance from a given value. - from_space: Generate a Sample instance from a Gym space. - pack_from: Pack a list of samples into a single sample with lists for attributes. - unpack: Unpack the packed Sample object into a list of Sample objects or dictionaries. - dict: Return the Sample object as a dictionary with None values excluded. - model_field_info: Get the FieldInfo for a given attribute key. 
- space: Return the corresponding Gym space for the Sample instance based on its instance attributes. - random_sample: Generate a random Sample instance based on its instance attributes. - - Examples: - >>> sample = Sample(x=1, y=2, z={"a": 3, "b": 4}, extra_field=5) - >>> flat_list = sample.flatten() - >>> print(flat_list) - [1, 2, 3, 4, 5] - >>> schema = sample.schema() - {'type': 'object', 'properties': {'x': {'type': 'number'}, 'y': {'type': 'number'}, 'z': {'type': 'object', 'properties': {'a': {'type': 'number'}, 'b': {'type': 'number'}}}, 'extra_field': {'type': 'number'}}} - >>> unflattened_sample = Sample.unflatten(flat_list, schema) - >>> print(unflattened_sample) - Sample(x=1, y=2, z={'a': 3, 'b': 4}, extra_field=5) - """ - - __doc__ = "A base model class for serializing, recording, and manipulating arbitray data." - - model_config: ConfigDict = ConfigDict( - use_enum_values=False, - from_attributes=True, - validate_assignment=False, - extra="allow", - arbitrary_types_allowed=True, - ) - - def __init__(self, datum=None, **data): - """Accepts an arbitrary datum as well as keyword arguments.""" - if datum is not None: - if isinstance(datum, Sample): - data.update(datum.dict()) - elif isinstance(datum, dict): - data.update(datum) - else: - data["datum"] = datum - super().__init__(**data) - - def __hash__(self) -> int: - """Return a hash of the Sample instance.""" - return hash(tuple(self.dict().values())) - - def __str__(self) -> str: - """Return a string representation of the Sample instance.""" - return f"{self.__class__.__name__}({', '.join([f'{k}={v}' for k, v in self.dict().items() if v is not None])})" - - def dict(self, exclude_none=True, exclude: set[str] = None) -> Dict[str, Any]: - """Return the Sample object as a dictionary with None values excluded. - - Args: - exclude_none (bool, optional): Whether to exclude None values. Defaults to True. - exclude (set[str], optional): Set of attribute names to exclude. Defaults to None. 
- - Returns: - Dict[str, Any]: Dictionary representation of the Sample object. - """ - return self.model_dump(exclude_none=exclude_none, exclude=exclude) - - @classmethod - def unflatten(cls, one_d_array_or_dict, schema=None) -> "Sample": - """Unflatten a one-dimensional array or dictionary into a Sample instance. - - If a dictionary is provided, its keys are ignored. - - Args: - one_d_array_or_dict: A one-dimensional array or dictionary to unflatten. - schema: A dictionary representing the JSON schema. Defaults to using the class's schema. - - Returns: - Sample: The unflattened Sample instance. - - Examples: - >>> sample = Sample(x=1, y=2, z={"a": 3, "b": 4}, extra_field=5) - >>> flat_list = sample.flatten() - >>> print(flat_list) - [1, 2, 3, 4, 5] - >>> Sample.unflatten(flat_list, sample.schema()) - Sample(x=1, y=2, z={'a': 3, 'b': 4}, extra_field=5) - """ - if schema is None: - schema = cls().schema() - - # Convert input to list if it's not already - if isinstance(one_d_array_or_dict, dict): - flat_data = list(one_d_array_or_dict.values()) - else: - flat_data = list(one_d_array_or_dict) - - def unflatten_recursive(schema_part, index=0): - if schema_part["type"] == "object": - result = {} - for prop, prop_schema in schema_part["properties"].items(): - value, index = unflatten_recursive(prop_schema, index) - result[prop] = value - return result, index - elif schema_part["type"] == "array": - items = [] - for _ in range(schema_part.get("maxItems", len(flat_data) - index)): - value, index = unflatten_recursive(schema_part["items"], index) - items.append(value) - return items, index - else: # Assuming it's a primitive type - return flat_data[index], index + 1 - - unflattened_dict, _ = unflatten_recursive(schema) - return cls(**unflattened_dict) - - def flatten( - self, - output_type: Flattenable = "dict", - non_numerical: Literal["ignore", "forbid", "allow"] = "allow", - ) -> Dict[str, Any] | np.ndarray | "torch.Tensor" | List: - accumulator = {} if output_type == 
"dict" else [] - - def flatten_recursive(obj, path=""): - if isinstance(obj, Sample): - for k, v in obj.dict().items(): - flatten_recursive(v, path + k + "/") - elif isinstance(obj, dict): - for k, v in obj.items(): - flatten_recursive(v, path + k + "/") - elif isinstance(obj, list | tuple): - for i, item in enumerate(obj): - flatten_recursive(item, path + str(i) + "/") - elif hasattr(obj, "__len__") and not isinstance(obj, str): - flat_list = obj.flatten().tolist() - if output_type == "dict": - # Convert to list for dict storage - accumulator[path[:-1]] = flat_list - else: - accumulator.extend(flat_list) - else: - if non_numerical == "ignore" and not isinstance(obj, int | float | bool): - return - final_key = path[:-1] # Remove trailing slash - if output_type == "dict": - accumulator[final_key] = obj - else: - accumulator.append(obj) - - flatten_recursive(self) - accumulator = accumulator.values() if output_type == "dict" else accumulator - if non_numerical == "forbid" and any( - not isinstance(v, int | float | bool) for v in accumulator - ): - raise ValueError("Non-numerical values found in flattened data.") - if output_type == "np": - return np.array(accumulator) - if output_type == "pt": - torch = smart_import("torch") - return torch.tensor(accumulator) - return accumulator - - @staticmethod - def obj_to_schema(value: Any) -> Dict: - """Generates a simplified JSON schema from a dictionary. - - Args: - value (Any): An object to generate a schema for. - - Returns: - dict: A simplified JSON schema representing the structure of the dictionary. 
- """ - if isinstance(value, dict): - return { - "type": "object", - "properties": {k: Sample.obj_to_schema(v) for k, v in value.items()}, - } - if isinstance(value, list | tuple | np.ndarray): - if len(value) > 0: - return {"type": "array", "items": Sample.obj_to_schema(value[0])} - return {"type": "array", "items": {}} - if isinstance(value, str): - return {"type": "string"} - if isinstance(value, int | np.integer): - return {"type": "integer"} - if isinstance(value, float | np.floating): - return {"type": "number"} - if isinstance(value, bool): - return {"type": "boolean"} - return {} - - def schema(self, resolve_refs: bool = True, include_descriptions=False) -> Dict: - """Returns a simplified json schema. - - Removing additionalProperties, - selecting the first type in anyOf, and converting numpy schema to the desired type. - Optionally resolves references. - - Args: - resolve_refs (bool): Whether to resolve references in the schema. Defaults to True. - include_descriptions (bool): Whether to include descriptions in the schema. Defaults to False. - - Returns: - dict: A simplified JSON schema. - """ - schema = self.model_json_schema() - if "additionalProperties" in schema: - del schema["additionalProperties"] - - if resolve_refs: - schema = replace_refs(schema) - - if not include_descriptions and "description" in schema: - del schema["description"] - - properties = schema.get("properties", {}) - for key, value in self.dict().items(): - if key not in properties: - properties[key] = Sample.obj_to_schema(value) - if isinstance(value, Sample): - properties[key] = value.schema( - resolve_refs=resolve_refs, include_descriptions=include_descriptions - ) - else: - properties[key] = Sample.obj_to_schema(value) - return schema - - @classmethod - def read(cls, data: Any) -> "Sample": - """Read a Sample instance from a JSON string or dictionary or path. - - Args: - data (Any): The JSON string or dictionary to read. - - Returns: - Sample: The read Sample instance. 
- """ - if isinstance(data, str): - try: - data = cls.model_validate(from_json(data)) - except Exception as e: - logging.info(f"Error reading data: {e}. Attempting to read as JSON.") - if isinstance(data, str): - if Path(data).exists(): - if hasattr(cls, "open"): - data = cls.open(data) - else: - data = Path(data).read_text() - data = json.loads(data) - else: - data = json.load(data) - - if isinstance(data, dict): - return cls(**data) - return cls(data) - - def to(self, container: Any) -> Any: - """Convert the Sample instance to a different container type. - - Args: - container (Any): The container type to convert to. Supported types are - 'dict', 'list', 'np', 'pt' (pytorch), 'space' (gym.space), - 'schema', 'json', 'hf' (datasets.Dataset) and any subtype of Sample. - - Returns: - Any: The converted container. - """ - if isinstance(container, Sample) and not issubclass(container, Sample): - return container(**self.dict()) - if isinstance(container, type) and issubclass(container, Sample): - return container.unflatten(self.flatten()) - - if container == "dict": - return self.dict() - if container == "list": - return self.flatten(output_type="list") - if container == "np": - return self.flatten(output_type="np") - if container == "pt": - return self.flatten(output_type="pt") - if container == "space": - return self.space() - if container == "schema": - return self.schema() - if container == "json": - return self.model_dump_json() - if container == "hf": - return Dataset.from_dict(self.dict()) - if container == "features": - return to_features(self.dict()) - raise ValueError(f"Unsupported container type: {container}") - - @classmethod - def default_value(cls) -> "Sample": - """Get the default value for the Sample instance. - - Returns: - Sample: The default value for the Sample instance. 
- """ - return cls() - - @classmethod - def space_for( - cls, - value: Any, - max_text_length: int = 1000, - info: Annotated = None, - ) -> spaces.Space: - """Default Gym space generation for a given value. - - Only used for subclasses that do not override the space method. - """ - if isinstance(value, Enum) or get_origin(value) == Literal: - return spaces.Discrete(len(value.__args__)) - if isinstance(value, bool): - return spaces.Discrete(2) - if isinstance(value, dict | Sample): - if isinstance(value, Sample): - value = value.dict() - return spaces.Dict( - {k: Sample.space_for(v, max_text_length, info) for k, v in value.items()}, - ) - if isinstance(value, str): - return spaces.Text(max_length=max_text_length) - if isinstance(value, int | float | list | tuple | np.ndarray): - shape = None - le = None - ge = None - dtype = None - if info is not None: - shape = info.metadata_lookup.get("shape") - le = info.metadata_lookup.get("le") - ge = info.metadata_lookup.get("ge") - dtype = info.metadata_lookup.get("dtype") - logging.debug( - "Generating space for value: %s, shape: %s, le: %s, ge: %s, dtype: %s", - value, - shape, - le, - ge, - dtype, - ) - try: - value = np.asfarray(value) - shape = shape or value.shape - dtype = dtype or value.dtype - le = le or -np.inf - ge = ge or np.inf - return spaces.Box(low=le, high=ge, shape=shape, dtype=dtype) - except Exception as e: - logging.info(f"Could not convert value {value} to numpy array: {e}") - if len(value) > 0 and isinstance(value[0], dict | Sample): - return spaces.Tuple( - [spaces.Dict(cls.space_for(v, max_text_length, info)) for v in value], - ) - return spaces.Tuple( - [cls.space_for(value[0], max_text_length, info) for value in value[:1]], - ) - raise ValueError(f"Unsupported object {value} of type: {type(value)} for space generation") - - @classmethod - def init_from(cls, d: Any, pack=False) -> "Sample": - if isinstance(d, spaces.Space): - return cls.from_space(d) - if isinstance(d, Union[Sequence, np.ndarray]): # 
noqa: UP007 - if pack: - return cls.pack_from(d) - return cls.unflatten(d) - if isinstance(d, dict): - try: - return cls.model_validate(d) - except ValidationError as e: - logging.info(f" Unable to validate {d} as {cls} {e}. Attempting to unflatten.") - - try: - return cls.unflatten(d) - except Exception as e: - logging.info(f" Unable to unflatten {d} as {cls} {e}. Attempting to read.") - return cls.read(d) - return cls(d) - - @classmethod - def from_flat_dict(cls, flat_dict: Dict[str, Any], schema: Dict = None) -> "Sample": - """Initialize a Sample instance from a flattened dictionary.""" - """ - Reconstructs the original JSON object from a flattened dictionary using the provided schema. - - Args: - flat_dict (dict): A flattened dictionary with keys like "key1.nestedkey1". - schema (dict): A dictionary representing the JSON schema. - - Returns: - dict: The reconstructed JSON object. - """ - schema = schema or replace_refs(cls.model_json_schema()) - reconstructed = {} - - for flat_key, value in flat_dict.items(): - keys = flat_key.split(".") - current = reconstructed - for key in keys[:-1]: - if key not in current: - current[key] = {} - current = current[key] - current[keys[-1]] = value - - return reconstructed - - @classmethod - def from_space(cls, space: spaces.Space) -> "Sample": - """Generate a Sample instance from a Gym space.""" - sampled = space.sample() - if isinstance(sampled, dict | OrderedDict): - return cls(**sampled) - if hasattr(sampled, "__len__") and not isinstance(sampled, str): - sampled = np.asarray(sampled) - if len(sampled.shape) > 0 and isinstance(sampled[0], dict | Sample): - return cls.pack_from(sampled) - return cls(sampled) - - @classmethod - def pack_from(cls, samples: List[Union["Sample", Dict]]) -> "Sample": - """Pack a list of samples into a single sample with lists for attributes. - - Args: - samples (List[Union[Sample, Dict]]): List of samples or dictionaries. - - Returns: - Sample: Packed sample with lists for attributes. 
- """ - if samples is None or len(samples) == 0: - return cls() - - first_sample = samples[0] - if isinstance(first_sample, dict): - attributes = list(first_sample.keys()) - elif hasattr(first_sample, "__dict__"): - attributes = list(first_sample.__dict__.keys()) - else: - attributes = ["item" + str(i) for i in range(len(samples))] - - aggregated = {attr: [] for attr in attributes} - for sample in samples: - for attr in attributes: - # Handle both Sample instances and dictionaries - if isinstance(sample, dict): - aggregated[attr].append(sample.get(attr, None)) - else: - aggregated[attr].append(getattr(sample, attr, None)) - return cls(**aggregated) - - def unpack(self, to_dicts=False) -> List[Union["Sample", Dict]]: - """Unpack the packed Sample object into a list of Sample objects or dictionaries.""" - attributes = list(self.model_extra.keys()) + list(self.model_fields.keys()) - attributes = [attr for attr in attributes if getattr(self, attr) is not None] - if not attributes or getattr(self, attributes[0]) is None: - return [] - - # Ensure all attributes are lists and have the same length - list_sizes = { - len(getattr(self, attr)) for attr in attributes if isinstance(getattr(self, attr), list) - } - if len(list_sizes) != 1: - raise ValueError("Not all attribute lists have the same length.") - list_size = list_sizes.pop() - - if to_dicts: - return [{key: getattr(self, key)[i] for key in attributes} for i in range(list_size)] - - return [ - self.__class__(**{key: getattr(self, key)[i] for key in attributes}) - for i in range(list_size) - ] - - @classmethod - def default_space(cls) -> spaces.Dict: - """Return the Gym space for the Sample class based on its class attributes.""" - return cls().space() - - @classmethod - def default_sample(cls, output_type="Sample") -> Union["Sample", Dict[str, Any]]: - """Generate a default Sample instance from its class attributes. Useful for padding. - - This is the "no-op" instance and should be overriden as needed. 
- """ - if output_type == "Sample": - return cls() - return cls().dict() - - def model_field_info(self, key: str) -> FieldInfo: - """Get the FieldInfo for a given attribute key.""" - if self.model_extra and self.model_extra.get(key) is not None: - info = FieldInfo(metadata=self.model_extra[key]) - if self.model_fields.get(key) is not None: - info = FieldInfo(metadata=self.model_fields[key]) - - if info and hasattr(info, "annotation"): - return info.annotation - return None - - def space(self) -> spaces.Dict: - """Return the corresponding Gym space for the Sample instance based on its instance attributes. Omits None values. - - Override this method in subclasses to customize the space generation. - """ - space_dict = {} - for key, value in self.dict().items(): - logging.debug("Generating space for key: '%s', value: %s", key, value) - info = self.model_field_info(key) - value = getattr(self, key) if hasattr(self, key) else value # noqa: PLW2901 - space_dict[key] = ( - value.space() if isinstance(value, Sample) else self.space_for(value, info=info) - ) - return spaces.Dict(space_dict) - - def random_sample(self) -> "Sample": - """Generate a random Sample instance based on its instance attributes. Omits None values. - - Override this method in subclasses to customize the sample generation. - """ - return self.__class__.model_validate(self.space().sample()) - - -if __name__ == "__main__": - sample = Sample(x=1, y=2, z={"a": 3, "b": 4}, extra_field=5) diff --git a/build/lib/dimos/types/segmentation.py b/build/lib/dimos/types/segmentation.py deleted file mode 100644 index 5995f302f9..0000000000 --- a/build/lib/dimos/types/segmentation.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import List, Any -import numpy as np - - -class SegmentationType: - def __init__(self, masks: List[np.ndarray], metadata: Any = None): - """ - Initializes a standardized segmentation type. - - Args: - masks (List[np.ndarray]): A list of binary masks for segmentation. - metadata (Any, optional): Additional metadata related to the segmentations. - """ - self.masks = masks - self.metadata = metadata - - def combine_masks(self): - """Combine all masks into a single mask.""" - combined_mask = np.zeros_like(self.masks[0]) - for mask in self.masks: - combined_mask = np.logical_or(combined_mask, mask) - return combined_mask - - def save_masks(self, directory: str): - """Save each mask to a separate file.""" - import os - - os.makedirs(directory, exist_ok=True) - for i, mask in enumerate(self.masks): - np.save(os.path.join(directory, f"mask_{i}.npy"), mask) diff --git a/build/lib/dimos/types/test_pose.py b/build/lib/dimos/types/test_pose.py deleted file mode 100644 index e95133e035..0000000000 --- a/build/lib/dimos/types/test_pose.py +++ /dev/null @@ -1,323 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -import numpy as np -import math -from dimos.types.pose import Pose, to_pose -from dimos.types.vector import Vector - - -def test_pose_default_init(): - """Test that default initialization of Pose() has zero vectors for pos and rot.""" - pose = Pose() - - # Check that pos is a zero vector - assert isinstance(pose.pos, Vector) - assert pose.pos.is_zero() - assert pose.pos.x == 0.0 - assert pose.pos.y == 0.0 - assert pose.pos.z == 0.0 - - # Check that rot is a zero vector - assert isinstance(pose.rot, Vector) - assert pose.rot.is_zero() - assert pose.rot.x == 0.0 - assert pose.rot.y == 0.0 - assert pose.rot.z == 0.0 - - assert pose.is_zero() - - assert not pose - - -def test_pose_vector_init(): - """Test initialization with custom vectors.""" - pos = Vector(1.0, 2.0, 3.0) - rot = Vector(4.0, 5.0, 6.0) - - pose = Pose(pos, rot) - - # Check pos vector - assert pose.pos == pos - assert pose.pos.x == 1.0 - assert pose.pos.y == 2.0 - assert pose.pos.z == 3.0 - - # Check rot vector - assert pose.rot == rot - assert pose.rot.x == 4.0 - assert pose.rot.y == 5.0 - assert pose.rot.z == 6.0 - - # even if pos has the same xyz as pos vector - # it shouldn't accept equality comparisons - # as both are not the same type - assert not pose == pos - - -def test_pose_partial_init(): - """Test initialization with only one custom vector.""" - pos = Vector(1.0, 2.0, 3.0) - assert pos - - # Only specify pos - pose1 = Pose(pos) - assert pose1.pos == pos - assert pose1.pos.x == 1.0 - assert pose1.pos.y == 2.0 - assert pose1.pos.z == 3.0 - assert not pose1.pos.is_zero() - - assert isinstance(pose1.rot, Vector) - assert pose1.rot.is_zero() - assert pose1.rot.x == 0.0 - assert pose1.rot.y == 0.0 - assert pose1.rot.z == 0.0 - - -def test_pose_equality(): - """Test equality comparison between positions.""" - pos1 = Pose(Vector(1.0, 2.0, 3.0), Vector(4.0, 5.0, 6.0)) - pos2 = Pose(Vector(1.0, 
2.0, 3.0), Vector(4.0, 5.0, 6.0)) - pos3 = Pose(Vector(1.0, 2.0, 3.0), Vector(7.0, 8.0, 9.0)) - pos4 = Pose(Vector(7.0, 8.0, 9.0), Vector(4.0, 5.0, 6.0)) - - # Same pos and rot values should be equal - assert pos1 == pos2 - - # Different rot values should not be equal - assert pos1 != pos3 - - # Different pos values should not be equal - assert pos1 != pos4 - - # Pose should not equal a vector even if values match - assert pos1 != Vector(1.0, 2.0, 3.0) - - -def test_pose_vector_operations(): - """Test that Pose inherits Vector operations.""" - pos1 = Pose(Vector(1.0, 2.0, 3.0), Vector(4.0, 5.0, 6.0)) - pos2 = Pose(Vector(2.0, 3.0, 4.0), Vector(7.0, 8.0, 9.0)) - - # Addition should work on both position and rotation components - sum_pos = pos1 + pos2 - assert isinstance(sum_pos, Pose) - assert sum_pos.x == 3.0 - assert sum_pos.y == 5.0 - assert sum_pos.z == 7.0 - # Rotation should be added as well - assert sum_pos.rot.x == 11.0 # 4.0 + 7.0 - assert sum_pos.rot.y == 13.0 # 5.0 + 8.0 - assert sum_pos.rot.z == 15.0 # 6.0 + 9.0 - - # Subtraction should work on both position and rotation components - diff_pos = pos2 - pos1 - assert isinstance(diff_pos, Pose) - assert diff_pos.x == 1.0 - assert diff_pos.y == 1.0 - assert diff_pos.z == 1.0 - # Rotation should be subtracted as well - assert diff_pos.rot.x == 3.0 # 7.0 - 4.0 - assert diff_pos.rot.y == 3.0 # 8.0 - 5.0 - assert diff_pos.rot.z == 3.0 # 9.0 - 6.0 - - # Scalar multiplication - scaled_pos = pos1 * 2.0 - assert isinstance(scaled_pos, Pose) - assert scaled_pos.x == 2.0 - assert scaled_pos.y == 4.0 - assert scaled_pos.z == 6.0 - assert scaled_pos.rot == pos1.rot # Rotation not affected by scalar multiplication - - # Adding a Vector to a Pose (only affects position component) - vec = Vector(5.0, 6.0, 7.0) - pos_plus_vec = pos1 + vec - assert isinstance(pos_plus_vec, Pose) - assert pos_plus_vec.x == 6.0 - assert pos_plus_vec.y == 8.0 - assert pos_plus_vec.z == 10.0 - assert pos_plus_vec.rot == pos1.rot # Rotation 
unchanged - - -def test_pose_serialization(): - """Test pose serialization.""" - pos = Pose(Vector(1.0, 2.0, 3.0), Vector(4.0, 5.0, 6.0)) - serialized = pos.serialize() - - assert serialized["type"] == "pose" - assert serialized["pos"] == [1.0, 2.0, 3.0] - assert serialized["rot"] == [4.0, 5.0, 6.0] - - -def test_pose_initialization_with_arrays(): - """Test initialization with numpy arrays, lists and tuples.""" - # Test with numpy arrays - np_pos = np.array([1.0, 2.0, 3.0]) - np_rot = np.array([4.0, 5.0, 6.0]) - - pos1 = Pose(np_pos, np_rot) - - assert pos1.x == 1.0 - assert pos1.y == 2.0 - assert pos1.z == 3.0 - assert pos1.rot.x == 4.0 - assert pos1.rot.y == 5.0 - assert pos1.rot.z == 6.0 - - # Test with lists - list_pos = [7.0, 8.0, 9.0] - list_rot = [10.0, 11.0, 12.0] - pos2 = Pose(list_pos, list_rot) - - assert pos2.x == 7.0 - assert pos2.y == 8.0 - assert pos2.z == 9.0 - assert pos2.rot.x == 10.0 - assert pos2.rot.y == 11.0 - assert pos2.rot.z == 12.0 - - # Test with tuples - tuple_pos = (13.0, 14.0, 15.0) - tuple_rot = (16.0, 17.0, 18.0) - pos3 = Pose(tuple_pos, tuple_rot) - - assert pos3.x == 13.0 - assert pos3.y == 14.0 - assert pos3.z == 15.0 - assert pos3.rot.x == 16.0 - assert pos3.rot.y == 17.0 - assert pos3.rot.z == 18.0 - - -def test_to_pose_with_pose(): - """Test to_pose with Pose input.""" - # Create a pose - original_pos = Pose(Vector(1.0, 2.0, 3.0), Vector(4.0, 5.0, 6.0)) - - # Convert using to_pose - converted_pos = to_pose(original_pos) - - # Should return the exact same object - assert converted_pos is original_pos - assert converted_pos == original_pos - - # Check values - assert converted_pos.x == 1.0 - assert converted_pos.y == 2.0 - assert converted_pos.z == 3.0 - assert converted_pos.rot.x == 4.0 - assert converted_pos.rot.y == 5.0 - assert converted_pos.rot.z == 6.0 - - -def test_to_pose_with_vector(): - """Test to_pose with Vector input.""" - # Create a vector - vec = Vector(1.0, 2.0, 3.0) - - # Convert using to_pose - pos = 
to_pose(vec) - - # Should return a Pose with the vector as position and zero rotation - assert isinstance(pos, Pose) - assert pos.pos == vec - assert pos.x == 1.0 - assert pos.y == 2.0 - assert pos.z == 3.0 - - # Rotation should be zero - assert isinstance(pos.rot, Vector) - assert pos.rot.is_zero() - assert pos.rot.x == 0.0 - assert pos.rot.y == 0.0 - assert pos.rot.z == 0.0 - - -def test_to_pose_with_vectorlike(): - """Test to_pose with VectorLike inputs (arrays, lists, tuples).""" - # Test with numpy arrays - np_arr = np.array([1.0, 2.0, 3.0]) - pos1 = to_pose(np_arr) - - assert isinstance(pos1, Pose) - assert pos1.x == 1.0 - assert pos1.y == 2.0 - assert pos1.z == 3.0 - assert pos1.rot.is_zero() - - # Test with lists - list_val = [4.0, 5.0, 6.0] - pos2 = to_pose(list_val) - - assert isinstance(pos2, Pose) - assert pos2.x == 4.0 - assert pos2.y == 5.0 - assert pos2.z == 6.0 - assert pos2.rot.is_zero() - - # Test with tuples - tuple_val = (7.0, 8.0, 9.0) - pos3 = to_pose(tuple_val) - - assert isinstance(pos3, Pose) - assert pos3.x == 7.0 - assert pos3.y == 8.0 - assert pos3.z == 9.0 - assert pos3.rot.is_zero() - - -def test_to_pose_with_sequence(): - """Test to_pose with Sequence of VectorLike inputs.""" - # Test with sequence of two vectors - pos_vec = Vector(1.0, 2.0, 3.0) - rot_vec = Vector(4.0, 5.0, 6.0) - pos1 = to_pose([pos_vec, rot_vec]) - - assert isinstance(pos1, Pose) - assert pos1.pos == pos_vec - assert pos1.rot == rot_vec - assert pos1.x == 1.0 - assert pos1.y == 2.0 - assert pos1.z == 3.0 - assert pos1.rot.x == 4.0 - assert pos1.rot.y == 5.0 - assert pos1.rot.z == 6.0 - - # Test with sequence of lists - pos2 = to_pose([[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]) - - assert isinstance(pos2, Pose) - assert pos2.x == 7.0 - assert pos2.y == 8.0 - assert pos2.z == 9.0 - assert pos2.rot.x == 10.0 - assert pos2.rot.y == 11.0 - assert pos2.rot.z == 12.0 - - # Test with mixed sequence (tuple and numpy array) - pos3 = to_pose([(13.0, 14.0, 15.0), np.array([16.0, 
17.0, 18.0])]) - - assert isinstance(pos3, Pose) - assert pos3.x == 13.0 - assert pos3.y == 14.0 - assert pos3.z == 15.0 - assert pos3.rot.x == 16.0 - assert pos3.rot.y == 17.0 - assert pos3.rot.z == 18.0 - - -def test_vector_transform(): - robot_pose = Pose(Vector(4.0, 2.0, 0.5), Vector(0.0, 0.0, math.pi / 2)) - target = Vector(1.0, 3.0, 0.0) - print(robot_pose.vector_to(target)) diff --git a/build/lib/dimos/types/test_timestamped.py b/build/lib/dimos/types/test_timestamped.py deleted file mode 100644 index bf7962371e..0000000000 --- a/build/lib/dimos/types/test_timestamped.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from datetime import datetime - -from dimos.types.timestamped import Timestamped - - -def test_timestamped_dt_method(): - ts = 1751075203.4120464 - timestamped = Timestamped(ts) - dt = timestamped.dt() - assert isinstance(dt, datetime) - assert abs(dt.timestamp() - ts) < 1e-6 - assert dt.tzinfo is not None, "datetime should be timezone-aware" diff --git a/build/lib/dimos/types/test_vector.py b/build/lib/dimos/types/test_vector.py deleted file mode 100644 index 6a93d37afd..0000000000 --- a/build/lib/dimos/types/test_vector.py +++ /dev/null @@ -1,384 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import numpy as np -import pytest - -from dimos.types.vector import Vector - - -def test_vector_default_init(): - """Test that default initialization of Vector() has x,y,z components all zero.""" - v = Vector() - assert v.x == 0.0 - assert v.y == 0.0 - assert v.z == 0.0 - assert v.dim == 0 - assert len(v.data) == 0 - assert v.to_list() == [] - assert v.is_zero() == True # Empty vector should be considered zero - - -def test_vector_specific_init(): - """Test initialization with specific values.""" - # 2D vector - v1 = Vector(1.0, 2.0) - assert v1.x == 1.0 - assert v1.y == 2.0 - assert v1.z == 0.0 - assert v1.dim == 2 - - # 3D vector - v2 = Vector(3.0, 4.0, 5.0) - assert v2.x == 3.0 - assert v2.y == 4.0 - assert v2.z == 5.0 - assert v2.dim == 3 - - # From list - v3 = Vector([6.0, 7.0, 8.0]) - assert v3.x == 6.0 - assert v3.y == 7.0 - assert v3.z == 8.0 - assert v3.dim == 3 - - # From numpy array - v4 = Vector(np.array([9.0, 10.0, 11.0])) - assert v4.x == 9.0 - assert v4.y == 10.0 - assert v4.z == 11.0 - assert v4.dim == 3 - - -def test_vector_addition(): - """Test vector addition.""" - v1 = Vector(1.0, 2.0, 3.0) - v2 = Vector(4.0, 5.0, 6.0) - - v_add = v1 + v2 - assert v_add.x == 5.0 - assert v_add.y == 7.0 - assert v_add.z == 9.0 - - -def test_vector_subtraction(): - """Test vector subtraction.""" - v1 = Vector(1.0, 2.0, 3.0) - v2 = Vector(4.0, 5.0, 6.0) - - v_sub = v2 - v1 - assert v_sub.x == 3.0 - assert v_sub.y == 3.0 - assert v_sub.z == 3.0 - - -def test_vector_scalar_multiplication(): - """Test vector multiplication by a scalar.""" - v1 = Vector(1.0, 
2.0, 3.0) - - v_mul = v1 * 2.0 - assert v_mul.x == 2.0 - assert v_mul.y == 4.0 - assert v_mul.z == 6.0 - - # Test right multiplication - v_rmul = 2.0 * v1 - assert v_rmul.x == 2.0 - assert v_rmul.y == 4.0 - assert v_rmul.z == 6.0 - - -def test_vector_scalar_division(): - """Test vector division by a scalar.""" - v2 = Vector(4.0, 5.0, 6.0) - - v_div = v2 / 2.0 - assert v_div.x == 2.0 - assert v_div.y == 2.5 - assert v_div.z == 3.0 - - -def test_vector_dot_product(): - """Test vector dot product.""" - v1 = Vector(1.0, 2.0, 3.0) - v2 = Vector(4.0, 5.0, 6.0) - - dot = v1.dot(v2) - assert dot == 32.0 - - -def test_vector_length(): - """Test vector length calculation.""" - # 2D vector with length 5 - v1 = Vector(3.0, 4.0) - assert v1.length() == 5.0 - - # 3D vector - v2 = Vector(2.0, 3.0, 6.0) - assert v2.length() == pytest.approx(7.0, 0.001) - - # Test length_squared - assert v1.length_squared() == 25.0 - assert v2.length_squared() == 49.0 - - -def test_vector_normalize(): - """Test vector normalization.""" - v = Vector(2.0, 3.0, 6.0) - assert v.is_zero() == False - - v_norm = v.normalize() - length = v.length() - expected_x = 2.0 / length - expected_y = 3.0 / length - expected_z = 6.0 / length - - assert np.isclose(v_norm.x, expected_x) - assert np.isclose(v_norm.y, expected_y) - assert np.isclose(v_norm.z, expected_z) - assert np.isclose(v_norm.length(), 1.0) - assert v_norm.is_zero() == False - - # Test normalizing a zero vector - v_zero = Vector(0.0, 0.0, 0.0) - assert v_zero.is_zero() == True - v_zero_norm = v_zero.normalize() - assert v_zero_norm.x == 0.0 - assert v_zero_norm.y == 0.0 - assert v_zero_norm.z == 0.0 - assert v_zero_norm.is_zero() == True - - -def test_vector_to_2d(): - """Test conversion to 2D vector.""" - v = Vector(2.0, 3.0, 6.0) - - v_2d = v.to_2d() - assert v_2d.x == 2.0 - assert v_2d.y == 3.0 - assert v_2d.z == 0.0 - assert v_2d.dim == 2 - - # Already 2D vector - v2 = Vector(4.0, 5.0) - v2_2d = v2.to_2d() - assert v2_2d.x == 4.0 - assert 
v2_2d.y == 5.0 - assert v2_2d.dim == 2 - - -def test_vector_distance(): - """Test distance calculations between vectors.""" - v1 = Vector(1.0, 2.0, 3.0) - v2 = Vector(4.0, 6.0, 8.0) - - # Distance - dist = v1.distance(v2) - expected_dist = np.sqrt(9.0 + 16.0 + 25.0) # sqrt((4-1)² + (6-2)² + (8-3)²) - assert dist == pytest.approx(expected_dist) - - # Distance squared - dist_sq = v1.distance_squared(v2) - assert dist_sq == 50.0 # 9 + 16 + 25 - - -def test_vector_cross_product(): - """Test vector cross product.""" - v1 = Vector(1.0, 0.0, 0.0) # Unit x vector - v2 = Vector(0.0, 1.0, 0.0) # Unit y vector - - # v1 × v2 should be unit z vector - cross = v1.cross(v2) - assert cross.x == 0.0 - assert cross.y == 0.0 - assert cross.z == 1.0 - - # Test with more complex vectors - a = Vector(2.0, 3.0, 4.0) - b = Vector(5.0, 6.0, 7.0) - c = a.cross(b) - - # Cross product manually calculated: - # (3*7-4*6, 4*5-2*7, 2*6-3*5) - assert c.x == -3.0 - assert c.y == 6.0 - assert c.z == -3.0 - - # Test with 2D vectors (should raise error) - v_2d = Vector(1.0, 2.0) - with pytest.raises(ValueError): - v_2d.cross(v2) - - -def test_vector_zeros(): - """Test Vector.zeros class method.""" - # 3D zero vector - v_zeros = Vector.zeros(3) - assert v_zeros.x == 0.0 - assert v_zeros.y == 0.0 - assert v_zeros.z == 0.0 - assert v_zeros.dim == 3 - assert v_zeros.is_zero() == True - - # 2D zero vector - v_zeros_2d = Vector.zeros(2) - assert v_zeros_2d.x == 0.0 - assert v_zeros_2d.y == 0.0 - assert v_zeros_2d.z == 0.0 - assert v_zeros_2d.dim == 2 - assert v_zeros_2d.is_zero() == True - - -def test_vector_ones(): - """Test Vector.ones class method.""" - # 3D ones vector - v_ones = Vector.ones(3) - assert v_ones.x == 1.0 - assert v_ones.y == 1.0 - assert v_ones.z == 1.0 - assert v_ones.dim == 3 - - # 2D ones vector - v_ones_2d = Vector.ones(2) - assert v_ones_2d.x == 1.0 - assert v_ones_2d.y == 1.0 - assert v_ones_2d.z == 0.0 - assert v_ones_2d.dim == 2 - - -def test_vector_conversion_methods(): - """Test 
vector conversion methods (to_list, to_tuple, to_numpy).""" - v = Vector(1.0, 2.0, 3.0) - - # to_list - assert v.to_list() == [1.0, 2.0, 3.0] - - # to_tuple - assert v.to_tuple() == (1.0, 2.0, 3.0) - - # to_numpy - np_array = v.to_numpy() - assert isinstance(np_array, np.ndarray) - assert np.array_equal(np_array, np.array([1.0, 2.0, 3.0])) - - -def test_vector_equality(): - """Test vector equality.""" - v1 = Vector(1, 2, 3) - v2 = Vector(1, 2, 3) - v3 = Vector(4, 5, 6) - - assert v1 == v2 - assert v1 != v3 - assert v1 != Vector(1, 2) # Different dimensions - assert v1 != Vector(1.1, 2, 3) # Different values - assert v1 != [1, 2, 3] - - -def test_vector_is_zero(): - """Test is_zero method for vectors.""" - # Default empty vector - v0 = Vector() - assert v0.is_zero() == True - - # Explicit zero vector - v1 = Vector(0.0, 0.0, 0.0) - assert v1.is_zero() == True - - # Zero vector with different dimensions - v2 = Vector(0.0, 0.0) - assert v2.is_zero() == True - - # Non-zero vectors - v3 = Vector(1.0, 0.0, 0.0) - assert v3.is_zero() == False - - v4 = Vector(0.0, 2.0, 0.0) - assert v4.is_zero() == False - - v5 = Vector(0.0, 0.0, 3.0) - assert v5.is_zero() == False - - # Almost zero (within tolerance) - v6 = Vector(1e-10, 1e-10, 1e-10) - assert v6.is_zero() == True - - # Almost zero (outside tolerance) - v7 = Vector(1e-6, 1e-6, 1e-6) - assert v7.is_zero() == False - - -def test_vector_bool_conversion(): - """Test boolean conversion of vectors.""" - # Zero vectors should be False - v0 = Vector() - assert bool(v0) == False - - v1 = Vector(0.0, 0.0, 0.0) - assert bool(v1) == False - - # Almost zero vectors should be False - v2 = Vector(1e-10, 1e-10, 1e-10) - assert bool(v2) == False - - # Non-zero vectors should be True - v3 = Vector(1.0, 0.0, 0.0) - assert bool(v3) == True - - v4 = Vector(0.0, 2.0, 0.0) - assert bool(v4) == True - - v5 = Vector(0.0, 0.0, 3.0) - assert bool(v5) == True - - # Direct use in if statements - if v0: - assert False, "Zero vector should be False in 
boolean context" - else: - pass # Expected path - - if v3: - pass # Expected path - else: - assert False, "Non-zero vector should be True in boolean context" - - -def test_vector_add(): - """Test vector addition operator.""" - v1 = Vector(1.0, 2.0, 3.0) - v2 = Vector(4.0, 5.0, 6.0) - - # Using __add__ method - v_add = v1.__add__(v2) - assert v_add.x == 5.0 - assert v_add.y == 7.0 - assert v_add.z == 9.0 - - # Using + operator - v_add_op = v1 + v2 - assert v_add_op.x == 5.0 - assert v_add_op.y == 7.0 - assert v_add_op.z == 9.0 - - # Adding zero vector should return original vector - v_zero = Vector.zeros(3) - assert (v1 + v_zero) == v1 - - -def test_vector_add_dim_mismatch(): - """Test vector addition operator.""" - v1 = Vector(1.0, 2.0) - v2 = Vector(4.0, 5.0, 6.0) - - # Using + operator - v_add_op = v1 + v2 diff --git a/build/lib/dimos/types/timestamped.py b/build/lib/dimos/types/timestamped.py deleted file mode 100644 index 189bf7eaec..0000000000 --- a/build/lib/dimos/types/timestamped.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from datetime import datetime, timezone -from typing import Generic, Iterable, Tuple, TypedDict, TypeVar, Union - -# any class that carries a timestamp should inherit from this -# this allows us to work with timeseries in consistent way, allign messages, replay etc -# aditional functionality will come to this class soon - - -class RosStamp(TypedDict): - sec: int - nanosec: int - - -EpochLike = Union[int, float, datetime, RosStamp] - - -def to_timestamp(ts: EpochLike) -> float: - """Convert EpochLike to a timestamp in seconds.""" - if isinstance(ts, datetime): - return ts.timestamp() - if isinstance(ts, (int, float)): - return float(ts) - if isinstance(ts, dict) and "sec" in ts and "nanosec" in ts: - return ts["sec"] + ts["nanosec"] / 1e9 - raise TypeError("unsupported timestamp type") - - -class Timestamped: - ts: float - - def __init__(self, ts: float): - self.ts = ts - - def dt(self) -> datetime: - return datetime.fromtimestamp(self.ts, tz=timezone.utc).astimezone() - - def ros_timestamp(self) -> dict[str, int]: - """Convert timestamp to ROS-style dictionary.""" - sec = int(self.ts) - nanosec = int((self.ts - sec) * 1_000_000_000) - return [sec, nanosec] diff --git a/build/lib/dimos/types/vector.py b/build/lib/dimos/types/vector.py deleted file mode 100644 index d980e28105..0000000000 --- a/build/lib/dimos/types/vector.py +++ /dev/null @@ -1,460 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from typing import List, Tuple, TypeVar, Union, Sequence - -import numpy as np -from dimos.types.ros_polyfill import Vector3 - -T = TypeVar("T", bound="Vector") - -# Vector-like types that can be converted to/from Vector -VectorLike = Union[Sequence[Union[int, float]], Vector3, "Vector", np.ndarray] - - -class Vector: - """A wrapper around numpy arrays for vector operations with intuitive syntax.""" - - def __init__(self, *args: VectorLike): - """Initialize a vector from components or another iterable. - - Examples: - Vector(1, 2) # 2D vector - Vector(1, 2, 3) # 3D vector - Vector([1, 2, 3]) # From list - Vector(np.array([1, 2, 3])) # From numpy array - """ - if len(args) == 1 and hasattr(args[0], "__iter__"): - self._data = np.array(args[0], dtype=float) - - elif len(args) == 1: - self._data = np.array([args[0].x, args[0].y, args[0].z], dtype=float) - - else: - self._data = np.array(args, dtype=float) - - @property - def yaw(self) -> float: - return self.x - - @property - def tuple(self) -> Tuple[float, ...]: - """Tuple representation of the vector.""" - return tuple(self._data) - - @property - def x(self) -> float: - """X component of the vector.""" - return self._data[0] if len(self._data) > 0 else 0.0 - - @property - def y(self) -> float: - """Y component of the vector.""" - return self._data[1] if len(self._data) > 1 else 0.0 - - @property - def z(self) -> float: - """Z component of the vector.""" - return self._data[2] if len(self._data) > 2 else 0.0 - - @property - def dim(self) -> int: - """Dimensionality of the vector.""" - return len(self._data) - - @property - def data(self) -> np.ndarray: - """Get the underlying numpy array.""" - return self._data - - def __getitem__(self, idx): - return self._data[idx] - - def __repr__(self) -> str: - return f"Vector({self.data})" - - def __str__(self) -> str: - if self.dim < 2: - return self.__repr__() - - def getArrow(): - repr = ["←", "↖", "↑", "↗", "→", "↘", "↓", "↙"] - - if self.x == 0 and self.y == 0: - return 
"·" - - # Calculate angle in radians and convert to directional index - angle = np.arctan2(self.y, self.x) - # Map angle to 0-7 index (8 directions) with proper orientation - dir_index = int(((angle + np.pi) * 4 / np.pi) % 8) - # Get directional arrow symbol - return repr[dir_index] - - return f"{getArrow()} Vector {self.__repr__()}" - - def serialize(self) -> Tuple: - """Serialize the vector to a tuple.""" - return {"type": "vector", "c": self._data.tolist()} - - def __eq__(self, other) -> bool: - """Check if two vectors are equal using numpy's allclose for floating point comparison.""" - if not isinstance(other, Vector): - return False - if len(self._data) != len(other._data): - return False - return np.allclose(self._data, other._data) - - def __add__(self: T, other: VectorLike) -> T: - other = to_vector(other) - if self.dim != other.dim: - max_dim = max(self.dim, other.dim) - return self.pad(max_dim) + other.pad(max_dim) - return self.__class__(self._data + other._data) - - def __sub__(self: T, other: VectorLike) -> T: - other = to_vector(other) - if self.dim != other.dim: - max_dim = max(self.dim, other.dim) - return self.pad(max_dim) - other.pad(max_dim) - return self.__class__(self._data - other._data) - - def __mul__(self: T, scalar: float) -> T: - return self.__class__(self._data * scalar) - - def __rmul__(self: T, scalar: float) -> T: - return self.__mul__(scalar) - - def __truediv__(self: T, scalar: float) -> T: - return self.__class__(self._data / scalar) - - def __neg__(self: T) -> T: - return self.__class__(-self._data) - - def dot(self, other: VectorLike) -> float: - """Compute dot product.""" - other = to_vector(other) - return float(np.dot(self._data, other._data)) - - def cross(self: T, other: VectorLike) -> T: - """Compute cross product (3D vectors only).""" - if self.dim != 3: - raise ValueError("Cross product is only defined for 3D vectors") - - other = to_vector(other) - if other.dim != 3: - raise ValueError("Cross product requires two 3D 
vectors") - - return self.__class__(np.cross(self._data, other._data)) - - def length(self) -> float: - """Compute the Euclidean length (magnitude) of the vector.""" - return float(np.linalg.norm(self._data)) - - def length_squared(self) -> float: - """Compute the squared length of the vector (faster than length()).""" - return float(np.sum(self._data * self._data)) - - def normalize(self: T) -> T: - """Return a normalized unit vector in the same direction.""" - length = self.length() - if length < 1e-10: # Avoid division by near-zero - return self.__class__(np.zeros_like(self._data)) - return self.__class__(self._data / length) - - def to_2d(self: T) -> T: - """Convert a vector to a 2D vector by taking only the x and y components.""" - return self.__class__(self._data[:2]) - - def pad(self: T, dim: int) -> T: - """Pad a vector with zeros to reach the specified dimension. - - If vector already has dimension >= dim, it is returned unchanged. - """ - if self.dim >= dim: - return self - - padded = np.zeros(dim, dtype=float) - padded[: len(self._data)] = self._data - return self.__class__(padded) - - def distance(self, other: VectorLike) -> float: - """Compute Euclidean distance to another vector.""" - other = to_vector(other) - return float(np.linalg.norm(self._data - other._data)) - - def distance_squared(self, other: VectorLike) -> float: - """Compute squared Euclidean distance to another vector (faster than distance()).""" - other = to_vector(other) - diff = self._data - other._data - return float(np.sum(diff * diff)) - - def angle(self, other: VectorLike) -> float: - """Compute the angle (in radians) between this vector and another.""" - other = to_vector(other) - if self.length() < 1e-10 or other.length() < 1e-10: - return 0.0 - - cos_angle = np.clip( - np.dot(self._data, other._data) - / (np.linalg.norm(self._data) * np.linalg.norm(other._data)), - -1.0, - 1.0, - ) - return float(np.arccos(cos_angle)) - - def project(self: T, onto: VectorLike) -> T: - """Project 
this vector onto another vector.""" - onto = to_vector(onto) - onto_length_sq = np.sum(onto._data * onto._data) - if onto_length_sq < 1e-10: - return self.__class__(np.zeros_like(self._data)) - - scalar_projection = np.dot(self._data, onto._data) / onto_length_sq - return self.__class__(scalar_projection * onto._data) - - # this is here to test ros_observable_topic - # doesn't happen irl afaik that we want a vector from ros message - @classmethod - def from_msg(cls: type[T], msg) -> T: - return cls(*msg) - - @classmethod - def zeros(cls: type[T], dim: int) -> T: - """Create a zero vector of given dimension.""" - return cls(np.zeros(dim)) - - @classmethod - def ones(cls: type[T], dim: int) -> T: - """Create a vector of ones with given dimension.""" - return cls(np.ones(dim)) - - @classmethod - def unit_x(cls: type[T], dim: int = 3) -> T: - """Create a unit vector in the x direction.""" - v = np.zeros(dim) - v[0] = 1.0 - return cls(v) - - @classmethod - def unit_y(cls: type[T], dim: int = 3) -> T: - """Create a unit vector in the y direction.""" - v = np.zeros(dim) - v[1] = 1.0 - return cls(v) - - @classmethod - def unit_z(cls: type[T], dim: int = 3) -> T: - """Create a unit vector in the z direction.""" - v = np.zeros(dim) - if dim > 2: - v[2] = 1.0 - return cls(v) - - def to_list(self) -> List[float]: - """Convert the vector to a list.""" - return self._data.tolist() - - def to_tuple(self) -> Tuple[float, ...]: - """Convert the vector to a tuple.""" - return tuple(self._data) - - def to_numpy(self) -> np.ndarray: - """Convert the vector to a numpy array.""" - return self._data - - def is_zero(self) -> bool: - """Check if this is a zero vector (all components are zero). - - Returns: - True if all components are zero, False otherwise - """ - return np.allclose(self._data, 0.0) - - def __bool__(self) -> bool: - """Boolean conversion for Vector. - - A Vector is considered False if it's a zero vector (all components are zero), - and True otherwise. 
- - Returns: - False if vector is zero, True otherwise - """ - return not self.is_zero() - - -def to_numpy(value: VectorLike) -> np.ndarray: - """Convert a vector-compatible value to a numpy array. - - Args: - value: Any vector-like object (Vector, numpy array, tuple, list) - - Returns: - Numpy array representation - """ - if isinstance(value, Vector3): - return np.array([value.x, value.y, value.z], dtype=float) - if isinstance(value, Vector): - return value.data - elif isinstance(value, np.ndarray): - return value - else: - return np.array(value, dtype=float) - - -def to_vector(value: VectorLike) -> Vector: - """Convert a vector-compatible value to a Vector object. - - Args: - value: Any vector-like object (Vector, numpy array, tuple, list) - - Returns: - Vector object - """ - if isinstance(value, Vector): - return value - else: - return Vector(value) - - -def to_tuple(value: VectorLike) -> Tuple[float, ...]: - """Convert a vector-compatible value to a tuple. - - Args: - value: Any vector-like object (Vector, numpy array, tuple, list) - - Returns: - Tuple of floats - """ - if isinstance(value, Vector3): - return tuple([value.x, value.y, value.z]) - if isinstance(value, Vector): - return tuple(value.data) - elif isinstance(value, np.ndarray): - return tuple(value.tolist()) - elif isinstance(value, tuple): - return value - else: - return tuple(value) - - -def to_list(value: VectorLike) -> List[float]: - """Convert a vector-compatible value to a list. - - Args: - value: Any vector-like object (Vector, numpy array, tuple, list) - - Returns: - List of floats - """ - if isinstance(value, Vector): - return value.data.tolist() - elif isinstance(value, np.ndarray): - return value.tolist() - elif isinstance(value, list): - return value - else: - return list(value) - - -# Helper functions to check dimensionality -def is_2d(value: VectorLike) -> bool: - """Check if a vector-compatible value is 2D. 
- - Args: - value: Any vector-like object (Vector, numpy array, tuple, list) - - Returns: - True if the value is 2D - """ - if isinstance(value, Vector3): - return False - elif isinstance(value, Vector): - return len(value) == 2 - elif isinstance(value, np.ndarray): - return value.shape[-1] == 2 or value.size == 2 - else: - return len(value) == 2 - - -def is_3d(value: VectorLike) -> bool: - """Check if a vector-compatible value is 3D. - - Args: - value: Any vector-like object (Vector, numpy array, tuple, list) - - Returns: - True if the value is 3D - """ - if isinstance(value, Vector): - return len(value) == 3 - elif isinstance(value, Vector3): - return True - elif isinstance(value, np.ndarray): - return value.shape[-1] == 3 or value.size == 3 - else: - return len(value) == 3 - - -# Extraction functions for XYZ components -def x(value: VectorLike) -> float: - """Get the X component of a vector-compatible value. - - Args: - value: Any vector-like object (Vector, numpy array, tuple, list) - - Returns: - X component as a float - """ - if isinstance(value, Vector): - return value.x - elif isinstance(value, Vector3): - return value.x - else: - return float(to_numpy(value)[0]) - - -def y(value: VectorLike) -> float: - """Get the Y component of a vector-compatible value. - - Args: - value: Any vector-like object (Vector, numpy array, tuple, list) - - Returns: - Y component as a float - """ - if isinstance(value, Vector): - return value.y - elif isinstance(value, Vector3): - return value.y - else: - arr = to_numpy(value) - return float(arr[1]) if len(arr) > 1 else 0.0 - - -def z(value: VectorLike) -> float: - """Get the Z component of a vector-compatible value. 
- - Args: - value: Any vector-like object (Vector, numpy array, tuple, list) - - Returns: - Z component as a float - """ - if isinstance(value, Vector): - return value.z - elif isinstance(value, Vector3): - return value.z - else: - arr = to_numpy(value) - return float(arr[2]) if len(arr) > 2 else 0.0 diff --git a/build/lib/dimos/web/__init__.py b/build/lib/dimos/web/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/build/lib/dimos/web/dimos_interface/__init__.py b/build/lib/dimos/web/dimos_interface/__init__.py deleted file mode 100644 index 5ca28b30e5..0000000000 --- a/build/lib/dimos/web/dimos_interface/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -""" -Dimensional Interface package -""" - -from .api.server import FastAPIServer - -__all__ = ["FastAPIServer"] diff --git a/build/lib/dimos/web/dimos_interface/api/__init__.py b/build/lib/dimos/web/dimos_interface/api/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/build/lib/dimos/web/dimos_interface/api/server.py b/build/lib/dimos/web/dimos_interface/api/server.py deleted file mode 100644 index bcc590ab46..0000000000 --- a/build/lib/dimos/web/dimos_interface/api/server.py +++ /dev/null @@ -1,362 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# Working FastAPI/Uvicorn Impl. - -# Notes: Do not use simultaneously with Flask, this includes imports. 
-# Workers are not yet setup, as this requires a much more intricate -# reorganization. There appears to be possible signalling issues when -# opening up streams on multiple windows/reloading which will need to -# be fixed. Also note, Chrome only supports 6 simultaneous web streams, -# and its advised to test threading/worker performance with another -# browser like Safari. - -# Fast Api & Uvicorn -import cv2 -from dimos.web.edge_io import EdgeIO -from fastapi import FastAPI, Request, Form, HTTPException, UploadFile, File -from fastapi.responses import HTMLResponse, StreamingResponse, JSONResponse -from sse_starlette.sse import EventSourceResponse -from fastapi.templating import Jinja2Templates -import uvicorn -from threading import Lock -from pathlib import Path -from queue import Queue, Empty -import asyncio - -from reactivex.disposable import SingleAssignmentDisposable -from reactivex import operators as ops -import reactivex as rx -from fastapi.middleware.cors import CORSMiddleware - -# For audio processing -import io -import time -import numpy as np -import ffmpeg -import soundfile as sf -from dimos.stream.audio.base import AudioEvent - -# TODO: Resolve threading, start/stop stream functionality. 
- - -class FastAPIServer(EdgeIO): - def __init__( - self, - dev_name="FastAPI Server", - edge_type="Bidirectional", - host="0.0.0.0", - port=5555, - text_streams=None, - audio_subject=None, - **streams, - ): - print("Starting FastAPIServer initialization...") # Debug print - super().__init__(dev_name, edge_type) - self.app = FastAPI() - - # Add CORS middleware with more permissive settings for development - self.app.add_middleware( - CORSMiddleware, - allow_origins=["*"], # More permissive for development - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], - expose_headers=["*"], - ) - - self.port = port - self.host = host - BASE_DIR = Path(__file__).resolve().parent - self.templates = Jinja2Templates(directory=str(BASE_DIR / "templates")) - self.streams = streams - self.active_streams = {} - self.stream_locks = {key: Lock() for key in self.streams} - self.stream_queues = {} - self.stream_disposables = {} - - # Initialize text streams - self.text_streams = text_streams or {} - self.text_queues = {} - self.text_disposables = {} - self.text_clients = set() - - # Create a Subject for text queries - self.query_subject = rx.subject.Subject() - self.query_stream = self.query_subject.pipe(ops.share()) - self.audio_subject = audio_subject - - for key in self.streams: - if self.streams[key] is not None: - self.active_streams[key] = self.streams[key].pipe( - ops.map(self.process_frame_fastapi), ops.share() - ) - - # Set up text stream subscriptions - for key, stream in self.text_streams.items(): - if stream is not None: - self.text_queues[key] = Queue(maxsize=100) - disposable = stream.subscribe( - lambda text, k=key: self.text_queues[k].put(text) if text is not None else None, - lambda e, k=key: self.text_queues[k].put(None), - lambda k=key: self.text_queues[k].put(None), - ) - self.text_disposables[key] = disposable - self.disposables.add(disposable) - - print("Setting up routes...") # Debug print - self.setup_routes() - print("FastAPIServer 
initialization complete") # Debug print - - def process_frame_fastapi(self, frame): - """Convert frame to JPEG format for streaming.""" - _, buffer = cv2.imencode(".jpg", frame) - return buffer.tobytes() - - def stream_generator(self, key): - """Generate frames for a given video stream.""" - - def generate(): - if key not in self.stream_queues: - self.stream_queues[key] = Queue(maxsize=10) - - frame_queue = self.stream_queues[key] - - # Clear any existing disposable for this stream - if key in self.stream_disposables: - self.stream_disposables[key].dispose() - - disposable = SingleAssignmentDisposable() - self.stream_disposables[key] = disposable - self.disposables.add(disposable) - - if key in self.active_streams: - with self.stream_locks[key]: - # Clear the queue before starting new subscription - while not frame_queue.empty(): - try: - frame_queue.get_nowait() - except Empty: - break - - disposable.disposable = self.active_streams[key].subscribe( - lambda frame: frame_queue.put(frame) if frame is not None else None, - lambda e: frame_queue.put(None), - lambda: frame_queue.put(None), - ) - - try: - while True: - try: - frame = frame_queue.get(timeout=1) - if frame is None: - break - yield (b"--frame\r\nContent-Type: image/jpeg\r\n\r\n" + frame + b"\r\n") - except Empty: - # Instead of breaking, continue waiting for new frames - continue - finally: - if key in self.stream_disposables: - self.stream_disposables[key].dispose() - - return generate - - def create_video_feed_route(self, key): - """Create a video feed route for a specific stream.""" - - async def video_feed(): - return StreamingResponse( - self.stream_generator(key)(), media_type="multipart/x-mixed-replace; boundary=frame" - ) - - return video_feed - - async def text_stream_generator(self, key): - """Generate SSE events for text stream.""" - client_id = id(object()) - self.text_clients.add(client_id) - - try: - while True: - if key not in self.text_queues: - yield {"event": "ping", "data": ""} - await 
asyncio.sleep(0.1) - continue - - try: - text = self.text_queues[key].get_nowait() - if text is not None: - yield {"event": "message", "id": key, "data": text} - else: - break - except Empty: - yield {"event": "ping", "data": ""} - await asyncio.sleep(0.1) - finally: - self.text_clients.remove(client_id) - - @staticmethod - def _decode_audio(raw: bytes) -> tuple[np.ndarray, int]: - """Convert the webm/opus blob sent by the browser into mono 16-kHz PCM.""" - try: - # Use ffmpeg to convert to 16-kHz mono 16-bit PCM WAV in memory - out, _ = ( - ffmpeg.input("pipe:0") - .output( - "pipe:1", - format="wav", - acodec="pcm_s16le", - ac=1, - ar="16000", - loglevel="quiet", - ) - .run(input=raw, capture_stdout=True, capture_stderr=True) - ) - # Load with soundfile (returns float32 by default) - audio, sr = sf.read(io.BytesIO(out), dtype="float32") - # Ensure 1-D array (mono) - if audio.ndim > 1: - audio = audio[:, 0] - return np.array(audio), sr - except Exception as exc: - print(f"ffmpeg decoding failed: {exc}") - return None, None - - def setup_routes(self): - """Set up FastAPI routes.""" - - @self.app.get("/streams") - async def get_streams(): - """Get list of available video streams""" - return {"streams": list(self.streams.keys())} - - @self.app.get("/text_streams") - async def get_text_streams(): - """Get list of available text streams""" - return {"streams": list(self.text_streams.keys())} - - @self.app.get("/", response_class=HTMLResponse) - async def index(request: Request): - stream_keys = list(self.streams.keys()) - text_stream_keys = list(self.text_streams.keys()) - return self.templates.TemplateResponse( - "index_fastapi.html", - { - "request": request, - "stream_keys": stream_keys, - "text_stream_keys": text_stream_keys, - "has_voice": self.audio_subject is not None, - }, - ) - - @self.app.post("/submit_query") - async def submit_query(query: str = Form(...)): - # Using Form directly as a dependency ensures proper form handling - try: - if query: - # Emit the 
query through our Subject - self.query_subject.on_next(query) - return JSONResponse({"success": True, "message": "Query received"}) - return JSONResponse({"success": False, "message": "No query provided"}) - except Exception as e: - # Ensure we always return valid JSON even on error - return JSONResponse( - status_code=500, - content={"success": False, "message": f"Server error: {str(e)}"}, - ) - - @self.app.post("/upload_audio") - async def upload_audio(file: UploadFile = File(...)): - """Handle audio upload from the browser.""" - if self.audio_subject is None: - return JSONResponse( - status_code=400, - content={"success": False, "message": "Voice input not configured"}, - ) - - try: - data = await file.read() - audio_np, sr = self._decode_audio(data) - if audio_np is None: - return JSONResponse( - status_code=400, - content={"success": False, "message": "Unable to decode audio"}, - ) - - event = AudioEvent( - data=audio_np, - sample_rate=sr, - timestamp=time.time(), - channels=1 if audio_np.ndim == 1 else audio_np.shape[1], - ) - - # Push to reactive stream - self.audio_subject.on_next(event) - print(f"Received audio – {event.data.shape[0] / sr:.2f} s, {sr} Hz") - return {"success": True} - except Exception as e: - print(f"Failed to process uploaded audio: {e}") - return JSONResponse(status_code=500, content={"success": False, "message": str(e)}) - - # Unitree API endpoints - @self.app.get("/unitree/status") - async def unitree_status(): - """Check the status of the Unitree API server""" - return JSONResponse({"status": "online", "service": "unitree"}) - - @self.app.post("/unitree/command") - async def unitree_command(request: Request): - """Process commands sent from the terminal frontend""" - try: - data = await request.json() - command_text = data.get("command", "") - - # Emit the command through the query_subject - self.query_subject.on_next(command_text) - - response = { - "success": True, - "command": command_text, - "result": f"Processed command: 
{command_text}", - } - - return JSONResponse(response) - except Exception as e: - print(f"Error processing command: {str(e)}") - return JSONResponse( - status_code=500, - content={"success": False, "message": f"Error processing command: {str(e)}"}, - ) - - @self.app.get("/text_stream/{key}") - async def text_stream(key: str): - if key not in self.text_streams: - raise HTTPException(status_code=404, detail=f"Text stream '{key}' not found") - return EventSourceResponse(self.text_stream_generator(key)) - - for key in self.streams: - self.app.get(f"/video_feed/{key}")(self.create_video_feed_route(key)) - - def run(self): - """Run the FastAPI server.""" - uvicorn.run( - self.app, host=self.host, port=self.port - ) # TODO: Translate structure to enable in-built workers' - - -if __name__ == "__main__": - server = FastAPIServer() - server.run() diff --git a/build/lib/dimos/web/edge_io.py b/build/lib/dimos/web/edge_io.py deleted file mode 100644 index 8511df2ce3..0000000000 --- a/build/lib/dimos/web/edge_io.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from reactivex.disposable import CompositeDisposable - - -class EdgeIO: - def __init__(self, dev_name: str = "NA", edge_type: str = "Base"): - self.dev_name = dev_name - self.edge_type = edge_type - self.disposables = CompositeDisposable() - - def dispose_all(self): - """Disposes of all active subscriptions managed by this agent.""" - self.disposables.dispose() diff --git a/build/lib/dimos/web/fastapi_server.py b/build/lib/dimos/web/fastapi_server.py deleted file mode 100644 index 7dcd0f6d73..0000000000 --- a/build/lib/dimos/web/fastapi_server.py +++ /dev/null @@ -1,224 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Working FastAPI/Uvicorn Impl. - -# Notes: Do not use simultaneously with Flask, this includes imports. -# Workers are not yet setup, as this requires a much more intricate -# reorganization. There appears to be possible signalling issues when -# opening up streams on multiple windows/reloading which will need to -# be fixed. Also note, Chrome only supports 6 simultaneous web streams, -# and its advised to test threading/worker performance with another -# browser like Safari. 
- -# Fast Api & Uvicorn -import cv2 -from dimos.web.edge_io import EdgeIO -from fastapi import FastAPI, Request, Form, HTTPException -from fastapi.responses import HTMLResponse, StreamingResponse, JSONResponse -from sse_starlette.sse import EventSourceResponse -from fastapi.templating import Jinja2Templates -import uvicorn -from threading import Lock -from pathlib import Path -from queue import Queue, Empty -import asyncio - -from reactivex.disposable import SingleAssignmentDisposable -from reactivex import operators as ops -import reactivex as rx - -# TODO: Resolve threading, start/stop stream functionality. - - -class FastAPIServer(EdgeIO): - def __init__( - self, - dev_name="FastAPI Server", - edge_type="Bidirectional", - host="0.0.0.0", - port=5555, - text_streams=None, - **streams, - ): - super().__init__(dev_name, edge_type) - self.app = FastAPI() - self.port = port - self.host = host - BASE_DIR = Path(__file__).resolve().parent - self.templates = Jinja2Templates(directory=str(BASE_DIR / "templates")) - self.streams = streams - self.active_streams = {} - self.stream_locks = {key: Lock() for key in self.streams} - self.stream_queues = {} - self.stream_disposables = {} - - # Initialize text streams - self.text_streams = text_streams or {} - self.text_queues = {} - self.text_disposables = {} - self.text_clients = set() - - # Create a Subject for text queries - self.query_subject = rx.subject.Subject() - self.query_stream = self.query_subject.pipe(ops.share()) - - for key in self.streams: - if self.streams[key] is not None: - self.active_streams[key] = self.streams[key].pipe( - ops.map(self.process_frame_fastapi), ops.share() - ) - - # Set up text stream subscriptions - for key, stream in self.text_streams.items(): - if stream is not None: - self.text_queues[key] = Queue(maxsize=100) - disposable = stream.subscribe( - lambda text, k=key: self.text_queues[k].put(text) if text is not None else None, - lambda e, k=key: self.text_queues[k].put(None), - lambda k=key: 
self.text_queues[k].put(None), - ) - self.text_disposables[key] = disposable - self.disposables.add(disposable) - - self.setup_routes() - - def process_frame_fastapi(self, frame): - """Convert frame to JPEG format for streaming.""" - _, buffer = cv2.imencode(".jpg", frame) - return buffer.tobytes() - - def stream_generator(self, key): - """Generate frames for a given video stream.""" - - def generate(): - if key not in self.stream_queues: - self.stream_queues[key] = Queue(maxsize=10) - - frame_queue = self.stream_queues[key] - - # Clear any existing disposable for this stream - if key in self.stream_disposables: - self.stream_disposables[key].dispose() - - disposable = SingleAssignmentDisposable() - self.stream_disposables[key] = disposable - self.disposables.add(disposable) - - if key in self.active_streams: - with self.stream_locks[key]: - # Clear the queue before starting new subscription - while not frame_queue.empty(): - try: - frame_queue.get_nowait() - except Empty: - break - - disposable.disposable = self.active_streams[key].subscribe( - lambda frame: frame_queue.put(frame) if frame is not None else None, - lambda e: frame_queue.put(None), - lambda: frame_queue.put(None), - ) - - try: - while True: - try: - frame = frame_queue.get(timeout=1) - if frame is None: - break - yield (b"--frame\r\nContent-Type: image/jpeg\r\n\r\n" + frame + b"\r\n") - except Empty: - # Instead of breaking, continue waiting for new frames - continue - finally: - if key in self.stream_disposables: - self.stream_disposables[key].dispose() - - return generate - - def create_video_feed_route(self, key): - """Create a video feed route for a specific stream.""" - - async def video_feed(): - return StreamingResponse( - self.stream_generator(key)(), media_type="multipart/x-mixed-replace; boundary=frame" - ) - - return video_feed - - async def text_stream_generator(self, key): - """Generate SSE events for text stream.""" - client_id = id(object()) - self.text_clients.add(client_id) - - try: 
- while True: - if key in self.text_queues: - try: - text = self.text_queues[key].get(timeout=1) - if text is not None: - yield {"event": "message", "id": key, "data": text} - except Empty: - # Send a keep-alive comment - yield {"event": "ping", "data": ""} - await asyncio.sleep(0.1) - finally: - self.text_clients.remove(client_id) - - def setup_routes(self): - """Set up FastAPI routes.""" - - @self.app.get("/", response_class=HTMLResponse) - async def index(request: Request): - stream_keys = list(self.streams.keys()) - text_stream_keys = list(self.text_streams.keys()) - return self.templates.TemplateResponse( - "index_fastapi.html", - { - "request": request, - "stream_keys": stream_keys, - "text_stream_keys": text_stream_keys, - }, - ) - - @self.app.post("/submit_query") - async def submit_query(query: str = Form(...)): - # Using Form directly as a dependency ensures proper form handling - try: - if query: - # Emit the query through our Subject - self.query_subject.on_next(query) - return JSONResponse({"success": True, "message": "Query received"}) - return JSONResponse({"success": False, "message": "No query provided"}) - except Exception as e: - # Ensure we always return valid JSON even on error - return JSONResponse( - status_code=500, - content={"success": False, "message": f"Server error: {str(e)}"}, - ) - - @self.app.get("/text_stream/{key}") - async def text_stream(key: str): - if key not in self.text_streams: - raise HTTPException(status_code=404, detail=f"Text stream '{key}' not found") - return EventSourceResponse(self.text_stream_generator(key)) - - for key in self.streams: - self.app.get(f"/video_feed/{key}")(self.create_video_feed_route(key)) - - def run(self): - """Run the FastAPI server.""" - uvicorn.run( - self.app, host=self.host, port=self.port - ) # TODO: Translate structure to enable in-built workers' diff --git a/build/lib/dimos/web/flask_server.py b/build/lib/dimos/web/flask_server.py deleted file mode 100644 index 01d79f63cd..0000000000 --- 
a/build/lib/dimos/web/flask_server.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from flask import Flask, Response, render_template -import cv2 -from reactivex import operators as ops -from reactivex.disposable import SingleAssignmentDisposable -from queue import Queue - -from dimos.web.edge_io import EdgeIO - - -class FlaskServer(EdgeIO): - def __init__(self, dev_name="Flask Server", edge_type="Bidirectional", port=5555, **streams): - super().__init__(dev_name, edge_type) - self.app = Flask(__name__) - self.port = port - self.streams = streams - self.active_streams = {} - - # Initialize shared stream references with ref_count - for key in self.streams: - if self.streams[key] is not None: - # Apply share and ref_count to manage subscriptions - self.active_streams[key] = self.streams[key].pipe( - ops.map(self.process_frame_flask), ops.share() - ) - - self.setup_routes() - - def process_frame_flask(self, frame): - """Convert frame to JPEG format for streaming.""" - _, buffer = cv2.imencode(".jpg", frame) - return buffer.tobytes() - - def setup_routes(self): - @self.app.route("/") - def index(): - stream_keys = list(self.streams.keys()) # Get the keys from the streams dictionary - return render_template("index_flask.html", stream_keys=stream_keys) - - # Function to create a streaming response - def stream_generator(key): - def generate(): - frame_queue = Queue() - disposable = 
SingleAssignmentDisposable() - - # Subscribe to the shared, ref-counted stream - if key in self.active_streams: - disposable.disposable = self.active_streams[key].subscribe( - lambda frame: frame_queue.put(frame) if frame is not None else None, - lambda e: frame_queue.put(None), - lambda: frame_queue.put(None), - ) - - try: - while True: - frame = frame_queue.get() - if frame is None: - break - yield (b"--frame\r\nContent-Type: image/jpeg\r\n\r\n" + frame + b"\r\n") - finally: - disposable.dispose() - - return generate - - def make_response_generator(key): - def response_generator(): - return Response( - stream_generator(key)(), mimetype="multipart/x-mixed-replace; boundary=frame" - ) - - return response_generator - - # Dynamically adding routes using add_url_rule - for key in self.streams: - endpoint = f"video_feed_{key}" - self.app.add_url_rule( - f"/video_feed/{key}", endpoint, view_func=make_response_generator(key) - ) - - def run(self, host="0.0.0.0", port=5555, threaded=True): - self.port = port - self.app.run(host=host, port=self.port, debug=False, threaded=threaded) diff --git a/build/lib/dimos/web/robot_web_interface.py b/build/lib/dimos/web/robot_web_interface.py deleted file mode 100644 index 33847c0056..0000000000 --- a/build/lib/dimos/web/robot_web_interface.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Robot Web Interface wrapper for DIMOS. 
-Provides a clean interface to the dimensional-interface FastAPI server. -""" - -from dimos.web.dimos_interface.api.server import FastAPIServer - - -class RobotWebInterface(FastAPIServer): - """Wrapper class for the dimos-interface FastAPI server.""" - - def __init__(self, port=5555, text_streams=None, audio_subject=None, **streams): - super().__init__( - dev_name="Robot Web Interface", - edge_type="Bidirectional", - host="0.0.0.0", - port=port, - text_streams=text_streams, - audio_subject=audio_subject, - **streams, - ) diff --git a/build/lib/tests/__init__.py b/build/lib/tests/__init__.py deleted file mode 100644 index 8b13789179..0000000000 --- a/build/lib/tests/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/build/lib/tests/agent_manip_flow_fastapi_test.py b/build/lib/tests/agent_manip_flow_fastapi_test.py deleted file mode 100644 index c7dec66f74..0000000000 --- a/build/lib/tests/agent_manip_flow_fastapi_test.py +++ /dev/null @@ -1,153 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This module initializes and manages the video processing pipeline integrated with a web server. -It handles video capture, frame processing, and exposes the processed video streams via HTTP endpoints. 
-""" - -import tests.test_header -import os - -# ----- - -# Standard library imports -import multiprocessing -from dotenv import load_dotenv - -# Third-party imports -from fastapi import FastAPI -from reactivex import operators as ops -from reactivex.disposable import CompositeDisposable -from reactivex.scheduler import ThreadPoolScheduler, CurrentThreadScheduler, ImmediateScheduler - -# Local application imports -from dimos.agents.agent import OpenAIAgent -from dimos.stream.frame_processor import FrameProcessor -from dimos.stream.video_operators import VideoOperators as vops -from dimos.stream.video_provider import VideoProvider -from dimos.web.fastapi_server import FastAPIServer - -# Load environment variables -load_dotenv() - - -def main(): - """ - Initializes and runs the video processing pipeline with web server output. - - This function orchestrates a video processing system that handles capture, processing, - and visualization of video streams. It demonstrates parallel processing capabilities - and various video manipulation techniques across multiple stages including capture - and processing at different frame rates, edge detection, and optical flow analysis. - - Raises: - RuntimeError: If video sources are unavailable or processing fails. 
- """ - disposables = CompositeDisposable() - - processor = FrameProcessor( - output_dir=f"{os.getcwd()}/assets/output/frames", delete_on_init=True - ) - - optimal_thread_count = multiprocessing.cpu_count() # Gets number of CPU cores - thread_pool_scheduler = ThreadPoolScheduler(optimal_thread_count) - - VIDEO_SOURCES = [ - f"{os.getcwd()}/assets/ldru.mp4", - f"{os.getcwd()}/assets/ldru_480p.mp4", - f"{os.getcwd()}/assets/trimmed_video_480p.mov", - f"{os.getcwd()}/assets/video-f30-480p.mp4", - "rtsp://192.168.50.207:8080/h264.sdp", - "rtsp://10.0.0.106:8080/h264.sdp", - ] - - VIDEO_SOURCE_INDEX = 3 - VIDEO_SOURCE_INDEX_2 = 2 - - my_video_provider = VideoProvider("Video File", video_source=VIDEO_SOURCES[VIDEO_SOURCE_INDEX]) - my_video_provider_2 = VideoProvider( - "Video File 2", video_source=VIDEO_SOURCES[VIDEO_SOURCE_INDEX_2] - ) - - video_stream_obs = my_video_provider.capture_video_as_observable(fps=120).pipe( - ops.subscribe_on(thread_pool_scheduler), - # Move downstream operations to thread pool for parallel processing - # Disabled: Evaluating performance impact - # ops.observe_on(thread_pool_scheduler), - vops.with_jpeg_export(processor, suffix="raw"), - vops.with_fps_sampling(fps=30), - vops.with_jpeg_export(processor, suffix="raw_slowed"), - ) - - video_stream_obs_2 = my_video_provider_2.capture_video_as_observable(fps=120).pipe( - ops.subscribe_on(thread_pool_scheduler), - # Move downstream operations to thread pool for parallel processing - # Disabled: Evaluating performance impact - # ops.observe_on(thread_pool_scheduler), - vops.with_jpeg_export(processor, suffix="raw_2"), - vops.with_fps_sampling(fps=30), - vops.with_jpeg_export(processor, suffix="raw_2_slowed"), - ) - - edge_detection_stream_obs = processor.process_stream_edge_detection(video_stream_obs).pipe( - vops.with_jpeg_export(processor, suffix="edge"), - ) - - optical_flow_relevancy_stream_obs = processor.process_stream_optical_flow_with_relevancy( - video_stream_obs - ) - - 
optical_flow_stream_obs = optical_flow_relevancy_stream_obs.pipe( - ops.do_action(lambda result: print(f"Optical Flow Relevancy Score: {result[1]}")), - vops.with_optical_flow_filtering(threshold=2.0), - ops.do_action(lambda _: print(f"Optical Flow Passed Threshold.")), - vops.with_jpeg_export(processor, suffix="optical"), - ) - - # - # ====== Agent Orchastrator (Qu.s Awareness, Temporality, Routing) ====== - # - - # Agent 1 - # my_agent = OpenAIAgent( - # "Agent 1", - # query="You are a robot. What do you see? Put a JSON with objects of what you see in the format {object, description}.") - # my_agent.subscribe_to_image_processing(slowed_video_stream_obs) - # disposables.add(my_agent.disposables) - - # # Agent 2 - # my_agent_two = OpenAIAgent( - # "Agent 2", - # query="This is a visualization of dense optical flow. What movement(s) have occured? Put a JSON with mapped directions you see in the format {direction, probability, english_description}.") - # my_agent_two.subscribe_to_image_processing(optical_flow_stream_obs) - # disposables.add(my_agent_two.disposables) - - # - # ====== Create and start the FastAPI server ====== - # - - # Will be visible at http://[host]:[port]/video_feed/[key] - streams = { - "video_one": video_stream_obs, - "video_two": video_stream_obs_2, - "edge_detection": edge_detection_stream_obs, - "optical_flow": optical_flow_stream_obs, - } - fast_api_server = FastAPIServer(port=5555, **streams) - fast_api_server.run() - - -if __name__ == "__main__": - main() diff --git a/build/lib/tests/agent_manip_flow_flask_test.py b/build/lib/tests/agent_manip_flow_flask_test.py deleted file mode 100644 index 2356eb74ae..0000000000 --- a/build/lib/tests/agent_manip_flow_flask_test.py +++ /dev/null @@ -1,195 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This module initializes and manages the video processing pipeline integrated with a web server. -It handles video capture, frame processing, and exposes the processed video streams via HTTP endpoints. -""" - -import tests.test_header -import os - -# ----- - -# Standard library imports -import multiprocessing -from dotenv import load_dotenv - -# Third-party imports -from flask import Flask -from reactivex import operators as ops -from reactivex import of, interval, zip -from reactivex.disposable import CompositeDisposable -from reactivex.scheduler import ThreadPoolScheduler, CurrentThreadScheduler, ImmediateScheduler - -# Local application imports -from dimos.agents.agent import PromptBuilder, OpenAIAgent -from dimos.stream.frame_processor import FrameProcessor -from dimos.stream.video_operators import VideoOperators as vops -from dimos.stream.video_provider import VideoProvider -from dimos.web.flask_server import FlaskServer - -# Load environment variables -load_dotenv() - -app = Flask(__name__) - - -def main(): - """ - Initializes and runs the video processing pipeline with web server output. - - This function orchestrates a video processing system that handles capture, processing, - and visualization of video streams. It demonstrates parallel processing capabilities - and various video manipulation techniques across multiple stages including capture - and processing at different frame rates, edge detection, and optical flow analysis. - - Raises: - RuntimeError: If video sources are unavailable or processing fails. 
- """ - disposables = CompositeDisposable() - - processor = FrameProcessor( - output_dir=f"{os.getcwd()}/assets/output/frames", delete_on_init=True - ) - - optimal_thread_count = multiprocessing.cpu_count() # Gets number of CPU cores - thread_pool_scheduler = ThreadPoolScheduler(optimal_thread_count) - - VIDEO_SOURCES = [ - f"{os.getcwd()}/assets/ldru.mp4", - f"{os.getcwd()}/assets/ldru_480p.mp4", - f"{os.getcwd()}/assets/trimmed_video_480p.mov", - f"{os.getcwd()}/assets/video-f30-480p.mp4", - f"{os.getcwd()}/assets/video.mov", - "rtsp://192.168.50.207:8080/h264.sdp", - "rtsp://10.0.0.106:8080/h264.sdp", - f"{os.getcwd()}/assets/people_1080p_24fps.mp4", - ] - - VIDEO_SOURCE_INDEX = 4 - - my_video_provider = VideoProvider("Video File", video_source=VIDEO_SOURCES[VIDEO_SOURCE_INDEX]) - - video_stream_obs = my_video_provider.capture_video_as_observable(fps=120).pipe( - ops.subscribe_on(thread_pool_scheduler), - # Move downstream operations to thread pool for parallel processing - # Disabled: Evaluating performance impact - # ops.observe_on(thread_pool_scheduler), - # vops.with_jpeg_export(processor, suffix="raw"), - vops.with_fps_sampling(fps=30), - # vops.with_jpeg_export(processor, suffix="raw_slowed"), - ) - - edge_detection_stream_obs = processor.process_stream_edge_detection(video_stream_obs).pipe( - # vops.with_jpeg_export(processor, suffix="edge"), - ) - - optical_flow_relevancy_stream_obs = processor.process_stream_optical_flow(video_stream_obs) - - optical_flow_stream_obs = optical_flow_relevancy_stream_obs.pipe( - # ops.do_action(lambda result: print(f"Optical Flow Relevancy Score: {result[1]}")), - # vops.with_optical_flow_filtering(threshold=2.0), - # ops.do_action(lambda _: print(f"Optical Flow Passed Threshold.")), - # vops.with_jpeg_export(processor, suffix="optical") - ) - - # - # ====== Agent Orchastrator (Qu.s Awareness, Temporality, Routing) ====== - # - - # Observable that emits every 2 seconds - secondly_emission = interval(2, 
scheduler=thread_pool_scheduler).pipe( - ops.map(lambda x: f"Second {x + 1}"), - # ops.take(30) - ) - - # Agent 1 - my_agent = OpenAIAgent( - "Agent 1", - query="You are a robot. What do you see? Put a JSON with objects of what you see in the format {object, description}.", - json_mode=False, - ) - - # Create an agent for each subset of questions that it would be theroized to handle. - # Set std. template/blueprints, and devs will add to that likely. - - ai_1_obs = video_stream_obs.pipe( - # vops.with_fps_sampling(fps=30), - # ops.throttle_first(1), - vops.with_jpeg_export(processor, suffix="open_ai_agent_1"), - ops.take(30), - ops.replay(buffer_size=30, scheduler=thread_pool_scheduler), - ) - ai_1_obs.connect() - - ai_1_repeat_obs = ai_1_obs.pipe(ops.repeat()) - - my_agent.subscribe_to_image_processing(ai_1_obs) - disposables.add(my_agent.disposables) - - # Agent 2 - my_agent_two = OpenAIAgent( - "Agent 2", - query="This is a visualization of dense optical flow. What movement(s) have occured? 
Put a JSON with mapped directions you see in the format {direction, probability, english_description}.", - max_input_tokens_per_request=1000, - max_output_tokens_per_request=300, - json_mode=False, - model_name="gpt-4o-2024-08-06", - ) - - ai_2_obs = optical_flow_stream_obs.pipe( - # vops.with_fps_sampling(fps=30), - # ops.throttle_first(1), - vops.with_jpeg_export(processor, suffix="open_ai_agent_2"), - ops.take(30), - ops.replay(buffer_size=30, scheduler=thread_pool_scheduler), - ) - ai_2_obs.connect() - - ai_2_repeat_obs = ai_2_obs.pipe(ops.repeat()) - - # Combine emissions using zip - ai_1_secondly_repeating_obs = zip(secondly_emission, ai_1_repeat_obs).pipe( - # ops.do_action(lambda s: print(f"AI 1 - Emission Count: {s[0]}")), - ops.map(lambda r: r[1]), - ) - - # Combine emissions using zip - ai_2_secondly_repeating_obs = zip(secondly_emission, ai_2_repeat_obs).pipe( - # ops.do_action(lambda s: print(f"AI 2 - Emission Count: {s[0]}")), - ops.map(lambda r: r[1]), - ) - - my_agent_two.subscribe_to_image_processing(ai_2_obs) - disposables.add(my_agent_two.disposables) - - # - # ====== Create and start the Flask server ====== - # - - # Will be visible at http://[host]:[port]/video_feed/[key] - flask_server = FlaskServer( - # video_one=video_stream_obs, - # edge_detection=edge_detection_stream_obs, - # optical_flow=optical_flow_stream_obs, - OpenAIAgent_1=ai_1_secondly_repeating_obs, - OpenAIAgent_2=ai_2_secondly_repeating_obs, - ) - - flask_server.run(threaded=True) - - -if __name__ == "__main__": - main() diff --git a/build/lib/tests/agent_memory_test.py b/build/lib/tests/agent_memory_test.py deleted file mode 100644 index b662af18bd..0000000000 --- a/build/lib/tests/agent_memory_test.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import tests.test_header -import os - -# ----- - -from dotenv import load_dotenv -import os - -load_dotenv() - -from dimos.agents.memory.chroma_impl import OpenAISemanticMemory - -agent_memory = OpenAISemanticMemory() -print("Initialization done.") - -agent_memory.add_vector("id0", "Food") -agent_memory.add_vector("id1", "Cat") -agent_memory.add_vector("id2", "Mouse") -agent_memory.add_vector("id3", "Bike") -agent_memory.add_vector("id4", "Dog") -agent_memory.add_vector("id5", "Tricycle") -agent_memory.add_vector("id6", "Car") -agent_memory.add_vector("id7", "Horse") -agent_memory.add_vector("id8", "Vehicle") -agent_memory.add_vector("id6", "Red") -agent_memory.add_vector("id7", "Orange") -agent_memory.add_vector("id8", "Yellow") -print("Adding vectors done.") - -print(agent_memory.get_vector("id1")) -print("Done retrieving sample vector.") - -results = agent_memory.query("Colors") -print(results) -print("Done querying agent memory (basic).") - -results = agent_memory.query("Colors", similarity_threshold=0.2) -print(results) -print("Done querying agent memory (similarity_threshold=0.2).") - -results = agent_memory.query("Colors", n_results=2) -print(results) -print("Done querying agent memory (n_results=2).") - -results = agent_memory.query("Colors", n_results=19, similarity_threshold=0.45) -print(results) -print("Done querying agent memory (n_results=19, similarity_threshold=0.45).") diff --git a/build/lib/tests/colmap_test.py b/build/lib/tests/colmap_test.py deleted file mode 100644 index e1f451a7dc..0000000000 --- a/build/lib/tests/colmap_test.py +++ 
/dev/null @@ -1,25 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import tests.test_header -import os -import sys - -# ----- - -# Now try to import -from dimos.environment.colmap_environment import COLMAPEnvironment - -env = COLMAPEnvironment() -env.initialize_from_video("data/IMG_1525.MOV", "data/frames") diff --git a/build/lib/tests/run.py b/build/lib/tests/run.py deleted file mode 100644 index 9ae6f81398..0000000000 --- a/build/lib/tests/run.py +++ /dev/null @@ -1,361 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import tests.test_header -import os - -import time -from dotenv import load_dotenv -from dimos.agents.cerebras_agent import CerebrasAgent -from dimos.agents.claude_agent import ClaudeAgent -from dimos.robot.unitree_webrtc.unitree_go2 import UnitreeGo2 - -# from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl -from dimos.robot.unitree.unitree_skills import MyUnitreeSkills -from dimos.web.robot_web_interface import RobotWebInterface -from dimos.web.websocket_vis.server import WebsocketVis -from dimos.skills.observe_stream import ObserveStream -from dimos.skills.observe import Observe -from dimos.skills.kill_skill import KillSkill -from dimos.skills.navigation import NavigateWithText, GetPose, NavigateToGoal, Explore -from dimos.skills.visual_navigation_skills import FollowHuman -import reactivex as rx -import reactivex.operators as ops -from dimos.stream.audio.pipelines import tts, stt -import threading -import json -from dimos.types.vector import Vector -from dimos.skills.unitree.unitree_speak import UnitreeSpeak - -from dimos.perception.object_detection_stream import ObjectDetectionStream -from dimos.perception.detection2d.detic_2d_det import Detic2DDetector -from dimos.utils.reactive import backpressure -import asyncio -import atexit -import signal -import sys -import warnings -import logging - -# Filter out known WebRTC warnings that don't affect functionality -warnings.filterwarnings("ignore", message="coroutine.*was never awaited") -warnings.filterwarnings("ignore", message=".*RTCSctpTransport.*") - -# Set up logging to reduce asyncio noise -logging.getLogger("asyncio").setLevel(logging.ERROR) - -# Load API key from environment -load_dotenv() - -# Allow command line arguments to control spatial memory parameters -import argparse - - -def parse_arguments(): - parser = argparse.ArgumentParser( - description="Run the robot with optional spatial memory parameters" - ) - parser.add_argument( - "--new-memory", action="store_true", help="Create a 
new spatial memory from scratch" - ) - parser.add_argument( - "--spatial-memory-dir", type=str, help="Directory for storing spatial memory data" - ) - return parser.parse_args() - - -args = parse_arguments() - -# Initialize robot with spatial memory parameters - using WebRTC mode instead of "ai" -robot = UnitreeGo2( - ip=os.getenv("ROBOT_IP"), - mode="normal", -) - - -# Add graceful shutdown handling to prevent WebRTC task destruction errors -def cleanup_robot(): - print("Cleaning up robot connection...") - try: - # Make cleanup non-blocking to avoid hangs - def quick_cleanup(): - try: - robot.liedown() - except: - pass - - # Run cleanup in a separate thread with timeout - cleanup_thread = threading.Thread(target=quick_cleanup) - cleanup_thread.daemon = True - cleanup_thread.start() - cleanup_thread.join(timeout=3.0) # Max 3 seconds for cleanup - - # Force stop the robot's WebRTC connection - try: - robot.stop() - except: - pass - - except Exception as e: - print(f"Error during cleanup: {e}") - # Continue anyway - - -atexit.register(cleanup_robot) - - -def signal_handler(signum, frame): - print("Received shutdown signal, cleaning up...") - try: - cleanup_robot() - except: - pass - # Force exit if cleanup hangs - os._exit(0) - - -signal.signal(signal.SIGINT, signal_handler) -signal.signal(signal.SIGTERM, signal_handler) - -# Initialize WebSocket visualization -websocket_vis = WebsocketVis() -websocket_vis.start() -websocket_vis.connect(robot.global_planner.vis_stream()) - - -def msg_handler(msgtype, data): - if msgtype == "click": - print(f"Received click at position: {data['position']}") - - try: - print("Setting goal...") - - # Instead of disabling visualization, make it timeout-safe - original_vis = robot.global_planner.vis - - def safe_vis(name, drawable): - """Visualization wrapper that won't block on timeouts""" - try: - # Use a separate thread for visualization to avoid blocking - def vis_update(): - try: - original_vis(name, drawable) - except Exception as 
e: - print(f"Visualization update failed (non-critical): {e}") - - vis_thread = threading.Thread(target=vis_update) - vis_thread.daemon = True - vis_thread.start() - # Don't wait for completion - let it run asynchronously - except Exception as e: - print(f"Visualization setup failed (non-critical): {e}") - - robot.global_planner.vis = safe_vis - robot.global_planner.set_goal(Vector(data["position"])) - robot.global_planner.vis = original_vis - - print("Goal set successfully") - except Exception as e: - print(f"Error setting goal: {e}") - import traceback - - traceback.print_exc() - - -def threaded_msg_handler(msgtype, data): - print(f"Processing message: {msgtype}") - - # Create a dedicated event loop for goal setting to avoid conflicts - def run_with_dedicated_loop(): - try: - # Use asyncio.run which creates and manages its own event loop - # This won't conflict with the robot's or websocket's event loops - async def async_msg_handler(): - msg_handler(msgtype, data) - - asyncio.run(async_msg_handler()) - print("Goal setting completed successfully") - except Exception as e: - print(f"Error in goal setting thread: {e}") - import traceback - - traceback.print_exc() - - thread = threading.Thread(target=run_with_dedicated_loop) - thread.daemon = True - thread.start() - - -websocket_vis.msg_handler = threaded_msg_handler - - -def newmap(msg): - return ["costmap", robot.map.costmap.smudge()] - - -websocket_vis.connect(robot.map_stream.pipe(ops.map(newmap))) -websocket_vis.connect(robot.odom_stream().pipe(ops.map(lambda pos: ["robot_pos", pos.pos.to_2d()]))) - -# Create a subject for agent responses -agent_response_subject = rx.subject.Subject() -agent_response_stream = agent_response_subject.pipe(ops.share()) -local_planner_viz_stream = robot.local_planner_viz_stream.pipe(ops.share()) -audio_subject = rx.subject.Subject() - -# Initialize object detection stream -min_confidence = 0.6 -class_filter = None # No class filtering - -# Create video stream from robot's camera 
-video_stream = backpressure(robot.get_video_stream()) # WebRTC doesn't use ROS video stream - -# # Initialize ObjectDetectionStream with robot -object_detector = ObjectDetectionStream( - camera_intrinsics=robot.camera_intrinsics, - class_filter=class_filter, - get_pose=robot.get_pose, - video_stream=video_stream, - draw_masks=True, -) - -# # Create visualization stream for web interface -viz_stream = backpressure(object_detector.get_stream()).pipe( - ops.share(), - ops.map(lambda x: x["viz_frame"] if x is not None else None), - ops.filter(lambda x: x is not None), -) - -# # Get the formatted detection stream -formatted_detection_stream = object_detector.get_formatted_stream().pipe( - ops.filter(lambda x: x is not None) -) - - -# Create a direct mapping that combines detection data with locations -def combine_with_locations(object_detections): - # Get locations from spatial memory - try: - spatial_memory = robot.get_spatial_memory() - if spatial_memory is None: - # If spatial memory is disabled, just return the object detections - return object_detections - - locations = spatial_memory.get_robot_locations() - - # Format the locations section - locations_text = "\n\nSaved Robot Locations:\n" - if locations: - for loc in locations: - locations_text += f"- {loc.name}: Position ({loc.position[0]:.2f}, {loc.position[1]:.2f}, {loc.position[2]:.2f}), " - locations_text += f"Rotation ({loc.rotation[0]:.2f}, {loc.rotation[1]:.2f}, {loc.rotation[2]:.2f})\n" - else: - locations_text += "None\n" - - # Simply concatenate the strings - return object_detections + locations_text - except Exception as e: - print(f"Error adding locations: {e}") - return object_detections - - -# Create the combined stream with a simple pipe operation -enhanced_data_stream = formatted_detection_stream.pipe(ops.map(combine_with_locations), ops.share()) - -streams = { - "unitree_video": robot.get_video_stream(), # Changed from get_ros_video_stream to get_video_stream for WebRTC - "local_planner_viz": 
local_planner_viz_stream, - "object_detection": viz_stream, # Uncommented object detection -} -text_streams = { - "agent_responses": agent_response_stream, -} - -web_interface = RobotWebInterface( - port=5555, text_streams=text_streams, audio_subject=audio_subject, **streams -) - -stt_node = stt() -stt_node.consume_audio(audio_subject.pipe(ops.share())) - -# Read system query from prompt.txt file -with open( - os.path.join(os.path.dirname(os.path.dirname(__file__)), "assets/agent/prompt.txt"), "r" -) as f: - system_query = f.read() - -# Create a ClaudeAgent instance -agent = ClaudeAgent( - dev_name="test_agent", - input_query_stream=stt_node.emit_text(), - # input_query_stream=web_interface.query_stream, - input_data_stream=enhanced_data_stream, - skills=robot.get_skills(), - system_query=system_query, - model_name="claude-3-5-haiku-latest", - thinking_budget_tokens=0, - max_output_tokens_per_request=8192, - # model_name="llama-4-scout-17b-16e-instruct", -) - -# tts_node = tts() -# tts_node.consume_text(agent.get_response_observable()) - -robot_skills = robot.get_skills() -robot_skills.add(ObserveStream) -robot_skills.add(Observe) -robot_skills.add(KillSkill) -robot_skills.add(NavigateWithText) -# robot_skills.add(FollowHuman) # TODO: broken -robot_skills.add(GetPose) -robot_skills.add(UnitreeSpeak) # Re-enable Speak skill -robot_skills.add(NavigateToGoal) -robot_skills.add(Explore) - -robot_skills.create_instance("ObserveStream", robot=robot, agent=agent) -robot_skills.create_instance("Observe", robot=robot, agent=agent) -robot_skills.create_instance("KillSkill", robot=robot, skill_library=robot_skills) -robot_skills.create_instance("NavigateWithText", robot=robot) -# robot_skills.create_instance("FollowHuman", robot=robot) -robot_skills.create_instance("GetPose", robot=robot) -robot_skills.create_instance("NavigateToGoal", robot=robot) -robot_skills.create_instance("Explore", robot=robot) -robot_skills.create_instance("UnitreeSpeak", robot=robot) # Now only needs 
robot instance - -# Subscribe to agent responses and send them to the subject -agent.get_response_observable().subscribe(lambda x: agent_response_subject.on_next(x)) - -print("ObserveStream and Kill skills registered and ready for use") -print("Created memory.txt file") - -# Start web interface in a separate thread to avoid blocking -web_thread = threading.Thread(target=web_interface.run) -web_thread.daemon = True -web_thread.start() - -try: - while True: - # Main loop - can add robot movement or other logic here - time.sleep(0.01) - -except KeyboardInterrupt: - print("Stopping robot") - robot.liedown() -except Exception as e: - print(f"Unexpected error in main loop: {e}") - import traceback - - traceback.print_exc() -finally: - print("Cleaning up...") - cleanup_robot() diff --git a/build/lib/tests/run_go2_ros.py b/build/lib/tests/run_go2_ros.py deleted file mode 100644 index 6bba1c1797..0000000000 --- a/build/lib/tests/run_go2_ros.py +++ /dev/null @@ -1,178 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import tests.test_header - -import os -import time - -from dimos.robot.unitree.unitree_go2 import UnitreeGo2, WebRTCConnectionMethod -from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl - - -def get_env_var(var_name, default=None, required=False): - """Get environment variable with validation.""" - value = os.getenv(var_name, default) - if value == "": - value = default - if required and not value: - raise ValueError(f"{var_name} environment variable is required") - return value - - -if __name__ == "__main__": - # Get configuration from environment variables - robot_ip = get_env_var("ROBOT_IP") - connection_method = get_env_var("CONNECTION_METHOD", "LocalSTA") - serial_number = get_env_var("SERIAL_NUMBER", None) - output_dir = get_env_var("ROS_OUTPUT_DIR", os.path.join(os.getcwd(), "assets/output/ros")) - - # Ensure output directory exists - os.makedirs(output_dir, exist_ok=True) - print(f"Ensuring output directory exists: {output_dir}") - - use_ros = True - use_webrtc = False - # Convert connection method string to enum - connection_method = getattr(WebRTCConnectionMethod, connection_method) - - print("Initializing UnitreeGo2...") - print(f"Configuration:") - print(f" IP: {robot_ip}") - print(f" Connection Method: {connection_method}") - print(f" Serial Number: {serial_number if serial_number else 'Not provided'}") - print(f" Output Directory: {output_dir}") - - if use_ros: - ros_control = UnitreeROSControl(node_name="unitree_go2", use_raw=True) - else: - ros_control = None - - robot = UnitreeGo2( - ip=robot_ip, - connection_method=connection_method, - serial_number=serial_number, - output_dir=output_dir, - ros_control=ros_control, - use_ros=use_ros, - use_webrtc=use_webrtc, - ) - time.sleep(5) - try: - # Start perception - print("\nStarting perception system...") - - # Get the processed stream - processed_stream = robot.get_ros_video_stream(fps=30) - - # Create frame counter for unique filenames - frame_count = 0 - - # Create a subscriber to 
handle the frames - def handle_frame(frame): - global frame_count - frame_count += 1 - - try: - # Save frame to output directory if desired for debugging frame streaming - # MAKE SURE TO CHANGE OUTPUT DIR depending on if running in ROS or local - # frame_path = os.path.join(output_dir, f"frame_{frame_count:04d}.jpg") - # success = cv2.imwrite(frame_path, frame) - # print(f"Frame #{frame_count} {'saved successfully' if success else 'failed to save'} to {frame_path}") - pass - - except Exception as e: - print(f"Error in handle_frame: {e}") - import traceback - - print(traceback.format_exc()) - - def handle_error(error): - print(f"Error in stream: {error}") - - def handle_completion(): - print("Stream completed") - - # Subscribe to the stream - print("Creating subscription...") - try: - subscription = processed_stream.subscribe( - on_next=handle_frame, - on_error=lambda e: print(f"Subscription error: {e}"), - on_completed=lambda: print("Subscription completed"), - ) - print("Subscription created successfully") - except Exception as e: - print(f"Error creating subscription: {e}") - - time.sleep(5) - - # First put the robot in a good starting state - print("Running recovery stand...") - robot.webrtc_req(api_id=1006) # RecoveryStand - - # Queue 20 WebRTC requests back-to-back - print("\n🤖 QUEUEING WEBRTC COMMANDS BACK-TO-BACK FOR TESTING UnitreeGo2🤖\n") - - # Dance 1 - robot.webrtc_req(api_id=1033) - print("Queued: WiggleHips (1033)") - - robot.reverse(distance=0.2, speed=0.5) - print("Queued: Reverse 0.5m at 0.5m/s") - - # Wiggle Hips - robot.webrtc_req(api_id=1033) - print("Queued: WiggleHips (1033)") - - robot.move(distance=0.2, speed=0.5) - print("Queued: Move forward 1.0m at 0.5m/s") - - robot.webrtc_req(api_id=1017) - print("Queued: Stretch (1017)") - - robot.move(distance=0.2, speed=0.5) - print("Queued: Move forward 1.0m at 0.5m/s") - - robot.webrtc_req(api_id=1017) - print("Queued: Stretch (1017)") - - robot.reverse(distance=0.2, speed=0.5) - print("Queued: 
Reverse 0.5m at 0.5m/s") - - robot.webrtc_req(api_id=1017) - print("Queued: Stretch (1017)") - robot.spin(degrees=-90.0, speed=45.0) - print("Queued: Spin right 90 degrees at 45 degrees/s") - - robot.spin(degrees=90.0, speed=45.0) - print("Queued: Spin left 90 degrees at 45 degrees/s") - - # To prevent termination - while True: - time.sleep(0.1) - - except KeyboardInterrupt: - print("\nStopping perception...") - if "subscription" in locals(): - subscription.dispose() - except Exception as e: - print(f"Error in main loop: {e}") - finally: - # Cleanup - print("Cleaning up resources...") - if "subscription" in locals(): - subscription.dispose() - del robot - print("Cleanup complete.") diff --git a/build/lib/tests/run_navigation_only.py b/build/lib/tests/run_navigation_only.py deleted file mode 100644 index 2995750e2b..0000000000 --- a/build/lib/tests/run_navigation_only.py +++ /dev/null @@ -1,191 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import os -from dotenv import load_dotenv -from dimos.robot.unitree_webrtc.unitree_go2 import UnitreeGo2 -from dimos.robot.unitree_webrtc.testing.helpers import show3d_stream -from dimos.web.websocket_vis.server import WebsocketVis -from dimos.web.robot_web_interface import RobotWebInterface -from dimos.types.vector import Vector -import reactivex.operators as ops -import time -import threading -import asyncio -import atexit -import signal -import sys -import warnings -import logging -# logging.basicConfig(level=logging.DEBUG) - -# Filter out known WebRTC warnings that don't affect functionality -warnings.filterwarnings("ignore", message="coroutine.*was never awaited") -warnings.filterwarnings("ignore", message=".*RTCSctpTransport.*") - -# Set up logging to reduce asyncio noise -logging.getLogger("asyncio").setLevel(logging.ERROR) - -load_dotenv() -robot = UnitreeGo2(ip=os.getenv("ROBOT_IP"), mode="normal", enable_perception=False) - - -# Add graceful shutdown handling to prevent WebRTC task destruction errors -def cleanup_robot(): - print("Cleaning up robot connection...") - try: - # Make cleanup non-blocking to avoid hangs - def quick_cleanup(): - try: - robot.liedown() - except: - pass - - # Run cleanup in a separate thread with timeout - cleanup_thread = threading.Thread(target=quick_cleanup) - cleanup_thread.daemon = True - cleanup_thread.start() - cleanup_thread.join(timeout=3.0) # Max 3 seconds for cleanup - - # Force stop the robot's WebRTC connection - try: - robot.stop() - except: - pass - - except Exception as e: - print(f"Error during cleanup: {e}") - # Continue anyway - - -atexit.register(cleanup_robot) - - -def signal_handler(signum, frame): - print("Received shutdown signal, cleaning up...") - try: - cleanup_robot() - except: - pass - # Force exit if cleanup hangs - os._exit(0) - - -signal.signal(signal.SIGINT, signal_handler) -signal.signal(signal.SIGTERM, signal_handler) - -websocket_vis = WebsocketVis() -websocket_vis.start() 
-websocket_vis.connect(robot.global_planner.vis_stream()) - - -def msg_handler(msgtype, data): - if msgtype == "click": - print(f"Received click at position: {data['position']}") - - try: - print("Setting goal...") - - # Instead of disabling visualization, make it timeout-safe - original_vis = robot.global_planner.vis - - def safe_vis(name, drawable): - """Visualization wrapper that won't block on timeouts""" - try: - # Use a separate thread for visualization to avoid blocking - def vis_update(): - try: - original_vis(name, drawable) - except Exception as e: - print(f"Visualization update failed (non-critical): {e}") - - vis_thread = threading.Thread(target=vis_update) - vis_thread.daemon = True - vis_thread.start() - # Don't wait for completion - let it run asynchronously - except Exception as e: - print(f"Visualization setup failed (non-critical): {e}") - - robot.global_planner.vis = safe_vis - robot.global_planner.set_goal(Vector(data["position"])) - robot.global_planner.vis = original_vis - - print("Goal set successfully") - except Exception as e: - print(f"Error setting goal: {e}") - import traceback - - traceback.print_exc() - - -def threaded_msg_handler(msgtype, data): - print(f"Processing message: {msgtype}") - - # Create a dedicated event loop for goal setting to avoid conflicts - def run_with_dedicated_loop(): - try: - # Use asyncio.run which creates and manages its own event loop - # This won't conflict with the robot's or websocket's event loops - async def async_msg_handler(): - msg_handler(msgtype, data) - - asyncio.run(async_msg_handler()) - print("Goal setting completed successfully") - except Exception as e: - print(f"Error in goal setting thread: {e}") - import traceback - - traceback.print_exc() - - thread = threading.Thread(target=run_with_dedicated_loop) - thread.daemon = True - thread.start() - - -websocket_vis.msg_handler = threaded_msg_handler - -print("standing up") -robot.standup() -print("robot is up") - - -def newmap(msg): - return 
["costmap", robot.map.costmap.smudge()] - - -websocket_vis.connect(robot.map_stream.pipe(ops.map(newmap))) -websocket_vis.connect(robot.odom_stream().pipe(ops.map(lambda pos: ["robot_pos", pos.pos.to_2d()]))) - -local_planner_viz_stream = robot.local_planner_viz_stream.pipe(ops.share()) - -# Add RobotWebInterface with video stream -streams = {"unitree_video": robot.get_video_stream(), "local_planner_viz": local_planner_viz_stream} -web_interface = RobotWebInterface(port=5555, **streams) -web_interface.run() - -try: - while True: - # robot.move_vel(Vector(0.1, 0.1, 0.1)) - time.sleep(0.01) - -except KeyboardInterrupt: - print("Stopping robot") - robot.liedown() -except Exception as e: - print(f"Unexpected error in main loop: {e}") - import traceback - - traceback.print_exc() -finally: - print("Cleaning up...") - cleanup_robot() diff --git a/build/lib/tests/simple_agent_test.py b/build/lib/tests/simple_agent_test.py deleted file mode 100644 index 2534eac31b..0000000000 --- a/build/lib/tests/simple_agent_test.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import tests.test_header - -from dimos.robot.unitree.unitree_go2 import UnitreeGo2 -from dimos.robot.unitree.unitree_skills import MyUnitreeSkills -from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl -from dimos.agents.agent import OpenAIAgent -import os - -# Initialize robot -robot = UnitreeGo2( - ip=os.getenv("ROBOT_IP"), ros_control=UnitreeROSControl(), skills=MyUnitreeSkills() -) - -# Initialize agent -agent = OpenAIAgent( - dev_name="UnitreeExecutionAgent", - input_video_stream=robot.get_ros_video_stream(), - skills=robot.get_skills(), - system_query="Wiggle when you see a person! Jump when you see a person waving!", -) - -try: - input("Press ESC to exit...") -except KeyboardInterrupt: - print("\nExiting...") diff --git a/build/lib/tests/test_agent.py b/build/lib/tests/test_agent.py deleted file mode 100644 index e2c8f89f8e..0000000000 --- a/build/lib/tests/test_agent.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import sys -import os -import tests.test_header - -# ----- - -from dotenv import load_dotenv - - -# Sanity check for dotenv -def test_dotenv(): - print("test_dotenv:") - load_dotenv() - openai_api_key = os.getenv("OPENAI_API_KEY") - print("\t\tOPENAI_API_KEY: ", openai_api_key) - - -# Sanity check for openai connection -def test_openai_connection(): - from openai import OpenAI - - client = OpenAI() - print("test_openai_connection:") - response = client.chat.completions.create( - model="gpt-4o-mini", - messages=[ - { - "role": "user", - "content": [ - {"type": "text", "text": "What's in this image?"}, - { - "type": "image_url", - "image_url": { - "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", - }, - }, - ], - } - ], - max_tokens=300, - ) - print("\t\tOpenAI Response: ", response.choices[0]) - - -test_dotenv() -test_openai_connection() diff --git a/build/lib/tests/test_agent_alibaba.py b/build/lib/tests/test_agent_alibaba.py deleted file mode 100644 index 9519387b7b..0000000000 --- a/build/lib/tests/test_agent_alibaba.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import tests.test_header - -import os -from dimos.agents.agent import OpenAIAgent -from openai import OpenAI -from dimos.stream.video_provider import VideoProvider -from dimos.utils.threadpool import get_scheduler -from dimos.agents.tokenizer.huggingface_tokenizer import HuggingFaceTokenizer -from dimos.robot.unitree.unitree_skills import MyUnitreeSkills - -# Initialize video stream -video_stream = VideoProvider( - dev_name="VideoProvider", - # video_source=f"{os.getcwd()}/assets/framecount.mp4", - video_source=f"{os.getcwd()}/assets/trimmed_video_office.mov", - pool_scheduler=get_scheduler(), -).capture_video_as_observable(realtime=False, fps=1) - -# Specify the OpenAI client for Alibaba -qwen_client = OpenAI( - base_url="https://dashscope-intl.aliyuncs.com/compatible-mode/v1", - api_key=os.getenv("ALIBABA_API_KEY"), -) - -# Initialize Unitree skills -myUnitreeSkills = MyUnitreeSkills() -myUnitreeSkills.initialize_skills() - -# Initialize agent -agent = OpenAIAgent( - dev_name="AlibabaExecutionAgent", - openai_client=qwen_client, - model_name="qwen2.5-vl-72b-instruct", - tokenizer=HuggingFaceTokenizer(model_name="Qwen/Qwen2.5-VL-72B-Instruct"), - max_output_tokens_per_request=8192, - input_video_stream=video_stream, - # system_query="Tell me the number in the video. Find me the center of the number spotted, and print the coordinates to the console using an appropriate function call. Then provide me a deep history of the number in question and its significance in history. Additionally, tell me what model and version of language model you are.", - system_query="Tell me about any objects seen. Print the coordinates for center of the objects seen to the console using an appropriate function call. Then provide me a deep history of the number in question and its significance in history. 
Additionally, tell me what model and version of language model you are.", - skills=myUnitreeSkills, -) - -try: - input("Press ESC to exit...") -except KeyboardInterrupt: - print("\nExiting...") diff --git a/build/lib/tests/test_agent_ctransformers_gguf.py b/build/lib/tests/test_agent_ctransformers_gguf.py deleted file mode 100644 index 6cd3405239..0000000000 --- a/build/lib/tests/test_agent_ctransformers_gguf.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import tests.test_header - -from dimos.agents.agent_ctransformers_gguf import CTransformersGGUFAgent - -system_query = "You are a robot with the following functions. Move(), Reverse(), Left(), Right(), Stop(). Given the following user comands return the correct function." 
- -# Initialize agent -agent = CTransformersGGUFAgent( - dev_name="GGUF-Agent", - model_name="TheBloke/Llama-2-7B-GGUF", - model_file="llama-2-7b.Q4_K_M.gguf", - model_type="llama", - system_query=system_query, - gpu_layers=50, - max_input_tokens_per_request=250, - max_output_tokens_per_request=10, -) - -test_query = "User: Travel forward 10 meters" - -agent.run_observable_query(test_query).subscribe( - on_next=lambda response: print(f"One-off query response: {response}"), - on_error=lambda error: print(f"Error: {error}"), - on_completed=lambda: print("Query completed"), -) - -try: - input("Press ESC to exit...") -except KeyboardInterrupt: - print("\nExiting...") diff --git a/build/lib/tests/test_agent_huggingface_local.py b/build/lib/tests/test_agent_huggingface_local.py deleted file mode 100644 index 4c4536a197..0000000000 --- a/build/lib/tests/test_agent_huggingface_local.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from dimos.stream.data_provider import QueryDataProvider -import tests.test_header - -import os -from dimos.stream.video_provider import VideoProvider -from dimos.utils.threadpool import get_scheduler -from dimos.agents.tokenizer.huggingface_tokenizer import HuggingFaceTokenizer -from dimos.agents.agent_huggingface_local import HuggingFaceLocalAgent -from dimos.robot.unitree.unitree_skills import MyUnitreeSkills - -# Initialize video stream -video_stream = VideoProvider( - dev_name="VideoProvider", - # video_source=f"{os.getcwd()}/assets/framecount.mp4", - video_source=f"{os.getcwd()}/assets/trimmed_video_office.mov", - pool_scheduler=get_scheduler(), -).capture_video_as_observable(realtime=False, fps=1) - -# Initialize Unitree skills -myUnitreeSkills = MyUnitreeSkills() -myUnitreeSkills.initialize_skills() - -# Initialize query stream -query_provider = QueryDataProvider() - -system_query = "You are a robot with the following functions. Move(), Reverse(), Left(), Right(), Stop(). Given the following user comands return ONLY the correct function." - -# Initialize agent -agent = HuggingFaceLocalAgent( - dev_name="HuggingFaceLLMAgent", - model_name="Qwen/Qwen2.5-3B", - agent_type="HF-LLM", - system_query=system_query, - input_query_stream=query_provider.data_stream, - process_all_inputs=False, - max_input_tokens_per_request=250, - max_output_tokens_per_request=20, - # output_dir=self.output_dir, - # skills=skills_instance, - # frame_processor=frame_processor, -) - -# Start the query stream. -# Queries will be pushed every 1 second, in a count from 100 to 5000. -# This will cause listening agents to consume the queries and respond -# to them via skill execution and provide 1-shot responses. 
-query_provider.start_query_stream( - query_template="{query}; User: travel forward by 10 meters", - frequency=10, - start_count=1, - end_count=10000, - step=1, -) - -try: - input("Press ESC to exit...") -except KeyboardInterrupt: - print("\nExiting...") diff --git a/build/lib/tests/test_agent_huggingface_local_jetson.py b/build/lib/tests/test_agent_huggingface_local_jetson.py deleted file mode 100644 index 6d29b3903f..0000000000 --- a/build/lib/tests/test_agent_huggingface_local_jetson.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from dimos.stream.data_provider import QueryDataProvider -import tests.test_header - -import os -from dimos.stream.video_provider import VideoProvider -from dimos.utils.threadpool import get_scheduler -from dimos.agents.tokenizer.huggingface_tokenizer import HuggingFaceTokenizer -from dimos.agents.agent_huggingface_local import HuggingFaceLocalAgent -from dimos.robot.unitree.unitree_skills import MyUnitreeSkills - -# Initialize video stream -video_stream = VideoProvider( - dev_name="VideoProvider", - # video_source=f"{os.getcwd()}/assets/framecount.mp4", - video_source=f"{os.getcwd()}/assets/trimmed_video_office.mov", - pool_scheduler=get_scheduler(), -).capture_video_as_observable(realtime=False, fps=1) - -# Initialize Unitree skills -myUnitreeSkills = MyUnitreeSkills() -myUnitreeSkills.initialize_skills() - -# Initialize query stream -query_provider = QueryDataProvider() - -system_query = "You are a helpful assistant." - -# Initialize agent -agent = HuggingFaceLocalAgent( - dev_name="HuggingFaceLLMAgent", - model_name="Qwen/Qwen2.5-0.5B", - # model_name="HuggingFaceTB/SmolLM2-135M", - agent_type="HF-LLM", - system_query=system_query, - input_query_stream=query_provider.data_stream, - process_all_inputs=False, - max_input_tokens_per_request=250, - max_output_tokens_per_request=20, - # output_dir=self.output_dir, - # skills=skills_instance, - # frame_processor=frame_processor, -) - -# Start the query stream. -# Queries will be pushed every 1 second, in a count from 100 to 5000. -# This will cause listening agents to consume the queries and respond -# to them via skill execution and provide 1-shot responses. 
-query_provider.start_query_stream( - query_template="{query}; User: Hello how are you!", - frequency=30, - start_count=1, - end_count=10000, - step=1, -) - -try: - input("Press ESC to exit...") -except KeyboardInterrupt: - print("\nExiting...") diff --git a/build/lib/tests/test_agent_huggingface_remote.py b/build/lib/tests/test_agent_huggingface_remote.py deleted file mode 100644 index 7129523bf0..0000000000 --- a/build/lib/tests/test_agent_huggingface_remote.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from dimos.stream.data_provider import QueryDataProvider -import tests.test_header - -import os -from dimos.stream.video_provider import VideoProvider -from dimos.utils.threadpool import get_scheduler -from dimos.agents.tokenizer.huggingface_tokenizer import HuggingFaceTokenizer -from dimos.agents.agent_huggingface_remote import HuggingFaceRemoteAgent -from dimos.robot.unitree.unitree_skills import MyUnitreeSkills - -# Initialize video stream -# video_stream = VideoProvider( -# dev_name="VideoProvider", -# # video_source=f"{os.getcwd()}/assets/framecount.mp4", -# video_source=f"{os.getcwd()}/assets/trimmed_video_office.mov", -# pool_scheduler=get_scheduler(), -# ).capture_video_as_observable(realtime=False, fps=1) - -# Initialize Unitree skills -# myUnitreeSkills = MyUnitreeSkills() -# myUnitreeSkills.initialize_skills() - -# Initialize query stream -query_provider = QueryDataProvider() - -# Initialize agent -agent = HuggingFaceRemoteAgent( - dev_name="HuggingFaceRemoteAgent", - model_name="meta-llama/Meta-Llama-3-8B-Instruct", - tokenizer=HuggingFaceTokenizer(model_name="meta-llama/Meta-Llama-3-8B-Instruct"), - max_output_tokens_per_request=8192, - input_query_stream=query_provider.data_stream, - # input_video_stream=video_stream, - system_query="You are a helpful assistant that can answer questions and help with tasks.", -) - -# Start the query stream. -# Queries will be pushed every 1 second, in a count from 100 to 5000. -query_provider.start_query_stream( - query_template="{query}; Denote the number at the beginning of this query before the semicolon as the 'reference number'. 
Provide the reference number, without any other text in your response.", - frequency=5, - start_count=1, - end_count=10000, - step=1, -) - -try: - input("Press ESC to exit...") -except KeyboardInterrupt: - print("\nExiting...") diff --git a/build/lib/tests/test_audio_agent.py b/build/lib/tests/test_audio_agent.py deleted file mode 100644 index 6caf24b9eb..0000000000 --- a/build/lib/tests/test_audio_agent.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dimos.stream.audio.utils import keepalive -from dimos.stream.audio.pipelines import tts, stt -from dimos.utils.threadpool import get_scheduler -from dimos.agents.agent import OpenAIAgent - - -def main(): - stt_node = stt() - - agent = OpenAIAgent( - dev_name="UnitreeExecutionAgent", - input_query_stream=stt_node.emit_text(), - system_query="You are a helpful robot named daneel that does my bidding", - pool_scheduler=get_scheduler(), - ) - - tts_node = tts() - tts_node.consume_text(agent.get_response_observable()) - - # Keep the main thread alive - keepalive() - - -if __name__ == "__main__": - main() diff --git a/build/lib/tests/test_audio_robot_agent.py b/build/lib/tests/test_audio_robot_agent.py deleted file mode 100644 index 411e4a56c1..0000000000 --- a/build/lib/tests/test_audio_robot_agent.py +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright 2025 Dimensional Inc. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dimos.utils.threadpool import get_scheduler -import os -from dimos.robot.unitree.unitree_go2 import UnitreeGo2 -from dimos.robot.unitree.unitree_skills import MyUnitreeSkills -from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl -from dimos.agents.agent import OpenAIAgent -from dimos.stream.audio.pipelines import tts, stt -from dimos.stream.audio.utils import keepalive - - -def main(): - stt_node = stt() - tts_node = tts() - - robot = UnitreeGo2( - ip=os.getenv("ROBOT_IP"), - ros_control=UnitreeROSControl(), - skills=MyUnitreeSkills(), - ) - - # Initialize agent with main thread pool scheduler - agent = OpenAIAgent( - dev_name="UnitreeExecutionAgent", - input_query_stream=stt_node.emit_text(), - system_query="You are a helpful robot named daneel that does my bidding", - pool_scheduler=get_scheduler(), - skills=robot.get_skills(), - ) - - tts_node.consume_text(agent.get_response_observable()) - - # Keep the main thread alive - keepalive() - - -if __name__ == "__main__": - main() diff --git a/build/lib/tests/test_cerebras_unitree_ros.py b/build/lib/tests/test_cerebras_unitree_ros.py deleted file mode 100644 index cbb7c130db..0000000000 --- a/build/lib/tests/test_cerebras_unitree_ros.py +++ /dev/null @@ -1,118 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import sys -import os -from dimos.robot.robot import MockRobot -import tests.test_header - -import time -from dotenv import load_dotenv -from dimos.agents.cerebras_agent import CerebrasAgent -from dimos.robot.unitree.unitree_go2 import UnitreeGo2 -from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl -from dimos.robot.unitree.unitree_skills import MyUnitreeSkills -from dimos.web.robot_web_interface import RobotWebInterface -from dimos.skills.observe_stream import ObserveStream -from dimos.skills.kill_skill import KillSkill -from dimos.skills.navigation import NavigateWithText, GetPose, NavigateToGoal -from dimos.skills.visual_navigation_skills import FollowHuman -import reactivex as rx -import reactivex.operators as ops -from dimos.stream.audio.pipelines import tts, stt -from dimos.web.websocket_vis.server import WebsocketVis -import threading -from dimos.types.vector import Vector -from dimos.skills.speak import Speak - -# Load API key from environment -load_dotenv() - -# robot = MockRobot() -robot_skills = MyUnitreeSkills() - -robot = UnitreeGo2( - ip=os.getenv("ROBOT_IP"), - ros_control=UnitreeROSControl(), - skills=robot_skills, - mock_connection=False, - new_memory=True, -) - -# Create a subject for agent responses -agent_response_subject = rx.subject.Subject() -agent_response_stream = agent_response_subject.pipe(ops.share()) - -streams = { - "unitree_video": robot.get_ros_video_stream(), -} -text_streams = { - "agent_responses": agent_response_stream, -} - -web_interface = RobotWebInterface( - port=5555, - text_streams=text_streams, 
- **streams, -) - -stt_node = stt() - -# Create a CerebrasAgent instance -agent = CerebrasAgent( - dev_name="test_cerebras_agent", - input_query_stream=stt_node.emit_text(), - # input_query_stream=web_interface.query_stream, - skills=robot_skills, - system_query="""You are an agent controlling a virtual robot. When given a query, respond by using the appropriate tool calls if needed to execute commands on the robot. - -IMPORTANT INSTRUCTIONS: -1. Each tool call must include the exact function name and appropriate parameters -2. If a function needs parameters like 'distance' or 'angle', be sure to include them -3. If you're unsure which tool to use, choose the most appropriate one based on the user's query -4. Parse the user's instructions carefully to determine correct parameter values - -When you need to call a skill or tool, ALWAYS respond ONLY with a JSON object in this exact format: {"name": "SkillName", "arguments": {"arg1": "value1", "arg2": "value2"}} - -Example: If the user asks to spin right by 90 degrees, output ONLY the following: {"name": "SpinRight", "arguments": {"degrees": 90}}""", - model_name="llama-4-scout-17b-16e-instruct", -) - -tts_node = tts() -tts_node.consume_text(agent.get_response_observable()) - -robot_skills.add(ObserveStream) -robot_skills.add(KillSkill) -robot_skills.add(NavigateWithText) -robot_skills.add(FollowHuman) -robot_skills.add(GetPose) -robot_skills.add(Speak) -robot_skills.add(NavigateToGoal) -robot_skills.create_instance("ObserveStream", robot=robot, agent=agent) -robot_skills.create_instance("KillSkill", robot=robot, skill_library=robot_skills) -robot_skills.create_instance("NavigateWithText", robot=robot) -robot_skills.create_instance("FollowHuman", robot=robot) -robot_skills.create_instance("GetPose", robot=robot) -robot_skills.create_instance("NavigateToGoal", robot=robot) - - -robot_skills.create_instance("Speak", tts_node=tts_node) - -# Subscribe to agent responses and send them to the subject 
-agent.get_response_observable().subscribe(lambda x: agent_response_subject.on_next(x)) - -# print(f"Registered skills: {', '.join([skill.__name__ for skill in robot_skills.skills])}") -print("Cerebras agent demo initialized. You can now interact with the agent via the web interface.") - -web_interface.run() diff --git a/build/lib/tests/test_claude_agent_query.py b/build/lib/tests/test_claude_agent_query.py deleted file mode 100644 index aabd85bc12..0000000000 --- a/build/lib/tests/test_claude_agent_query.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import tests.test_header - -from dotenv import load_dotenv -from dimos.agents.claude_agent import ClaudeAgent - -# Load API key from environment -load_dotenv() - -# Create a ClaudeAgent instance -agent = ClaudeAgent(dev_name="test_agent", query="What is the capital of France?") - -# Use the stream_query method to get a response -response = agent.run_observable_query("What is the capital of France?").run() - -print(f"Response from Claude Agent: {response}") diff --git a/build/lib/tests/test_claude_agent_skills_query.py b/build/lib/tests/test_claude_agent_skills_query.py deleted file mode 100644 index 1aaeb795f1..0000000000 --- a/build/lib/tests/test_claude_agent_skills_query.py +++ /dev/null @@ -1,135 +0,0 @@ -# Copyright 2025 Dimensional Inc. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import tests.test_header -import os - -import time -from dotenv import load_dotenv -from dimos.agents.claude_agent import ClaudeAgent -from dimos.robot.unitree.unitree_go2 import UnitreeGo2 -from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl -from dimos.robot.unitree.unitree_skills import MyUnitreeSkills -from dimos.web.robot_web_interface import RobotWebInterface -from dimos.skills.observe_stream import ObserveStream -from dimos.skills.kill_skill import KillSkill -from dimos.skills.navigation import Navigate, BuildSemanticMap, GetPose, NavigateToGoal -from dimos.skills.visual_navigation_skills import NavigateToObject, FollowHuman -import reactivex as rx -import reactivex.operators as ops -from dimos.stream.audio.pipelines import tts, stt -from dimos.web.websocket_vis.server import WebsocketVis -import threading -from dimos.types.vector import Vector -from dimos.skills.speak import Speak - -# Load API key from environment -load_dotenv() - -robot = UnitreeGo2( - ip=os.getenv("ROBOT_IP"), - ros_control=UnitreeROSControl(), - skills=MyUnitreeSkills(), - mock_connection=False, -) - -# Create a subject for agent responses -agent_response_subject = rx.subject.Subject() -agent_response_stream = agent_response_subject.pipe(ops.share()) -local_planner_viz_stream = robot.local_planner_viz_stream.pipe(ops.share()) - -streams = { - "unitree_video": robot.get_ros_video_stream(), - "local_planner_viz": 
local_planner_viz_stream, -} -text_streams = { - "agent_responses": agent_response_stream, -} - -web_interface = RobotWebInterface(port=5555, text_streams=text_streams, **streams) - -stt_node = stt() - -# Create a ClaudeAgent instance -agent = ClaudeAgent( - dev_name="test_agent", - input_query_stream=stt_node.emit_text(), - # input_query_stream=web_interface.query_stream, - skills=robot.get_skills(), - system_query="""You are an agent controlling a virtual robot. When given a query, respond by using the appropriate tool calls if needed to execute commands on the robot. - -IMPORTANT INSTRUCTIONS: -1. Each tool call must include the exact function name and appropriate parameters -2. If a function needs parameters like 'distance' or 'angle', be sure to include them -3. If you're unsure which tool to use, choose the most appropriate one based on the user's query -4. Parse the user's instructions carefully to determine correct parameter values - -Example: If the user asks to move forward 1 meter, call the Move function with distance=1""", - model_name="claude-3-7-sonnet-latest", - thinking_budget_tokens=2000, -) - -tts_node = tts() -# tts_node.consume_text(agent.get_response_observable()) - -robot_skills = robot.get_skills() -robot_skills.add(ObserveStream) -robot_skills.add(KillSkill) -robot_skills.add(Navigate) -robot_skills.add(BuildSemanticMap) -robot_skills.add(NavigateToObject) -robot_skills.add(FollowHuman) -robot_skills.add(GetPose) -robot_skills.add(Speak) -robot_skills.add(NavigateToGoal) -robot_skills.create_instance("ObserveStream", robot=robot, agent=agent) -robot_skills.create_instance("KillSkill", robot=robot, skill_library=robot_skills) -robot_skills.create_instance("Navigate", robot=robot) -robot_skills.create_instance("BuildSemanticMap", robot=robot) -robot_skills.create_instance("NavigateToObject", robot=robot) -robot_skills.create_instance("FollowHuman", robot=robot) -robot_skills.create_instance("GetPose", robot=robot) 
-robot_skills.create_instance("NavigateToGoal", robot=robot) -robot_skills.create_instance("Speak", tts_node=tts_node) - -# Subscribe to agent responses and send them to the subject -agent.get_response_observable().subscribe(lambda x: agent_response_subject.on_next(x)) - -print("ObserveStream and Kill skills registered and ready for use") -print("Created memory.txt file") - -websocket_vis = WebsocketVis() -websocket_vis.start() -websocket_vis.connect(robot.global_planner.vis_stream()) - - -def msg_handler(msgtype, data): - if msgtype == "click": - target = Vector(data["position"]) - try: - robot.global_planner.set_goal(target) - except Exception as e: - print(f"Error setting goal: {e}") - return - - -def threaded_msg_handler(msgtype, data): - thread = threading.Thread(target=msg_handler, args=(msgtype, data)) - thread.daemon = True - thread.start() - - -websocket_vis.msg_handler = threaded_msg_handler - -web_interface.run() diff --git a/build/lib/tests/test_command_pose_unitree.py b/build/lib/tests/test_command_pose_unitree.py deleted file mode 100644 index 22cf0e82ed..0000000000 --- a/build/lib/tests/test_command_pose_unitree.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import os -import sys - -# Add the parent directory to the Python path -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -from dimos.robot.unitree.unitree_go2 import UnitreeGo2 -from dimos.robot.unitree.unitree_skills import MyUnitreeSkills -from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl -import os -import time -import math - -# Initialize robot -robot = UnitreeGo2( - ip=os.getenv("ROBOT_IP"), ros_control=UnitreeROSControl(), skills=MyUnitreeSkills() -) - - -# Helper function to send pose commands continuously for a duration -def send_pose_for_duration(roll, pitch, yaw, duration, hz=10): - """Send the same pose command repeatedly at specified frequency for the given duration""" - start_time = time.time() - while time.time() - start_time < duration: - robot.pose_command(roll=roll, pitch=pitch, yaw=yaw) - time.sleep(1.0 / hz) # Sleep to achieve the desired frequency - - -# Test pose commands - -# First, make sure the robot is in a stable position -print("Setting default pose...") -send_pose_for_duration(0.0, 0.0, 0.0, 1) - -# Test roll angle (lean left/right) -print("Testing roll angle - lean right...") -send_pose_for_duration(0.5, 0.0, 0.0, 1.5) # Lean right - -print("Testing roll angle - lean left...") -send_pose_for_duration(-0.5, 0.0, 0.0, 1.5) # Lean left - -# Test pitch angle (lean forward/backward) -print("Testing pitch angle - lean forward...") -send_pose_for_duration(0.0, 0.5, 0.0, 1.5) # Lean forward - -print("Testing pitch angle - lean backward...") -send_pose_for_duration(0.0, -0.5, 0.0, 1.5) # Lean backward - -# Test yaw angle (rotate body without moving feet) -print("Testing yaw angle - rotate clockwise...") -send_pose_for_duration(0.0, 0.0, 0.5, 1.5) # Rotate body clockwise - -print("Testing yaw angle - rotate counterclockwise...") -send_pose_for_duration(0.0, 0.0, -0.5, 1.5) # Rotate body counterclockwise - -# Reset to default pose -print("Resetting to default pose...") 
-send_pose_for_duration(0.0, 0.0, 0.0, 2) - -print("Pose command test completed") - -# Keep the program running (optional) -print("Press Ctrl+C to exit") -try: - while True: - time.sleep(1) -except KeyboardInterrupt: - print("Test terminated by user") diff --git a/build/lib/tests/test_header.py b/build/lib/tests/test_header.py deleted file mode 100644 index 48ea6dd509..0000000000 --- a/build/lib/tests/test_header.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Test utilities for identifying caller information and path setup. - -This module provides functionality to determine which file called the current -script and sets up the Python path to include the parent directory, allowing -tests to import from the main application. -""" - -import sys -import os -import inspect - -# Add the parent directory of 'tests' to the Python path -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - - -def get_caller_info(): - """Identify the filename of the caller in the stack. - - Examines the call stack to find the first non-internal file that called - this module. Skips the current file and Python internal files. - - Returns: - str: The basename of the caller's filename, or "unknown" if not found. 
- """ - current_file = os.path.abspath(__file__) - - # Look through the call stack to find the first file that's not this one - for frame in inspect.stack()[1:]: - filename = os.path.abspath(frame.filename) - # Skip this file and Python internals - if filename != current_file and " 0: - best_score = max(grasp.get("score", 0.0) for grasp in grasps) - print(f" Best grasp score: {best_score:.3f}") - last_grasp_count = current_count - last_update_time = current_time - else: - # Show periodic "still waiting" message - if current_time - last_update_time > 10.0: - print(f" Still waiting for grasps... ({time.strftime('%H:%M:%S')})") - last_update_time = current_time - - time.sleep(1.0) # Check every second - - except Exception as e: - print(f" Error in grasp monitor: {e}") - time.sleep(2.0) - - -def main(): - """Test point cloud filtering with grasp generation using ManipulationPipeline.""" - print(" Testing point cloud filtering + grasp generation with ManipulationPipeline...") - - # Configuration - min_confidence = 0.6 - web_port = 5555 - grasp_server_url = "ws://18.224.39.74:8000/ws/grasp" - - try: - # Initialize ZED camera stream - zed_stream = ZEDCameraStream(resolution=sl.RESOLUTION.HD1080, fps=10) - - # Get camera intrinsics - camera_intrinsics_dict = zed_stream.get_camera_info() - camera_intrinsics = [ - camera_intrinsics_dict["fx"], - camera_intrinsics_dict["fy"], - camera_intrinsics_dict["cx"], - camera_intrinsics_dict["cy"], - ] - - # Create the concurrent manipulation pipeline WITH grasp generation - pipeline = ManipulationPipeline( - camera_intrinsics=camera_intrinsics, - min_confidence=min_confidence, - max_objects=10, - grasp_server_url=grasp_server_url, - enable_grasp_generation=True, # Enable grasp generation - ) - - # Create ZED stream - zed_frame_stream = zed_stream.create_stream().pipe(ops.share()) - - # Create concurrent processing streams - streams = pipeline.create_streams(zed_frame_stream) - detection_viz_stream = streams["detection_viz"] - 
pointcloud_viz_stream = streams["pointcloud_viz"] - grasps_stream = streams.get("grasps") # Get grasp stream if available - grasp_overlay_stream = streams.get("grasp_overlay") # Get grasp overlay stream if available - - except ImportError: - print("Error: ZED SDK not installed. Please install pyzed package.") - sys.exit(1) - except RuntimeError as e: - print(f"Error: Failed to open ZED camera: {e}") - sys.exit(1) - - try: - # Set up web interface with concurrent visualization streams - print("Initializing web interface...") - web_interface = RobotWebInterface( - port=web_port, - object_detection=detection_viz_stream, - pointcloud_stream=pointcloud_viz_stream, - grasp_overlay_stream=grasp_overlay_stream, - ) - - # Start grasp monitoring in background thread - grasp_monitor_thread = threading.Thread( - target=monitor_grasps, args=(pipeline,), daemon=True - ) - grasp_monitor_thread.start() - - print(f"\n Point Cloud + Grasp Generation Test Running:") - print(f" Web Interface: http://localhost:{web_port}") - print(f" Object Detection View: RGB with bounding boxes") - print(f" Point Cloud View: Depth with colored point clouds and 3D bounding boxes") - print(f" Confidence threshold: {min_confidence}") - print(f" Grasp server: {grasp_server_url}") - print(f" Available streams: {list(streams.keys())}") - print("\nPress Ctrl+C to stop the test\n") - - # Start web server (blocking call) - web_interface.run() - - except KeyboardInterrupt: - print("\nTest interrupted by user") - except Exception as e: - print(f"Error during test: {e}") - finally: - print("Cleaning up resources...") - if "zed_stream" in locals(): - zed_stream.cleanup() - if "pipeline" in locals(): - pipeline.cleanup() - print("Test completed") - - -if __name__ == "__main__": - main() diff --git a/build/lib/tests/test_manipulation_pipeline_single_frame.py b/build/lib/tests/test_manipulation_pipeline_single_frame.py deleted file mode 100644 index fa7187f948..0000000000 --- 
a/build/lib/tests/test_manipulation_pipeline_single_frame.py +++ /dev/null @@ -1,248 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Test manipulation processor with direct visualization and grasp data output.""" - -import os -import sys -import cv2 -import numpy as np -import time -import argparse -import matplotlib - -# Try to use TkAgg backend for live display, fallback to Agg if not available -try: - matplotlib.use("TkAgg") -except: - try: - matplotlib.use("Qt5Agg") - except: - matplotlib.use("Agg") # Fallback to non-interactive -import matplotlib.pyplot as plt -import open3d as o3d -from typing import Dict, List - -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -from dimos.perception.pointcloud.utils import visualize_clustered_point_clouds, visualize_voxel_grid -from dimos.manipulation.manip_aio_processer import ManipulationProcessor -from dimos.perception.pointcloud.utils import ( - load_camera_matrix_from_yaml, - visualize_pcd, - combine_object_pointclouds, -) -from dimos.utils.logging_config import setup_logger - -from dimos.perception.grasp_generation.utils import visualize_grasps_3d, create_grasp_overlay - -logger = setup_logger("test_pipeline_viz") - - -def load_first_frame(data_dir: str): - """Load first RGB-D frame and camera intrinsics.""" - # Load images - color_img = cv2.imread(os.path.join(data_dir, "color", "00000.png")) - color_img = cv2.cvtColor(color_img, 
cv2.COLOR_BGR2RGB) - - depth_img = cv2.imread(os.path.join(data_dir, "depth", "00000.png"), cv2.IMREAD_ANYDEPTH) - if depth_img.dtype == np.uint16: - depth_img = depth_img.astype(np.float32) / 1000.0 - # Load intrinsics - camera_matrix = load_camera_matrix_from_yaml(os.path.join(data_dir, "color_camera_info.yaml")) - intrinsics = [ - camera_matrix[0, 0], - camera_matrix[1, 1], - camera_matrix[0, 2], - camera_matrix[1, 2], - ] - - return color_img, depth_img, intrinsics - - -def create_point_cloud(color_img, depth_img, intrinsics): - """Create Open3D point cloud.""" - fx, fy, cx, cy = intrinsics - height, width = depth_img.shape - - o3d_intrinsics = o3d.camera.PinholeCameraIntrinsic(width, height, fx, fy, cx, cy) - color_o3d = o3d.geometry.Image(color_img) - depth_o3d = o3d.geometry.Image((depth_img * 1000).astype(np.uint16)) - - rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth( - color_o3d, depth_o3d, depth_scale=1000.0, convert_rgb_to_intensity=False - ) - - return o3d.geometry.PointCloud.create_from_rgbd_image(rgbd, o3d_intrinsics) - - -def run_processor(color_img, depth_img, intrinsics, grasp_server_url=None): - """Run processor and collect results.""" - processor_kwargs = { - "camera_intrinsics": intrinsics, - "enable_grasp_generation": True, - "enable_segmentation": True, - } - - if grasp_server_url: - processor_kwargs["grasp_server_url"] = grasp_server_url - - processor = ManipulationProcessor(**processor_kwargs) - - # Process frame without grasp generation - results = processor.process_frame(color_img, depth_img, generate_grasps=False) - - # Run grasp generation separately - grasps = processor.run_grasp_generation(results["all_objects"], results["full_pointcloud"]) - results["grasps"] = grasps - results["grasp_overlay"] = create_grasp_overlay(color_img, grasps, intrinsics) - - processor.cleanup() - return results - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--data-dir", default="assets/rgbd_data") - 
parser.add_argument("--wait-time", type=float, default=5.0) - parser.add_argument( - "--grasp-server-url", - default="ws://18.224.39.74:8000/ws/grasp", - help="WebSocket URL for AnyGrasp server", - ) - args = parser.parse_args() - - # Load data - color_img, depth_img, intrinsics = load_first_frame(args.data_dir) - logger.info(f"Loaded images: color {color_img.shape}, depth {depth_img.shape}") - - # Run processor - results = run_processor(color_img, depth_img, intrinsics, args.grasp_server_url) - - # Print results summary - print(f"Processing time: {results.get('processing_time', 0):.3f}s") - print(f"Detection objects: {len(results.get('detected_objects', []))}") - print(f"All objects processed: {len(results.get('all_objects', []))}") - - # Print grasp summary - grasp_data = results["grasps"] - total_grasps = len(grasp_data) if isinstance(grasp_data, list) else 0 - best_score = max(grasp["score"] for grasp in grasp_data) if grasp_data else 0 - - print(f"AnyGrasp grasps: {total_grasps} total (best score: {best_score:.3f})") - - # Create visualizations - plot_configs = [] - if results["detection_viz"] is not None: - plot_configs.append(("detection_viz", "Object Detection")) - if results["segmentation_viz"] is not None: - plot_configs.append(("segmentation_viz", "Semantic Segmentation")) - if results["pointcloud_viz"] is not None: - plot_configs.append(("pointcloud_viz", "All Objects Point Cloud")) - if results["detected_pointcloud_viz"] is not None: - plot_configs.append(("detected_pointcloud_viz", "Detection Objects Point Cloud")) - if results["misc_pointcloud_viz"] is not None: - plot_configs.append(("misc_pointcloud_viz", "Misc/Background Points")) - if results["grasp_overlay"] is not None: - plot_configs.append(("grasp_overlay", "Grasp Overlay")) - - # Create subplot layout - num_plots = len(plot_configs) - if num_plots <= 3: - fig, axes = plt.subplots(1, num_plots, figsize=(6 * num_plots, 5)) - else: - rows = 2 - cols = (num_plots + 1) // 2 - fig, axes = 
plt.subplots(rows, cols, figsize=(6 * cols, 5 * rows)) - - if num_plots == 1: - axes = [axes] - elif num_plots > 2: - axes = axes.flatten() - - # Plot each result - for i, (key, title) in enumerate(plot_configs): - axes[i].imshow(results[key]) - axes[i].set_title(title) - axes[i].axis("off") - - # Hide unused subplots - if num_plots > 3: - for i in range(num_plots, len(axes)): - axes[i].axis("off") - - plt.tight_layout() - plt.savefig("manipulation_results.png", dpi=150, bbox_inches="tight") - plt.show(block=True) - plt.close() - - point_clouds = [obj["point_cloud"] for obj in results["all_objects"]] - colors = [obj["color"] for obj in results["all_objects"]] - combined_pcd = combine_object_pointclouds(point_clouds, colors) - - # 3D Grasp visualization - if grasp_data: - # Convert grasp format to visualization format for 3D display - viz_grasps = [] - for grasp in grasp_data: - translation = grasp.get("translation", [0, 0, 0]) - rotation_matrix = np.array(grasp.get("rotation_matrix", np.eye(3).tolist())) - score = grasp.get("score", 0.0) - width = grasp.get("width", 0.08) - - viz_grasp = { - "translation": translation, - "rotation_matrix": rotation_matrix, - "width": width, - "score": score, - } - viz_grasps.append(viz_grasp) - - # Use unified 3D visualization - visualize_grasps_3d(combined_pcd, viz_grasps) - - # Visualize full point cloud - visualize_pcd( - results["full_pointcloud"], - window_name="Full Scene Point Cloud", - point_size=2.0, - show_coordinate_frame=True, - ) - - # Visualize all objects point cloud - visualize_pcd( - combined_pcd, - window_name="All Objects Point Cloud", - point_size=3.0, - show_coordinate_frame=True, - ) - - # Visualize misc clusters - visualize_clustered_point_clouds( - results["misc_clusters"], - window_name="Misc/Background Clusters (DBSCAN)", - point_size=3.0, - show_coordinate_frame=True, - ) - - # Visualize voxel grid - visualize_voxel_grid( - results["misc_voxel_grid"], - window_name="Misc/Background Voxel Grid", - 
show_coordinate_frame=True, - ) - - -if __name__ == "__main__": - main() diff --git a/build/lib/tests/test_manipulation_pipeline_single_frame_lcm.py b/build/lib/tests/test_manipulation_pipeline_single_frame_lcm.py deleted file mode 100644 index 62898816fa..0000000000 --- a/build/lib/tests/test_manipulation_pipeline_single_frame_lcm.py +++ /dev/null @@ -1,431 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Test manipulation processor with LCM topic subscription.""" - -import os -import sys -import cv2 -import numpy as np -import time -import argparse -import threading -import pickle -import matplotlib -import json -import copy - -# Try to use TkAgg backend for live display, fallback to Agg if not available -try: - matplotlib.use("TkAgg") -except: - try: - matplotlib.use("Qt5Agg") - except: - matplotlib.use("Agg") # Fallback to non-interactive -import matplotlib.pyplot as plt -import open3d as o3d -from typing import Dict, List, Optional - -# LCM imports -import lcm -from lcm_msgs.sensor_msgs import Image as LCMImage -from lcm_msgs.sensor_msgs import CameraInfo as LCMCameraInfo - -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -from dimos.perception.pointcloud.utils import visualize_clustered_point_clouds, visualize_voxel_grid -from dimos.manipulation.manip_aio_processer import ManipulationProcessor -from dimos.perception.grasp_generation.utils import visualize_grasps_3d -from 
dimos.perception.pointcloud.utils import visualize_pcd -from dimos.utils.logging_config import setup_logger - -logger = setup_logger("test_pipeline_lcm") - - -class LCMDataCollector: - """Collects one message from each required LCM topic.""" - - def __init__(self, lcm_url: str = "udpm://239.255.76.67:7667?ttl=1"): - self.lcm = lcm.LCM(lcm_url) - - # Data storage - self.rgb_data: Optional[np.ndarray] = None - self.depth_data: Optional[np.ndarray] = None - self.camera_intrinsics: Optional[List[float]] = None - - # Synchronization - self.data_lock = threading.Lock() - self.data_ready_event = threading.Event() - - # Flags to track received messages - self.rgb_received = False - self.depth_received = False - self.camera_info_received = False - - # Subscribe to topics - self.lcm.subscribe("head_cam_rgb#sensor_msgs.Image", self._handle_rgb_message) - self.lcm.subscribe("head_cam_depth#sensor_msgs.Image", self._handle_depth_message) - self.lcm.subscribe("head_cam_info#sensor_msgs.CameraInfo", self._handle_camera_info_message) - - logger.info("LCM Data Collector initialized") - logger.info("Subscribed to topics:") - logger.info(" - head_cam_rgb#sensor_msgs.Image") - logger.info(" - head_cam_depth#sensor_msgs.Image") - logger.info(" - head_cam_info#sensor_msgs.CameraInfo") - - def _handle_rgb_message(self, channel: str, data: bytes): - """Handle RGB image message.""" - if self.rgb_received: - return # Already got one, ignore subsequent messages - - try: - msg = LCMImage.decode(data) - - # Convert message data to numpy array - if msg.encoding == "rgb8": - # RGB8 format: 3 bytes per pixel - rgb_array = np.frombuffer(msg.data[: msg.data_length], dtype=np.uint8) - rgb_image = rgb_array.reshape((msg.height, msg.width, 3)) - - with self.data_lock: - self.rgb_data = rgb_image - self.rgb_received = True - logger.info( - f"RGB message received: {msg.width}x{msg.height}, encoding: {msg.encoding}" - ) - self._check_all_data_received() - - else: - logger.warning(f"Unsupported RGB 
encoding: {msg.encoding}") - - except Exception as e: - logger.error(f"Error processing RGB message: {e}") - - def _handle_depth_message(self, channel: str, data: bytes): - """Handle depth image message.""" - if self.depth_received: - return # Already got one, ignore subsequent messages - - try: - msg = LCMImage.decode(data) - - # Convert message data to numpy array - if msg.encoding == "32FC1": - # 32FC1 format: 4 bytes (float32) per pixel - depth_array = np.frombuffer(msg.data[: msg.data_length], dtype=np.float32) - depth_image = depth_array.reshape((msg.height, msg.width)) - - with self.data_lock: - self.depth_data = depth_image - self.depth_received = True - logger.info( - f"Depth message received: {msg.width}x{msg.height}, encoding: {msg.encoding}" - ) - logger.info( - f"Depth range: {depth_image.min():.3f} - {depth_image.max():.3f} meters" - ) - self._check_all_data_received() - - else: - logger.warning(f"Unsupported depth encoding: {msg.encoding}") - - except Exception as e: - logger.error(f"Error processing depth message: {e}") - - def _handle_camera_info_message(self, channel: str, data: bytes): - """Handle camera info message.""" - if self.camera_info_received: - return # Already got one, ignore subsequent messages - - try: - msg = LCMCameraInfo.decode(data) - - # Extract intrinsics from K matrix: [fx, 0, cx, 0, fy, cy, 0, 0, 1] - K = msg.K - fx = K[0] # K[0,0] - fy = K[4] # K[1,1] - cx = K[2] # K[0,2] - cy = K[5] # K[1,2] - - intrinsics = [fx, fy, cx, cy] - - with self.data_lock: - self.camera_intrinsics = intrinsics - self.camera_info_received = True - logger.info(f"Camera info received: {msg.width}x{msg.height}") - logger.info(f"Intrinsics: fx={fx:.1f}, fy={fy:.1f}, cx={cx:.1f}, cy={cy:.1f}") - self._check_all_data_received() - - except Exception as e: - logger.error(f"Error processing camera info message: {e}") - - def _check_all_data_received(self): - """Check if all required data has been received.""" - if self.rgb_received and self.depth_received 
and self.camera_info_received: - logger.info("✅ All required data received!") - self.data_ready_event.set() - - def wait_for_data(self, timeout: float = 30.0) -> bool: - """Wait for all data to be received.""" - logger.info("Waiting for RGB, depth, and camera info messages...") - - # Start LCM handling in a separate thread - lcm_thread = threading.Thread(target=self._lcm_handle_loop, daemon=True) - lcm_thread.start() - - # Wait for data with timeout - return self.data_ready_event.wait(timeout) - - def _lcm_handle_loop(self): - """LCM message handling loop.""" - try: - while not self.data_ready_event.is_set(): - self.lcm.handle_timeout(100) # 100ms timeout - except Exception as e: - logger.error(f"Error in LCM handling loop: {e}") - - def get_data(self): - """Get the collected data.""" - with self.data_lock: - return self.rgb_data, self.depth_data, self.camera_intrinsics - - -def create_point_cloud(color_img, depth_img, intrinsics): - """Create Open3D point cloud.""" - fx, fy, cx, cy = intrinsics - height, width = depth_img.shape - - o3d_intrinsics = o3d.camera.PinholeCameraIntrinsic(width, height, fx, fy, cx, cy) - color_o3d = o3d.geometry.Image(color_img) - depth_o3d = o3d.geometry.Image((depth_img * 1000).astype(np.uint16)) - - rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth( - color_o3d, depth_o3d, depth_scale=1000.0, convert_rgb_to_intensity=False - ) - - return o3d.geometry.PointCloud.create_from_rgbd_image(rgbd, o3d_intrinsics) - - -def run_processor(color_img, depth_img, intrinsics): - """Run processor and collect results.""" - # Create processor - processor = ManipulationProcessor( - camera_intrinsics=intrinsics, - grasp_server_url="ws://18.224.39.74:8000/ws/grasp", - enable_grasp_generation=False, - enable_segmentation=True, - ) - - # Process single frame directly - results = processor.process_frame(color_img, depth_img) - - # Debug: print available results - print(f"Available results: {list(results.keys())}") - - processor.cleanup() - - return 
results - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--lcm-url", default="udpm://239.255.76.67:7667?ttl=1", help="LCM URL for subscription" - ) - parser.add_argument( - "--timeout", type=float, default=30.0, help="Timeout in seconds to wait for messages" - ) - parser.add_argument( - "--save-images", action="store_true", help="Save received RGB and depth images to files" - ) - args = parser.parse_args() - - # Create data collector - collector = LCMDataCollector(args.lcm_url) - - # Wait for data - if not collector.wait_for_data(args.timeout): - logger.error(f"Timeout waiting for data after {args.timeout} seconds") - logger.error("Make sure Unity is running and publishing to the LCM topics") - return - - # Get the collected data - color_img, depth_img, intrinsics = collector.get_data() - - logger.info(f"Loaded images: color {color_img.shape}, depth {depth_img.shape}") - logger.info(f"Intrinsics: {intrinsics}") - - # Save images if requested - if args.save_images: - try: - cv2.imwrite("received_rgb.png", cv2.cvtColor(color_img, cv2.COLOR_RGB2BGR)) - # Save depth as 16-bit for visualization - depth_viz = (np.clip(depth_img * 1000, 0, 65535)).astype(np.uint16) - cv2.imwrite("received_depth.png", depth_viz) - logger.info("Saved received_rgb.png and received_depth.png") - except Exception as e: - logger.warning(f"Failed to save images: {e}") - - # Run processor - results = run_processor(color_img, depth_img, intrinsics) - - # Debug: Print what we received - print(f"\n✅ Processor Results:") - print(f" Available results: {list(results.keys())}") - print(f" Processing time: {results.get('processing_time', 0):.3f}s") - - # Show timing breakdown if available - if "timing_breakdown" in results: - breakdown = results["timing_breakdown"] - print(f" Timing breakdown:") - print(f" - Detection: {breakdown.get('detection', 0):.3f}s") - print(f" - Segmentation: {breakdown.get('segmentation', 0):.3f}s") - print(f" - Point cloud: 
{breakdown.get('pointcloud', 0):.3f}s") - print(f" - Misc extraction: {breakdown.get('misc_extraction', 0):.3f}s") - - # Print object information - detected_count = len(results.get("detected_objects", [])) - all_count = len(results.get("all_objects", [])) - - print(f" Detection objects: {detected_count}") - print(f" All objects processed: {all_count}") - - # Print misc clusters information - if "misc_clusters" in results and results["misc_clusters"]: - cluster_count = len(results["misc_clusters"]) - total_misc_points = sum( - len(np.asarray(cluster.points)) for cluster in results["misc_clusters"] - ) - print(f" Misc clusters: {cluster_count} clusters with {total_misc_points} total points") - else: - print(f" Misc clusters: None") - - # Print grasp summary - if "grasps" in results and results["grasps"]: - total_grasps = 0 - best_score = 0 - for grasp in results["grasps"]: - score = grasp.get("score", 0) - if score > best_score: - best_score = score - total_grasps += 1 - print(f" Grasps generated: {total_grasps} (best score: {best_score:.3f})") - else: - print(" Grasps: None generated") - - # Save results to pickle file - pickle_path = "manipulation_results.pkl" - print(f"\nSaving results to pickle file: {pickle_path}") - - def serialize_point_cloud(pcd): - """Convert Open3D PointCloud to serializable format.""" - if pcd is None: - return None - data = { - "points": np.asarray(pcd.points).tolist() if hasattr(pcd, "points") else [], - "colors": np.asarray(pcd.colors).tolist() - if hasattr(pcd, "colors") and pcd.colors - else [], - } - return data - - def serialize_voxel_grid(voxel_grid): - """Convert Open3D VoxelGrid to serializable format.""" - if voxel_grid is None: - return None - - # Extract voxel data - voxels = voxel_grid.get_voxels() - data = { - "voxel_size": voxel_grid.voxel_size, - "origin": np.asarray(voxel_grid.origin).tolist(), - "voxels": [ - ( - v.grid_index[0], - v.grid_index[1], - v.grid_index[2], - v.color[0], - v.color[1], - v.color[2], - ) - for v 
in voxels - ], - } - return data - - # Create a copy of results with non-picklable objects converted - pickle_data = { - "color_img": color_img, - "depth_img": depth_img, - "intrinsics": intrinsics, - "results": {}, - } - - # Convert and store all results, properly handling Open3D objects - for key, value in results.items(): - if key.endswith("_viz") or key in [ - "processing_time", - "timing_breakdown", - "detection2d_objects", - "segmentation2d_objects", - ]: - # These are already serializable - pickle_data["results"][key] = value - elif key == "full_pointcloud": - # Serialize PointCloud object - pickle_data["results"][key] = serialize_point_cloud(value) - print(f"Serialized {key}") - elif key == "misc_voxel_grid": - # Serialize VoxelGrid object - pickle_data["results"][key] = serialize_voxel_grid(value) - print(f"Serialized {key}") - elif key == "misc_clusters": - # List of PointCloud objects - if value: - serialized_clusters = [serialize_point_cloud(cluster) for cluster in value] - pickle_data["results"][key] = serialized_clusters - print(f"Serialized {key} ({len(serialized_clusters)} clusters)") - elif key == "detected_objects" or key == "all_objects": - # Objects with PointCloud attributes - serialized_objects = [] - for obj in value: - obj_dict = {k: v for k, v in obj.items() if k != "point_cloud"} - if "point_cloud" in obj: - obj_dict["point_cloud"] = serialize_point_cloud(obj.get("point_cloud")) - serialized_objects.append(obj_dict) - pickle_data["results"][key] = serialized_objects - print(f"Serialized {key} ({len(serialized_objects)} objects)") - else: - try: - # Try to pickle as is - pickle_data["results"][key] = value - print(f"Preserved {key} as is") - except (TypeError, ValueError): - print(f"Warning: Could not serialize {key}, skipping") - - with open(pickle_path, "wb") as f: - pickle.dump(pickle_data, f) - - print(f"Results saved successfully with all 3D data serialized!") - print(f"Pickled data keys: {list(pickle_data['results'].keys())}") - - # 
Visualization code has been moved to visualization_script.py - # The results have been pickled and can be loaded from there - print("\nVisualization code has been moved to visualization_script.py") - print("Run 'python visualization_script.py' to visualize the results") - - -if __name__ == "__main__": - main() diff --git a/build/lib/tests/test_move_vel_unitree.py b/build/lib/tests/test_move_vel_unitree.py deleted file mode 100644 index fe4d09a8e1..0000000000 --- a/build/lib/tests/test_move_vel_unitree.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import tests.test_header - -from dimos.robot.unitree.unitree_go2 import UnitreeGo2 -from dimos.robot.unitree.unitree_skills import MyUnitreeSkills -from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl -import os -import time - -# Initialize robot -robot = UnitreeGo2( - ip=os.getenv("ROBOT_IP"), ros_control=UnitreeROSControl(), skills=MyUnitreeSkills() -) - -# Move the robot forward -robot.move_vel(x=0.5, y=0, yaw=0, duration=5) - -while True: - time.sleep(1) diff --git a/build/lib/tests/test_navigate_to_object_robot.py b/build/lib/tests/test_navigate_to_object_robot.py deleted file mode 100644 index eb2767d6ca..0000000000 --- a/build/lib/tests/test_navigate_to_object_robot.py +++ /dev/null @@ -1,137 +0,0 @@ -# Copyright 2025 Dimensional Inc. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import time -import sys -import argparse -import threading -from reactivex import Subject, operators as RxOps - -from dimos.robot.unitree.unitree_go2 import UnitreeGo2 -from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl -from dimos.robot.unitree.unitree_skills import MyUnitreeSkills -from dimos.web.robot_web_interface import RobotWebInterface -from dimos.utils.logging_config import logger -from dimos.skills.navigation import Navigate -import tests.test_header - - -def parse_args(): - parser = argparse.ArgumentParser(description="Navigate to an object using Qwen vision.") - parser.add_argument( - "--object", - type=str, - default="chair", - help="Name of the object to navigate to (default: chair)", - ) - parser.add_argument( - "--distance", - type=float, - default=1.0, - help="Desired distance to maintain from object in meters (default: 0.8)", - ) - parser.add_argument( - "--timeout", - type=float, - default=60.0, - help="Maximum navigation time in seconds (default: 30.0)", - ) - return parser.parse_args() - - -def main(): - # Get command line arguments - args = parse_args() - object_name = args.object # Object to navigate to - distance = args.distance # Desired distance to object - timeout = args.timeout # Maximum navigation time - - print(f"Initializing Unitree Go2 robot for navigating to a {object_name}...") - - # Initialize the robot with ROS control and skills - robot = UnitreeGo2( - 
ip=os.getenv("ROBOT_IP"), - ros_control=UnitreeROSControl(), - skills=MyUnitreeSkills(), - ) - - # Add and create instance of NavigateToObject skill - robot_skills = robot.get_skills() - robot_skills.add(Navigate) - robot_skills.create_instance("Navigate", robot=robot) - - # Set up tracking and visualization streams - object_tracking_stream = robot.object_tracking_stream - viz_stream = object_tracking_stream.pipe( - RxOps.share(), - RxOps.map(lambda x: x["viz_frame"] if x is not None else None), - RxOps.filter(lambda x: x is not None), - ) - - # The local planner visualization stream is created during robot initialization - local_planner_stream = robot.local_planner_viz_stream - - local_planner_stream = local_planner_stream.pipe( - RxOps.share(), - RxOps.map(lambda x: x if x is not None else None), - RxOps.filter(lambda x: x is not None), - ) - - try: - # Set up web interface - logger.info("Initializing web interface") - streams = { - # "robot_video": video_stream, - "object_tracking": viz_stream, - "local_planner": local_planner_stream, - } - - web_interface = RobotWebInterface(port=5555, **streams) - - # Wait for camera and tracking to initialize - print("Waiting for camera and tracking to initialize...") - time.sleep(3) - - def navigate_to_object(): - try: - result = robot_skills.call( - "Navigate", robot=robot, query=object_name, timeout=timeout - ) - print(f"Navigation result: {result}") - except Exception as e: - print(f"Error during navigation: {e}") - - navigate_thread = threading.Thread(target=navigate_to_object, daemon=True) - navigate_thread.start() - - print( - f"Navigating to {object_name} with desired distance {distance}m and timeout {timeout}s..." 
- ) - print("Web interface available at http://localhost:5555") - - # Start web server (blocking call) - web_interface.run() - - except KeyboardInterrupt: - print("\nInterrupted by user") - except Exception as e: - print(f"Error during navigation test: {e}") - finally: - print("Test completed") - robot.cleanup() - - -if __name__ == "__main__": - main() diff --git a/build/lib/tests/test_navigation_skills.py b/build/lib/tests/test_navigation_skills.py deleted file mode 100644 index 9a91d1aba5..0000000000 --- a/build/lib/tests/test_navigation_skills.py +++ /dev/null @@ -1,269 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Simple test script for semantic / spatial memory skills. - -This script is a simplified version that focuses only on making the workflow work. 
- -Usage: - # Build and query in one run: - python simple_navigation_test.py --query "kitchen" - - # Skip build and just query: - python simple_navigation_test.py --skip-build --query "kitchen" -""" - -import os -import sys -import time -import logging -import argparse -import threading -from reactivex import Subject, operators as RxOps -import os - -import tests.test_header - -from dimos.robot.unitree.unitree_go2 import UnitreeGo2 -from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl -from dimos.robot.unitree.unitree_skills import MyUnitreeSkills -from dimos.skills.navigation import BuildSemanticMap, Navigate -from dimos.utils.logging_config import setup_logger -from dimos.web.robot_web_interface import RobotWebInterface - -# Setup logging -logger = setup_logger("simple_navigation_test") - - -def parse_args(): - spatial_memory_dir = os.path.abspath( - os.path.join(os.path.dirname(__file__), "../assets/spatial_memory_vegas") - ) - - parser = argparse.ArgumentParser(description="Simple test for semantic map skills.") - parser.add_argument( - "--skip-build", - action="store_true", - help="Skip building the map and run navigation with existing semantic and visual memory", - ) - parser.add_argument( - "--query", type=str, default="kitchen", help="Text query for navigation (default: kitchen)" - ) - parser.add_argument( - "--db-path", - type=str, - default=os.path.join(spatial_memory_dir, "chromadb_data"), - help="Path to ChromaDB database", - ) - parser.add_argument("--justgo", type=str, help="Globally navigate to location") - parser.add_argument( - "--visual-memory-dir", - type=str, - default=spatial_memory_dir, - help="Directory for visual memory", - ) - parser.add_argument( - "--visual-memory-file", - type=str, - default="visual_memory.pkl", - help="Filename for visual memory", - ) - parser.add_argument( - "--port", type=int, default=5555, help="Port for web visualization interface" - ) - return parser.parse_args() - - -def build_map(robot, args): - 
logger.info("Starting to build spatial memory...") - - # Create the BuildSemanticMap skill - build_skill = BuildSemanticMap( - robot=robot, - db_path=args.db_path, - visual_memory_dir=args.visual_memory_dir, - visual_memory_file=args.visual_memory_file, - ) - - # Start the skill - build_skill() - - # Wait for user to press Ctrl+C - logger.info("Press Ctrl+C to stop mapping and proceed to navigation...") - - try: - while True: - time.sleep(0.5) - except KeyboardInterrupt: - logger.info("Stopping map building...") - - # Stop the skill - build_skill.stop() - logger.info("Map building complete.") - - -def query_map(robot, args): - logger.info(f"Querying spatial memory for: '{args.query}'") - - # Create the Navigate skill - nav_skill = Navigate( - robot=robot, - query=args.query, - db_path=args.db_path, - visual_memory_path=os.path.join(args.visual_memory_dir, args.visual_memory_file), - ) - - # Query the map - result = nav_skill() - - # Display the result - if isinstance(result, dict) and result.get("success", False): - position = result.get("position", (0, 0, 0)) - similarity = result.get("similarity", 0) - logger.info(f"Found '{args.query}' at position: {position}") - logger.info(f"Similarity score: {similarity:.4f}") - return position - - else: - logger.error(f"Navigation query failed: {result}") - return False - - -def setup_visualization(robot, port=5555): - """Set up visualization streams for the web interface""" - logger.info(f"Setting up visualization streams on port {port}") - - # Get video stream from robot - video_stream = robot.video_stream_ros.pipe( - RxOps.share(), - RxOps.map(lambda frame: frame), - RxOps.filter(lambda frame: frame is not None), - ) - - # Get local planner visualization stream - local_planner_stream = robot.local_planner_viz_stream.pipe( - RxOps.share(), - RxOps.map(lambda frame: frame), - RxOps.filter(lambda frame: frame is not None), - ) - - # Create web interface with streams - streams = {"robot_video": video_stream, "local_planner": 
local_planner_stream} - - web_interface = RobotWebInterface(port=port, **streams) - - return web_interface - - -def run_navigation(robot, target): - """Run navigation in a separate thread""" - logger.info(f"Starting navigation to target: {target}") - return robot.global_planner.set_goal(target) - - -def main(): - args = parse_args() - - # Ensure directories exist - if not args.justgo: - os.makedirs(args.db_path, exist_ok=True) - os.makedirs(args.visual_memory_dir, exist_ok=True) - - # Initialize robot - logger.info("Initializing robot...") - ros_control = UnitreeROSControl(node_name="simple_nav_test", mock_connection=False) - robot = UnitreeGo2(ros_control=ros_control, ip=os.getenv("ROBOT_IP"), skills=MyUnitreeSkills()) - - # Set up visualization - web_interface = None - try: - # Set up visualization first if the robot has video capabilities - if hasattr(robot, "video_stream_ros") and robot.video_stream_ros is not None: - web_interface = setup_visualization(robot, port=args.port) - # Start web interface in a separate thread - viz_thread = threading.Thread(target=web_interface.run, daemon=True) - viz_thread.start() - logger.info(f"Web visualization available at http://localhost:{args.port}") - # Wait a moment for the web interface to initialize - time.sleep(2) - - if args.justgo: - # Just go to the specified location - coords = list(map(float, args.justgo.split(","))) - logger.info(f"Navigating to coordinates: {coords}") - - # Run navigation - navigate_thread = threading.Thread( - target=lambda: run_navigation(robot, coords), daemon=True - ) - navigate_thread.start() - - # Wait for navigation to complete or user to interrupt - try: - while navigate_thread.is_alive(): - time.sleep(0.5) - logger.info("Navigation completed") - except KeyboardInterrupt: - logger.info("Navigation interrupted by user") - else: - # Build map if not skipped - if not args.skip_build: - build_map(robot, args) - - # Query the map - target = query_map(robot, args) - - if not target: - 
logger.error("No target found for navigation.") - return - - # Run navigation - navigate_thread = threading.Thread( - target=lambda: run_navigation(robot, target), daemon=True - ) - navigate_thread.start() - - # Wait for navigation to complete or user to interrupt - try: - while navigate_thread.is_alive(): - time.sleep(0.5) - logger.info("Navigation completed") - except KeyboardInterrupt: - logger.info("Navigation interrupted by user") - - # If web interface is running, keep the main thread alive - if web_interface: - logger.info( - "Navigation completed. Visualization still available. Press Ctrl+C to exit." - ) - try: - while True: - time.sleep(0.5) - except KeyboardInterrupt: - logger.info("Exiting...") - - finally: - # Clean up - logger.info("Cleaning up resources...") - try: - robot.cleanup() - except Exception as e: - logger.error(f"Error during cleanup: {e}") - - logger.info("Test completed successfully") - - -if __name__ == "__main__": - main() diff --git a/build/lib/tests/test_object_detection_agent_data_query_stream.py b/build/lib/tests/test_object_detection_agent_data_query_stream.py deleted file mode 100644 index 00e5625119..0000000000 --- a/build/lib/tests/test_object_detection_agent_data_query_stream.py +++ /dev/null @@ -1,191 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import os -import time -import sys -import argparse -import threading -from typing import List, Dict, Any -from reactivex import Subject, operators as ops - -from dimos.robot.unitree.unitree_go2 import UnitreeGo2 -from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl -from dimos.robot.unitree.unitree_skills import MyUnitreeSkills -from dimos.web.robot_web_interface import RobotWebInterface -from dimos.utils.logging_config import logger -from dimos.stream.video_provider import VideoProvider -from dimos.perception.object_detection_stream import ObjectDetectionStream -from dimos.types.vector import Vector -from dimos.utils.reactive import backpressure -from dimos.perception.detection2d.detic_2d_det import Detic2DDetector -from dimos.agents.claude_agent import ClaudeAgent - -from dotenv import load_dotenv - - -def parse_args(): - parser = argparse.ArgumentParser( - description="Test ObjectDetectionStream for object detection and position estimation" - ) - parser.add_argument( - "--mode", - type=str, - default="webcam", - choices=["robot", "webcam"], - help='Mode to run: "robot" or "webcam" (default: webcam)', - ) - return parser.parse_args() - - -load_dotenv() - - -def main(): - # Get command line arguments - args = parse_args() - - # Set default parameters - min_confidence = 0.6 - class_filter = None # No class filtering - web_port = 5555 - - # Initialize detector - detector = Detic2DDetector(vocabulary=None, threshold=min_confidence) - - # Initialize based on mode - if args.mode == "robot": - print("Initializing in robot mode...") - - # Get robot IP from environment - robot_ip = os.getenv("ROBOT_IP") - if not robot_ip: - print("Error: ROBOT_IP environment variable not set.") - sys.exit(1) - - # Initialize robot - robot = UnitreeGo2( - ip=robot_ip, - ros_control=UnitreeROSControl(), - skills=MyUnitreeSkills(), - ) - # Create video stream from robot's camera - video_stream = robot.video_stream_ros - - # Initialize ObjectDetectionStream with robot and 
transform function - object_detector = ObjectDetectionStream( - camera_intrinsics=robot.camera_intrinsics, - min_confidence=min_confidence, - class_filter=class_filter, - transform_to_map=robot.ros_control.transform_pose, - detector=detector, - video_stream=video_stream, - ) - - else: # webcam mode - print("Initializing in webcam mode...") - - # Define camera intrinsics for the webcam - # These are approximate values for a typical 640x480 webcam - width, height = 640, 480 - focal_length_mm = 3.67 # mm (typical webcam) - sensor_width_mm = 4.8 # mm (1/4" sensor) - - # Calculate focal length in pixels - focal_length_x_px = width * focal_length_mm / sensor_width_mm - focal_length_y_px = height * focal_length_mm / sensor_width_mm - - # Principal point (center of image) - cx, cy = width / 2, height / 2 - - # Camera intrinsics in [fx, fy, cx, cy] format - camera_intrinsics = [focal_length_x_px, focal_length_y_px, cx, cy] - - # Initialize video provider and ObjectDetectionStream - video_provider = VideoProvider("test_camera", video_source=0) # Default camera - # Create video stream - video_stream = backpressure( - video_provider.capture_video_as_observable(realtime=True, fps=30) - ) - - object_detector = ObjectDetectionStream( - camera_intrinsics=camera_intrinsics, - min_confidence=min_confidence, - class_filter=class_filter, - detector=detector, - video_stream=video_stream, - ) - - # Set placeholder robot for cleanup - robot = None - - # Create visualization stream for web interface - viz_stream = object_detector.get_stream().pipe( - ops.share(), - ops.map(lambda x: x["viz_frame"] if x is not None else None), - ops.filter(lambda x: x is not None), - ) - - # Create object data observable for Agent using the formatted stream - object_data_stream = object_detector.get_formatted_stream().pipe( - ops.share(), ops.filter(lambda x: x is not None) - ) - - # Create stop event for clean shutdown - stop_event = threading.Event() - - try: - # Set up web interface - 
print("Initializing web interface...") - web_interface = RobotWebInterface(port=web_port, object_detection=viz_stream) - - agent = ClaudeAgent( - dev_name="test_agent", - # input_query_stream=stt_node.emit_text(), - input_query_stream=web_interface.query_stream, - input_data_stream=object_data_stream, - system_query="Tell me what you see", - model_name="claude-3-7-sonnet-latest", - thinking_budget_tokens=0, - ) - - # Print configuration information - print("\nObjectDetectionStream Test Running:") - print(f"Mode: {args.mode}") - print(f"Web Interface: http://localhost:{web_port}") - print("\nPress Ctrl+C to stop the test\n") - - # Start web server (blocking call) - web_interface.run() - - except KeyboardInterrupt: - print("\nTest interrupted by user") - except Exception as e: - print(f"Error during test: {e}") - finally: - # Clean up resources - print("Cleaning up resources...") - stop_event.set() - - if args.mode == "robot" and robot: - robot.cleanup() - elif args.mode == "webcam": - if "video_provider" in locals(): - video_provider.dispose_all() - - print("Test completed") - - -if __name__ == "__main__": - main() diff --git a/build/lib/tests/test_object_detection_stream.py b/build/lib/tests/test_object_detection_stream.py deleted file mode 100644 index 1cf8aeab01..0000000000 --- a/build/lib/tests/test_object_detection_stream.py +++ /dev/null @@ -1,240 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import os -import time -import sys -import argparse -import threading -from typing import List, Dict, Any -from reactivex import Subject, operators as ops - -from dimos.robot.unitree.unitree_skills import MyUnitreeSkills -from dimos.web.robot_web_interface import RobotWebInterface -from dimos.utils.logging_config import logger -from dimos.stream.video_provider import VideoProvider -from dimos.perception.object_detection_stream import ObjectDetectionStream -from dimos.types.vector import Vector -from dimos.utils.reactive import backpressure -from dotenv import load_dotenv - - -def parse_args(): - parser = argparse.ArgumentParser( - description="Test ObjectDetectionStream for object detection and position estimation" - ) - parser.add_argument( - "--mode", - type=str, - default="webcam", - choices=["robot", "webcam"], - help='Mode to run: "robot" or "webcam" (default: webcam)', - ) - return parser.parse_args() - - -load_dotenv() - - -class ResultPrinter: - def __init__(self, print_interval: float = 1.0): - """ - Initialize a result printer that limits console output frequency. - - Args: - print_interval: Minimum time between console prints in seconds - """ - self.print_interval = print_interval - self.last_print_time = 0 - - def print_results(self, objects: List[Dict[str, Any]]): - """Print object detection results to console with rate limiting.""" - current_time = time.time() - - # Only print results at the specified interval - if current_time - self.last_print_time >= self.print_interval: - self.last_print_time = current_time - - if not objects: - print("\n[No objects detected]") - return - - print("\n" + "=" * 50) - print(f"Detected {len(objects)} objects at {time.strftime('%H:%M:%S')}:") - print("=" * 50) - - for i, obj in enumerate(objects): - pos = obj["position"] - rot = obj["rotation"] - size = obj["size"] - - print( - f"{i + 1}. 
{obj['label']} (ID: {obj['object_id']}, Conf: {obj['confidence']:.2f})" - ) - print(f" Position: x={pos.x:.2f}, y={pos.y:.2f}, z={pos.z:.2f} m") - print(f" Rotation: yaw={rot.z:.2f} rad") - print(f" Size: width={size['width']:.2f}, height={size['height']:.2f} m") - print(f" Depth: {obj['depth']:.2f} m") - print("-" * 30) - - -def main(): - # Get command line arguments - args = parse_args() - - # Set up the result printer for console output - result_printer = ResultPrinter(print_interval=1.0) - - # Set default parameters - min_confidence = 0.6 - class_filter = None # No class filtering - web_port = 5555 - - # Initialize based on mode - if args.mode == "robot": - print("Initializing in robot mode...") - - # Get robot IP from environment - robot_ip = os.getenv("ROBOT_IP") - if not robot_ip: - print("Error: ROBOT_IP environment variable not set.") - sys.exit(1) - - # Initialize robot - robot = UnitreeGo2( - ip=robot_ip, - ros_control=UnitreeROSControl(), - skills=MyUnitreeSkills(), - ) - # Create video stream from robot's camera - video_stream = robot.video_stream_ros - - # Initialize ObjectDetectionStream with robot and transform function - object_detector = ObjectDetectionStream( - camera_intrinsics=robot.camera_intrinsics, - min_confidence=min_confidence, - class_filter=class_filter, - transform_to_map=robot.ros_control.transform_pose, - detector=detector, - video_stream=video_stream, - disable_depth=False, - ) - - else: # webcam mode - print("Initializing in webcam mode...") - - # Define camera intrinsics for the webcam - # These are approximate values for a typical 640x480 webcam - width, height = 640, 480 - focal_length_mm = 3.67 # mm (typical webcam) - sensor_width_mm = 4.8 # mm (1/4" sensor) - - # Calculate focal length in pixels - focal_length_x_px = width * focal_length_mm / sensor_width_mm - focal_length_y_px = height * focal_length_mm / sensor_width_mm - - # Principal point (center of image) - cx, cy = width / 2, height / 2 - - # Camera intrinsics in [fx, 
fy, cx, cy] format - camera_intrinsics = [focal_length_x_px, focal_length_y_px, cx, cy] - - # Initialize video provider and ObjectDetectionStream - video_provider = VideoProvider("test_camera", video_source=0) # Default camera - # Create video stream - video_stream = backpressure( - video_provider.capture_video_as_observable(realtime=True, fps=30) - ) - - object_detector = ObjectDetectionStream( - camera_intrinsics=camera_intrinsics, - min_confidence=min_confidence, - class_filter=class_filter, - video_stream=video_stream, - disable_depth=False, - draw_masks=True, - ) - - # Set placeholder robot for cleanup - robot = None - - # Create visualization stream for web interface - viz_stream = object_detector.get_stream().pipe( - ops.share(), - ops.map(lambda x: x["viz_frame"] if x is not None else None), - ops.filter(lambda x: x is not None), - ) - - # Create stop event for clean shutdown - stop_event = threading.Event() - - # Define subscription callback to print results - def on_next(result): - if stop_event.is_set(): - return - - # Print detected objects to console - if "objects" in result: - result_printer.print_results(result["objects"]) - - def on_error(error): - print(f"Error in detection stream: {error}") - stop_event.set() - - def on_completed(): - print("Detection stream completed") - stop_event.set() - - try: - # Subscribe to the detection stream - subscription = object_detector.get_stream().subscribe( - on_next=on_next, on_error=on_error, on_completed=on_completed - ) - - # Set up web interface - print("Initializing web interface...") - web_interface = RobotWebInterface(port=web_port, object_detection=viz_stream) - - # Print configuration information - print("\nObjectDetectionStream Test Running:") - print(f"Mode: {args.mode}") - print(f"Web Interface: http://localhost:{web_port}") - print("\nPress Ctrl+C to stop the test\n") - - # Start web server (blocking call) - web_interface.run() - - except KeyboardInterrupt: - print("\nTest interrupted by user") - 
except Exception as e: - print(f"Error during test: {e}") - finally: - # Clean up resources - print("Cleaning up resources...") - stop_event.set() - - if subscription: - subscription.dispose() - - if args.mode == "robot" and robot: - robot.cleanup() - elif args.mode == "webcam": - if "video_provider" in locals(): - video_provider.dispose_all() - - print("Test completed") - - -if __name__ == "__main__": - main() diff --git a/build/lib/tests/test_object_tracking_webcam.py b/build/lib/tests/test_object_tracking_webcam.py deleted file mode 100644 index a9d792d51b..0000000000 --- a/build/lib/tests/test_object_tracking_webcam.py +++ /dev/null @@ -1,222 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import cv2 -import numpy as np -import os -import sys -import queue -import threading -import tests.test_header - -from dimos.stream.video_provider import VideoProvider -from dimos.perception.object_tracker import ObjectTrackingStream - -# Global variables for bounding box selection -selecting_bbox = False -bbox_points = [] -current_bbox = None -tracker_initialized = False -object_size = 0.30 # Hardcoded object size in meters (adjust based on your tracking target) - - -def mouse_callback(event, x, y, flags, param): - global selecting_bbox, bbox_points, current_bbox, tracker_initialized, tracker_stream - - if event == cv2.EVENT_LBUTTONDOWN: - # Start bbox selection - selecting_bbox = True - bbox_points = [(x, y)] - current_bbox = None - tracker_initialized = False - - elif event == cv2.EVENT_MOUSEMOVE and selecting_bbox: - # Update current selection for visualization - current_bbox = [bbox_points[0][0], bbox_points[0][1], x, y] - - elif event == cv2.EVENT_LBUTTONUP: - # End bbox selection - selecting_bbox = False - if bbox_points: - bbox_points.append((x, y)) - x1, y1 = bbox_points[0] - x2, y2 = bbox_points[1] - # Ensure x1,y1 is top-left and x2,y2 is bottom-right - current_bbox = [min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)] - # Add the bbox to the tracking queue - if param.get("bbox_queue") and not tracker_initialized: - param["bbox_queue"].put((current_bbox, object_size)) - tracker_initialized = True - - -def main(): - global tracker_initialized - - # Create queues for thread communication - frame_queue = queue.Queue(maxsize=5) - bbox_queue = queue.Queue() - stop_event = threading.Event() - - # Logitech C920e camera parameters at 480p - # Convert physical parameters to pixel-based intrinsics - width, height = 640, 480 - focal_length_mm = 3.67 # mm - sensor_width_mm = 4.8 # mm (1/4" sensor) - sensor_height_mm = 3.6 # mm - - # Calculate focal length in pixels - focal_length_x_px = width * focal_length_mm / sensor_width_mm - focal_length_y_px = height * 
focal_length_mm / sensor_height_mm - - # Principal point (assuming center of image) - cx = width / 2 - cy = height / 2 - - # Final camera intrinsics in [fx, fy, cx, cy] format - camera_intrinsics = [focal_length_x_px, focal_length_y_px, cx, cy] - - # Initialize video provider and object tracking stream - video_provider = VideoProvider("test_camera", video_source=0) - tracker_stream = ObjectTrackingStream( - camera_intrinsics=camera_intrinsics, - camera_pitch=0.0, # Adjust if your camera is tilted - camera_height=0.5, # Height of camera from ground in meters (adjust as needed) - ) - - # Create video stream - video_stream = video_provider.capture_video_as_observable(realtime=True, fps=30) - tracking_stream = tracker_stream.create_stream(video_stream) - - # Define callbacks for the tracking stream - def on_next(result): - if stop_event.is_set(): - return - - # Get the visualization frame - viz_frame = result["viz_frame"] - - # If we're selecting a bbox, draw the current selection - if selecting_bbox and current_bbox is not None: - x1, y1, x2, y2 = current_bbox - cv2.rectangle(viz_frame, (x1, y1), (x2, y2), (0, 255, 255), 2) - - # Add instructions - cv2.putText( - viz_frame, - "Click and drag to select object", - (10, 30), - cv2.FONT_HERSHEY_SIMPLEX, - 0.7, - (255, 255, 255), - 2, - ) - cv2.putText( - viz_frame, - f"Object size: {object_size:.2f}m", - (10, 60), - cv2.FONT_HERSHEY_SIMPLEX, - 0.7, - (255, 255, 255), - 2, - ) - - # Show tracking status - status = "Tracking" if tracker_initialized else "Not tracking" - cv2.putText( - viz_frame, - f"Status: {status}", - (10, 90), - cv2.FONT_HERSHEY_SIMPLEX, - 0.7, - (0, 255, 0) if tracker_initialized else (0, 0, 255), - 2, - ) - - # Put frame in queue for main thread to display - try: - frame_queue.put_nowait(viz_frame) - except queue.Full: - # Skip frame if queue is full - pass - - def on_error(error): - print(f"Error: {error}") - stop_event.set() - - def on_completed(): - print("Stream completed") - stop_event.set() - - # 
Start the subscription - subscription = None - - try: - # Subscribe to start processing in background thread - subscription = tracking_stream.subscribe( - on_next=on_next, on_error=on_error, on_completed=on_completed - ) - - print("Object tracking started. Click and drag to select an object. Press 'q' to exit.") - - # Create window and set mouse callback - cv2.namedWindow("Object Tracker") - cv2.setMouseCallback("Object Tracker", mouse_callback, {"bbox_queue": bbox_queue}) - - # Main thread loop for displaying frames and handling bbox selection - while not stop_event.is_set(): - # Check if there's a new bbox to track - try: - new_bbox, size = bbox_queue.get_nowait() - print(f"New object selected: {new_bbox}, size: {size}m") - # Initialize tracker with the new bbox and size - tracker_stream.track(new_bbox, size=size) - except queue.Empty: - pass - - try: - # Get frame with timeout - viz_frame = frame_queue.get(timeout=1.0) - - # Display the frame - cv2.imshow("Object Tracker", viz_frame) - # Check for exit key - if cv2.waitKey(1) & 0xFF == ord("q"): - print("Exit key pressed") - break - - except queue.Empty: - # No frame available, check if we should continue - if cv2.waitKey(1) & 0xFF == ord("q"): - print("Exit key pressed") - break - continue - - except KeyboardInterrupt: - print("\nKeyboard interrupt received. Stopping...") - finally: - # Signal threads to stop - stop_event.set() - - # Clean up resources - if subscription: - subscription.dispose() - - video_provider.dispose_all() - tracker_stream.cleanup() - cv2.destroyAllWindows() - print("Cleanup complete") - - -if __name__ == "__main__": - main() diff --git a/build/lib/tests/test_object_tracking_with_qwen.py b/build/lib/tests/test_object_tracking_with_qwen.py deleted file mode 100644 index 959565ae55..0000000000 --- a/build/lib/tests/test_object_tracking_with_qwen.py +++ /dev/null @@ -1,216 +0,0 @@ -# Copyright 2025 Dimensional Inc. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import sys -import time -import cv2 -import numpy as np -import queue -import threading -import json -from reactivex import Subject, operators as RxOps -from openai import OpenAI -import tests.test_header - -from dimos.stream.video_provider import VideoProvider -from dimos.perception.object_tracker import ObjectTrackingStream -from dimos.models.qwen.video_query import get_bbox_from_qwen -from dimos.utils.logging_config import logger - -# Global variables for tracking control -object_size = 0.30 # Hardcoded object size in meters (adjust based on your tracking target) -tracking_object_name = "object" # Will be updated by Qwen -object_name = "hairbrush" # Example object name for Qwen - -global tracker_initialized, detection_in_progress - -# Create queues for thread communication -frame_queue = queue.Queue(maxsize=5) -stop_event = threading.Event() - -# Logitech C920e camera parameters at 480p -width, height = 640, 480 -focal_length_mm = 3.67 # mm -sensor_width_mm = 4.8 # mm (1/4" sensor) -sensor_height_mm = 3.6 # mm - -# Calculate focal length in pixels -focal_length_x_px = width * focal_length_mm / sensor_width_mm -focal_length_y_px = height * focal_length_mm / sensor_height_mm -cx, cy = width / 2, height / 2 - -# Final camera intrinsics in [fx, fy, cx, cy] format -camera_intrinsics = [focal_length_x_px, focal_length_y_px, cx, cy] - -# Initialize video provider and object tracking stream -video_provider = 
VideoProvider("webcam", video_source=0) -tracker_stream = ObjectTrackingStream( - camera_intrinsics=camera_intrinsics, camera_pitch=0.0, camera_height=0.5 -) - -# Create video streams -video_stream = video_provider.capture_video_as_observable(realtime=True, fps=10) -tracking_stream = tracker_stream.create_stream(video_stream) - -# Check if display is available -if "DISPLAY" not in os.environ: - raise RuntimeError( - "No display available. Please set DISPLAY environment variable or run in headless mode." - ) - - -# Define callbacks for the tracking stream -def on_next(result): - global tracker_initialized, detection_in_progress - if stop_event.is_set(): - return - - # Get the visualization frame - viz_frame = result["viz_frame"] - - # Add information to the visualization - cv2.putText( - viz_frame, - f"Tracking {tracking_object_name}", - (10, 30), - cv2.FONT_HERSHEY_SIMPLEX, - 0.7, - (255, 255, 255), - 2, - ) - cv2.putText( - viz_frame, - f"Object size: {object_size:.2f}m", - (10, 60), - cv2.FONT_HERSHEY_SIMPLEX, - 0.7, - (255, 255, 255), - 2, - ) - - # Show tracking status - status = "Tracking" if tracker_initialized else "Waiting for detection" - color = (0, 255, 0) if tracker_initialized else (0, 0, 255) - cv2.putText(viz_frame, f"Status: {status}", (10, 90), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2) - - # If detection is in progress, show a message - if detection_in_progress: - cv2.putText( - viz_frame, "Querying Qwen...", (10, 120), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2 - ) - - # Put frame in queue for main thread to display - try: - frame_queue.put_nowait(viz_frame) - except queue.Full: - pass - - -def on_error(error): - print(f"Error: {error}") - stop_event.set() - - -def on_completed(): - print("Stream completed") - stop_event.set() - - -# Start the subscription -subscription = None - -try: - # Initialize global flags - tracker_initialized = False - detection_in_progress = False - # Subscribe to start processing in background thread - subscription = 
tracking_stream.subscribe( - on_next=on_next, on_error=on_error, on_completed=on_completed - ) - - print("Object tracking with Qwen started. Press 'q' to exit.") - print("Waiting for initial object detection...") - - # Main thread loop for displaying frames and updating tracking - while not stop_event.is_set(): - # Check if we need to update tracking - - if not detection_in_progress: - detection_in_progress = True - print("Requesting object detection from Qwen...") - - print("detection_in_progress: ", detection_in_progress) - print("tracker_initialized: ", tracker_initialized) - - def detection_task(): - global detection_in_progress, tracker_initialized, tracking_object_name, object_size - try: - result = get_bbox_from_qwen(video_stream, object_name=object_name) - print(f"Got result from Qwen: {result}") - - if result: - bbox, size = result - print(f"Detected object at {bbox} with size {size}") - tracker_stream.track(bbox, size=size) - tracker_initialized = True - return - - print("No object detected by Qwen") - tracker_initialized = False - tracker_stream.stop_track() - - except Exception as e: - print(f"Error in update_tracking: {e}") - tracker_initialized = False - tracker_stream.stop_track() - finally: - detection_in_progress = False - - # Run detection task in a separate thread - threading.Thread(target=detection_task, daemon=True).start() - - try: - # Get frame with timeout - viz_frame = frame_queue.get(timeout=0.1) - - # Display the frame - cv2.imshow("Object Tracking with Qwen", viz_frame) - - # Check for exit key - if cv2.waitKey(1) & 0xFF == ord("q"): - print("Exit key pressed") - break - - except queue.Empty: - # No frame available, check if we should continue - if cv2.waitKey(1) & 0xFF == ord("q"): - print("Exit key pressed") - break - continue - -except KeyboardInterrupt: - print("\nKeyboard interrupt received. 
Stopping...") -finally: - # Signal threads to stop - stop_event.set() - - # Clean up resources - if subscription: - subscription.dispose() - - video_provider.dispose_all() - tracker_stream.cleanup() - cv2.destroyAllWindows() - print("Cleanup complete") diff --git a/build/lib/tests/test_observe_stream_skill.py b/build/lib/tests/test_observe_stream_skill.py deleted file mode 100644 index 7f18789fb0..0000000000 --- a/build/lib/tests/test_observe_stream_skill.py +++ /dev/null @@ -1,131 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Test for the monitor skill and kill skill. - -This script demonstrates how to use the monitor skill to periodically -send images from the robot's video stream to a Claude agent, and how -to use the kill skill to terminate the monitor skill. 
-""" - -import os -import time -import threading -from dotenv import load_dotenv -import reactivex as rx -from reactivex import operators as ops -import logging - -from dimos.agents.claude_agent import ClaudeAgent -from dimos.robot.unitree.unitree_go2 import UnitreeGo2 -from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl -from dimos.robot.unitree.unitree_skills import MyUnitreeSkills -from dimos.skills.observe_stream import ObserveStream -from dimos.skills.kill_skill import KillSkill -from dimos.web.robot_web_interface import RobotWebInterface -from dimos.utils.logging_config import setup_logger -import tests.test_header - -logger = setup_logger("tests.test_observe_stream_skill") - -load_dotenv() - - -def main(): - # Initialize the robot with mock connection for testing - robot = UnitreeGo2( - ip=os.getenv("ROBOT_IP", "192.168.123.161"), skills=MyUnitreeSkills(), mock_connection=True - ) - - agent_response_subject = rx.subject.Subject() - agent_response_stream = agent_response_subject.pipe(ops.share()) - - streams = {"unitree_video": robot.get_ros_video_stream()} - text_streams = { - "agent_responses": agent_response_stream, - } - - web_interface = RobotWebInterface(port=5555, text_streams=text_streams, **streams) - - agent = ClaudeAgent( - dev_name="test_agent", - input_query_stream=web_interface.query_stream, - skills=robot.get_skills(), - system_query="""You are an agent monitoring a robot's environment. - When you see an image, describe what you see and alert if you notice any people or important changes. 
- Be concise but thorough in your observations.""", - model_name="claude-3-7-sonnet-latest", - thinking_budget_tokens=10000, - ) - - agent.get_response_observable().subscribe(lambda x: agent_response_subject.on_next(x)) - - robot_skills = robot.get_skills() - - robot_skills.add(ObserveStream) - robot_skills.add(KillSkill) - - robot_skills.create_instance("ObserveStream", robot=robot, agent=agent) - robot_skills.create_instance("KillSkill", skill_library=robot_skills) - - web_interface_thread = threading.Thread(target=web_interface.run) - web_interface_thread.daemon = True - web_interface_thread.start() - - logger.info("Starting monitor skill...") - - memory_file = os.path.join(agent.output_dir, "memory.txt") - with open(memory_file, "a") as f: - f.write( - "SKILL CALL: ObserveStream(timestep=10.0, query_text='What do you see in this image? Alert me if you see any people.', max_duration=120.0)" - ) - - result = robot_skills.call( - "ObserveStream", - timestep=10.0, # 20 seconds between monitoring queries - query_text="What do you see in this image? Alert me if you see any people.", - max_duration=120.0, - ) # Run for 120 seconds - logger.info(f"Monitor skill result: {result}") - - logger.info(f"Running skills: {robot_skills.get_running_skills().keys()}") - - try: - logger.info("Observer running. 
Will stop after 35 seconds...") - time.sleep(20.0) - - logger.info(f"Running skills before kill: {robot_skills.get_running_skills().keys()}") - logger.info("Killing the observer skill...") - - memory_file = os.path.join(agent.output_dir, "memory.txt") - with open(memory_file, "a") as f: - f.write("\n\nSKILL CALL: KillSkill(skill_name='observer')\n\n") - - kill_result = robot_skills.call("KillSkill", skill_name="observer") - logger.info(f"Kill skill result: {kill_result}") - - logger.info(f"Running skills after kill: {robot_skills.get_running_skills().keys()}") - - # Keep test running until user interrupts - while True: - time.sleep(1.0) - except KeyboardInterrupt: - logger.info("Test interrupted by user") - - logger.info("Test completed") - - -if __name__ == "__main__": - main() diff --git a/build/lib/tests/test_person_following_robot.py b/build/lib/tests/test_person_following_robot.py deleted file mode 100644 index 46f91cc7a3..0000000000 --- a/build/lib/tests/test_person_following_robot.py +++ /dev/null @@ -1,113 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import os -import time -import sys -from reactivex import operators as RxOps -import tests.test_header - -from dimos.robot.unitree.unitree_go2 import UnitreeGo2 -from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl -from dimos.robot.unitree.unitree_skills import MyUnitreeSkills -from dimos.web.robot_web_interface import RobotWebInterface -from dimos.utils.logging_config import logger -from dimos.models.qwen.video_query import query_single_frame_observable - - -def main(): - # Hardcoded parameters - timeout = 60.0 # Maximum time to follow a person (seconds) - distance = 0.5 # Desired distance to maintain from target (meters) - - print("Initializing Unitree Go2 robot...") - - # Initialize the robot with ROS control and skills - robot = UnitreeGo2( - ip=os.getenv("ROBOT_IP"), - ros_control=UnitreeROSControl(), - skills=MyUnitreeSkills(), - ) - - tracking_stream = robot.person_tracking_stream - viz_stream = tracking_stream.pipe( - RxOps.share(), - RxOps.map(lambda x: x["viz_frame"] if x is not None else None), - RxOps.filter(lambda x: x is not None), - ) - video_stream = robot.get_ros_video_stream() - - try: - # Set up web interface - logger.info("Initializing web interface") - streams = {"unitree_video": video_stream, "person_tracking": viz_stream} - - web_interface = RobotWebInterface(port=5555, **streams) - - # Wait for camera and tracking to initialize - print("Waiting for camera and tracking to initialize...") - time.sleep(5) - # Get initial point from Qwen - - max_retries = 5 - delay = 3 - - for attempt in range(max_retries): - try: - qwen_point = eval( - query_single_frame_observable( - video_stream, - "Look at this frame and point to the person shirt. 
Return ONLY their center coordinates as a tuple (x,y).", - ) - .pipe(RxOps.take(1)) - .run() - ) # Get first response and convert string tuple to actual tuple - logger.info(f"Found person at coordinates {qwen_point}") - break # If successful, break out of retry loop - except Exception as e: - if attempt < max_retries - 1: - logger.error( - f"Person not found. Attempt {attempt + 1}/{max_retries} failed. Retrying in {delay}s... Error: {e}" - ) - time.sleep(delay) - else: - logger.error(f"Person not found after {max_retries} attempts. Last error: {e}") - return - - # Start following human in a separate thread - import threading - - follow_thread = threading.Thread( - target=lambda: robot.follow_human(timeout=timeout, distance=distance, point=qwen_point), - daemon=True, - ) - follow_thread.start() - - print(f"Following human at point {qwen_point} for {timeout} seconds...") - print("Web interface available at http://localhost:5555") - - # Start web server (blocking call) - web_interface.run() - - except KeyboardInterrupt: - print("\nInterrupted by user") - except Exception as e: - print(f"Error during test: {e}") - finally: - print("Test completed") - robot.cleanup() - - -if __name__ == "__main__": - main() diff --git a/build/lib/tests/test_person_following_webcam.py b/build/lib/tests/test_person_following_webcam.py deleted file mode 100644 index 2108c4cf95..0000000000 --- a/build/lib/tests/test_person_following_webcam.py +++ /dev/null @@ -1,230 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -import cv2 -import numpy as np -import os -import sys -import queue -import threading -import tests.test_header - - -from dimos.stream.video_provider import VideoProvider -from dimos.perception.person_tracker import PersonTrackingStream -from dimos.perception.visual_servoing import VisualServoing - - -def main(): - # Create a queue for thread communication (limit to prevent memory issues) - frame_queue = queue.Queue(maxsize=5) - result_queue = queue.Queue(maxsize=5) # For tracking results - stop_event = threading.Event() - - # Logitech C920e camera parameters at 480p - # Convert physical parameters to intrinsics [fx, fy, cx, cy] - resolution = (640, 480) # 480p resolution - focal_length_mm = 3.67 # mm - sensor_size_mm = (4.8, 3.6) # mm (1/4" sensor) - - # Calculate focal length in pixels - fx = (resolution[0] * focal_length_mm) / sensor_size_mm[0] - fy = (resolution[1] * focal_length_mm) / sensor_size_mm[1] - - # Principal point (typically at image center) - cx = resolution[0] / 2 - cy = resolution[1] / 2 - - # Camera intrinsics in [fx, fy, cx, cy] format - camera_intrinsics = [fx, fy, cx, cy] - - # Camera mounted parameters - camera_pitch = np.deg2rad(-5) # negative for downward pitch - camera_height = 1.4 # meters - - # Initialize video provider and person tracking stream - video_provider = VideoProvider("test_camera", video_source=0) - person_tracker = PersonTrackingStream( - camera_intrinsics=camera_intrinsics, camera_pitch=camera_pitch, camera_height=camera_height - ) - - # Create streams - video_stream = video_provider.capture_video_as_observable(realtime=False, fps=20) - person_tracking_stream = person_tracker.create_stream(video_stream) - - # Create visual servoing object - visual_servoing = VisualServoing( - tracking_stream=person_tracking_stream, - max_linear_speed=0.5, - max_angular_speed=0.75, - desired_distance=2.5, - ) - - # Track if we have 
selected a person to follow - selected_point = None - tracking_active = False - - # Define callbacks for the tracking stream - def on_next(result): - if stop_event.is_set(): - return - - # Get the visualization frame which already includes person detections - # with bounding boxes, tracking IDs, and distance/angle information - viz_frame = result["viz_frame"] - - # Store the result for the main thread to use with visual servoing - try: - result_queue.put_nowait(result) - except queue.Full: - # Skip if queue is full - pass - - # Put frame in queue for main thread to display (non-blocking) - try: - frame_queue.put_nowait(viz_frame) - except queue.Full: - # Skip frame if queue is full - pass - - def on_error(error): - print(f"Error: {error}") - stop_event.set() - - def on_completed(): - print("Stream completed") - stop_event.set() - - # Mouse callback for selecting a person to track - def mouse_callback(event, x, y, flags, param): - nonlocal selected_point, tracking_active - - if event == cv2.EVENT_LBUTTONDOWN: - # Store the clicked point - selected_point = (x, y) - tracking_active = False # Will be set to True if start_tracking succeeds - print(f"Selected point: {selected_point}") - - # Start the subscription - subscription = None - - try: - # Subscribe to start processing in background thread - subscription = person_tracking_stream.subscribe( - on_next=on_next, on_error=on_error, on_completed=on_completed - ) - - print("Person tracking visualization started.") - print("Click on a person to start visual servoing. 
Press 'q' to exit.") - - # Set up mouse callback - cv2.namedWindow("Person Tracking") - cv2.setMouseCallback("Person Tracking", mouse_callback) - - # Main thread loop for displaying frames - while not stop_event.is_set(): - try: - # Get frame with timeout (allows checking stop_event periodically) - frame = frame_queue.get(timeout=1.0) - - # Call the visual servoing if we have a selected point - if selected_point is not None: - # If not actively tracking, try to start tracking - if not tracking_active: - tracking_active = visual_servoing.start_tracking(point=selected_point) - if not tracking_active: - print("Failed to start tracking") - selected_point = None - - # If tracking is active, update tracking - if tracking_active: - servoing_result = visual_servoing.updateTracking() - - # Display visual servoing output on the frame - linear_vel = servoing_result.get("linear_vel", 0.0) - angular_vel = servoing_result.get("angular_vel", 0.0) - running = visual_servoing.running - - status_color = ( - (0, 255, 0) if running else (0, 0, 255) - ) # Green if running, red if not - - # Add velocity text to frame - cv2.putText( - frame, - f"Linear: {linear_vel:.2f} m/s", - (10, 30), - cv2.FONT_HERSHEY_SIMPLEX, - 0.7, - status_color, - 2, - ) - cv2.putText( - frame, - f"Angular: {angular_vel:.2f} rad/s", - (10, 60), - cv2.FONT_HERSHEY_SIMPLEX, - 0.7, - status_color, - 2, - ) - cv2.putText( - frame, - f"Tracking: {'ON' if running else 'OFF'}", - (10, 90), - cv2.FONT_HERSHEY_SIMPLEX, - 0.7, - status_color, - 2, - ) - - # If tracking is lost, reset selected_point and tracking_active - if not running: - selected_point = None - tracking_active = False - - # Display the frame in main thread - cv2.imshow("Person Tracking", frame) - - # Check for exit key - if cv2.waitKey(1) & 0xFF == ord("q"): - print("Exit key pressed") - break - - except queue.Empty: - # No frame available, check if we should continue - if cv2.waitKey(1) & 0xFF == ord("q"): - print("Exit key pressed") - break - continue - 
- except KeyboardInterrupt: - print("\nKeyboard interrupt received. Stopping...") - finally: - # Signal threads to stop - stop_event.set() - - # Clean up resources - if subscription: - subscription.dispose() - - visual_servoing.cleanup() - video_provider.dispose_all() - person_tracker.cleanup() - cv2.destroyAllWindows() - print("Cleanup complete") - - -if __name__ == "__main__": - main() diff --git a/build/lib/tests/test_planning_agent_web_interface.py b/build/lib/tests/test_planning_agent_web_interface.py deleted file mode 100644 index 1d1e3fcd87..0000000000 --- a/build/lib/tests/test_planning_agent_web_interface.py +++ /dev/null @@ -1,180 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Planning agent demo with FastAPI server and robot integration. - -Connects a planning agent, execution agent, and robot with a web interface. - -Environment Variables: - OPENAI_API_KEY: Required. OpenAI API key. - ROBOT_IP: Required. IP address of the robot. - CONN_TYPE: Required. Connection method to the robot. - ROS_OUTPUT_DIR: Optional. Directory for ROS output files. 
-""" - -import tests.test_header -import os -import sys - -# ----- - -from textwrap import dedent -import threading -import time -import reactivex as rx -import reactivex.operators as ops - -# Local application imports -from dimos.agents.agent import OpenAIAgent -from dimos.agents.planning_agent import PlanningAgent -from dimos.robot.unitree.unitree_go2 import UnitreeGo2 -from dimos.robot.unitree.unitree_skills import MyUnitreeSkills -from dimos.utils.logging_config import logger - -# from dimos.web.fastapi_server import FastAPIServer -from dimos.web.robot_web_interface import RobotWebInterface -from dimos.utils.threadpool import make_single_thread_scheduler - - -def main(): - # Get environment variables - robot_ip = os.getenv("ROBOT_IP") - if not robot_ip: - raise ValueError("ROBOT_IP environment variable is required") - connection_method = os.getenv("CONN_TYPE") or "webrtc" - output_dir = os.getenv("ROS_OUTPUT_DIR", os.path.join(os.getcwd(), "assets/output/ros")) - - # Initialize components as None for proper cleanup - robot = None - web_interface = None - planner = None - executor = None - - try: - # Initialize robot - logger.info("Initializing Unitree Robot") - robot = UnitreeGo2( - ip=robot_ip, - connection_method=connection_method, - output_dir=output_dir, - mock_connection=False, - skills=MyUnitreeSkills(), - ) - # Set up video stream - logger.info("Starting video stream") - video_stream = robot.get_ros_video_stream() - - # Initialize robot skills - logger.info("Initializing robot skills") - - # Create subjects for planner and executor responses - logger.info("Creating response streams") - planner_response_subject = rx.subject.Subject() - planner_response_stream = planner_response_subject.pipe(ops.share()) - - executor_response_subject = rx.subject.Subject() - executor_response_stream = executor_response_subject.pipe(ops.share()) - - # Web interface mode with FastAPI server - logger.info("Initializing FastAPI server") - streams = {"unitree_video": 
video_stream} - text_streams = { - "planner_responses": planner_response_stream, - "executor_responses": executor_response_stream, - } - - web_interface = RobotWebInterface(port=5555, text_streams=text_streams, **streams) - - logger.info("Starting planning agent with web interface") - planner = PlanningAgent( - dev_name="TaskPlanner", - model_name="gpt-4o", - input_query_stream=web_interface.query_stream, - skills=robot.get_skills(), - ) - - # Get planner's response observable - logger.info("Setting up agent response streams") - planner_responses = planner.get_response_observable() - - # Connect planner to its subject - planner_responses.subscribe(lambda x: planner_response_subject.on_next(x)) - - planner_responses.subscribe( - on_next=lambda x: logger.info(f"Planner response: {x}"), - on_error=lambda e: logger.error(f"Planner error: {e}"), - on_completed=lambda: logger.info("Planner completed"), - ) - - # Initialize execution agent with robot skills - logger.info("Starting execution agent") - system_query = dedent( - """ - You are a robot execution agent that can execute tasks on a virtual - robot. The sole text you will be given is the task to execute. - You will be given a list of skills that you can use to execute the task. - ONLY OUTPUT THE SKILLS TO EXECUTE, NOTHING ELSE. 
- """ - ) - executor = OpenAIAgent( - dev_name="StepExecutor", - input_query_stream=planner_responses, - output_dir=output_dir, - skills=robot.get_skills(), - system_query=system_query, - pool_scheduler=make_single_thread_scheduler(), - ) - - # Get executor's response observable - executor_responses = executor.get_response_observable() - - # Subscribe to responses for logging - executor_responses.subscribe( - on_next=lambda x: logger.info(f"Executor response: {x}"), - on_error=lambda e: logger.error(f"Executor error: {e}"), - on_completed=lambda: logger.info("Executor completed"), - ) - - # Connect executor to its subject - executor_responses.subscribe(lambda x: executor_response_subject.on_next(x)) - - # Start web server (blocking call) - logger.info("Starting FastAPI server") - web_interface.run() - - except KeyboardInterrupt: - print("Stopping demo...") - except Exception as e: - logger.error(f"Error: {e}") - return 1 - finally: - # Clean up all components - logger.info("Cleaning up components") - if executor: - executor.dispose_all() - if planner: - planner.dispose_all() - if web_interface: - web_interface.dispose_all() - if robot: - robot.cleanup() - # Halt execution forever - while True: - time.sleep(1) - - -if __name__ == "__main__": - sys.exit(main()) - -# Example Task: Move the robot forward by 1 meter, then turn 90 degrees clockwise, then move backward by 1 meter, then turn a random angle counterclockwise, then repeat this sequence 5 times. diff --git a/build/lib/tests/test_planning_robot_agent.py b/build/lib/tests/test_planning_robot_agent.py deleted file mode 100644 index 6e55e5de71..0000000000 --- a/build/lib/tests/test_planning_robot_agent.py +++ /dev/null @@ -1,177 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Planning agent demo with FastAPI server and robot integration. - -Connects a planning agent, execution agent, and robot with a web interface. - -Environment Variables: - OPENAI_API_KEY: Required. OpenAI API key. - ROBOT_IP: Required. IP address of the robot. - CONN_TYPE: Required. Connection method to the robot. - ROS_OUTPUT_DIR: Optional. Directory for ROS output files. - USE_TERMINAL: Optional. If set to "true", use terminal interface instead of web. -""" - -import tests.test_header -import os -import sys - -# ----- - -from textwrap import dedent -import threading -import time - -# Local application imports -from dimos.agents.agent import OpenAIAgent -from dimos.agents.planning_agent import PlanningAgent -from dimos.robot.unitree.unitree_go2 import UnitreeGo2 -from dimos.robot.unitree.unitree_skills import MyUnitreeSkills -from dimos.utils.logging_config import logger -from dimos.web.robot_web_interface import RobotWebInterface -from dimos.utils.threadpool import make_single_thread_scheduler - - -def main(): - # Get environment variables - robot_ip = os.getenv("ROBOT_IP") - if not robot_ip: - raise ValueError("ROBOT_IP environment variable is required") - connection_method = os.getenv("CONN_TYPE") or "webrtc" - output_dir = os.getenv("ROS_OUTPUT_DIR", os.path.join(os.getcwd(), "assets/output/ros")) - use_terminal = os.getenv("USE_TERMINAL", "").lower() == "true" - - use_terminal = True - # Initialize components as None for proper cleanup - robot = None - web_interface = None - planner = None - executor = None - - try: - # Initialize robot - 
logger.info("Initializing Unitree Robot") - robot = UnitreeGo2( - ip=robot_ip, - connection_method=connection_method, - output_dir=output_dir, - mock_connection=True, - ) - - # Set up video stream - logger.info("Starting video stream") - video_stream = robot.get_ros_video_stream() - - # Initialize robot skills - logger.info("Initializing robot skills") - skills_instance = MyUnitreeSkills(robot=robot) - - if use_terminal: - # Terminal mode - no web interface needed - logger.info("Starting planning agent in terminal mode") - planner = PlanningAgent( - dev_name="TaskPlanner", - model_name="gpt-4o", - use_terminal=True, - skills=skills_instance, - ) - else: - # Web interface mode - logger.info("Initializing FastAPI server") - streams = {"unitree_video": video_stream} - web_interface = RobotWebInterface(port=5555, **streams) - - logger.info("Starting planning agent with web interface") - planner = PlanningAgent( - dev_name="TaskPlanner", - model_name="gpt-4o", - input_query_stream=web_interface.query_stream, - skills=skills_instance, - ) - - # Get planner's response observable - logger.info("Setting up agent response streams") - planner_responses = planner.get_response_observable() - - # Initialize execution agent with robot skills - logger.info("Starting execution agent") - system_query = dedent( - """ - You are a robot execution agent that can execute tasks on a virtual - robot. You are given a task to execute and a list of skills that - you can use to execute the task. ONLY OUTPUT THE SKILLS TO EXECUTE, - NOTHING ELSE. 
- """ - ) - executor = OpenAIAgent( - dev_name="StepExecutor", - input_query_stream=planner_responses, - output_dir=output_dir, - skills=skills_instance, - system_query=system_query, - pool_scheduler=make_single_thread_scheduler(), - ) - - # Get executor's response observable - executor_responses = executor.get_response_observable() - - # Subscribe to responses for logging - executor_responses.subscribe( - on_next=lambda x: logger.info(f"Executor response: {x}"), - on_error=lambda e: logger.error(f"Executor error: {e}"), - on_completed=lambda: logger.info("Executor completed"), - ) - - if use_terminal: - # In terminal mode, just wait for the planning session to complete - logger.info("Waiting for planning session to complete") - while not planner.plan_confirmed: - pass - logger.info("Planning session completed") - else: - # Start web server (blocking call) - logger.info("Starting FastAPI server") - web_interface.run() - - # Keep the main thread alive - logger.error("NOTE: Keeping main thread alive") - while True: - time.sleep(1) - - except KeyboardInterrupt: - print("Stopping demo...") - except Exception as e: - logger.error(f"Error: {e}") - return 1 - finally: - # Clean up all components - logger.info("Cleaning up components") - if executor: - executor.dispose_all() - if planner: - planner.dispose_all() - if web_interface: - web_interface.dispose_all() - if robot: - robot.cleanup() - # Halt execution forever - while True: - time.sleep(1) - - -if __name__ == "__main__": - sys.exit(main()) - -# Example Task: Move the robot forward by 1 meter, then turn 90 degrees clockwise, then move backward by 1 meter, then turn a random angle counterclockwise, then repeat this sequence 5 times. diff --git a/build/lib/tests/test_pointcloud_filtering.py b/build/lib/tests/test_pointcloud_filtering.py deleted file mode 100644 index 57a1cb5b00..0000000000 --- a/build/lib/tests/test_pointcloud_filtering.py +++ /dev/null @@ -1,105 +0,0 @@ -# Copyright 2025 Dimensional Inc. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import sys -import time -import threading -from reactivex import operators as ops - -import tests.test_header - -from pyzed import sl -from dimos.stream.stereo_camera_streams.zed import ZEDCameraStream -from dimos.web.robot_web_interface import RobotWebInterface -from dimos.utils.logging_config import logger -from dimos.manipulation.manip_aio_pipeline import ManipulationPipeline - - -def main(): - """Test point cloud filtering using the concurrent stream-based ManipulationPipeline.""" - print("Testing point cloud filtering with ManipulationPipeline...") - - # Configuration - min_confidence = 0.6 - web_port = 5555 - - try: - # Initialize ZED camera stream - zed_stream = ZEDCameraStream(resolution=sl.RESOLUTION.HD1080, fps=10) - - # Get camera intrinsics - camera_intrinsics_dict = zed_stream.get_camera_info() - camera_intrinsics = [ - camera_intrinsics_dict["fx"], - camera_intrinsics_dict["fy"], - camera_intrinsics_dict["cx"], - camera_intrinsics_dict["cy"], - ] - - # Create the concurrent manipulation pipeline - pipeline = ManipulationPipeline( - camera_intrinsics=camera_intrinsics, - min_confidence=min_confidence, - max_objects=10, - ) - - # Create ZED stream - zed_frame_stream = zed_stream.create_stream().pipe(ops.share()) - - # Create concurrent processing streams - streams = pipeline.create_streams(zed_frame_stream) - detection_viz_stream = streams["detection_viz"] - pointcloud_viz_stream = streams["pointcloud_viz"] 
- - except ImportError: - print("Error: ZED SDK not installed. Please install pyzed package.") - sys.exit(1) - except RuntimeError as e: - print(f"Error: Failed to open ZED camera: {e}") - sys.exit(1) - - try: - # Set up web interface with concurrent visualization streams - print("Initializing web interface...") - web_interface = RobotWebInterface( - port=web_port, - object_detection=detection_viz_stream, - pointcloud_stream=pointcloud_viz_stream, - ) - - print(f"\nPoint Cloud Filtering Test Running:") - print(f"Web Interface: http://localhost:{web_port}") - print(f"Object Detection View: RGB with bounding boxes") - print(f"Point Cloud View: Depth with colored point clouds and 3D bounding boxes") - print(f"Confidence threshold: {min_confidence}") - print("\nPress Ctrl+C to stop the test\n") - - # Start web server (blocking call) - web_interface.run() - - except KeyboardInterrupt: - print("\nTest interrupted by user") - except Exception as e: - print(f"Error during test: {e}") - finally: - print("Cleaning up resources...") - if "zed_stream" in locals(): - zed_stream.cleanup() - if "pipeline" in locals(): - pipeline.cleanup() - print("Test completed") - - -if __name__ == "__main__": - main() diff --git a/build/lib/tests/test_qwen_image_query.py b/build/lib/tests/test_qwen_image_query.py deleted file mode 100644 index 13feaf7eb3..0000000000 --- a/build/lib/tests/test_qwen_image_query.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -"""Test the Qwen image query functionality.""" - -import os -from PIL import Image -from dimos.models.qwen.video_query import query_single_frame - - -def test_qwen_image_query(): - """Test querying Qwen with a single image.""" - # Skip if no API key - if not os.getenv("ALIBABA_API_KEY"): - print("ALIBABA_API_KEY not set") - return - - # Load test image - image_path = os.path.join(os.getcwd(), "assets", "test_spatial_memory", "frame_038.jpg") - image = Image.open(image_path) - - # Test basic object detection query - response = query_single_frame( - image=image, - query="What objects do you see in this image? Return as a comma-separated list.", - ) - print(response) - - # Test coordinate query - response = query_single_frame( - image=image, - query="Return the center coordinates of any person in the image as a tuple (x,y)", - ) - print(response) - - -if __name__ == "__main__": - test_qwen_image_query() diff --git a/build/lib/tests/test_robot.py b/build/lib/tests/test_robot.py deleted file mode 100644 index 76289273f7..0000000000 --- a/build/lib/tests/test_robot.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import os -import time -import threading -from dimos.robot.unitree_webrtc.unitree_go2 import UnitreeGo2 -from dimos.robot.local_planner.local_planner import navigate_to_goal_local -from dimos.web.robot_web_interface import RobotWebInterface -from reactivex import operators as RxOps -import tests.test_header - - -def main(): - print("Initializing Unitree Go2 robot with local planner visualization...") - - # Initialize the robot with webrtc interface - robot = UnitreeGo2(ip=os.getenv("ROBOT_IP"), mode="ai") - - # Get the camera stream - video_stream = robot.get_video_stream() - - # The local planner visualization stream is created during robot initialization - local_planner_stream = robot.local_planner_viz_stream - - local_planner_stream = local_planner_stream.pipe( - RxOps.share(), - RxOps.map(lambda x: x if x is not None else None), - RxOps.filter(lambda x: x is not None), - ) - - goal_following_thread = None - try: - # Set up web interface with both streams - streams = {"camera": video_stream, "local_planner": local_planner_stream} - - # Create and start the web interface - web_interface = RobotWebInterface(port=5555, **streams) - - # Wait for initialization - print("Waiting for camera and systems to initialize...") - time.sleep(2) - - # Start the goal following test in a separate thread - print("Starting navigation to local goal (2m ahead) in a separate thread...") - goal_following_thread = threading.Thread( - target=navigate_to_goal_local, - kwargs={"robot": robot, "goal_xy_robot": (3.0, 0.0), "distance": 0.0, "timeout": 300}, - daemon=True, - ) - goal_following_thread.start() - - print("Robot streams running") - print("Web interface available at http://localhost:5555") - print("Press Ctrl+C to exit") - - # Start web server (blocking call) - web_interface.run() - - except KeyboardInterrupt: - print("\nInterrupted by user") - except Exception as e: - print(f"Error during test: {e}") - finally: - print("Cleaning up...") - # Make sure the robot stands down 
safely - try: - robot.liedown() - except: - pass - print("Test completed") - - -if __name__ == "__main__": - main() diff --git a/build/lib/tests/test_rtsp_video_provider.py b/build/lib/tests/test_rtsp_video_provider.py deleted file mode 100644 index e3824740a6..0000000000 --- a/build/lib/tests/test_rtsp_video_provider.py +++ /dev/null @@ -1,146 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dimos.stream.rtsp_video_provider import RtspVideoProvider -from dimos.web.robot_web_interface import RobotWebInterface -import tests.test_header - -import logging -import time - -import numpy as np -import reactivex as rx -from reactivex import operators as ops - -from dimos.stream.frame_processor import FrameProcessor -from dimos.stream.video_operators import VideoOperators as vops -from dimos.stream.video_provider import get_scheduler -from dimos.utils.logging_config import setup_logger - - -logger = setup_logger("tests.test_rtsp_video_provider") - -import sys -import os - -# Load environment variables from .env file -from dotenv import load_dotenv - -load_dotenv() - -# RTSP URL must be provided as a command-line argument or environment variable -RTSP_URL = os.environ.get("TEST_RTSP_URL", "") -if len(sys.argv) > 1: - RTSP_URL = sys.argv[1] # Allow overriding with command-line argument -elif RTSP_URL == "": - print("Please provide an RTSP URL for testing.") - print( - "You can set the TEST_RTSP_URL environment variable or pass it 
as a command-line argument." - ) - print("Example: python -m dimos.stream.rtsp_video_provider rtsp://...") - sys.exit(1) - -logger.info(f"Attempting to connect to provided RTSP URL.") -provider = RtspVideoProvider(dev_name="TestRtspCam", rtsp_url=RTSP_URL) - -logger.info("Creating observable...") -video_stream_observable = provider.capture_video_as_observable() - -logger.info("Subscribing to observable...") -frame_counter = 0 -start_time = time.monotonic() # Re-initialize start_time -last_log_time = start_time # Keep this for interval timing - -# Create a subject for ffmpeg responses -ffmpeg_response_subject = rx.subject.Subject() -ffmpeg_response_stream = ffmpeg_response_subject.pipe(ops.observe_on(get_scheduler()), ops.share()) - - -def process_frame(frame: np.ndarray): - """Callback function executed for each received frame.""" - global frame_counter, last_log_time, start_time # Add start_time to global - frame_counter += 1 - current_time = time.monotonic() - # Log stats periodically (e.g., every 5 seconds) - if current_time - last_log_time >= 5.0: - total_elapsed_time = current_time - start_time # Calculate total elapsed time - avg_fps = frame_counter / total_elapsed_time if total_elapsed_time > 0 else 0 - logger.info(f"Received frame {frame_counter}. Shape: {frame.shape}. Avg FPS: {avg_fps:.2f}") - ffmpeg_response_subject.on_next( - f"Received frame {frame_counter}. Shape: {frame.shape}. 
Avg FPS: {avg_fps:.2f}" - ) - last_log_time = current_time # Update log time for the next interval - - -def handle_error(error: Exception): - """Callback function executed if the observable stream errors.""" - logger.error(f"Stream error: {error}", exc_info=True) # Log with traceback - - -def handle_completion(): - """Callback function executed when the observable stream completes.""" - logger.info("Stream completed.") - - -# Subscribe to the observable stream -processor = FrameProcessor() -subscription = video_stream_observable.pipe( - # ops.subscribe_on(get_scheduler()), - ops.observe_on(get_scheduler()), - ops.share(), - vops.with_jpeg_export(processor, suffix="reolink_", save_limit=30, loop=True), -).subscribe(on_next=process_frame, on_error=handle_error, on_completed=handle_completion) - -streams = {"reolink_video": video_stream_observable} -text_streams = { - "ffmpeg_responses": ffmpeg_response_stream, -} - -web_interface = RobotWebInterface(port=5555, text_streams=text_streams, **streams) - -web_interface.run() # This may block the main thread - -# TODO: Redo disposal / keep-alive loop - -# Keep the main thread alive to receive frames (e.g., for 60 seconds) -print("Stream running. Press Ctrl+C to stop...") -try: - # Keep running indefinitely until interrupted - while True: - time.sleep(1) - # Optional: Check if subscription is still active - # if not subscription.is_disposed: - # # logger.debug("Subscription active...") - # pass - # else: - # logger.warning("Subscription was disposed externally.") - # break - -except KeyboardInterrupt: - print("KeyboardInterrupt received. 
Shutting down...") -finally: - # Ensure resources are cleaned up regardless of how the loop exits - print("Disposing subscription...") - # subscription.dispose() - print("Disposing provider resources...") - provider.dispose_all() - print("Cleanup finished.") - -# Final check (optional, for debugging) -time.sleep(1) # Give background threads a moment -final_process = provider._ffmpeg_process -if final_process and final_process.poll() is None: - print(f"WARNING: ffmpeg process (PID: {final_process.pid}) may still be running after cleanup!") -else: - print("ffmpeg process appears terminated.") diff --git a/build/lib/tests/test_semantic_seg_robot.py b/build/lib/tests/test_semantic_seg_robot.py deleted file mode 100644 index eb5beb88e2..0000000000 --- a/build/lib/tests/test_semantic_seg_robot.py +++ /dev/null @@ -1,151 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import cv2 -import numpy as np -import os -import sys -import queue -import threading - -# Add the parent directory to the Python path -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -from dimos.stream.video_provider import VideoProvider -from dimos.perception.semantic_seg import SemanticSegmentationStream -from dimos.robot.unitree.unitree_go2 import UnitreeGo2 -from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl -from dimos.robot.unitree.unitree_skills import MyUnitreeSkills -from dimos.web.robot_web_interface import RobotWebInterface -from dimos.stream.video_operators import VideoOperators as MyVideoOps, Operators as MyOps -from dimos.stream.frame_processor import FrameProcessor -from reactivex import operators as RxOps - - -def main(): - # Create a queue for thread communication (limit to prevent memory issues) - frame_queue = queue.Queue(maxsize=5) - stop_event = threading.Event() - - # Unitree Go2 camera parameters at 1080p - camera_params = { - "resolution": (1920, 1080), # 1080p resolution - "focal_length": 3.2, # mm - "sensor_size": (4.8, 3.6), # mm (1/4" sensor) - } - - # Initialize video provider and segmentation stream - # video_provider = VideoProvider("test_camera", video_source=0) - robot = UnitreeGo2( - ip=os.getenv("ROBOT_IP"), - ros_control=UnitreeROSControl(), - ) - - seg_stream = SemanticSegmentationStream( - enable_mono_depth=False, camera_params=camera_params, gt_depth_scale=512.0 - ) - - # Create streams - video_stream = robot.get_ros_video_stream(fps=5) - segmentation_stream = seg_stream.create_stream(video_stream) - - # Define callbacks for the segmentation stream - def on_next(segmentation): - if stop_event.is_set(): - return - # Get the frame and visualize - vis_frame = segmentation.metadata["viz_frame"] - depth_viz = segmentation.metadata["depth_viz"] - # Get the image dimensions - height, width = vis_frame.shape[:2] - depth_height, depth_width = depth_viz.shape[:2] - - # Resize depth 
visualization to match segmentation height - # (maintaining aspect ratio if needed) - depth_resized = cv2.resize(depth_viz, (int(depth_width * height / depth_height), height)) - - # Create a combined frame for side-by-side display - combined_viz = np.hstack((vis_frame, depth_resized)) - - # Add labels - font = cv2.FONT_HERSHEY_SIMPLEX - cv2.putText(combined_viz, "Semantic Segmentation", (10, 30), font, 0.8, (255, 255, 255), 2) - cv2.putText( - combined_viz, "Depth Estimation", (width + 10, 30), font, 0.8, (255, 255, 255), 2 - ) - - # Put frame in queue for main thread to display (non-blocking) - try: - frame_queue.put_nowait(combined_viz) - except queue.Full: - # Skip frame if queue is full - pass - - def on_error(error): - print(f"Error: {error}") - stop_event.set() - - def on_completed(): - print("Stream completed") - stop_event.set() - - # Start the subscription - subscription = None - - try: - # Subscribe to start processing in background thread - print_emission_args = { - "enabled": True, - "dev_name": "SemanticSegmentation", - "counts": {}, - } - - frame_processor = FrameProcessor(delete_on_init=True) - subscription = segmentation_stream.pipe( - MyOps.print_emission(id="A", **print_emission_args), - RxOps.share(), - MyOps.print_emission(id="B", **print_emission_args), - RxOps.map(lambda x: x.metadata["viz_frame"] if x is not None else None), - MyOps.print_emission(id="C", **print_emission_args), - RxOps.filter(lambda x: x is not None), - MyOps.print_emission(id="D", **print_emission_args), - # MyVideoOps.with_jpeg_export(frame_processor=frame_processor, suffix="_frame_"), - MyOps.print_emission(id="E", **print_emission_args), - ) - - print("Semantic segmentation visualization started. Press 'q' to exit.") - - streams = { - "segmentation_stream": subscription, - } - fast_api_server = RobotWebInterface(port=5555, **streams) - fast_api_server.run() - - except KeyboardInterrupt: - print("\nKeyboard interrupt received. 
Stopping...") - finally: - # Signal threads to stop - stop_event.set() - - # Clean up resources - if subscription: - subscription.dispose() - - seg_stream.cleanup() - cv2.destroyAllWindows() - print("Cleanup complete") - - -if __name__ == "__main__": - main() diff --git a/build/lib/tests/test_semantic_seg_robot_agent.py b/build/lib/tests/test_semantic_seg_robot_agent.py deleted file mode 100644 index 8007e700a0..0000000000 --- a/build/lib/tests/test_semantic_seg_robot_agent.py +++ /dev/null @@ -1,141 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import cv2 -import numpy as np -import os -import sys - -from dimos.stream.video_provider import VideoProvider -from dimos.perception.semantic_seg import SemanticSegmentationStream -from dimos.robot.unitree.unitree_go2 import UnitreeGo2 -from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl -from dimos.robot.unitree.unitree_skills import MyUnitreeSkills -from dimos.web.robot_web_interface import RobotWebInterface -from dimos.stream.video_operators import VideoOperators as MyVideoOps, Operators as MyOps -from dimos.stream.frame_processor import FrameProcessor -from reactivex import Subject, operators as RxOps -from dimos.agents.agent import OpenAIAgent -from dimos.utils.threadpool import get_scheduler - - -def main(): - # Unitree Go2 camera parameters at 1080p - camera_params = { - "resolution": (1920, 1080), # 1080p resolution - "focal_length": 3.2, # mm - "sensor_size": (4.8, 3.6), # mm (1/4" sensor) - } - - robot = UnitreeGo2( - ip=os.getenv("ROBOT_IP"), ros_control=UnitreeROSControl(), skills=MyUnitreeSkills() - ) - - seg_stream = SemanticSegmentationStream( - enable_mono_depth=True, camera_params=camera_params, gt_depth_scale=512.0 - ) - - # Create streams - video_stream = robot.get_ros_video_stream(fps=5) - segmentation_stream = seg_stream.create_stream( - video_stream.pipe(MyVideoOps.with_fps_sampling(fps=0.5)) - ) - # Throttling to slowdown SegmentationAgent calls - # TODO: add Agent parameter to handle this called api_call_interval - - frame_processor = FrameProcessor(delete_on_init=True) - seg_stream = segmentation_stream.pipe( - RxOps.share(), - RxOps.map(lambda x: x.metadata["viz_frame"] if x is not None else None), - RxOps.filter(lambda x: x is not None), - # MyVideoOps.with_jpeg_export(frame_processor=frame_processor, suffix="_frame_"), # debugging - ) - - depth_stream = segmentation_stream.pipe( - RxOps.share(), - RxOps.map(lambda x: x.metadata["depth_viz"] if x is not None else None), - RxOps.filter(lambda x: x is not None), - ) - 
- object_stream = segmentation_stream.pipe( - RxOps.share(), - RxOps.map(lambda x: x.metadata["objects"] if x is not None else None), - RxOps.filter(lambda x: x is not None), - RxOps.map( - lambda objects: "\n".join( - f"Object {obj['object_id']}: {obj['label']} (confidence: {obj['prob']:.2f})" - + (f", depth: {obj['depth']:.2f}m" if "depth" in obj else "") - for obj in objects - ) - if objects - else "No objects detected." - ), - ) - - text_query_stream = Subject() - - # Combine text query with latest object data when a new text query arrives - enriched_query_stream = text_query_stream.pipe( - RxOps.with_latest_from(object_stream), - RxOps.map( - lambda combined: { - "query": combined[0], - "objects": combined[1] if len(combined) > 1 else "No object data available", - } - ), - RxOps.map(lambda data: f"{data['query']}\n\nCurrent objects detected:\n{data['objects']}"), - RxOps.do_action( - lambda x: print(f"\033[34mEnriched query: {x.split(chr(10))[0]}\033[0m") - or [print(f"\033[34m{line}\033[0m") for line in x.split(chr(10))[1:]] - ), - ) - - segmentation_agent = OpenAIAgent( - dev_name="SemanticSegmentationAgent", - model_name="gpt-4o", - system_query="You are a helpful assistant that can control a virtual robot with semantic segmentation / distnace data as a guide. Only output skill calls, no other text", - input_query_stream=enriched_query_stream, - process_all_inputs=False, - pool_scheduler=get_scheduler(), - skills=robot.get_skills(), - ) - agent_response_stream = segmentation_agent.get_response_observable() - - print("Semantic segmentation visualization started. 
Press 'q' to exit.") - - streams = { - "raw_stream": video_stream, - "depth_stream": depth_stream, - "seg_stream": seg_stream, - } - text_streams = { - "object_stream": object_stream, - "enriched_query_stream": enriched_query_stream, - "agent_response_stream": agent_response_stream, - } - - try: - fast_api_server = RobotWebInterface(port=5555, text_streams=text_streams, **streams) - fast_api_server.query_stream.subscribe(lambda x: text_query_stream.on_next(x)) - fast_api_server.run() - except KeyboardInterrupt: - print("\nKeyboard interrupt received. Stopping...") - finally: - seg_stream.cleanup() - cv2.destroyAllWindows() - print("Cleanup complete") - - -if __name__ == "__main__": - main() diff --git a/build/lib/tests/test_semantic_seg_webcam.py b/build/lib/tests/test_semantic_seg_webcam.py deleted file mode 100644 index 083d1a0090..0000000000 --- a/build/lib/tests/test_semantic_seg_webcam.py +++ /dev/null @@ -1,140 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import cv2 -import numpy as np -import os -import sys -import queue -import threading - -# Add the parent directory to the Python path -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -from dimos.stream.video_provider import VideoProvider -from dimos.perception.semantic_seg import SemanticSegmentationStream - - -def main(): - # Create a queue for thread communication (limit to prevent memory issues) - frame_queue = queue.Queue(maxsize=5) - stop_event = threading.Event() - - # Logitech C920e camera parameters at 480p - camera_params = { - "resolution": (640, 480), # 480p resolution - "focal_length": 3.67, # mm - "sensor_size": (4.8, 3.6), # mm (1/4" sensor) - } - - # Initialize video provider and segmentation stream - video_provider = VideoProvider("test_camera", video_source=0) - seg_stream = SemanticSegmentationStream( - enable_mono_depth=True, camera_params=camera_params, gt_depth_scale=512.0 - ) - - # Create streams - video_stream = video_provider.capture_video_as_observable(realtime=False, fps=5) - segmentation_stream = seg_stream.create_stream(video_stream) - - # Define callbacks for the segmentation stream - def on_next(segmentation): - if stop_event.is_set(): - return - - # Get the frame and visualize - vis_frame = segmentation.metadata["viz_frame"] - depth_viz = segmentation.metadata["depth_viz"] - # Get the image dimensions - height, width = vis_frame.shape[:2] - depth_height, depth_width = depth_viz.shape[:2] - - # Resize depth visualization to match segmentation height - # (maintaining aspect ratio if needed) - depth_resized = cv2.resize(depth_viz, (int(depth_width * height / depth_height), height)) - - # Create a combined frame for side-by-side display - combined_viz = np.hstack((vis_frame, depth_resized)) - - # Add labels - font = cv2.FONT_HERSHEY_SIMPLEX - cv2.putText(combined_viz, "Semantic Segmentation", (10, 30), font, 0.8, (255, 255, 255), 2) - cv2.putText( - combined_viz, "Depth Estimation", (width + 10, 30), 
font, 0.8, (255, 255, 255), 2 - ) - - # Put frame in queue for main thread to display (non-blocking) - try: - frame_queue.put_nowait(combined_viz) - except queue.Full: - # Skip frame if queue is full - pass - - def on_error(error): - print(f"Error: {error}") - stop_event.set() - - def on_completed(): - print("Stream completed") - stop_event.set() - - # Start the subscription - subscription = None - - try: - # Subscribe to start processing in background thread - subscription = segmentation_stream.subscribe( - on_next=on_next, on_error=on_error, on_completed=on_completed - ) - - print("Semantic segmentation visualization started. Press 'q' to exit.") - - # Main thread loop for displaying frames - while not stop_event.is_set(): - try: - # Get frame with timeout (allows checking stop_event periodically) - combined_viz = frame_queue.get(timeout=1.0) - - # Display the frame in main thread - cv2.imshow("Semantic Segmentation", combined_viz) - # Check for exit key - if cv2.waitKey(1) & 0xFF == ord("q"): - print("Exit key pressed") - break - - except queue.Empty: - # No frame available, check if we should continue - if cv2.waitKey(1) & 0xFF == ord("q"): - print("Exit key pressed") - break - continue - - except KeyboardInterrupt: - print("\nKeyboard interrupt received. Stopping...") - finally: - # Signal threads to stop - stop_event.set() - - # Clean up resources - if subscription: - subscription.dispose() - - video_provider.dispose_all() - seg_stream.cleanup() - cv2.destroyAllWindows() - print("Cleanup complete") - - -if __name__ == "__main__": - main() diff --git a/build/lib/tests/test_skills.py b/build/lib/tests/test_skills.py deleted file mode 100644 index 0d4b7f2ff8..0000000000 --- a/build/lib/tests/test_skills.py +++ /dev/null @@ -1,185 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for the skills module in the dimos package.""" - -import unittest -from unittest import mock - -import tests.test_header - -from dimos.skills.skills import AbstractSkill, SkillLibrary -from dimos.robot.robot import MockRobot -from dimos.robot.unitree.unitree_skills import MyUnitreeSkills -from dimos.types.constants import Colors -from dimos.agents.agent import OpenAIAgent - - -class TestSkill(AbstractSkill): - """A test skill that tracks its execution for testing purposes.""" - - _called: bool = False - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._called = False - - def __call__(self): - self._called = True - return "TestSkill executed successfully" - - -class SkillLibraryTest(unittest.TestCase): - """Tests for the SkillLibrary functionality.""" - - def setUp(self): - """Set up test fixtures before each test method.""" - self.robot = MockRobot() - self.skill_library = MyUnitreeSkills(robot=self.robot) - self.skill_library.initialize_skills() - - def test_skill_iteration(self): - """Test that skills can be properly iterated in the skill library.""" - skills_count = 0 - for skill in self.skill_library: - skills_count += 1 - self.assertTrue(hasattr(skill, "__name__")) - self.assertTrue(issubclass(skill, AbstractSkill)) - - self.assertGreater(skills_count, 0, "Skill library should contain at least one skill") - - def test_skill_registration(self): - """Test that skills can be properly registered in the skill library.""" - # Clear existing skills for isolated test - self.skill_library = 
MyUnitreeSkills(robot=self.robot) - original_count = len(list(self.skill_library)) - - # Add a custom test skill - test_skill = TestSkill - self.skill_library.add(test_skill) - - # Verify the skill was added - new_count = len(list(self.skill_library)) - self.assertEqual(new_count, original_count + 1) - - # Check if the skill can be found by name - found = False - for skill in self.skill_library: - if skill.__name__ == "TestSkill": - found = True - break - self.assertTrue(found, "Added skill should be found in skill library") - - def test_skill_direct_execution(self): - """Test that a skill can be executed directly.""" - test_skill = TestSkill() - self.assertFalse(test_skill._called) - result = test_skill() - self.assertTrue(test_skill._called) - self.assertEqual(result, "TestSkill executed successfully") - - def test_skill_library_execution(self): - """Test that a skill can be executed through the skill library.""" - # Add our test skill to the library - test_skill = TestSkill - self.skill_library.add(test_skill) - - # Create an instance to confirm it was executed - with mock.patch.object(TestSkill, "__call__", return_value="Success") as mock_call: - result = self.skill_library.call("TestSkill") - mock_call.assert_called_once() - self.assertEqual(result, "Success") - - def test_skill_not_found(self): - """Test that calling a non-existent skill raises an appropriate error.""" - with self.assertRaises(ValueError): - self.skill_library.call("NonExistentSkill") - - -class SkillWithAgentTest(unittest.TestCase): - """Tests for skills used with an agent.""" - - def setUp(self): - """Set up test fixtures before each test method.""" - self.robot = MockRobot() - self.skill_library = MyUnitreeSkills(robot=self.robot) - self.skill_library.initialize_skills() - - # Add a test skill - self.skill_library.add(TestSkill) - - # Create the agent - self.agent = OpenAIAgent( - dev_name="SkillTestAgent", - system_query="You are a skill testing agent. 
When prompted to perform an action, use the appropriate skill.", - skills=self.skill_library, - ) - - @mock.patch("dimos.agents.agent.OpenAIAgent.run_observable_query") - def test_agent_skill_identification(self, mock_query): - """Test that the agent can identify skills based on natural language.""" - # Mock the agent response - mock_response = mock.MagicMock() - mock_response.run.return_value = "I found the TestSkill and executed it." - mock_query.return_value = mock_response - - # Run the test - response = self.agent.run_observable_query("Please run the test skill").run() - - # Assertions - mock_query.assert_called_once_with("Please run the test skill") - self.assertEqual(response, "I found the TestSkill and executed it.") - - @mock.patch.object(TestSkill, "__call__") - @mock.patch("dimos.agents.agent.OpenAIAgent.run_observable_query") - def test_agent_skill_execution(self, mock_query, mock_skill_call): - """Test that the agent can execute skills properly.""" - # Mock the agent and skill call - mock_skill_call.return_value = "TestSkill executed successfully" - mock_response = mock.MagicMock() - mock_response.run.return_value = "Executed TestSkill successfully." 
- mock_query.return_value = mock_response - - # Run the test - response = self.agent.run_observable_query("Execute the TestSkill skill").run() - - # We can't directly verify the skill was called since our mocking setup - # doesn't capture the internal skill execution of the agent, but we can - # verify the agent was properly called - mock_query.assert_called_once_with("Execute the TestSkill skill") - self.assertEqual(response, "Executed TestSkill successfully.") - - def test_agent_multi_skill_registration(self): - """Test that multiple skills can be registered with an agent.""" - - # Create a new skill - class AnotherTestSkill(AbstractSkill): - def __call__(self): - return "Another test skill executed" - - # Register the new skill - initial_count = len(list(self.skill_library)) - self.skill_library.add(AnotherTestSkill) - - # Verify two distinct skills now exist - self.assertEqual(len(list(self.skill_library)), initial_count + 1) - - # Verify both skills are found by name - skill_names = [skill.__name__ for skill in self.skill_library] - self.assertIn("TestSkill", skill_names) - self.assertIn("AnotherTestSkill", skill_names) - - -if __name__ == "__main__": - unittest.main() diff --git a/build/lib/tests/test_skills_rest.py b/build/lib/tests/test_skills_rest.py deleted file mode 100644 index 70a15fcfd5..0000000000 --- a/build/lib/tests/test_skills_rest.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import tests.test_header - -from textwrap import dedent -from dimos.skills.skills import SkillLibrary - -from dotenv import load_dotenv -from dimos.agents.claude_agent import ClaudeAgent -from dimos.web.robot_web_interface import RobotWebInterface -from dimos.skills.rest.rest import GenericRestSkill -import reactivex as rx -import reactivex.operators as ops - -# Load API key from environment -load_dotenv() - -# Create a skill library and add the GenericRestSkill -skills = SkillLibrary() -skills.add(GenericRestSkill) - -# Create a subject for agent responses -agent_response_subject = rx.subject.Subject() -agent_response_stream = agent_response_subject.pipe(ops.share()) - -# Create a text stream for agent responses in the web interface -text_streams = { - "agent_responses": agent_response_stream, -} -web_interface = RobotWebInterface(port=5555, text_streams=text_streams) - -# Create a ClaudeAgent instance -agent = ClaudeAgent( - dev_name="test_agent", - input_query_stream=web_interface.query_stream, - skills=skills, - system_query=dedent( - """ - You are a virtual agent. When given a query, respond by using - the appropriate tool calls if needed to execute commands on the robot. - - IMPORTANT: - Only return the response directly asked of the user. E.G. if the user asks for the time, - only return the time. If the user asks for the weather, only return the weather. - """ - ), - model_name="claude-3-7-sonnet-latest", - thinking_budget_tokens=2000, -) - -# Subscribe to agent responses and send them to the subject -agent.get_response_observable().subscribe(lambda x: agent_response_subject.on_next(x)) - -# Start the web interface -web_interface.run() - -# Run this query in the web interface: -# -# Make a web request to nist to get the current time. 
-# You should use http://worldclockapi.com/api/json/utc/now -# diff --git a/build/lib/tests/test_spatial_memory.py b/build/lib/tests/test_spatial_memory.py deleted file mode 100644 index b400749cb4..0000000000 --- a/build/lib/tests/test_spatial_memory.py +++ /dev/null @@ -1,297 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import sys -import time -import pickle -import numpy as np -import cv2 -import matplotlib.pyplot as plt -from matplotlib.patches import Circle -import reactivex -from reactivex import operators as ops -import chromadb - -from dimos.agents.memory.visual_memory import VisualMemory - -import tests.test_header - -from dimos.robot.unitree.unitree_go2 import UnitreeGo2 -from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl -from dimos.perception.spatial_perception import SpatialMemory - - -def extract_position(transform): - """Extract position coordinates from a transform message""" - if transform is None: - return (0, 0, 0) - - pos = transform.transform.translation - return (pos.x, pos.y, pos.z) - - -def setup_persistent_chroma_db(db_path="chromadb_data"): - """ - Set up a persistent ChromaDB database at the specified path. 
- - Args: - db_path: Path to store the ChromaDB database - - Returns: - The ChromaDB client instance - """ - # Create a persistent ChromaDB client - full_db_path = os.path.join("/home/stash/dimensional/dimos/assets/test_spatial_memory", db_path) - print(f"Setting up persistent ChromaDB at: {full_db_path}") - - # Ensure the directory exists - os.makedirs(full_db_path, exist_ok=True) - - return chromadb.PersistentClient(path=full_db_path) - - -def main(): - print("Starting spatial memory test...") - - # Initialize ROS control and robot - ros_control = UnitreeROSControl(node_name="spatial_memory_test", mock_connection=False) - - robot = UnitreeGo2(ros_control=ros_control, ip=os.getenv("ROBOT_IP")) - - # Create counters for tracking - frame_count = 0 - transform_count = 0 - stored_count = 0 - - print("Setting up video stream...") - video_stream = robot.get_ros_video_stream() - - # Create transform stream at 1 Hz - print("Setting up transform stream...") - transform_stream = ros_control.get_transform_stream( - child_frame="map", - parent_frame="base_link", - rate_hz=1.0, # 1 transform per second - ) - - # Setup output directory for visual memory - visual_memory_dir = "/home/stash/dimensional/dimos/assets/test_spatial_memory" - os.makedirs(visual_memory_dir, exist_ok=True) - - # Setup persistent storage path for visual memory - visual_memory_path = os.path.join(visual_memory_dir, "visual_memory.pkl") - - # Try to load existing visual memory if it exists - if os.path.exists(visual_memory_path): - try: - print(f"Loading existing visual memory from {visual_memory_path}...") - visual_memory = VisualMemory.load(visual_memory_path, output_dir=visual_memory_dir) - print(f"Loaded {visual_memory.count()} images from previous runs") - except Exception as e: - print(f"Error loading visual memory: {e}") - visual_memory = VisualMemory(output_dir=visual_memory_dir) - else: - print("No existing visual memory found. 
Starting with empty visual memory.") - visual_memory = VisualMemory(output_dir=visual_memory_dir) - - # Setup a persistent database for ChromaDB - db_client = setup_persistent_chroma_db() - - # Create spatial perception instance with persistent storage - print("Creating SpatialMemory with persistent vector database...") - spatial_memory = SpatialMemory( - collection_name="test_spatial_memory", - min_distance_threshold=1, # Store frames every 1 meter - min_time_threshold=1, # Store frames at least every 1 second - chroma_client=db_client, # Use the persistent client - visual_memory=visual_memory, # Use the visual memory we loaded or created - ) - - # Combine streams using combine_latest - # This will pair up items properly without buffering - combined_stream = reactivex.combine_latest(video_stream, transform_stream).pipe( - ops.map( - lambda pair: { - "frame": pair[0], # First element is the frame - "position": extract_position(pair[1]), # Second element is the transform - } - ) - ) - - # Process with spatial memory - result_stream = spatial_memory.process_stream(combined_stream) - - # Simple callback to track stored frames and save them to the assets directory - def on_stored_frame(result): - nonlocal stored_count - # Only count actually stored frames (not debug frames) - if not result.get("stored", True) == False: - stored_count += 1 - pos = result["position"] - print(f"\nStored frame #{stored_count} at ({pos[0]:.2f}, {pos[1]:.2f}, {pos[2]:.2f})") - - # Save the frame to the assets directory - if "frame" in result: - frame_filename = f"/home/stash/dimensional/dimos/assets/test_spatial_memory/frame_{stored_count:03d}.jpg" - cv2.imwrite(frame_filename, result["frame"]) - print(f"Saved frame to {frame_filename}") - - # Subscribe to results - print("Subscribing to spatial perception results...") - result_subscription = result_stream.subscribe(on_stored_frame) - - print("\nRunning until interrupted...") - try: - while True: - time.sleep(1.0) - print(f"Running: 
{stored_count} frames stored so far", end="\r") - except KeyboardInterrupt: - print("\nTest interrupted by user") - finally: - # Clean up resources - print("\nCleaning up...") - if "result_subscription" in locals(): - result_subscription.dispose() - - # Visualize spatial memory with multiple object queries - visualize_spatial_memory_with_objects( - spatial_memory, - objects=[ - "kitchen", - "conference room", - "vacuum", - "office", - "bathroom", - "boxes", - "telephone booth", - ], - output_filename="spatial_memory_map.png", - ) - - # Save visual memory to disk for later use - saved_path = spatial_memory.vector_db.visual_memory.save("visual_memory.pkl") - print(f"Saved {spatial_memory.vector_db.visual_memory.count()} images to disk at {saved_path}") - - -def visualize_spatial_memory_with_objects( - spatial_memory, objects, output_filename="spatial_memory_map.png" -): - """ - Visualize a spatial memory map with multiple labeled objects. - - Args: - spatial_memory: SpatialMemory instance - objects: List of object names to query and visualize (e.g. 
["kitchen", "office"]) - output_filename: Filename to save the visualization - """ - # Define colors for different objects - will cycle through these - colors = ["red", "green", "orange", "purple", "brown", "cyan", "magenta", "yellow"] - - # Get all stored locations for background - locations = spatial_memory.vector_db.get_all_locations() - if not locations: - print("No locations stored in spatial memory.") - return - - # Extract coordinates from all stored locations - if len(locations[0]) >= 3: - x_coords = [loc[0] for loc in locations] - y_coords = [loc[1] for loc in locations] - else: - x_coords, y_coords = zip(*locations) - - # Create figure - plt.figure(figsize=(12, 10)) - - # Plot all points in blue - plt.scatter(x_coords, y_coords, c="blue", s=50, alpha=0.5, label="All Frames") - - # Container for all object coordinates - object_coords = {} - - # Query for each object and store the result - for i, obj in enumerate(objects): - color = colors[i % len(colors)] # Cycle through colors - print(f"\nProcessing {obj} query for visualization...") - - # Get best match for this object - results = spatial_memory.query_by_text(obj, limit=1) - if not results: - print(f"No results found for '{obj}'") - continue - - # Get the first (best) result - result = results[0] - metadata = result["metadata"] - - # Extract coordinates from the first metadata item - if isinstance(metadata, list) and metadata: - metadata = metadata[0] - - if isinstance(metadata, dict) and "x" in metadata and "y" in metadata: - x = metadata.get("x", 0) - y = metadata.get("y", 0) - - # Store coordinates for this object - object_coords[obj] = (x, y) - - # Plot this object's position - plt.scatter([x], [y], c=color, s=100, alpha=0.8, label=obj.title()) - - # Add annotation - obj_abbrev = obj[0].upper() if len(obj) > 0 else "X" - plt.annotate( - f"{obj_abbrev}", (x, y), textcoords="offset points", xytext=(0, 10), ha="center" - ) - - # Save the image to a file using the object name - if "image" in result and 
result["image"] is not None: - # Clean the object name to make it suitable for a filename - clean_name = obj.replace(" ", "_").lower() - output_img_filename = f"{clean_name}_result.jpg" - cv2.imwrite(output_img_filename, result["image"]) - print(f"Saved {obj} image to {output_img_filename}") - - # Finalize the plot - plt.title("Spatial Memory Map with Query Results") - plt.xlabel("X Position (m)") - plt.ylabel("Y Position (m)") - plt.grid(True) - plt.axis("equal") - plt.legend() - - # Add origin circle - plt.gca().add_patch(Circle((0, 0), 1.0, fill=False, color="blue", linestyle="--")) - - # Save the visualization - plt.savefig(output_filename, dpi=300) - print(f"Saved enhanced map visualization to {output_filename}") - - return object_coords - - # Final cleanup - print("Performing final cleanup...") - spatial_memory.cleanup() - - try: - robot.cleanup() - except Exception as e: - print(f"Error during robot cleanup: {e}") - - print("Test completed successfully") - - -if __name__ == "__main__": - main() diff --git a/build/lib/tests/test_spatial_memory_query.py b/build/lib/tests/test_spatial_memory_query.py deleted file mode 100644 index a0e77e9444..0000000000 --- a/build/lib/tests/test_spatial_memory_query.py +++ /dev/null @@ -1,297 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -""" -Test script for querying an existing spatial memory database - -Usage: - python test_spatial_memory_query.py --query "kitchen table" --limit 5 --threshold 0.7 --save-all - python test_spatial_memory_query.py --query "robot" --limit 3 --save-one -""" - -import os -import sys -import argparse -import numpy as np -import cv2 -import matplotlib.pyplot as plt -import chromadb -from datetime import datetime - -import tests.test_header -from dimos.perception.spatial_perception import SpatialMemory -from dimos.agents.memory.visual_memory import VisualMemory - - -def setup_persistent_chroma_db(db_path): - """Set up a persistent ChromaDB client at the specified path.""" - print(f"Setting up persistent ChromaDB at: {db_path}") - os.makedirs(db_path, exist_ok=True) - return chromadb.PersistentClient(path=db_path) - - -def parse_args(): - """Parse command-line arguments.""" - parser = argparse.ArgumentParser(description="Query spatial memory database.") - parser.add_argument( - "--query", type=str, default=None, help="Text query to search for (e.g., 'kitchen table')" - ) - parser.add_argument("--limit", type=int, default=3, help="Maximum number of results to return") - parser.add_argument( - "--threshold", - type=float, - default=None, - help="Similarity threshold (0.0-1.0). 
Only return results above this threshold.", - ) - parser.add_argument("--save-all", action="store_true", help="Save all result images") - parser.add_argument("--save-one", action="store_true", help="Save only the best matching image") - parser.add_argument( - "--visualize", - action="store_true", - help="Create a visualization of all stored memory locations", - ) - parser.add_argument( - "--db-path", - type=str, - default="/home/stash/dimensional/dimos/assets/test_spatial_memory/chromadb_data", - help="Path to ChromaDB database", - ) - parser.add_argument( - "--visual-memory-path", - type=str, - default="/home/stash/dimensional/dimos/assets/test_spatial_memory/visual_memory.pkl", - help="Path to visual memory file", - ) - return parser.parse_args() - - -def main(): - args = parse_args() - print("Loading existing spatial memory database for querying...") - - # Setup the persistent ChromaDB client - db_client = setup_persistent_chroma_db(args.db_path) - - # Setup output directory for any saved results - output_dir = os.path.dirname(args.visual_memory_path) - - # Load the visual memory - print(f"Loading visual memory from {args.visual_memory_path}...") - if os.path.exists(args.visual_memory_path): - visual_memory = VisualMemory.load(args.visual_memory_path, output_dir=output_dir) - print(f"Loaded {visual_memory.count()} images from visual memory") - else: - visual_memory = VisualMemory(output_dir=output_dir) - print("No existing visual memory found. 
Query results won't include images.") - - # Create SpatialMemory with the existing database and visual memory - spatial_memory = SpatialMemory( - collection_name="test_spatial_memory", chroma_client=db_client, visual_memory=visual_memory - ) - - # Create a visualization if requested - if args.visualize: - print("\nCreating visualization of spatial memory...") - common_objects = [ - "kitchen", - "conference room", - "vacuum", - "office", - "bathroom", - "boxes", - "telephone booth", - ] - visualize_spatial_memory_with_objects( - spatial_memory, objects=common_objects, output_filename="spatial_memory_map.png" - ) - - # Handle query if provided - if args.query: - query = args.query - limit = args.limit - print(f"\nQuerying for: '{query}' (limit: {limit})...") - - # Run the query - results = spatial_memory.query_by_text(query, limit=limit) - - if not results: - print(f"No results found for query: '{query}'") - return - - # Filter by threshold if specified - if args.threshold is not None: - print(f"Filtering results with similarity threshold: {args.threshold}") - filtered_results = [] - for result in results: - # Distance is inverse of similarity (0 is perfect match) - # Convert to similarity score (1.0 is perfect match) - similarity = 1.0 - ( - result.get("distance", 0) if result.get("distance") is not None else 0 - ) - if similarity >= args.threshold: - filtered_results.append((result, similarity)) - - # Sort by similarity (highest first) - filtered_results.sort(key=lambda x: x[1], reverse=True) - - if not filtered_results: - print(f"No results met the similarity threshold of {args.threshold}") - return - - print(f"Found {len(filtered_results)} results above threshold") - results_with_scores = filtered_results - else: - # Add similarity scores for all results - results_with_scores = [] - for result in results: - similarity = 1.0 - ( - result.get("distance", 0) if result.get("distance") is not None else 0 - ) - results_with_scores.append((result, similarity)) - - # 
Process and display results - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - - for i, (result, similarity) in enumerate(results_with_scores): - metadata = result.get("metadata", {}) - if isinstance(metadata, list) and metadata: - metadata = metadata[0] - - # Display result information - print(f"\nResult {i + 1} for '{query}':") - print(f"Similarity: {similarity:.4f} (distance: {1.0 - similarity:.4f})") - - # Extract and display position information - if isinstance(metadata, dict): - x = metadata.get("x", 0) - y = metadata.get("y", 0) - z = metadata.get("z", 0) - print(f"Position: ({x:.2f}, {y:.2f}, {z:.2f})") - if "timestamp" in metadata: - print(f"Timestamp: {metadata['timestamp']}") - if "frame_id" in metadata: - print(f"Frame ID: {metadata['frame_id']}") - - # Save image if requested and available - if "image" in result and result["image"] is not None: - # Only save first image, or all images based on flags - if args.save_one and i > 0: - continue - if not (args.save_all or args.save_one): - continue - - # Create a descriptive filename - clean_query = query.replace(" ", "_").replace("/", "_").lower() - output_filename = f"{clean_query}_result_{i + 1}_{timestamp}.jpg" - - # Save the image - cv2.imwrite(output_filename, result["image"]) - print(f"Saved image to {output_filename}") - elif "image" in result and result["image"] is None: - print("Image data not available for this result") - else: - print('No query specified. 
Use --query "text to search for" to run a query.') - print("Use --help to see all available options.") - - print("\nQuery completed successfully!") - - -def visualize_spatial_memory_with_objects( - spatial_memory, objects, output_filename="spatial_memory_map.png" -): - """Visualize spatial memory with labeled objects.""" - # Define colors for different objects - colors = ["red", "green", "orange", "purple", "brown", "cyan", "magenta", "yellow"] - - # Get all stored locations for background - locations = spatial_memory.vector_db.get_all_locations() - if not locations: - print("No locations stored in spatial memory.") - return - - # Extract coordinates - if len(locations[0]) >= 3: - x_coords = [loc[0] for loc in locations] - y_coords = [loc[1] for loc in locations] - else: - x_coords, y_coords = zip(*locations) - - # Create figure - plt.figure(figsize=(12, 10)) - plt.scatter(x_coords, y_coords, c="blue", s=50, alpha=0.5, label="All Frames") - - # Container for object coordinates - object_coords = {} - - # Query for each object - for i, obj in enumerate(objects): - color = colors[i % len(colors)] - print(f"Processing {obj} query for visualization...") - - # Get best match - results = spatial_memory.query_by_text(obj, limit=1) - if not results: - print(f"No results found for '{obj}'") - continue - - # Process result - result = results[0] - metadata = result["metadata"] - - if isinstance(metadata, list) and metadata: - metadata = metadata[0] - - if isinstance(metadata, dict) and "x" in metadata and "y" in metadata: - x = metadata.get("x", 0) - y = metadata.get("y", 0) - - # Store coordinates - object_coords[obj] = (x, y) - - # Plot position - plt.scatter([x], [y], c=color, s=100, alpha=0.8, label=obj.title()) - - # Add annotation - obj_abbrev = obj[0].upper() if len(obj) > 0 else "X" - plt.annotate( - f"{obj_abbrev}", (x, y), textcoords="offset points", xytext=(0, 10), ha="center" - ) - - # Save image if available - if "image" in result and result["image"] is not None: 
- clean_name = obj.replace(" ", "_").lower() - output_img_filename = f"{clean_name}_result.jpg" - cv2.imwrite(output_img_filename, result["image"]) - print(f"Saved {obj} image to {output_img_filename}") - - # Finalize plot - plt.title("Spatial Memory Map with Query Results") - plt.xlabel("X Position (m)") - plt.ylabel("Y Position (m)") - plt.grid(True) - plt.axis("equal") - plt.legend() - - # Add origin marker - plt.gca().add_patch(plt.Circle((0, 0), 1.0, fill=False, color="blue", linestyle="--")) - - # Save visualization - plt.savefig(output_filename, dpi=300) - print(f"Saved visualization to {output_filename}") - - return object_coords - - -if __name__ == "__main__": - main() diff --git a/build/lib/tests/test_standalone_chromadb.py b/build/lib/tests/test_standalone_chromadb.py deleted file mode 100644 index a5dc0e9b73..0000000000 --- a/build/lib/tests/test_standalone_chromadb.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import tests.test_header -import os - -# ----- - -import chromadb -from langchain_openai import OpenAIEmbeddings -from langchain_chroma import Chroma - -OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") -if not OPENAI_API_KEY: - raise Exception("OpenAI key not specified.") - -collection_name = "my_collection" - -embeddings = OpenAIEmbeddings( - model="text-embedding-3-large", - dimensions=1024, - api_key=OPENAI_API_KEY, -) - -db_connection = Chroma( - collection_name=collection_name, - embedding_function=embeddings, -) - - -def add_vector(vector_id, vector_data): - """Add a vector to the ChromaDB collection.""" - if not db_connection: - raise Exception("Collection not initialized. Call connect() first.") - db_connection.add_texts( - ids=[vector_id], - texts=[vector_data], - metadatas=[{"name": vector_id}], - ) - - -add_vector("id0", "Food") -add_vector("id1", "Cat") -add_vector("id2", "Mouse") -add_vector("id3", "Bike") -add_vector("id4", "Dog") -add_vector("id5", "Tricycle") -add_vector("id6", "Car") -add_vector("id7", "Horse") -add_vector("id8", "Vehicle") -add_vector("id6", "Red") -add_vector("id7", "Orange") -add_vector("id8", "Yellow") - - -def get_vector(vector_id): - """Retrieve a vector from the ChromaDB by its identifier.""" - result = db_connection.get(include=["embeddings"], ids=[vector_id]) - return result - - -print(get_vector("id1")) -# print(get_vector("id3")) -# print(get_vector("id0")) -# print(get_vector("id2")) - - -def query(query_texts, n_results=2): - """Query the collection with a specific text and return up to n results.""" - if not db_connection: - raise Exception("Collection not initialized. 
Call connect() first.") - return db_connection.similarity_search(query=query_texts, k=n_results) - - -results = query("Colors") -print(results) diff --git a/build/lib/tests/test_standalone_fastapi.py b/build/lib/tests/test_standalone_fastapi.py deleted file mode 100644 index 6fac013546..0000000000 --- a/build/lib/tests/test_standalone_fastapi.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import tests.test_header -import os - -import logging - -logging.basicConfig(level=logging.DEBUG) - -from fastapi import FastAPI, Response -import cv2 -import uvicorn -from starlette.responses import StreamingResponse - -app = FastAPI() - -# Note: Chrome does not allow for loading more than 6 simultaneous -# video streams. Use Safari or another browser for utilizing -# multiple simultaneous streams. Possibly build out functionality -# that will stop live streams. 
- - -@app.get("/") -async def root(): - pid = os.getpid() # Get the current process ID - return {"message": f"Video Streaming Server, PID: {pid}"} - - -def video_stream_generator(): - pid = os.getpid() - print(f"Stream initiated by worker with PID: {pid}") # Log the PID when the generator is called - - # Use the correct path for your video source - cap = cv2.VideoCapture( - f"{os.getcwd()}/assets/trimmed_video_480p.mov" - ) # Change 0 to a filepath for video files - - if not cap.isOpened(): - yield (b"--frame\r\nContent-Type: text/plain\r\n\r\n" + b"Could not open video source\r\n") - return - - try: - while True: - ret, frame = cap.read() - # If frame is read correctly ret is True - if not ret: - print(f"Reached the end of the video, restarting... PID: {pid}") - cap.set( - cv2.CAP_PROP_POS_FRAMES, 0 - ) # Set the position of the next video frame to 0 (the beginning) - continue - _, buffer = cv2.imencode(".jpg", frame) - yield (b"--frame\r\nContent-Type: image/jpeg\r\n\r\n" + buffer.tobytes() + b"\r\n") - finally: - cap.release() - - -@app.get("/video") -async def video_endpoint(): - logging.debug("Attempting to open video stream.") - response = StreamingResponse( - video_stream_generator(), media_type="multipart/x-mixed-replace; boundary=frame" - ) - logging.debug("Streaming response set up.") - return response - - -if __name__ == "__main__": - uvicorn.run("__main__:app", host="0.0.0.0", port=5555, workers=20) diff --git a/build/lib/tests/test_standalone_hugging_face.py b/build/lib/tests/test_standalone_hugging_face.py deleted file mode 100644 index d0b2e68e61..0000000000 --- a/build/lib/tests/test_standalone_hugging_face.py +++ /dev/null @@ -1,147 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import tests.test_header - -# from transformers import AutoModelForCausalLM, AutoTokenizer - -# model_name = "Qwen/QwQ-32B" - -# model = AutoModelForCausalLM.from_pretrained( -# model_name, -# torch_dtype="auto", -# device_map="auto" -# ) -# tokenizer = AutoTokenizer.from_pretrained(model_name) - -# prompt = "How many r's are in the word \"strawberry\"" -# messages = [ -# {"role": "user", "content": prompt} -# ] -# text = tokenizer.apply_chat_template( -# messages, -# tokenize=False, -# add_generation_prompt=True -# ) - -# model_inputs = tokenizer([text], return_tensors="pt").to(model.device) - -# generated_ids = model.generate( -# **model_inputs, -# max_new_tokens=32768 -# ) -# generated_ids = [ -# output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) -# ] - -# response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] -# print(response) - -# ----------------------------------------------------------------------------- - -# import requests -# import json - -# API_URL = "https://api-inference.huggingface.co/models/Qwen/QwQ-32B" -# api_key = os.getenv('HUGGINGFACE_ACCESS_TOKEN') - -# HEADERS = {"Authorization": f"Bearer {api_key}"} - -# prompt = "How many r's are in the word \"strawberry\"" -# messages = [ -# {"role": "user", "content": prompt} -# ] - -# # Format the prompt in the desired chat format -# chat_template = ( -# f"{messages[0]['content']}\n" -# "Assistant:" -# ) - -# payload = { -# "inputs": chat_template, -# "parameters": { -# "max_new_tokens": 32768, -# "temperature": 0.7 -# } 
-# } - -# # API request -# response = requests.post(API_URL, headers=HEADERS, json=payload) - -# # Handle response -# if response.status_code == 200: -# output = response.json()[0]['generated_text'] -# print(output.strip()) -# else: -# print(f"Error {response.status_code}: {response.text}") - -# ----------------------------------------------------------------------------- - -# import os -# import requests -# import time - -# API_URL = "https://api-inference.huggingface.co/models/Qwen/QwQ-32B" -# api_key = os.getenv('HUGGINGFACE_ACCESS_TOKEN') - -# HEADERS = {"Authorization": f"Bearer {api_key}"} - -# def query_with_retries(payload, max_retries=5, delay=15): -# for attempt in range(max_retries): -# response = requests.post(API_URL, headers=HEADERS, json=payload) -# if response.status_code == 200: -# return response.json()[0]['generated_text'] -# elif response.status_code == 500: # Service unavailable -# print(f"Attempt {attempt + 1}/{max_retries}: Model busy. Retrying in {delay} seconds...") -# time.sleep(delay) -# else: -# print(f"Error {response.status_code}: {response.text}") -# break -# return "Failed after multiple retries." 
- -# prompt = "How many r's are in the word \"strawberry\"" -# messages = [{"role": "user", "content": prompt}] -# chat_template = f"{messages[0]['content']}\nAssistant:" - -# payload = { -# "inputs": chat_template, -# "parameters": {"max_new_tokens": 32768, "temperature": 0.7} -# } - -# output = query_with_retries(payload) -# print(output.strip()) - -# ----------------------------------------------------------------------------- - -import os -from huggingface_hub import InferenceClient - -# Use environment variable for API key -api_key = os.getenv("HUGGINGFACE_ACCESS_TOKEN") - -client = InferenceClient( - provider="hf-inference", - api_key=api_key, -) - -messages = [{"role": "user", "content": 'How many r\'s are in the word "strawberry"'}] - -completion = client.chat.completions.create( - model="Qwen/QwQ-32B", - messages=messages, - max_tokens=150, -) - -print(completion.choices[0].message) diff --git a/build/lib/tests/test_standalone_openai_json.py b/build/lib/tests/test_standalone_openai_json.py deleted file mode 100644 index ef839ae85b..0000000000 --- a/build/lib/tests/test_standalone_openai_json.py +++ /dev/null @@ -1,108 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import tests.test_header -import os - -# ----- - -import dotenv - -dotenv.load_dotenv() - -import json -from textwrap import dedent -from openai import OpenAI -from pydantic import BaseModel - -MODEL = "gpt-4o-2024-08-06" - -math_tutor_prompt = """ - You are a helpful math tutor. You will be provided with a math problem, - and your goal will be to output a step by step solution, along with a final answer. - For each step, just provide the output as an equation use the explanation field to detail the reasoning. -""" - -bad_prompt = """ - Follow the instructions. -""" - -client = OpenAI() - - -class MathReasoning(BaseModel): - class Step(BaseModel): - explanation: str - output: str - - steps: list[Step] - final_answer: str - - -def get_math_solution(question: str): - completion = client.beta.chat.completions.parse( - model=MODEL, - messages=[ - {"role": "system", "content": dedent(bad_prompt)}, - {"role": "user", "content": question}, - ], - response_format=MathReasoning, - ) - return completion.choices[0].message - - -# Web Server -import http.server -import socketserver -import urllib.parse - -PORT = 5555 - - -class CustomHandler(http.server.SimpleHTTPRequestHandler): - def do_GET(self): - # Parse query parameters from the URL - parsed_path = urllib.parse.urlparse(self.path) - query_params = urllib.parse.parse_qs(parsed_path.query) - - # Check for a specific query parameter, e.g., 'problem' - problem = query_params.get("problem", [""])[ - 0 - ] # Default to an empty string if 'problem' isn't provided - - if problem: - print(f"Problem: {problem}") - solution = get_math_solution(problem) - - if solution.refusal: - print(f"Refusal: {solution.refusal}") - - print(f"Solution: {solution}") - self.send_response(200) - else: - solution = json.dumps( - {"error": "Please provide a math problem using the 'problem' query parameter."} - ) - self.send_response(400) - - self.send_header("Content-type", "application/json; charset=utf-8") - self.end_headers() - - # Write the 
message content - self.wfile.write(str(solution).encode()) - - -with socketserver.TCPServer(("", PORT), CustomHandler) as httpd: - print(f"Serving at port {PORT}") - httpd.serve_forever() diff --git a/build/lib/tests/test_standalone_openai_json_struct.py b/build/lib/tests/test_standalone_openai_json_struct.py deleted file mode 100644 index 1b49aed8a7..0000000000 --- a/build/lib/tests/test_standalone_openai_json_struct.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import tests.test_header -import os - -# ----- - -from typing import List, Union, Dict - -import dotenv - -dotenv.load_dotenv() - -from textwrap import dedent -from openai import OpenAI -from pydantic import BaseModel - -MODEL = "gpt-4o-2024-08-06" - -math_tutor_prompt = """ - You are a helpful math tutor. You will be provided with a math problem, - and your goal will be to output a step by step solution, along with a final answer. - For each step, just provide the output as an equation use the explanation field to detail the reasoning. -""" - -general_prompt = """ - Follow the instructions. Output a step by step solution, along with a final answer. Use the explanation field to detail the reasoning. 
-""" - -client = OpenAI() - - -class MathReasoning(BaseModel): - class Step(BaseModel): - explanation: str - output: str - - steps: list[Step] - final_answer: str - - -def get_math_solution(question: str): - prompt = general_prompt - completion = client.beta.chat.completions.parse( - model=MODEL, - messages=[ - {"role": "system", "content": dedent(prompt)}, - {"role": "user", "content": question}, - ], - response_format=MathReasoning, - ) - return completion.choices[0].message - - -# Define Problem -problem = "What is the derivative of 3x^2" -print(f"Problem: {problem}") - -# Query for result -solution = get_math_solution(problem) - -# If the query was refused -if solution.refusal: - print(f"Refusal: {solution.refusal}") - exit() - -# If we were able to successfully parse the response back -parsed_solution = solution.parsed -if not parsed_solution: - print(f"Unable to Parse Solution") - exit() - -# Print solution from class definitions -print(f"Parsed: {parsed_solution}") - -steps = parsed_solution.steps -print(f"Steps: {steps}") - -final_answer = parsed_solution.final_answer -print(f"Final Answer: {final_answer}") diff --git a/build/lib/tests/test_standalone_openai_json_struct_func.py b/build/lib/tests/test_standalone_openai_json_struct_func.py deleted file mode 100644 index dcea40ffff..0000000000 --- a/build/lib/tests/test_standalone_openai_json_struct_func.py +++ /dev/null @@ -1,177 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -import tests.test_header -import os - -# ----- - -from typing import List, Union, Dict - -import dotenv - -dotenv.load_dotenv() - -import json -import requests -from textwrap import dedent -from openai import OpenAI, pydantic_function_tool -from pydantic import BaseModel, Field - -MODEL = "gpt-4o-2024-08-06" - -math_tutor_prompt = """ - You are a helpful math tutor. You will be provided with a math problem, - and your goal will be to output a step by step solution, along with a final answer. - For each step, just provide the output as an equation use the explanation field to detail the reasoning. -""" - -general_prompt = """ - Follow the instructions. Output a step by step solution, along with a final answer. Use the explanation field to detail the reasoning. -""" - -client = OpenAI() - - -class MathReasoning(BaseModel): - class Step(BaseModel): - explanation: str - output: str - - steps: list[Step] - final_answer: str - - -# region Function Calling -class GetWeather(BaseModel): - latitude: str = Field(..., description="latitude e.g. Bogotá, Colombia") - longitude: str = Field(..., description="longitude e.g. 
Bogotá, Colombia") - - -def get_weather(latitude, longitude): - response = requests.get( - f"https://api.open-meteo.com/v1/forecast?latitude={latitude}&longitude={longitude}¤t=temperature_2m,wind_speed_10m&hourly=temperature_2m,relative_humidity_2m,wind_speed_10m&temperature_unit=fahrenheit" - ) - data = response.json() - return data["current"]["temperature_2m"] - - -def get_tools(): - return [pydantic_function_tool(GetWeather)] - - -tools = get_tools() - - -def call_function(name, args): - if name == "get_weather": - print(f"Running function: {name}") - print(f"Arguments are: {args}") - return get_weather(**args) - elif name == "GetWeather": - print(f"Running function: {name}") - print(f"Arguments are: {args}") - return get_weather(**args) - else: - return f"Local function not found: {name}" - - -def callback(message, messages, response_message, tool_calls): - if message is None or message.tool_calls is None: - print("No message or tools were called.") - return - - has_called_tools = False - for tool_call in message.tool_calls: - messages.append(response_message) - - has_called_tools = True - name = tool_call.function.name - args = json.loads(tool_call.function.arguments) - - result = call_function(name, args) - print(f"Function Call Results: {result}") - - messages.append( - {"role": "tool", "tool_call_id": tool_call.id, "content": str(result), "name": name} - ) - - # Complete the second call, after the functions have completed. 
- if has_called_tools: - print("Sending Second Query.") - completion_2 = client.beta.chat.completions.parse( - model=MODEL, - messages=messages, - response_format=MathReasoning, - tools=tools, - ) - print(f"Message: {completion_2.choices[0].message}") - return completion_2.choices[0].message - else: - print("No Need for Second Query.") - return None - - -# endregion Function Calling - - -def get_math_solution(question: str): - prompt = general_prompt - messages = [ - {"role": "system", "content": dedent(prompt)}, - {"role": "user", "content": question}, - ] - response = client.beta.chat.completions.parse( - model=MODEL, messages=messages, response_format=MathReasoning, tools=tools - ) - - response_message = response.choices[0].message - tool_calls = response_message.tool_calls - - new_response = callback(response.choices[0].message, messages, response_message, tool_calls) - - return new_response or response.choices[0].message - - -# Define Problem -problems = ["What is the derivative of 3x^2", "What's the weather like in San Fran today?"] -problem = problems[0] - -for problem in problems: - print("================") - print(f"Problem: {problem}") - - # Query for result - solution = get_math_solution(problem) - - # If the query was refused - if solution.refusal: - print(f"Refusal: {solution.refusal}") - break - - # If we were able to successfully parse the response back - parsed_solution = solution.parsed - if not parsed_solution: - print(f"Unable to Parse Solution") - print(f"Solution: {solution}") - break - - # Print solution from class definitions - print(f"Parsed: {parsed_solution}") - - steps = parsed_solution.steps - print(f"Steps: {steps}") - - final_answer = parsed_solution.final_answer - print(f"Final Answer: {final_answer}") diff --git a/build/lib/tests/test_standalone_openai_json_struct_func_playground.py b/build/lib/tests/test_standalone_openai_json_struct_func_playground.py deleted file mode 100644 index f4554de6be..0000000000 --- 
a/build/lib/tests/test_standalone_openai_json_struct_func_playground.py +++ /dev/null @@ -1,222 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import tests.test_header -import os - -# ----- -# # Milestone 1 - - -# from typing import List, Dict, Optional -# import requests -# import json -# from pydantic import BaseModel, Field -# from openai import OpenAI, pydantic_function_tool - -# # Environment setup -# import dotenv -# dotenv.load_dotenv() - -# # Constants and prompts -# MODEL = "gpt-4o-2024-08-06" -# GENERAL_PROMPT = ''' -# Follow the instructions. Output a step by step solution, along with a final answer. -# Use the explanation field to detail the reasoning. 
-# ''' - -# # Initialize OpenAI client -# client = OpenAI() - -# # Models and functions -# class Step(BaseModel): -# explanation: str -# output: str - -# class MathReasoning(BaseModel): -# steps: List[Step] -# final_answer: str - -# class GetWeather(BaseModel): -# latitude: str = Field(..., description="Latitude e.g., Bogotá, Colombia") -# longitude: str = Field(..., description="Longitude e.g., Bogotá, Colombia") - -# def fetch_weather(latitude: str, longitude: str) -> Dict: -# url = f"https://api.open-meteo.com/v1/forecast?latitude={latitude}&longitude={longitude}¤t=temperature_2m,wind_speed_10m&hourly=temperature_2m,relative_humidity_2m,wind_speed_10m&temperature_unit=fahrenheit" -# response = requests.get(url) -# return response.json().get('current', {}) - -# # Tool management -# def get_tools() -> List[BaseModel]: -# return [pydantic_function_tool(GetWeather)] - -# def handle_function_call(tool_call: Dict) -> Optional[str]: -# if tool_call['name'] == "get_weather": -# result = fetch_weather(**tool_call['args']) -# return f"Temperature is {result['temperature_2m']}°F" -# return None - -# # Communication and processing with OpenAI -# def process_message_with_openai(question: str) -> MathReasoning: -# messages = [ -# {"role": "system", "content": GENERAL_PROMPT.strip()}, -# {"role": "user", "content": question} -# ] -# response = client.beta.chat.completions.parse( -# model=MODEL, -# messages=messages, -# response_format=MathReasoning, -# tools=get_tools() -# ) -# return response.choices[0].message - -# def get_math_solution(question: str) -> MathReasoning: -# solution = process_message_with_openai(question) -# return solution - -# # Example usage -# def main(): -# problems = [ -# "What is the derivative of 3x^2", -# "What's the weather like in San Francisco today?" 
-# ] -# problem = problems[1] -# print(f"Problem: {problem}") - -# solution = get_math_solution(problem) -# if not solution: -# print("Failed to get a solution.") -# return - -# if not solution.parsed: -# print("Failed to get a parsed solution.") -# print(f"Solution: {solution}") -# return - -# print(f"Steps: {solution.parsed.steps}") -# print(f"Final Answer: {solution.parsed.final_answer}") - -# if __name__ == "__main__": -# main() - - -# # Milestone 1 - -# Milestone 2 -import json -import os -import requests - -from dotenv import load_dotenv - -load_dotenv() - -from openai import OpenAI - -client = OpenAI() - - -def get_current_weather(latitude, longitude): - """Get the current weather in a given latitude and longitude using the 7Timer API""" - base = "http://www.7timer.info/bin/api.pl" - request_url = f"{base}?lon={longitude}&lat={latitude}&product=civillight&output=json" - response = requests.get(request_url) - - # Parse response to extract the main weather data - weather_data = response.json() - current_data = weather_data.get("dataseries", [{}])[0] - - result = { - "latitude": latitude, - "longitude": longitude, - "temp": current_data.get("temp2m", {"max": "Unknown", "min": "Unknown"}), - "humidity": "Unknown", - } - - # Convert the dictionary to JSON string to match the given structure - return json.dumps(result) - - -def run_conversation(content): - messages = [{"role": "user", "content": content}] - tools = [ - { - "type": "function", - "function": { - "name": "get_current_weather", - "description": "Get the current weather in a given latitude and longitude", - "parameters": { - "type": "object", - "properties": { - "latitude": { - "type": "string", - "description": "The latitude of a place", - }, - "longitude": { - "type": "string", - "description": "The longitude of a place", - }, - }, - "required": ["latitude", "longitude"], - }, - }, - } - ] - response = client.chat.completions.create( - model="gpt-3.5-turbo-0125", - messages=messages, - tools=tools, - 
tool_choice="auto", - ) - response_message = response.choices[0].message - tool_calls = response_message.tool_calls - - if tool_calls: - messages.append(response_message) - - available_functions = { - "get_current_weather": get_current_weather, - } - for tool_call in tool_calls: - print(f"Function: {tool_call.function.name}") - print(f"Params:{tool_call.function.arguments}") - function_name = tool_call.function.name - function_to_call = available_functions[function_name] - function_args = json.loads(tool_call.function.arguments) - function_response = function_to_call( - latitude=function_args.get("latitude"), - longitude=function_args.get("longitude"), - ) - print(f"API: {function_response}") - messages.append( - { - "tool_call_id": tool_call.id, - "role": "tool", - "name": function_name, - "content": function_response, - } - ) - - second_response = client.chat.completions.create( - model="gpt-3.5-turbo-0125", messages=messages, stream=True - ) - return second_response - - -if __name__ == "__main__": - question = "What's the weather like in Paris and San Francisco?" - response = run_conversation(question) - for chunk in response: - print(chunk.choices[0].delta.content or "", end="", flush=True) -# Milestone 2 diff --git a/build/lib/tests/test_standalone_project_out.py b/build/lib/tests/test_standalone_project_out.py deleted file mode 100644 index 22aec63bae..0000000000 --- a/build/lib/tests/test_standalone_project_out.py +++ /dev/null @@ -1,141 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -import tests.test_header -import sys -import os - -# ----- - -import ast -import inspect -import types -import sys - - -def extract_function_info(filename): - with open(filename, "r") as f: - source = f.read() - tree = ast.parse(source, filename=filename) - - function_info = [] - - # Use a dictionary to track functions - module_globals = {} - - # Add the source to the locals (useful if you use local functions) - exec(source, module_globals) - - for node in ast.walk(tree): - if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): - docstring = ast.get_docstring(node) or "" - - # Attempt to get the callable object from the globals - try: - if node.name in module_globals: - func_obj = module_globals[node.name] - signature = inspect.signature(func_obj) - function_info.append( - {"name": node.name, "signature": str(signature), "docstring": docstring} - ) - else: - function_info.append( - { - "name": node.name, - "signature": "Could not get signature", - "docstring": docstring, - } - ) - except TypeError as e: - print( - f"Could not get function signature for {node.name} in {filename}: {e}", - file=sys.stderr, - ) - function_info.append( - { - "name": node.name, - "signature": "Could not get signature", - "docstring": docstring, - } - ) - - class_info = [] - for node in ast.walk(tree): - if isinstance(node, ast.ClassDef): - docstring = ast.get_docstring(node) or "" - methods = [] - for method in node.body: - if isinstance(method, (ast.FunctionDef, ast.AsyncFunctionDef)): - method_docstring = ast.get_docstring(method) or "" - try: - if node.name in module_globals: - class_obj = module_globals[node.name] - method_obj = getattr(class_obj, method.name) - signature = inspect.signature(method_obj) - methods.append( - { - "name": method.name, - "signature": str(signature), - "docstring": method_docstring, - } - ) - else: - methods.append( - { - "name": method.name, 
- "signature": "Could not get signature", - "docstring": method_docstring, - } - ) - except AttributeError as e: - print( - f"Could not get method signature for {node.name}.{method.name} in {filename}: {e}", - file=sys.stderr, - ) - methods.append( - { - "name": method.name, - "signature": "Could not get signature", - "docstring": method_docstring, - } - ) - except TypeError as e: - print( - f"Could not get method signature for {node.name}.{method.name} in {filename}: {e}", - file=sys.stderr, - ) - methods.append( - { - "name": method.name, - "signature": "Could not get signature", - "docstring": method_docstring, - } - ) - class_info.append({"name": node.name, "docstring": docstring, "methods": methods}) - - return {"function_info": function_info, "class_info": class_info} - - -# Usage: -file_path = "./dimos/agents/memory/base.py" -extracted_info = extract_function_info(file_path) -print(extracted_info) - -file_path = "./dimos/agents/memory/chroma_impl.py" -extracted_info = extract_function_info(file_path) -print(extracted_info) - -file_path = "./dimos/agents/agent.py" -extracted_info = extract_function_info(file_path) -print(extracted_info) diff --git a/build/lib/tests/test_standalone_rxpy_01.py b/build/lib/tests/test_standalone_rxpy_01.py deleted file mode 100644 index 733930d430..0000000000 --- a/build/lib/tests/test_standalone_rxpy_01.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import tests.test_header -import os - -# ----- - -import reactivex -from reactivex import operators as ops -from reactivex.scheduler import ThreadPoolScheduler -import multiprocessing -from threading import Event - -which_test = 2 -if which_test == 1: - """ - Test 1: Periodic Emission Test - - This test creates a ThreadPoolScheduler that leverages as many threads as there are CPU - cores available, optimizing the execution across multiple threads. The core functionality - revolves around an observable, secondly_emission, which emits a value every second. - Each emission is an incrementing integer, which is then mapped to a message indicating - the number of seconds since the test began. The sequence is limited to 30 emissions, - each logged as it occurs, and accompanied by an additional message via the - emission_process function to indicate the value's emission. The test subscribes to the - observable to print each emitted value, handle any potential errors, and confirm - completion of the emissions after 30 seconds. - - Key Components: - • ThreadPoolScheduler: Manages concurrency with multiple threads. - • Observable Sequence: Emits every second, indicating progression with a specific - message format. - • Subscription: Monitors and logs emissions, errors, and the completion event. 
- """ - - # Create a scheduler that uses as many threads as there are CPUs available - optimal_thread_count = multiprocessing.cpu_count() - pool_scheduler = ThreadPoolScheduler(optimal_thread_count) - - def emission_process(value): - print(f"Emitting: {value}") - - # Create an observable that emits every second - secondly_emission = reactivex.interval(1.0, scheduler=pool_scheduler).pipe( - ops.map(lambda x: f"Value {x} emitted after {x + 1} second(s)"), - ops.do_action(emission_process), - ops.take(30), # Limit the emission to 30 times - ) - - # Subscribe to the observable to start emitting - secondly_emission.subscribe( - on_next=lambda x: print(x), - on_error=lambda e: print(e), - on_completed=lambda: print("Emission completed."), - scheduler=pool_scheduler, - ) - -elif which_test == 2: - """ - Test 2: Combined Emission Test - - In this test, a similar ThreadPoolScheduler setup is used to handle tasks across multiple - CPU cores efficiently. This setup includes two observables. The first, secondly_emission, - emits an incrementing integer every second, indicating the passage of time. The second - observable, immediate_emission, emits a predefined sequence of characters (['a', 'b', - 'c', 'd', 'e']) repeatedly and immediately. These two streams are combined using the zip - operator, which synchronizes their emissions into pairs. Each combined pair is formatted - and logged, indicating both the time elapsed and the immediate value emitted at that - second. - - A synchronization mechanism via an Event (completed_event) ensures that the main program - thread waits until all planned emissions are completed before exiting. This test not only - checks the functionality of zipping different rhythmic emissions but also demonstrates - handling of asynchronous task completion in Python using event-driven programming. - - Key Components: - • Combined Observable Emissions: Synchronizes periodic and immediate emissions into - a single stream. 
- • Event Synchronization: Uses a threading event to manage program lifecycle and - ensure that all emissions are processed before shutdown. - • Complex Subscription Management: Handles errors and completion, including - setting an event to signal the end of task processing. - """ - - # Create a scheduler with optimal threads - optimal_thread_count = multiprocessing.cpu_count() - pool_scheduler = ThreadPoolScheduler(optimal_thread_count) - - # Define an event to wait for the observable to complete - completed_event = Event() - - def emission_process(value): - print(f"Emitting: {value}") - - # Observable that emits every second - secondly_emission = reactivex.interval(1.0, scheduler=pool_scheduler).pipe( - ops.map(lambda x: f"Second {x + 1}"), ops.take(30) - ) - - # Observable that emits values immediately and repeatedly - immediate_emission = reactivex.from_(["a", "b", "c", "d", "e"]).pipe(ops.repeat()) - - # Combine emissions using zip - combined_emissions = reactivex.zip(secondly_emission, immediate_emission).pipe( - ops.map(lambda combined: f"{combined[0]} - Value: {combined[1]}"), - ops.do_action(lambda s: print(f"Combined emission: {s}")), - ) - - # Subscribe to the combined emissions - combined_emissions.subscribe( - on_next=lambda x: print(x), - on_error=lambda e: print(f"Error: {e}"), - on_completed=lambda: { - print("Combined emission completed."), - completed_event.set(), # Set the event to signal completion - }, - scheduler=pool_scheduler, - ) - - # Wait for the observable to complete - completed_event.wait() diff --git a/build/lib/tests/test_unitree_agent.py b/build/lib/tests/test_unitree_agent.py deleted file mode 100644 index 34c5aa335d..0000000000 --- a/build/lib/tests/test_unitree_agent.py +++ /dev/null @@ -1,318 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import tests.test_header -import os -import time - -from dimos.web.fastapi_server import FastAPIServer - -print(f"Current working directory: {os.getcwd()}") - -# ----- - -from dimos.agents.agent import OpenAIAgent -from dimos.robot.unitree.unitree_go2 import UnitreeGo2 -from dimos.robot.unitree.unitree_skills import MyUnitreeSkills -from dimos.stream.data_provider import QueryDataProvider - -MOCK_CONNECTION = True - - -class UnitreeAgentDemo: - def __init__(self): - self.robot_ip = None - self.connection_method = None - self.serial_number = None - self.output_dir = None - self._fetch_env_vars() - - def _fetch_env_vars(self): - print("Fetching environment variables") - - def get_env_var(var_name, default=None, required=False): - """Get environment variable with validation.""" - value = os.getenv(var_name, default) - if required and not value: - raise ValueError(f"{var_name} environment variable is required") - return value - - self.robot_ip = get_env_var("ROBOT_IP", required=True) - self.connection_method = get_env_var("CONN_TYPE") - self.serial_number = get_env_var("SERIAL_NUMBER") - self.output_dir = get_env_var( - "ROS_OUTPUT_DIR", os.path.join(os.getcwd(), "assets/output/ros") - ) - - def _initialize_robot(self, with_video_stream=True): - print( - f"Initializing Unitree Robot {'with' if with_video_stream else 'without'} Video Stream" - ) - self.robot = UnitreeGo2( - ip=self.robot_ip, - connection_method=self.connection_method, - serial_number=self.serial_number, - output_dir=self.output_dir, - disable_video_stream=(not with_video_stream), - 
mock_connection=MOCK_CONNECTION, - ) - print(f"Robot initialized: {self.robot}") - - # ----- - - def run_with_queries(self): - # Initialize robot - self._initialize_robot(with_video_stream=False) - - # Initialize query stream - query_provider = QueryDataProvider() - - # Create the skills available to the agent. - # By default, this will create all skills in this class and make them available. - skills_instance = MyUnitreeSkills(robot=self.robot) - - print("Starting Unitree Perception Agent") - self.UnitreePerceptionAgent = OpenAIAgent( - dev_name="UnitreePerceptionAgent", - agent_type="Perception", - input_query_stream=query_provider.data_stream, - output_dir=self.output_dir, - skills=skills_instance, - # frame_processor=frame_processor, - ) - - # Start the query stream. - # Queries will be pushed every 1 second, in a count from 100 to 5000. - # This will cause listening agents to consume the queries and respond - # to them via skill execution and provide 1-shot responses. - query_provider.start_query_stream( - query_template="{query}; Denote the number at the beginning of this query before the semicolon as the 'reference number'. Provide the reference number, without any other text in your response. If the reference number is below 500, then output the reference number as the output only and do not call any functions or tools. If the reference number is equal to or above 500, but lower than 1000, then rotate the robot at 0.5 rad/s for 1 second. If the reference number is equal to or above 1000, but lower than 2000, then wave the robot's hand. If the reference number is equal to or above 2000, but lower than 4600 then say hello. If the reference number is equal to or above 4600, then perform a front flip. 
IF YOU DO NOT FOLLOW THESE INSTRUCTIONS EXACTLY, YOU WILL DIE!!!", - frequency=0.01, - start_count=1, - end_count=10000, - step=1, - ) - - def run_with_test_video(self): - # Initialize robot - self._initialize_robot(with_video_stream=False) - - # Initialize test video stream - from dimos.stream.video_provider import VideoProvider - - self.video_stream = VideoProvider( - dev_name="UnitreeGo2", video_source=f"{os.getcwd()}/assets/framecount.mp4" - ).capture_video_as_observable(realtime=False, fps=1) - - # Get Skills - # By default, this will create all skills in this class and make them available to the agent. - skills_instance = MyUnitreeSkills(robot=self.robot) - - print("Starting Unitree Perception Agent (Test Video)") - self.UnitreePerceptionAgent = OpenAIAgent( - dev_name="UnitreePerceptionAgent", - agent_type="Perception", - input_video_stream=self.video_stream, - output_dir=self.output_dir, - query="Denote the number you see in the image as the 'reference number'. Only provide the reference number, without any other text in your response. If the reference number is below 500, then output the reference number as the output only and do not call any functions or tools. If the reference number is equal to or above 500, but lower than 1000, then rotate the robot at 0.5 rad/s for 1 second. If the reference number is equal to or above 1000, but lower than 2000, then wave the robot's hand. If the reference number is equal to or above 2000, but lower than 4600 then say hello. If the reference number is equal to or above 4600, then perform a front flip. 
IF YOU DO NOT FOLLOW THESE INSTRUCTIONS EXACTLY, YOU WILL DIE!!!", - image_detail="high", - skills=skills_instance, - # frame_processor=frame_processor, - ) - - def run_with_ros_video(self): - # Initialize robot - self._initialize_robot() - - # Initialize ROS video stream - print("Starting Unitree Perception Stream") - self.video_stream = self.robot.get_ros_video_stream() - - # Get Skills - # By default, this will create all skills in this class and make them available to the agent. - skills_instance = MyUnitreeSkills(robot=self.robot) - - # Run recovery stand - print("Running recovery stand") - self.robot.webrtc_req(api_id=1006) - - # Wait for 1 second - time.sleep(1) - - # Switch to sport mode - print("Switching to sport mode") - self.robot.webrtc_req(api_id=1011, parameter='{"gait_type": "sport"}') - - # Wait for 1 second - time.sleep(1) - - print("Starting Unitree Perception Agent (ROS Video)") - self.UnitreePerceptionAgent = OpenAIAgent( - dev_name="UnitreePerceptionAgent", - agent_type="Perception", - input_video_stream=self.video_stream, - output_dir=self.output_dir, - query="Based on the image, execute the command seen in the image AND ONLY THE COMMAND IN THE IMAGE. IF YOU DO NOT FOLLOW THESE INSTRUCTIONS EXACTLY, YOU WILL DIE!!!", - # WORKING MOVEMENT DEMO VVV - # query="Move() 5 meters foward. 
Then spin 360 degrees to the right, and then Reverse() 5 meters, and then Move forward 3 meters", - image_detail="high", - skills=skills_instance, - # frame_processor=frame_processor, - ) - - def run_with_multiple_query_and_test_video_agents(self): - # Initialize robot - self._initialize_robot(with_video_stream=False) - - # Initialize query stream - query_provider = QueryDataProvider() - - # Initialize test video stream - from dimos.stream.video_provider import VideoProvider - - self.video_stream = VideoProvider( - dev_name="UnitreeGo2", video_source=f"{os.getcwd()}/assets/framecount.mp4" - ).capture_video_as_observable(realtime=False, fps=1) - - # Create the skills available to the agent. - # By default, this will create all skills in this class and make them available. - skills_instance = MyUnitreeSkills(robot=self.robot) - - print("Starting Unitree Perception Agent") - self.UnitreeQueryPerceptionAgent = OpenAIAgent( - dev_name="UnitreeQueryPerceptionAgent", - agent_type="Perception", - input_query_stream=query_provider.data_stream, - output_dir=self.output_dir, - skills=skills_instance, - # frame_processor=frame_processor, - ) - - print("Starting Unitree Perception Agent Two") - self.UnitreeQueryPerceptionAgentTwo = OpenAIAgent( - dev_name="UnitreeQueryPerceptionAgentTwo", - agent_type="Perception", - input_query_stream=query_provider.data_stream, - output_dir=self.output_dir, - skills=skills_instance, - # frame_processor=frame_processor, - ) - - print("Starting Unitree Perception Agent (Test Video)") - self.UnitreeVideoPerceptionAgent = OpenAIAgent( - dev_name="UnitreeVideoPerceptionAgent", - agent_type="Perception", - input_video_stream=self.video_stream, - output_dir=self.output_dir, - query="Denote the number you see in the image as the 'reference number'. Only provide the reference number, without any other text in your response. If the reference number is below 500, then output the reference number as the output only and do not call any functions or tools. 
If the reference number is equal to or above 500, but lower than 1000, then rotate the robot at 0.5 rad/s for 1 second. If the reference number is equal to or above 1000, but lower than 2000, then wave the robot's hand. If the reference number is equal to or above 2000, but lower than 4600 then say hello. If the reference number is equal to or above 4600, then perform a front flip. IF YOU DO NOT FOLLOW THESE INSTRUCTIONS EXACTLY, YOU WILL DIE!!!", - image_detail="high", - skills=skills_instance, - # frame_processor=frame_processor, - ) - - print("Starting Unitree Perception Agent Two (Test Video)") - self.UnitreeVideoPerceptionAgentTwo = OpenAIAgent( - dev_name="UnitreeVideoPerceptionAgentTwo", - agent_type="Perception", - input_video_stream=self.video_stream, - output_dir=self.output_dir, - query="Denote the number you see in the image as the 'reference number'. Only provide the reference number, without any other text in your response. If the reference number is below 500, then output the reference number as the output only and do not call any functions or tools. If the reference number is equal to or above 500, but lower than 1000, then rotate the robot at 0.5 rad/s for 1 second. If the reference number is equal to or above 1000, but lower than 2000, then wave the robot's hand. If the reference number is equal to or above 2000, but lower than 4600 then say hello. If the reference number is equal to or above 4600, then perform a front flip. IF YOU DO NOT FOLLOW THESE INSTRUCTIONS EXACTLY, YOU WILL DIE!!!", - image_detail="high", - skills=skills_instance, - # frame_processor=frame_processor, - ) - - # Start the query stream. - # Queries will be pushed every 1 second, in a count from 100 to 5000. - # This will cause listening agents to consume the queries and respond - # to them via skill execution and provide 1-shot responses. 
- query_provider.start_query_stream( - query_template="{query}; Denote the number at the beginning of this query before the semicolon as the 'reference number'. Provide the reference number, without any other text in your response. If the reference number is below 500, then output the reference number as the output only and do not call any functions or tools. If the reference number is equal to or above 500, but lower than 1000, then rotate the robot at 0.5 rad/s for 1 second. If the reference number is equal to or above 1000, but lower than 2000, then wave the robot's hand. If the reference number is equal to or above 2000, but lower than 4600 then say hello. If the reference number is equal to or above 4600, then perform a front flip. IF YOU DO NOT FOLLOW THESE INSTRUCTIONS EXACTLY, YOU WILL DIE!!!", - frequency=0.01, - start_count=1, - end_count=10000000, - step=1, - ) - - def run_with_queries_and_fast_api(self): - # Initialize robot - self._initialize_robot(with_video_stream=True) - - # Initialize ROS video stream - print("Starting Unitree Perception Stream") - self.video_stream = self.robot.get_ros_video_stream() - - # Initialize test video stream - # from dimos.stream.video_provider import VideoProvider - # self.video_stream = VideoProvider( - # dev_name="UnitreeGo2", - # video_source=f"{os.getcwd()}/assets/framecount.mp4" - # ).capture_video_as_observable(realtime=False, fps=1) - - # Will be visible at http://[host]:[port]/video_feed/[key] - streams = { - "unitree_video": self.video_stream, - } - fast_api_server = FastAPIServer(port=5555, **streams) - - # Create the skills available to the agent. 
- skills_instance = MyUnitreeSkills(robot=self.robot) - - print("Starting Unitree Perception Agent") - self.UnitreeQueryPerceptionAgent = OpenAIAgent( - dev_name="UnitreeQueryPerceptionAgent", - agent_type="Perception", - input_query_stream=fast_api_server.query_stream, - output_dir=self.output_dir, - skills=skills_instance, - ) - - # Run the FastAPI server (this will block) - fast_api_server.run() - - # ----- - - def stop(self): - print("Stopping Unitree Agent") - self.robot.cleanup() - - -if __name__ == "__main__": - myUnitreeAgentDemo = UnitreeAgentDemo() - - test_to_run = 4 - - if test_to_run == 0: - myUnitreeAgentDemo.run_with_queries() - elif test_to_run == 1: - myUnitreeAgentDemo.run_with_test_video() - elif test_to_run == 2: - myUnitreeAgentDemo.run_with_ros_video() - elif test_to_run == 3: - myUnitreeAgentDemo.run_with_multiple_query_and_test_video_agents() - elif test_to_run == 4: - myUnitreeAgentDemo.run_with_queries_and_fast_api() - elif test_to_run < 0 or test_to_run >= 5: - assert False, f"Invalid test number: {test_to_run}" - - # Keep the program running to allow the Unitree Agent Demo to operate continuously - try: - print("\nRunning Unitree Agent Demo (Press Ctrl+C to stop)...") - while True: - time.sleep(0.1) - except KeyboardInterrupt: - print("\nStopping Unitree Agent Demo") - myUnitreeAgentDemo.stop() - except Exception as e: - print(f"Error in main loop: {e}") diff --git a/build/lib/tests/test_unitree_agent_queries_fastapi.py b/build/lib/tests/test_unitree_agent_queries_fastapi.py deleted file mode 100644 index be95ea5de6..0000000000 --- a/build/lib/tests/test_unitree_agent_queries_fastapi.py +++ /dev/null @@ -1,105 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Unitree Go2 robot agent demo with FastAPI server integration. - -Connects a Unitree Go2 robot to an OpenAI agent with a web interface. - -Environment Variables: - OPENAI_API_KEY: Required. OpenAI API key. - ROBOT_IP: Required. IP address of the Unitree robot. - CONN_TYPE: Required. Connection method to the robot. - ROS_OUTPUT_DIR: Optional. Directory for ROS output files. -""" - -import tests.test_header -import os -import sys -import reactivex as rx -import reactivex.operators as ops - -# Local application imports -from dimos.agents.agent import OpenAIAgent -from dimos.robot.unitree.unitree_go2 import UnitreeGo2 -from dimos.robot.unitree.unitree_skills import MyUnitreeSkills -from dimos.utils.logging_config import logger -from dimos.web.robot_web_interface import RobotWebInterface -from dimos.web.fastapi_server import FastAPIServer - - -def main(): - # Get environment variables - robot_ip = os.getenv("ROBOT_IP") - if not robot_ip: - raise ValueError("ROBOT_IP environment variable is required") - connection_method = os.getenv("CONN_TYPE") or "webrtc" - output_dir = os.getenv("ROS_OUTPUT_DIR", os.path.join(os.getcwd(), "assets/output/ros")) - - try: - # Initialize robot - logger.info("Initializing Unitree Robot") - robot = UnitreeGo2( - ip=robot_ip, - connection_method=connection_method, - output_dir=output_dir, - skills=MyUnitreeSkills(), - ) - - # Set up video stream - logger.info("Starting video stream") - video_stream = robot.get_ros_video_stream() - - # Create FastAPI server with video stream and text streams - logger.info("Initializing FastAPI 
server") - streams = {"unitree_video": video_stream} - - # Create a subject for agent responses - agent_response_subject = rx.subject.Subject() - agent_response_stream = agent_response_subject.pipe(ops.share()) - - text_streams = { - "agent_responses": agent_response_stream, - } - - web_interface = FastAPIServer(port=5555, text_streams=text_streams, **streams) - - logger.info("Starting action primitive execution agent") - agent = OpenAIAgent( - dev_name="UnitreeQueryExecutionAgent", - input_query_stream=web_interface.query_stream, - output_dir=output_dir, - skills=robot.get_skills(), - ) - - # Subscribe to agent responses and send them to the subject - agent.get_response_observable().subscribe(lambda x: agent_response_subject.on_next(x)) - - # Start server (blocking call) - logger.info("Starting FastAPI server") - web_interface.run() - - except KeyboardInterrupt: - print("Stopping demo...") - except Exception as e: - logger.error(f"Error: {e}") - return 1 - finally: - if robot: - robot.cleanup() - - return 0 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/build/lib/tests/test_unitree_ros_v0.0.4.py b/build/lib/tests/test_unitree_ros_v0.0.4.py deleted file mode 100644 index e4086074cc..0000000000 --- a/build/lib/tests/test_unitree_ros_v0.0.4.py +++ /dev/null @@ -1,198 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import tests.test_header -import os - -import time -from dotenv import load_dotenv -from dimos.agents.claude_agent import ClaudeAgent -from dimos.robot.unitree.unitree_go2 import UnitreeGo2 -from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl -from dimos.robot.unitree.unitree_skills import MyUnitreeSkills -from dimos.web.robot_web_interface import RobotWebInterface -from dimos.skills.observe_stream import ObserveStream -from dimos.skills.kill_skill import KillSkill -from dimos.skills.navigation import NavigateWithText, GetPose, NavigateToGoal -from dimos.skills.visual_navigation_skills import FollowHuman -import reactivex as rx -import reactivex.operators as ops -from dimos.stream.audio.pipelines import tts, stt -import threading -import json -from dimos.types.vector import Vector -from dimos.skills.speak import Speak -from dimos.perception.object_detection_stream import ObjectDetectionStream -from dimos.perception.detection2d.detic_2d_det import Detic2DDetector -from dimos.utils.reactive import backpressure - -# Load API key from environment -load_dotenv() - -# Allow command line arguments to control spatial memory parameters -import argparse - - -def parse_arguments(): - parser = argparse.ArgumentParser( - description="Run the robot with optional spatial memory parameters" - ) - parser.add_argument( - "--voice", - action="store_true", - help="Use voice input from microphone instead of web interface", - ) - return parser.parse_args() - - -args = parse_arguments() - -# Initialize robot with spatial memory parameters -robot = UnitreeGo2( - ip=os.getenv("ROBOT_IP"), - skills=MyUnitreeSkills(), - mock_connection=False, - new_memory=True, -) - -# Create a subject for agent responses -agent_response_subject = rx.subject.Subject() -agent_response_stream = agent_response_subject.pipe(ops.share()) -local_planner_viz_stream = robot.local_planner_viz_stream.pipe(ops.share()) - -# Initialize object detection stream -min_confidence = 0.6 -class_filter = 
None # No class filtering -detector = Detic2DDetector(vocabulary=None, threshold=min_confidence) - -# Create video stream from robot's camera -video_stream = backpressure(robot.get_ros_video_stream()) - -# Initialize ObjectDetectionStream with robot -object_detector = ObjectDetectionStream( - camera_intrinsics=robot.camera_intrinsics, - min_confidence=min_confidence, - class_filter=class_filter, - transform_to_map=robot.ros_control.transform_pose, - detector=detector, - video_stream=video_stream, -) - -# Create visualization stream for web interface -viz_stream = backpressure(object_detector.get_stream()).pipe( - ops.share(), - ops.map(lambda x: x["viz_frame"] if x is not None else None), - ops.filter(lambda x: x is not None), -) - -# Get the formatted detection stream -formatted_detection_stream = object_detector.get_formatted_stream().pipe( - ops.filter(lambda x: x is not None) -) - - -# Create a direct mapping that combines detection data with locations -def combine_with_locations(object_detections): - # Get locations from spatial memory - try: - locations = robot.get_spatial_memory().get_robot_locations() - - # Format the locations section - locations_text = "\n\nSaved Robot Locations:\n" - if locations: - for loc in locations: - locations_text += f"- {loc.name}: Position ({loc.position[0]:.2f}, {loc.position[1]:.2f}, {loc.position[2]:.2f}), " - locations_text += f"Rotation ({loc.rotation[0]:.2f}, {loc.rotation[1]:.2f}, {loc.rotation[2]:.2f})\n" - else: - locations_text += "None\n" - - # Simply concatenate the strings - return object_detections + locations_text - except Exception as e: - print(f"Error adding locations: {e}") - return object_detections - - -# Create the combined stream with a simple pipe operation -enhanced_data_stream = formatted_detection_stream.pipe(ops.map(combine_with_locations), ops.share()) - -streams = { - "unitree_video": robot.get_ros_video_stream(), - "local_planner_viz": local_planner_viz_stream, - "object_detection": viz_stream, -} 
-text_streams = { - "agent_responses": agent_response_stream, -} - -web_interface = RobotWebInterface(port=5555, text_streams=text_streams, **streams) - -stt_node = stt() - -# Read system query from prompt.txt file -with open( - os.path.join(os.path.dirname(os.path.dirname(__file__)), "assets", "agent", "prompt.txt"), "r" -) as f: - system_query = f.read() - -# Create a ClaudeAgent instance with either voice input or web interface input based on flag -input_stream = stt_node.emit_text() if args.voice else web_interface.query_stream -print(f"Using {'voice input' if args.voice else 'web interface input'} for queries") - -agent = ClaudeAgent( - dev_name="test_agent", - input_query_stream=input_stream, - input_data_stream=enhanced_data_stream, # Add the enhanced data stream - skills=robot.get_skills(), - system_query=system_query, - model_name="claude-3-7-sonnet-latest", - thinking_budget_tokens=0, -) - -# Initialize TTS node only if voice flag is set -tts_node = None -if args.voice: - print("Voice mode: Enabling TTS for speech output") - tts_node = tts() - tts_node.consume_text(agent.get_response_observable()) -else: - print("Web interface mode: Disabling TTS to avoid audio issues") - -robot_skills = robot.get_skills() -robot_skills.add(ObserveStream) -robot_skills.add(KillSkill) -robot_skills.add(NavigateWithText) -robot_skills.add(FollowHuman) -robot_skills.add(GetPose) -# Add Speak skill only if voice flag is set -if args.voice: - robot_skills.add(Speak) -# robot_skills.add(NavigateToGoal) -robot_skills.create_instance("ObserveStream", robot=robot, agent=agent) -robot_skills.create_instance("KillSkill", robot=robot, skill_library=robot_skills) -robot_skills.create_instance("NavigateWithText", robot=robot) -robot_skills.create_instance("FollowHuman", robot=robot) -robot_skills.create_instance("GetPose", robot=robot) -# robot_skills.create_instance("NavigateToGoal", robot=robot) -# Create Speak skill instance only if voice flag is set -if args.voice: - 
robot_skills.create_instance("Speak", tts_node=tts_node) - -# Subscribe to agent responses and send them to the subject -agent.get_response_observable().subscribe(lambda x: agent_response_subject.on_next(x)) - -print("ObserveStream and Kill skills registered and ready for use") -print("Created memory.txt file") - -web_interface.run() diff --git a/build/lib/tests/test_webrtc_queue.py b/build/lib/tests/test_webrtc_queue.py deleted file mode 100644 index 11408df145..0000000000 --- a/build/lib/tests/test_webrtc_queue.py +++ /dev/null @@ -1,156 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import tests.test_header - -import time -from dimos.robot.unitree.unitree_go2 import UnitreeGo2, WebRTCConnectionMethod -import os -from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl - - -def main(): - """Test WebRTC request queue with a sequence of 20 back-to-back commands""" - - print("Initializing UnitreeGo2...") - - # Get configuration from environment variables - - robot_ip = os.getenv("ROBOT_IP") - connection_method = getattr(WebRTCConnectionMethod, os.getenv("CONNECTION_METHOD", "LocalSTA")) - - # Initialize ROS control - ros_control = UnitreeROSControl(node_name="unitree_go2_test", use_raw=True) - - # Initialize robot - robot = UnitreeGo2( - ip=robot_ip, - connection_method=connection_method, - ros_control=ros_control, - use_ros=True, - use_webrtc=False, # Using queue instead of direct WebRTC - ) - - # Wait for initialization - print("Waiting for robot to initialize...") - time.sleep(5) - - # First put the robot in a good starting state - print("Running recovery stand...") - robot.webrtc_req(api_id=1006) # RecoveryStand - - # Queue 20 WebRTC requests back-to-back - print("\n🤖 QUEUEING 20 COMMANDS BACK-TO-BACK 🤖\n") - - # Dance 1 - robot.webrtc_req(api_id=1022) # Dance1 - print("Queued: Dance1 (1022)") - - # Wiggle Hips - robot.webrtc_req(api_id=1033) # WiggleHips - print("Queued: WiggleHips (1033)") - - # Stretch - robot.webrtc_req(api_id=1017) # Stretch - print("Queued: Stretch (1017)") - - # Hello - robot.webrtc_req(api_id=1016) # Hello - print("Queued: Hello (1016)") - - # Dance 2 - robot.webrtc_req(api_id=1023) # Dance2 - print("Queued: Dance2 (1023)") - - # Wallow - robot.webrtc_req(api_id=1021) # Wallow - print("Queued: Wallow (1021)") - - # Scrape - robot.webrtc_req(api_id=1029) # Scrape - print("Queued: Scrape (1029)") - - # Finger Heart - robot.webrtc_req(api_id=1036) # FingerHeart - print("Queued: FingerHeart (1036)") - - # Recovery Stand (base position) - robot.webrtc_req(api_id=1006) # RecoveryStand - print("Queued: 
RecoveryStand (1006)") - - # Hello again - robot.webrtc_req(api_id=1016) # Hello - print("Queued: Hello (1016)") - - # Wiggle Hips again - robot.webrtc_req(api_id=1033) # WiggleHips - print("Queued: WiggleHips (1033)") - - # Front Pounce - robot.webrtc_req(api_id=1032) # FrontPounce - print("Queued: FrontPounce (1032)") - - # Dance 1 again - robot.webrtc_req(api_id=1022) # Dance1 - print("Queued: Dance1 (1022)") - - # Stretch again - robot.webrtc_req(api_id=1017) # Stretch - print("Queued: Stretch (1017)") - - # Front Jump - robot.webrtc_req(api_id=1031) # FrontJump - print("Queued: FrontJump (1031)") - - # Finger Heart again - robot.webrtc_req(api_id=1036) # FingerHeart - print("Queued: FingerHeart (1036)") - - # Scrape again - robot.webrtc_req(api_id=1029) # Scrape - print("Queued: Scrape (1029)") - - # Hello one more time - robot.webrtc_req(api_id=1016) # Hello - print("Queued: Hello (1016)") - - # Dance 2 again - robot.webrtc_req(api_id=1023) # Dance2 - print("Queued: Dance2 (1023)") - - # Finish with recovery stand - robot.webrtc_req(api_id=1006) # RecoveryStand - print("Queued: RecoveryStand (1006)") - - print("\nAll 20 commands queued successfully! Watch the robot perform them in sequence.") - print("The WebRTC queue manager will process them one by one when the robot is ready.") - print("Press Ctrl+C to stop the program when you've seen enough.\n") - - try: - # Keep the program running so the queue can be processed - while True: - time.sleep(1) - except KeyboardInterrupt: - print("\nStopping the test...") - finally: - # Cleanup - print("Cleaning up resources...") - robot.cleanup() - print("Test completed.") - - -if __name__ == "__main__": - main() diff --git a/build/lib/tests/test_websocketvis.py b/build/lib/tests/test_websocketvis.py deleted file mode 100644 index a400bd9d14..0000000000 --- a/build/lib/tests/test_websocketvis.py +++ /dev/null @@ -1,152 +0,0 @@ -# Copyright 2025 Dimensional Inc. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -import os -import time -import threading -from dimos.robot.unitree.unitree_go2 import UnitreeGo2 -from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl -from dimos.web.websocket_vis.server import WebsocketVis -from dimos.web.websocket_vis.helpers import vector_stream -from dimos.robot.global_planner.planner import AstarPlanner -from dimos.types.costmap import Costmap -from dimos.types.vector import Vector -from reactivex import operators as ops -import argparse -import pickle -import reactivex as rx -from dimos.web.robot_web_interface import RobotWebInterface - - -def parse_args(): - parser = argparse.ArgumentParser(description="Simple test for vis.") - parser.add_argument( - "--live", - action="store_true", - ) - parser.add_argument( - "--port", type=int, default=5555, help="Port for web visualization interface" - ) - return parser.parse_args() - - -def setup_web_interface(robot, port=5555): - """Set up web interface with robot video and local planner visualization""" - print(f"Setting up web interface on port {port}") - - # Get video stream from robot - video_stream = robot.video_stream_ros.pipe( - ops.share(), - ops.map(lambda frame: frame), - ops.filter(lambda frame: frame is not None), - ) - - # Get local planner visualization stream - local_planner_stream = robot.local_planner_viz_stream.pipe( - ops.share(), - ops.map(lambda frame: frame), - ops.filter(lambda frame: frame is not 
None), - ) - - # Create web interface with streams - web_interface = RobotWebInterface( - port=port, robot_video=video_stream, local_planner=local_planner_stream - ) - - return web_interface - - -def main(): - args = parse_args() - - websocket_vis = WebsocketVis() - websocket_vis.start() - - web_interface = None - - if args.live: - ros_control = UnitreeROSControl(node_name="web_nav_test", mock_connection=False) - robot = UnitreeGo2(ros_control=ros_control, ip=os.getenv("ROBOT_IP")) - planner = robot.global_planner - - websocket_vis.connect( - vector_stream("robot", lambda: robot.ros_control.transform_euler_pos("base_link")) - ) - websocket_vis.connect( - robot.ros_control.topic("map", Costmap).pipe(ops.map(lambda x: ["costmap", x])) - ) - - # Also set up the web interface with both streams - if hasattr(robot, "video_stream_ros") and hasattr(robot, "local_planner_viz_stream"): - web_interface = setup_web_interface(robot, port=args.port) - - # Start web interface in a separate thread - viz_thread = threading.Thread(target=web_interface.run, daemon=True) - viz_thread.start() - print(f"Web interface available at http://localhost:{args.port}") - - else: - pickle_path = f"{__file__.rsplit('/', 1)[0]}/mockdata/vegas.pickle" - print(f"Loading costmap from {pickle_path}") - planner = AstarPlanner( - get_costmap=lambda: pickle.load(open(pickle_path, "rb")), - get_robot_pos=lambda: Vector(5.0, 5.0), - set_local_nav=lambda x: time.sleep(1) and True, - ) - - def msg_handler(msgtype, data): - if msgtype == "click": - target = Vector(data["position"]) - try: - planner.set_goal(target) - except Exception as e: - print(f"Error setting goal: {e}") - return - - def threaded_msg_handler(msgtype, data): - thread = threading.Thread(target=msg_handler, args=(msgtype, data)) - thread.daemon = True - thread.start() - - websocket_vis.connect(planner.vis_stream()) - websocket_vis.msg_handler = threaded_msg_handler - - print(f"WebSocket server started on port {websocket_vis.port}") - 
print(planner.get_costmap()) - - planner.plan(Vector(-4.8, -1.0)) # plan a path to the origin - - def fakepos(): - # Simulate a fake vector position change (to test realtime rendering) - vec = Vector(math.sin(time.time()) * 2, math.cos(time.time()) * 2, 0) - print(vec) - return vec - - # if not args.live: - # websocket_vis.connect(rx.interval(0.05).pipe(ops.map(lambda _: ["fakepos", fakepos()]))) - - try: - # Keep the server running - while True: - time.sleep(0.1) - pass - except KeyboardInterrupt: - print("Stopping WebSocket server...") - websocket_vis.stop() - print("WebSocket server stopped") - - -if __name__ == "__main__": - main() diff --git a/build/lib/tests/test_zed_setup.py b/build/lib/tests/test_zed_setup.py deleted file mode 100644 index ca50bb63fb..0000000000 --- a/build/lib/tests/test_zed_setup.py +++ /dev/null @@ -1,182 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Simple test script to verify ZED camera setup and basic functionality. 
-""" - -import sys -from pathlib import Path - - -def test_imports(): - """Test that all required modules can be imported.""" - print("Testing imports...") - - try: - import numpy as np - - print("✓ NumPy imported successfully") - except ImportError as e: - print(f"✗ NumPy import failed: {e}") - return False - - try: - import cv2 - - print("✓ OpenCV imported successfully") - except ImportError as e: - print(f"✗ OpenCV import failed: {e}") - return False - - try: - from PIL import Image, ImageDraw, ImageFont - - print("✓ PIL imported successfully") - except ImportError as e: - print(f"✗ PIL import failed: {e}") - return False - - try: - import pyzed.sl as sl - - print("✓ ZED SDK (pyzed) imported successfully") - # Note: SDK version method varies between versions - except ImportError as e: - print(f"✗ ZED SDK import failed: {e}") - print(" Please install ZED SDK and pyzed package") - return False - - try: - from dimos.hardware.zed_camera import ZEDCamera - - print("✓ ZEDCamera class imported successfully") - except ImportError as e: - print(f"✗ ZEDCamera import failed: {e}") - return False - - try: - from dimos.perception.zed_visualizer import ZEDVisualizer - - print("✓ ZEDVisualizer class imported successfully") - except ImportError as e: - print(f"✗ ZEDVisualizer import failed: {e}") - return False - - return True - - -def test_camera_detection(): - """Test if ZED cameras are detected.""" - print("\nTesting camera detection...") - - try: - import pyzed.sl as sl - - # List available cameras - cameras = sl.Camera.get_device_list() - print(f"Found {len(cameras)} ZED camera(s):") - - for i, camera_info in enumerate(cameras): - print(f" Camera {i}:") - print(f" Model: {camera_info.camera_model}") - print(f" Serial: {camera_info.serial_number}") - print(f" State: {camera_info.camera_state}") - - return len(cameras) > 0 - - except Exception as e: - print(f"Error detecting cameras: {e}") - return False - - -def test_basic_functionality(): - """Test basic ZED camera 
functionality without actually opening the camera.""" - print("\nTesting basic functionality...") - - try: - import pyzed.sl as sl - from dimos.hardware.zed_camera import ZEDCamera - from dimos.perception.zed_visualizer import ZEDVisualizer - - # Test camera initialization (without opening) - camera = ZEDCamera( - camera_id=0, - resolution=sl.RESOLUTION.HD720, - depth_mode=sl.DEPTH_MODE.NEURAL, - ) - print("✓ ZEDCamera instance created successfully") - - # Test visualizer initialization - visualizer = ZEDVisualizer(max_depth=10.0) - print("✓ ZEDVisualizer instance created successfully") - - # Test creating a dummy visualization - dummy_rgb = np.zeros((480, 640, 3), dtype=np.uint8) - dummy_depth = np.ones((480, 640), dtype=np.float32) * 2.0 - - vis = visualizer.create_side_by_side_image(dummy_rgb, dummy_depth) - print("✓ Dummy visualization created successfully") - - return True - - except Exception as e: - print(f"✗ Basic functionality test failed: {e}") - return False - - -def main(): - """Run all tests.""" - print("ZED Camera Setup Test") - print("=" * 50) - - # Test imports - if not test_imports(): - print("\n❌ Import tests failed. Please install missing dependencies.") - return False - - # Test camera detection - cameras_found = test_camera_detection() - if not cameras_found: - print( - "\n⚠️ No ZED cameras detected. Please connect a ZED camera to test capture functionality." - ) - - # Test basic functionality - if not test_basic_functionality(): - print("\n❌ Basic functionality tests failed.") - return False - - print("\n" + "=" * 50) - if cameras_found: - print("✅ All tests passed! 
You can now run the ZED demo:") - print(" python examples/zed_neural_depth_demo.py --display-time 10") - else: - print("✅ Setup is ready, but no camera detected.") - print(" Connect a ZED camera and run:") - print(" python examples/zed_neural_depth_demo.py --display-time 10") - - return True - - -if __name__ == "__main__": - # Add the project root to Python path - sys.path.append(str(Path(__file__).parent)) - - # Import numpy after path setup - import numpy as np - - success = main() - sys.exit(0 if success else 1) diff --git a/build/lib/tests/visualization_script.py b/build/lib/tests/visualization_script.py deleted file mode 100644 index d0c4c6af84..0000000000 --- a/build/lib/tests/visualization_script.py +++ /dev/null @@ -1,1041 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Visualize pickled manipulation pipeline results.""" - -import os -import sys -import pickle -import numpy as np -import json -import matplotlib - -# Try to use TkAgg backend for live display, fallback to Agg if not available -try: - matplotlib.use("TkAgg") -except: - try: - matplotlib.use("Qt5Agg") - except: - matplotlib.use("Agg") # Fallback to non-interactive -import matplotlib.pyplot as plt -import open3d as o3d - -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -from dimos.perception.pointcloud.utils import visualize_clustered_point_clouds, visualize_voxel_grid -from dimos.perception.grasp_generation.utils import visualize_grasps_3d -from dimos.perception.pointcloud.utils import visualize_pcd -from dimos.utils.logging_config import setup_logger -import trimesh - -import tf_lcm_py -import cv2 -from contextlib import contextmanager -import lcm_msgs -from lcm_msgs.sensor_msgs import JointState, PointCloud2, CameraInfo, PointCloud2, PointField -from lcm_msgs.std_msgs import Header -from typing import List, Tuple, Optional -import atexit -from datetime import datetime -import time - -from pydrake.all import ( - AddMultibodyPlantSceneGraph, - CoulombFriction, - Diagram, - DiagramBuilder, - InverseKinematics, - MeshcatVisualizer, - MeshcatVisualizerParams, - MultibodyPlant, - Parser, - RigidTransform, - RollPitchYaw, - RotationMatrix, - JointIndex, - Solve, - StartMeshcat, -) -from pydrake.geometry import ( - CollisionFilterDeclaration, - Mesh, - ProximityProperties, - InMemoryMesh, - Box, - Cylinder, -) -from pydrake.math import RigidTransform as DrakeRigidTransform -from pydrake.common import MemoryFile - -from pydrake.all import ( - MinimumDistanceLowerBoundConstraint, - MultibodyPlant, - Parser, - DiagramBuilder, - AddMultibodyPlantSceneGraph, - MeshcatVisualizer, - StartMeshcat, - RigidTransform, - Role, - RollPitchYaw, - RotationMatrix, - Solve, - InverseKinematics, - MeshcatVisualizerParams, - 
MinimumDistanceLowerBoundConstraint, - DoDifferentialInverseKinematics, - DifferentialInverseKinematicsStatus, - DifferentialInverseKinematicsParameters, - DepthImageToPointCloud, -) -from manipulation.scenarios import AddMultibodyTriad -from manipulation.meshcat_utils import ( # TODO(russt): switch to pydrake version - _MeshcatPoseSliders, -) -from manipulation.scenarios import AddIiwa, AddShape, AddWsg - -logger = setup_logger("visualization_script") - - -def create_point_cloud(color_img, depth_img, intrinsics): - """Create Open3D point cloud from RGB and depth images.""" - fx, fy, cx, cy = intrinsics - height, width = depth_img.shape - - o3d_intrinsics = o3d.camera.PinholeCameraIntrinsic(width, height, fx, fy, cx, cy) - color_o3d = o3d.geometry.Image(color_img) - depth_o3d = o3d.geometry.Image((depth_img * 1000).astype(np.uint16)) - - rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth( - color_o3d, depth_o3d, depth_scale=1000.0, convert_rgb_to_intensity=False - ) - - return o3d.geometry.PointCloud.create_from_rgbd_image(rgbd, o3d_intrinsics) - - -def deserialize_point_cloud(data): - """Reconstruct Open3D PointCloud from serialized data.""" - if data is None: - return None - - pcd = o3d.geometry.PointCloud() - if "points" in data and data["points"]: - pcd.points = o3d.utility.Vector3dVector(np.array(data["points"])) - if "colors" in data and data["colors"]: - pcd.colors = o3d.utility.Vector3dVector(np.array(data["colors"])) - return pcd - - -def deserialize_voxel_grid(data): - """Reconstruct Open3D VoxelGrid from serialized data.""" - if data is None: - return None - - # Create a point cloud to convert to voxel grid - pcd = o3d.geometry.PointCloud() - voxel_size = data["voxel_size"] - origin = np.array(data["origin"]) - - # Create points from voxel indices - points = [] - colors = [] - for voxel in data["voxels"]: - # Each voxel is (i, j, k, r, g, b) - i, j, k, r, g, b = voxel - # Convert voxel grid index to 3D point - point = origin + np.array([i, j, k]) 
* voxel_size - points.append(point) - colors.append([r, g, b]) - - if points: - pcd.points = o3d.utility.Vector3dVector(np.array(points)) - pcd.colors = o3d.utility.Vector3dVector(np.array(colors)) - - # Convert to voxel grid - voxel_grid = o3d.geometry.VoxelGrid.create_from_point_cloud(pcd, voxel_size) - return voxel_grid - - -def visualize_results(pickle_path="manipulation_results.pkl"): - """Load pickled results and visualize them.""" - print(f"Loading results from {pickle_path}...") - try: - with open(pickle_path, "rb") as f: - data = pickle.load(f) - - results = data["results"] - color_img = data["color_img"] - depth_img = data["depth_img"] - intrinsics = data["intrinsics"] - - print(f"Loaded results with keys: {list(results.keys())}") - - except FileNotFoundError: - print(f"Error: Pickle file {pickle_path} not found.") - print("Make sure to run test_manipulation_pipeline_single_frame_lcm.py first.") - return - except Exception as e: - print(f"Error loading pickle file: {e}") - return - - # Determine number of subplots based on what results we have - num_plots = 0 - plot_configs = [] - - if "detection_viz" in results and results["detection_viz"] is not None: - plot_configs.append(("detection_viz", "Object Detection")) - num_plots += 1 - - if "segmentation_viz" in results and results["segmentation_viz"] is not None: - plot_configs.append(("segmentation_viz", "Semantic Segmentation")) - num_plots += 1 - - if "pointcloud_viz" in results and results["pointcloud_viz"] is not None: - plot_configs.append(("pointcloud_viz", "All Objects Point Cloud")) - num_plots += 1 - - if "detected_pointcloud_viz" in results and results["detected_pointcloud_viz"] is not None: - plot_configs.append(("detected_pointcloud_viz", "Detection Objects Point Cloud")) - num_plots += 1 - - if "misc_pointcloud_viz" in results and results["misc_pointcloud_viz"] is not None: - plot_configs.append(("misc_pointcloud_viz", "Misc/Background Points")) - num_plots += 1 - - if "grasp_overlay" in 
results and results["grasp_overlay"] is not None: - plot_configs.append(("grasp_overlay", "Grasp Overlay")) - num_plots += 1 - - if num_plots == 0: - print("No visualization results to display") - return - - # Create subplot layout - if num_plots <= 3: - fig, axes = plt.subplots(1, num_plots, figsize=(6 * num_plots, 5)) - else: - rows = 2 - cols = (num_plots + 1) // 2 - fig, axes = plt.subplots(rows, cols, figsize=(6 * cols, 5 * rows)) - - # Ensure axes is always a list for consistent indexing - if num_plots == 1: - axes = [axes] - elif num_plots > 2: - axes = axes.flatten() - - # Plot each result - for i, (key, title) in enumerate(plot_configs): - axes[i].imshow(results[key]) - axes[i].set_title(title) - axes[i].axis("off") - - # Hide unused subplots if any - if num_plots > 3: - for i in range(num_plots, len(axes)): - axes[i].axis("off") - - plt.tight_layout() - - # Save and show the plot - output_path = "visualization_results.png" - plt.savefig(output_path, dpi=150, bbox_inches="tight") - print(f"Results visualization saved to: {output_path}") - - # Show plot live as well - plt.show(block=True) - plt.close() - - # Deserialize and reconstruct 3D objects from the pickle file - print("\nReconstructing 3D visualization objects from serialized data...") - - # Reconstruct full point cloud if available - full_pcd = None - if "full_pointcloud" in results and results["full_pointcloud"] is not None: - full_pcd = deserialize_point_cloud(results["full_pointcloud"]) - print(f"Reconstructed full point cloud with {len(np.asarray(full_pcd.points))} points") - - # Visualize reconstructed full point cloud - try: - visualize_pcd( - full_pcd, - window_name="Reconstructed Full Scene Point Cloud", - point_size=2.0, - show_coordinate_frame=True, - ) - except (KeyboardInterrupt, EOFError): - print("\nSkipping full point cloud visualization") - except Exception as e: - print(f"Error in point cloud visualization: {e}") - else: - print("No full point cloud available for visualization") - - 
# Reconstruct misc clusters if available - if "misc_clusters" in results and results["misc_clusters"]: - misc_clusters = [deserialize_point_cloud(cluster) for cluster in results["misc_clusters"]] - cluster_count = len(misc_clusters) - total_misc_points = sum(len(np.asarray(cluster.points)) for cluster in misc_clusters) - print(f"Reconstructed {cluster_count} misc clusters with {total_misc_points} total points") - - # Visualize reconstructed misc clusters - try: - visualize_clustered_point_clouds( - misc_clusters, - window_name="Reconstructed Misc/Background Clusters (DBSCAN)", - point_size=3.0, - show_coordinate_frame=True, - ) - except (KeyboardInterrupt, EOFError): - print("\nSkipping misc clusters visualization") - except Exception as e: - print(f"Error in misc clusters visualization: {e}") - else: - print("No misc clusters available for visualization") - - # Reconstruct voxel grid if available - if "misc_voxel_grid" in results and results["misc_voxel_grid"] is not None: - misc_voxel_grid = deserialize_voxel_grid(results["misc_voxel_grid"]) - if misc_voxel_grid: - voxel_count = len(misc_voxel_grid.get_voxels()) - print(f"Reconstructed voxel grid with {voxel_count} voxels") - - # Visualize reconstructed voxel grid - try: - visualize_voxel_grid( - misc_voxel_grid, - window_name="Reconstructed Misc/Background Voxel Grid", - show_coordinate_frame=True, - ) - except (KeyboardInterrupt, EOFError): - print("\nSkipping voxel grid visualization") - except Exception as e: - print(f"Error in voxel grid visualization: {e}") - else: - print("Failed to reconstruct voxel grid") - else: - print("No voxel grid available for visualization") - - -class DrakeKinematicsEnv: - def __init__( - self, - urdf_path: str, - kinematic_chain_joints: List[str], - links_to_ignore: Optional[List[str]] = None, - ): - self._resources_to_cleanup = [] - - # Register cleanup at exit - atexit.register(self.cleanup_resources) - - # Initialize tf resources once and reuse them - self.buffer = 
tf_lcm_py.Buffer(30.0) - self._resources_to_cleanup.append(self.buffer) - with self.safe_lcm_instance() as lcm_instance: - self.tf_lcm_instance = lcm_instance - self._resources_to_cleanup.append(self.tf_lcm_instance) - # Create TransformListener with our LCM instance and buffer - self.listener = tf_lcm_py.TransformListener(self.tf_lcm_instance, self.buffer) - self._resources_to_cleanup.append(self.listener) - - # Check if URDF file exists - if not os.path.exists(urdf_path): - raise FileNotFoundError(f"URDF file not found: {urdf_path}") - - # Drake utils initialization - self.meshcat = StartMeshcat() - print(f"Meshcat started at: {self.meshcat.web_url()}") - - self.urdf_path = urdf_path - self.builder = DiagramBuilder() - - self.plant, self.scene_graph = AddMultibodyPlantSceneGraph(self.builder, time_step=0.01) - self.parser = Parser(self.plant) - - # Load the robot URDF - print(f"Loading URDF from: {self.urdf_path}") - self.model_instances = self.parser.AddModelsFromUrl(f"file://{self.urdf_path}") - self.kinematic_chain_joints = kinematic_chain_joints - self.model_instance = self.model_instances[0] if self.model_instances else None - - if not self.model_instances: - raise RuntimeError("Failed to load any model instances from URDF") - - print(f"Loaded {len(self.model_instances)} model instances") - - # Set up collision filtering - if links_to_ignore: - bodies = [] - for link_name in links_to_ignore: - try: - body = self.plant.GetBodyByName(link_name) - if body is not None: - bodies.extend(self.plant.GetBodiesWeldedTo(body)) - except RuntimeError: - print(f"Warning: Link '{link_name}' not found in URDF") - - if bodies: - arm_geoms = self.plant.CollectRegisteredGeometries(bodies) - decl = CollisionFilterDeclaration().ExcludeWithin(arm_geoms) - manager = self.scene_graph.collision_filter_manager() - manager.Apply(decl) - - # Load and process point cloud data - self._load_and_process_point_clouds() - - # Finalize the plant before adding visualizer - 
self.plant.Finalize() - - # Print some debug info about the plant - print(f"Plant has {self.plant.num_bodies()} bodies") - print(f"Plant has {self.plant.num_joints()} joints") - for i in range(self.plant.num_joints()): - joint = self.plant.get_joint(JointIndex(i)) - print(f" Joint {i}: {joint.name()} (type: {joint.type_name()})") - - # Add visualizer - self.visualizer = MeshcatVisualizer.AddToBuilder( - self.builder, self.scene_graph, self.meshcat, params=MeshcatVisualizerParams() - ) - - # Build the diagram - self.diagram = self.builder.Build() - self.diagram_context = self.diagram.CreateDefaultContext() - self.plant_context = self.plant.GetMyContextFromRoot(self.diagram_context) - - # Set up joint indices - self.joint_indices = [] - for joint_name in self.kinematic_chain_joints: - try: - joint = self.plant.GetJointByName(joint_name) - if joint.num_positions() > 0: - start_index = joint.position_start() - for i in range(joint.num_positions()): - self.joint_indices.append(start_index + i) - print( - f"Added joint '{joint_name}' at indices {start_index} to {start_index + joint.num_positions() - 1}" - ) - except RuntimeError: - print(f"Warning: Joint '{joint_name}' not found in URDF.") - - # Get important frames/bodies - try: - self.end_effector_link = self.plant.GetBodyByName("link6") - self.end_effector_frame = self.plant.GetFrameByName("link6") - print("Found end effector link6") - except RuntimeError: - print("Warning: link6 not found") - self.end_effector_link = None - self.end_effector_frame = None - - try: - self.camera_link = self.plant.GetBodyByName("camera_center_link") - print("Found camera_center_link") - except RuntimeError: - print("Warning: camera_center_link not found") - self.camera_link = None - - # Set robot to a reasonable initial configuration - self._set_initial_configuration() - - # Force initial visualization update - self._update_visualization() - - print("Drake environment initialization complete!") - print(f"Visit {self.meshcat.web_url()} 
to see the visualization") - - def _load_and_process_point_clouds(self): - """Load point cloud data from pickle file and add to scene""" - pickle_path = "manipulation_results.pkl" - try: - with open(pickle_path, "rb") as f: - data = pickle.load(f) - - results = data["results"] - print(f"Loaded results with keys: {list(results.keys())}") - - except FileNotFoundError: - print(f"Warning: Pickle file {pickle_path} not found.") - print("Skipping point cloud loading.") - return - except Exception as e: - print(f"Warning: Error loading pickle file: {e}") - return - - full_detected_pcd = o3d.geometry.PointCloud() - for obj in results["detected_objects"]: - pcd = o3d.geometry.PointCloud() - pcd.points = o3d.utility.Vector3dVector(obj["point_cloud_numpy"]) - full_detected_pcd += pcd - - self.process_and_add_object_class("all_objects", results) - self.process_and_add_object_class("misc_clusters", results) - misc_clusters = results["misc_clusters"] - print(type(misc_clusters[0]["points"])) - print(np.asarray(misc_clusters[0]["points"]).shape) - - def process_and_add_object_class(self, object_key: str, results: dict): - # Process detected objects - if object_key in results: - detected_objects = results[object_key] - if detected_objects: - print(f"Processing {len(detected_objects)} {object_key}") - all_decomposed_meshes = [] - - transform = self.get_transform("world", "camera_center_link") - for i in range(len(detected_objects)): - try: - if object_key == "misc_clusters": - points = np.asarray(detected_objects[i]["points"]) - elif "point_cloud_numpy" in detected_objects[i]: - points = detected_objects[i]["point_cloud_numpy"] - elif ( - "point_cloud" in detected_objects[i] - and detected_objects[i]["point_cloud"] - ): - # Handle serialized point cloud - points = np.array(detected_objects[i]["point_cloud"]["points"]) - else: - print(f"Warning: No point cloud data found for object {i}") - continue - - if len(points) < 10: # Need more points for mesh reconstruction - print( - 
f"Warning: Object {i} has too few points ({len(points)}) for mesh reconstruction" - ) - continue - - # Swap y-z axes since this is a common problem - points = np.column_stack((points[:, 0], points[:, 2], -points[:, 1])) - # Transform points to world frame - points = self.transform_point_cloud_with_open3d(points, transform) - - # Use fast DBSCAN clustering + convex hulls approach - clustered_hulls = self._create_clustered_convex_hulls(points, i) - all_decomposed_meshes.extend(clustered_hulls) - - print( - f"Created {len(clustered_hulls)} clustered convex hulls for object {i}" - ) - - except Exception as e: - print(f"Warning: Failed to process object {i}: {e}") - - if all_decomposed_meshes: - self.register_convex_hulls_as_collision(all_decomposed_meshes, object_key) - print(f"Registered {len(all_decomposed_meshes)} total clustered convex hulls") - else: - print("Warning: No valid clustered convex hulls created from detected objects") - else: - print("No detected objects found") - - def _create_clustered_convex_hulls( - self, points: np.ndarray, object_id: int - ) -> List[o3d.geometry.TriangleMesh]: - """ - Create convex hulls from DBSCAN clusters of point cloud data. - Fast approach: cluster points, then convex hull each cluster. 
- - Args: - points: Nx3 numpy array of 3D points - object_id: ID for debugging/logging - - Returns: - List of Open3D triangle meshes (convex hulls of clusters) - """ - try: - # Create Open3D point cloud - pcd = o3d.geometry.PointCloud() - pcd.points = o3d.utility.Vector3dVector(points) - - # Quick outlier removal (optional, can skip for speed) - if len(points) > 50: # Only for larger point clouds - pcd, _ = pcd.remove_statistical_outlier(nb_neighbors=10, std_ratio=2.0) - points = np.asarray(pcd.points) - - if len(points) < 4: - print(f"Warning: Too few points after filtering for object {object_id}") - return [] - - # Try multiple DBSCAN parameter combinations to find clusters - clusters = [] - labels = None - - # Calculate some basic statistics for parameter estimation - if len(points) > 10: - # Compute nearest neighbor distances for better eps estimation - distances = pcd.compute_nearest_neighbor_distance() - avg_nn_distance = np.mean(distances) - std_nn_distance = np.std(distances) - - print( - f"Object {object_id}: {len(points)} points, avg_nn_dist={avg_nn_distance:.4f}" - ) - - for i in range(20): - try: - eps = avg_nn_distance * (2.0 + (i * 0.1)) - min_samples = 20 - labels = np.array(pcd.cluster_dbscan(eps=eps, min_points=min_samples)) - unique_labels = np.unique(labels) - clusters = unique_labels[unique_labels >= 0] # Remove noise label (-1) - - noise_points = np.sum(labels == -1) - clustered_points = len(points) - noise_points - - print( - f" Try {i + 1}: eps={eps:.4f}, min_samples={min_samples} → {len(clusters)} clusters, {clustered_points}/{len(points)} points clustered" - ) - - # Accept if we found clusters and most points are clustered - if ( - len(clusters) > 0 and clustered_points >= len(points) * 0.95 - ): # At least 30% of points clustered - print(f" ✓ Accepted parameter set {i + 1}") - break - - except Exception as e: - print( - f" Try {i + 1}: Failed with eps={eps:.4f}, min_samples={min_samples}: {e}" - ) - continue - - if len(clusters) == 0 or 
labels is None: - print( - f"No clusters found for object {object_id} after all attempts, using entire point cloud" - ) - # Fallback: use entire point cloud as single convex hull - hull_mesh, _ = pcd.compute_convex_hull() - hull_mesh.compute_vertex_normals() - return [hull_mesh] - - print( - f"Found {len(clusters)} clusters for object {object_id} (eps={eps:.3f}, min_samples={min_samples})" - ) - - # Create convex hull for each cluster - convex_hulls = [] - for cluster_id in clusters: - try: - # Get points for this cluster - cluster_mask = labels == cluster_id - cluster_points = points[cluster_mask] - - if len(cluster_points) < 4: - print( - f"Skipping cluster {cluster_id} with only {len(cluster_points)} points" - ) - continue - - # Create point cloud for this cluster - cluster_pcd = o3d.geometry.PointCloud() - cluster_pcd.points = o3d.utility.Vector3dVector(cluster_points) - - # Compute convex hull - hull_mesh, _ = cluster_pcd.compute_convex_hull() - hull_mesh.compute_vertex_normals() - - # Validate hull - if ( - len(np.asarray(hull_mesh.vertices)) >= 4 - and len(np.asarray(hull_mesh.triangles)) >= 4 - ): - convex_hulls.append(hull_mesh) - print( - f" Cluster {cluster_id}: {len(cluster_points)} points → convex hull with {len(np.asarray(hull_mesh.vertices))} vertices" - ) - else: - print(f" Skipping degenerate hull for cluster {cluster_id}") - - except Exception as e: - print(f"Error processing cluster {cluster_id} for object {object_id}: {e}") - - if not convex_hulls: - print( - f"No valid convex hulls created for object {object_id}, using entire point cloud" - ) - # Fallback: use entire point cloud as single convex hull - hull_mesh, _ = pcd.compute_convex_hull() - hull_mesh.compute_vertex_normals() - return [hull_mesh] - - return convex_hulls - - except Exception as e: - print(f"Error in DBSCAN clustering for object {object_id}: {e}") - # Final fallback: single convex hull - try: - pcd = o3d.geometry.PointCloud() - pcd.points = o3d.utility.Vector3dVector(points) - 
hull_mesh, _ = pcd.compute_convex_hull() - hull_mesh.compute_vertex_normals() - return [hull_mesh] - except: - return [] - - def _set_initial_configuration(self): - """Set the robot to a reasonable initial joint configuration""" - # Set all joints to zero initially - if self.joint_indices: - q = np.zeros(len(self.joint_indices)) - - # You can customize these values for a better initial pose - # For example, if you know good default joint angles: - if len(q) >= 6: # Assuming at least 6 DOF arm - q[1] = 0.0 # joint1 - q[2] = 0.0 # joint2 - q[3] = 0.0 # joint3 - q[4] = 0.0 # joint4 - q[5] = 0.0 # joint5 - q[6] = 0.0 # joint6 - - # Set the joint positions in the plant context - positions = self.plant.GetPositions(self.plant_context) - for i, joint_idx in enumerate(self.joint_indices): - if joint_idx < len(positions): - positions[joint_idx] = q[i] - - self.plant.SetPositions(self.plant_context, positions) - print(f"Set initial joint configuration: {q}") - else: - print("Warning: No joint indices found, using default configuration") - - def _update_visualization(self): - """Force update the visualization""" - try: - # Get the visualizer's context from the diagram context - visualizer_context = self.visualizer.GetMyContextFromRoot(self.diagram_context) - self.visualizer.ForcedPublish(visualizer_context) - print("Visualization updated successfully") - except Exception as e: - print(f"Error updating visualization: {e}") - - def set_joint_positions(self, joint_positions): - """Set specific joint positions and update visualization""" - if len(joint_positions) != len(self.joint_indices): - raise ValueError( - f"Expected {len(self.joint_indices)} joint positions, got {len(joint_positions)}" - ) - - positions = self.plant.GetPositions(self.plant_context) - for i, joint_idx in enumerate(self.joint_indices): - if joint_idx < len(positions): - positions[joint_idx] = joint_positions[i] - - self.plant.SetPositions(self.plant_context, positions) - self._update_visualization() - 
print(f"Updated joint positions: {joint_positions}") - - def register_convex_hulls_as_collision( - self, meshes: List[o3d.geometry.TriangleMesh], hull_type: str - ): - """Register convex hulls as collision and visual geometry""" - if not meshes: - print("No meshes to register") - return - - world = self.plant.world_body() - proximity = ProximityProperties() - - for i, mesh in enumerate(meshes): - try: - # Convert Open3D → numpy arrays → trimesh.Trimesh - vertices = np.asarray(mesh.vertices) - faces = np.asarray(mesh.triangles) - - if len(vertices) == 0 or len(faces) == 0: - print(f"Warning: Mesh {i} is empty, skipping") - continue - - tmesh = trimesh.Trimesh(vertices=vertices, faces=faces) - - # Export to OBJ in memory - tmesh_obj_blob = tmesh.export(file_type="obj") - mem_file = MemoryFile( - contents=tmesh_obj_blob, extension=".obj", filename_hint=f"convex_hull_{i}.obj" - ) - in_memory_mesh = InMemoryMesh() - in_memory_mesh.mesh_file = mem_file - drake_mesh = Mesh(in_memory_mesh, scale=1.0) - - pos = np.array([0.0, 0.0, 0.0]) - rpy = RollPitchYaw(0.0, 0.0, 0.0) - X_WG = DrakeRigidTransform(RotationMatrix(rpy), pos) - - # Register collision and visual geometry - self.plant.RegisterCollisionGeometry( - body=world, - X_BG=X_WG, - shape=drake_mesh, - name=f"convex_hull_collision_{i}_{hull_type}", - properties=proximity, - ) - self.plant.RegisterVisualGeometry( - body=world, - X_BG=X_WG, - shape=drake_mesh, - name=f"convex_hull_visual_{i}_{hull_type}", - diffuse_color=np.array([0.7, 0.5, 0.3, 0.8]), # Orange-ish color - ) - - print( - f"Registered convex hull {i} with {len(vertices)} vertices and {len(faces)} faces" - ) - - except Exception as e: - print(f"Warning: Failed to register mesh {i}: {e}") - - # Add a simple table for reference - try: - table_shape = Box(1.0, 1.0, 0.1) # Thinner table - table_pose = RigidTransform(p=[0.5, 0.0, -0.05]) # In front of robot - self.plant.RegisterCollisionGeometry( - world, table_pose, table_shape, "table_collision", proximity - 
) - self.plant.RegisterVisualGeometry( - world, table_pose, table_shape, "table_visual", [0.8, 0.6, 0.4, 1.0] - ) - print("Added reference table") - except Exception as e: - print(f"Warning: Failed to add table: {e}") - - def get_seeded_random_rgba(self, id: int): - np.random.seed(id) - return np.random.rand(4) - - @contextmanager - def safe_lcm_instance(self): - """Context manager for safely managing LCM instance lifecycle""" - lcm_instance = tf_lcm_py.LCM() - try: - yield lcm_instance - finally: - pass - - def cleanup_resources(self): - """Clean up resources before exiting""" - # Only clean up once when exiting - print("Cleaning up resources...") - # Force cleanup of resources in reverse order (last created first) - for resource in reversed(self._resources_to_cleanup): - try: - # For objects like TransformListener that might have a close or shutdown method - if hasattr(resource, "close"): - resource.close() - elif hasattr(resource, "shutdown"): - resource.shutdown() - - # Explicitly delete the resource - del resource - except Exception as e: - print(f"Error during cleanup: {e}") - - # Clear the resources list - self._resources_to_cleanup = [] - - def get_transform(self, target_frame, source_frame): - print("Getting transform from", source_frame, "to", target_frame) - attempts = 0 - max_attempts = 20 # Reduced from 120 to avoid long blocking - - while attempts < max_attempts: - try: - # Process LCM messages with error handling - if not self.tf_lcm_instance.handle_timeout(100): # 100ms timeout - # If handle_timeout returns false, we might need to re-check if LCM is still good - if not self.tf_lcm_instance.good(): - print("WARNING: LCM instance is no longer in a good state") - - # Get the most recent timestamp from the buffer instead of using current time - try: - timestamp = self.buffer.get_most_recent_timestamp() - if attempts % 10 == 0: - print(f"Using timestamp from buffer: {timestamp}") - except Exception as e: - # Fall back to current time if 
get_most_recent_timestamp fails - timestamp = datetime.now() - if not hasattr(timestamp, "timestamp"): - timestamp.timestamp = ( - lambda: time.mktime(timestamp.timetuple()) + timestamp.microsecond / 1e6 - ) - if attempts % 10 == 0: - print(f"Falling back to current time: {timestamp}") - - # Check if we can find the transform - if self.buffer.can_transform(target_frame, source_frame, timestamp): - # print(f"Found transform between '{target_frame}' and '{source_frame}'!") - - # Look up the transform with the timestamp from the buffer - transform = self.buffer.lookup_transform( - target_frame, - source_frame, - timestamp, - timeout=10.0, - time_tolerance=0.1, - lcm_module=lcm_msgs, - ) - - return transform - - # Increment counter and report status every 10 attempts - attempts += 1 - if attempts % 10 == 0: - print(f"Still waiting... (attempt {attempts}/{max_attempts})") - frames = self.buffer.get_all_frame_names() - if frames: - print(f"Frames received so far ({len(frames)} total):") - for frame in sorted(frames): - print(f" {frame}") - else: - print("No frames received yet") - - # Brief pause - time.sleep(0.5) - - except Exception as e: - print(f"Error during transform lookup: {e}") - attempts += 1 - time.sleep(1) # Longer pause after an error - - print(f"\nERROR: No transform found after {max_attempts} attempts") - return None - - def transform_point_cloud_with_open3d(self, points_np: np.ndarray, transform) -> np.ndarray: - """ - Transforms a point cloud using Open3D given a transform. - - Args: - points_np (np.ndarray): Nx3 array of 3D points. - transform: Transform from tf_lcm_py. - - Returns: - np.ndarray: Nx3 array of transformed 3D points. 
- """ - if points_np.shape[1] != 3: - print("Input point cloud must have shape Nx3.") - return points_np - - # Convert transform to 4x4 numpy matrix - tf_matrix = np.eye(4) - - # Extract rotation quaternion components - qw = transform.transform.rotation.w - qx = transform.transform.rotation.x - qy = transform.transform.rotation.y - qz = transform.transform.rotation.z - - # Convert quaternion to rotation matrix - # Formula from: https://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation#Quaternion-derived_rotation_matrix - tf_matrix[0, 0] = 1 - 2 * qy * qy - 2 * qz * qz - tf_matrix[0, 1] = 2 * qx * qy - 2 * qz * qw - tf_matrix[0, 2] = 2 * qx * qz + 2 * qy * qw - - tf_matrix[1, 0] = 2 * qx * qy + 2 * qz * qw - tf_matrix[1, 1] = 1 - 2 * qx * qx - 2 * qz * qz - tf_matrix[1, 2] = 2 * qy * qz - 2 * qx * qw - - tf_matrix[2, 0] = 2 * qx * qz - 2 * qy * qw - tf_matrix[2, 1] = 2 * qy * qz + 2 * qx * qw - tf_matrix[2, 2] = 1 - 2 * qx * qx - 2 * qy * qy - - # Set translation - tf_matrix[0, 3] = transform.transform.translation.x - tf_matrix[1, 3] = transform.transform.translation.y - tf_matrix[2, 3] = transform.transform.translation.z - - # Create Open3D point cloud - pcd = o3d.geometry.PointCloud() - pcd.points = o3d.utility.Vector3dVector(points_np) - - # Apply transformation - pcd.transform(tf_matrix) - - # Return as NumPy array - return np.asarray(pcd.points) - - -# Updated main function -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser(description="Visualize manipulation results") - parser.add_argument("--visualize-only", action="store_true", help="Only visualize results") - args = parser.parse_args() - - if args.visualize_only: - visualize_results() - exit(0) - - try: - # Then set up Drake environment - kinematic_chain_joints = [ - "pillar_platform_joint", - "joint1", - "joint2", - "joint3", - "joint4", - "joint5", - "joint6", - ] - - links_to_ignore = [ - "devkit_base_link", - "pillar_platform", - "piper_angled_mount", - 
"pan_tilt_base", - "pan_tilt_head", - "pan_tilt_pan", - "base_link", - "link1", - "link2", - "link3", - "link4", - "link5", - "link6", - ] - - urdf_path = "./assets/devkit_base_descr.urdf" - urdf_path = os.path.abspath(urdf_path) - - print(f"Attempting to load URDF from: {urdf_path}") - - env = DrakeKinematicsEnv(urdf_path, kinematic_chain_joints, links_to_ignore) - env.set_joint_positions([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) - transform = env.get_transform("world", "camera_center_link") - print( - transform.transform.translation.x, - transform.transform.translation.y, - transform.transform.translation.z, - ) - print( - transform.transform.rotation.w, - transform.transform.rotation.x, - transform.transform.rotation.y, - transform.transform.rotation.z, - ) - - # Keep the visualization alive - print("\nVisualization is running. Press Ctrl+C to exit.") - while True: - time.sleep(1) - - except KeyboardInterrupt: - print("\nExiting...") - except Exception as e: - print(f"Error: {e}") - import traceback - - traceback.print_exc() diff --git a/build/lib/tests/zed_neural_depth_demo.py b/build/lib/tests/zed_neural_depth_demo.py deleted file mode 100644 index 5edce9633f..0000000000 --- a/build/lib/tests/zed_neural_depth_demo.py +++ /dev/null @@ -1,450 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -""" -ZED Camera Neural Depth Demo - OpenCV Live Visualization with Data Saving - -This script demonstrates live visualization of ZED camera RGB and depth data using OpenCV. -Press SPACE to save RGB and depth images to rgbd_data2 folder. -Press ESC or 'q' to quit. -""" - -import os -import sys -import time -import argparse -import logging -from pathlib import Path -import numpy as np -import cv2 -import yaml -from datetime import datetime -import open3d as o3d - -# Add the project root to Python path -sys.path.append(str(Path(__file__).parent.parent)) - -try: - import pyzed.sl as sl -except ImportError: - print("ERROR: ZED SDK not found. Please install the ZED SDK and pyzed Python package.") - print("Download from: https://www.stereolabs.com/developers/release/") - sys.exit(1) - -from dimos.hardware.zed_camera import ZEDCamera -from dimos.perception.pointcloud.utils import visualize_pcd, visualize_clustered_point_clouds - -# Configure logging -logging.basicConfig( - level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" -) -logger = logging.getLogger(__name__) - - -class ZEDLiveVisualizer: - """Live OpenCV visualization for ZED camera data with saving functionality.""" - - def __init__(self, camera, max_depth=10.0, output_dir="assets/rgbd_data2"): - self.camera = camera - self.max_depth = max_depth - self.output_dir = Path(output_dir) - self.save_counter = 0 - - # Store captured pointclouds for later visualization - self.captured_pointclouds = [] - - # Display settings for 480p - self.display_width = 640 - self.display_height = 480 - - # Create output directory structure - self.setup_output_directory() - - # Get camera info for saving - self.camera_info = camera.get_camera_info() - - # Save camera info files once - self.save_camera_info() - - # OpenCV window name (single window) - self.window_name = "ZED Camera - RGB + Depth" - - # Create window - cv2.namedWindow(self.window_name, cv2.WINDOW_AUTOSIZE) - - def 
setup_output_directory(self): - """Create the output directory structure.""" - self.output_dir.mkdir(exist_ok=True) - (self.output_dir / "color").mkdir(exist_ok=True) - (self.output_dir / "depth").mkdir(exist_ok=True) - (self.output_dir / "pointclouds").mkdir(exist_ok=True) - logger.info(f"Created output directory: {self.output_dir}") - - def save_camera_info(self): - """Save camera info YAML files with ZED camera parameters.""" - # Get current timestamp - now = datetime.now() - timestamp_sec = int(now.timestamp()) - timestamp_nanosec = int((now.timestamp() % 1) * 1e9) - - # Get camera resolution - resolution = self.camera_info.get("resolution", {}) - width = int(resolution.get("width", 1280)) - height = int(resolution.get("height", 720)) - - # Extract left camera parameters (for RGB) from already available camera_info - left_cam = self.camera_info.get("left_cam", {}) - # Convert numpy values to Python floats - fx = float(left_cam.get("fx", 749.341552734375)) - fy = float(left_cam.get("fy", 748.5587768554688)) - cx = float(left_cam.get("cx", 639.4312744140625)) - cy = float(left_cam.get("cy", 357.2478942871094)) - - # Build distortion coefficients from ZED format - # ZED provides k1, k2, p1, p2, k3 - convert to rational_polynomial format - k1 = float(left_cam.get("k1", 0.0)) - k2 = float(left_cam.get("k2", 0.0)) - p1 = float(left_cam.get("p1", 0.0)) - p2 = float(left_cam.get("p2", 0.0)) - k3 = float(left_cam.get("k3", 0.0)) - distortion = [k1, k2, p1, p2, k3, 0.0, 0.0, 0.0] - - # Create camera info structure with plain Python types - camera_info = { - "D": distortion, - "K": [fx, 0.0, cx, 0.0, fy, cy, 0.0, 0.0, 1.0], - "P": [fx, 0.0, cx, 0.0, 0.0, fy, cy, 0.0, 0.0, 0.0, 1.0, 0.0], - "R": [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0], - "binning_x": 0, - "binning_y": 0, - "distortion_model": "rational_polynomial", - "header": { - "frame_id": "camera_color_optical_frame", - "stamp": {"nanosec": timestamp_nanosec, "sec": timestamp_sec}, - }, - "height": height, - 
"roi": {"do_rectify": False, "height": 0, "width": 0, "x_offset": 0, "y_offset": 0}, - "width": width, - } - - # Save color camera info - color_info_path = self.output_dir / "color_camera_info.yaml" - with open(color_info_path, "w") as f: - yaml.dump(camera_info, f, default_flow_style=False) - - # Save depth camera info (same as color for ZED) - depth_info_path = self.output_dir / "depth_camera_info.yaml" - with open(depth_info_path, "w") as f: - yaml.dump(camera_info, f, default_flow_style=False) - - logger.info(f"Saved camera info files to {self.output_dir}") - - def normalize_depth_for_display(self, depth_map): - """Normalize depth map for OpenCV visualization.""" - # Handle invalid values - valid_mask = (depth_map > 0) & np.isfinite(depth_map) - - if not np.any(valid_mask): - return np.zeros_like(depth_map, dtype=np.uint8) - - # Normalize to 0-255 for display - depth_norm = np.zeros_like(depth_map, dtype=np.float32) - depth_clipped = np.clip(depth_map[valid_mask], 0, self.max_depth) - depth_norm[valid_mask] = depth_clipped / self.max_depth - - # Convert to 8-bit and apply colormap - depth_8bit = (depth_norm * 255).astype(np.uint8) - depth_colored = cv2.applyColorMap(depth_8bit, cv2.COLORMAP_JET) - - return depth_colored - - def save_frame(self, rgb_img, depth_map): - """Save RGB, depth images, and pointcloud with proper naming convention.""" - # Generate filename with 5-digit zero-padding - filename = f"{self.save_counter:05d}.png" - pcd_filename = f"{self.save_counter:05d}.ply" - - # Save RGB image - rgb_path = self.output_dir / "color" / filename - cv2.imwrite(str(rgb_path), rgb_img) - - # Save depth image (convert to 16-bit for proper depth storage) - depth_path = self.output_dir / "depth" / filename - # Convert meters to millimeters and save as 16-bit - depth_mm = (depth_map * 1000).astype(np.uint16) - cv2.imwrite(str(depth_path), depth_mm) - - # Capture and save pointcloud - pcd = self.camera.capture_pointcloud() - if pcd is not None and 
len(np.asarray(pcd.points)) > 0: - pcd_path = self.output_dir / "pointclouds" / pcd_filename - o3d.io.write_point_cloud(str(pcd_path), pcd) - - # Store pointcloud for later visualization - self.captured_pointclouds.append(pcd) - - logger.info( - f"Saved frame {self.save_counter}: {rgb_path}, {depth_path}, and {pcd_path}" - ) - else: - logger.warning(f"Failed to capture pointcloud for frame {self.save_counter}") - logger.info(f"Saved frame {self.save_counter}: {rgb_path} and {depth_path}") - - self.save_counter += 1 - - def visualize_captured_pointclouds(self): - """Visualize all captured pointclouds using Open3D, one by one.""" - if not self.captured_pointclouds: - logger.info("No pointclouds captured to visualize") - return - - logger.info( - f"Visualizing {len(self.captured_pointclouds)} captured pointclouds one by one..." - ) - logger.info("Close each pointcloud window to proceed to the next one") - - for i, pcd in enumerate(self.captured_pointclouds): - if len(np.asarray(pcd.points)) > 0: - logger.info(f"Displaying pointcloud {i + 1}/{len(self.captured_pointclouds)}") - visualize_pcd(pcd, window_name=f"ZED Pointcloud {i + 1:05d}", point_size=2.0) - else: - logger.warning(f"Pointcloud {i + 1} is empty, skipping...") - - logger.info("Finished displaying all pointclouds") - - def update_display(self): - """Update the live display with new frames.""" - # Capture frame - left_img, right_img, depth_map = self.camera.capture_frame() - - if left_img is None or depth_map is None: - return False, None, None - - # Resize RGB to 480p - rgb_resized = cv2.resize(left_img, (self.display_width, self.display_height)) - - # Create depth visualization - depth_colored = self.normalize_depth_for_display(depth_map) - - # Resize depth to 480p - depth_resized = cv2.resize(depth_colored, (self.display_width, self.display_height)) - - # Add text overlays - text_color = (255, 255, 255) - font = cv2.FONT_HERSHEY_SIMPLEX - font_scale = 0.6 - thickness = 2 - - # Add title and instructions 
to RGB - cv2.putText( - rgb_resized, "RGB Camera Feed", (10, 25), font, font_scale, text_color, thickness - ) - cv2.putText( - rgb_resized, - "SPACE: Save | ESC/Q: Quit", - (10, 50), - font, - font_scale - 0.1, - text_color, - thickness, - ) - - # Add title and stats to depth - cv2.putText( - depth_resized, - f"Depth Map (0-{self.max_depth}m)", - (10, 25), - font, - font_scale, - text_color, - thickness, - ) - cv2.putText( - depth_resized, - f"Saved: {self.save_counter} frames", - (10, 50), - font, - font_scale - 0.1, - text_color, - thickness, - ) - - # Stack images horizontally - combined_display = np.hstack((rgb_resized, depth_resized)) - - # Display combined image - cv2.imshow(self.window_name, combined_display) - - return True, left_img, depth_map - - def handle_key_events(self, rgb_img, depth_map): - """Handle keyboard input.""" - key = cv2.waitKey(1) & 0xFF - - if key == ord(" "): # Space key - save frame - if rgb_img is not None and depth_map is not None: - self.save_frame(rgb_img, depth_map) - return "save" - elif key == 27 or key == ord("q"): # ESC or 'q' - quit - return "quit" - - return "continue" - - def cleanup(self): - """Clean up OpenCV windows.""" - cv2.destroyAllWindows() - - -def main(): - parser = argparse.ArgumentParser( - description="ZED Camera Neural Depth Demo - OpenCV with Data Saving" - ) - parser.add_argument("--camera-id", type=int, default=0, help="ZED camera ID (default: 0)") - parser.add_argument( - "--resolution", - type=str, - default="HD1080", - choices=["HD2K", "HD1080", "HD720", "VGA"], - help="Camera resolution (default: HD1080)", - ) - parser.add_argument( - "--max-depth", - type=float, - default=10.0, - help="Maximum depth for visualization in meters (default: 10.0)", - ) - parser.add_argument( - "--camera-fps", type=int, default=15, help="Camera capture FPS (default: 30)" - ) - parser.add_argument( - "--depth-mode", - type=str, - default="NEURAL", - choices=["NEURAL", "NEURAL_PLUS"], - help="Depth mode (NEURAL=faster, 
NEURAL_PLUS=more accurate)", - ) - parser.add_argument( - "--output-dir", - type=str, - default="assets/rgbd_data2", - help="Output directory for saved data (default: rgbd_data2)", - ) - - args = parser.parse_args() - - # Map resolution string to ZED enum - resolution_map = { - "HD2K": sl.RESOLUTION.HD2K, - "HD1080": sl.RESOLUTION.HD1080, - "HD720": sl.RESOLUTION.HD720, - "VGA": sl.RESOLUTION.VGA, - } - - depth_mode_map = {"NEURAL": sl.DEPTH_MODE.NEURAL, "NEURAL_PLUS": sl.DEPTH_MODE.NEURAL_PLUS} - - try: - # Initialize ZED camera with neural depth - logger.info( - f"Initializing ZED camera with {args.depth_mode} depth processing at {args.camera_fps} FPS..." - ) - camera = ZEDCamera( - camera_id=args.camera_id, - resolution=resolution_map[args.resolution], - depth_mode=depth_mode_map[args.depth_mode], - fps=args.camera_fps, - ) - - # Open camera - with camera: - # Get camera information - info = camera.get_camera_info() - logger.info(f"Camera Model: {info.get('model', 'Unknown')}") - logger.info(f"Serial Number: {info.get('serial_number', 'Unknown')}") - logger.info(f"Firmware: {info.get('firmware', 'Unknown')}") - logger.info(f"Resolution: {info.get('resolution', {})}") - logger.info(f"Baseline: {info.get('baseline', 0):.3f}m") - - # Initialize visualizer - visualizer = ZEDLiveVisualizer( - camera, max_depth=args.max_depth, output_dir=args.output_dir - ) - - logger.info("Starting live visualization...") - logger.info("Controls:") - logger.info(" SPACE - Save current RGB and depth frame") - logger.info(" ESC/Q - Quit") - - frame_count = 0 - start_time = time.time() - - try: - while True: - loop_start = time.time() - - # Update display - success, rgb_img, depth_map = visualizer.update_display() - - if success: - frame_count += 1 - - # Handle keyboard events - action = visualizer.handle_key_events(rgb_img, depth_map) - - if action == "quit": - break - elif action == "save": - # Frame was saved, no additional action needed - pass - - # Print performance stats every 60 
frames - if frame_count % 60 == 0: - elapsed = time.time() - start_time - fps = frame_count / elapsed - logger.info( - f"Frame {frame_count} | FPS: {fps:.1f} | Saved: {visualizer.save_counter}" - ) - - # Small delay to prevent CPU overload - elapsed = time.time() - loop_start - min_frame_time = 1.0 / 60.0 # Cap at 60 FPS - if elapsed < min_frame_time: - time.sleep(min_frame_time - elapsed) - - except KeyboardInterrupt: - logger.info("Stopped by user") - - # Final stats - total_time = time.time() - start_time - if total_time > 0: - avg_fps = frame_count / total_time - logger.info( - f"Final stats: {frame_count} frames in {total_time:.1f}s (avg {avg_fps:.1f} FPS)" - ) - logger.info(f"Total saved frames: {visualizer.save_counter}") - - # Visualize captured pointclouds - visualizer.visualize_captured_pointclouds() - - except Exception as e: - logger.error(f"Error during execution: {e}") - raise - finally: - if "visualizer" in locals(): - visualizer.cleanup() - logger.info("Demo completed") - - -if __name__ == "__main__": - main() From 1e8da3f6df40557d25dd4529cf3aab719daf21d1 Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Fri, 18 Jul 2025 01:39:23 -0700 Subject: [PATCH 69/89] use piper from ours --- dimos/hardware/piper_arm.py | 247 ++++++++++++++++++++++++++++++------ 1 file changed, 206 insertions(+), 41 deletions(-) diff --git a/dimos/hardware/piper_arm.py b/dimos/hardware/piper_arm.py index ee528792d1..c661ad37f6 100644 --- a/dimos/hardware/piper_arm.py +++ b/dimos/hardware/piper_arm.py @@ -27,9 +27,18 @@ import tty import select from scipy.spatial.transform import Rotation as R -from dimos.msgs.geometry_msgs import Pose, Vector3, Quaternion from dimos.utils.transform_utils import euler_to_quaternion, quaternion_to_euler +import random +import threading + +import pytest + +import dimos.core as core +import dimos.protocol.service.lcmservice as lcmservice +from dimos.core import In, Module, Out, rpc +from dimos_lcm.geometry_msgs import Pose, Vector3, Twist + class 
PiperArm: def __init__(self, arm_name: str = "arm"): @@ -40,6 +49,7 @@ def __init__(self, arm_name: str = "arm"): self.resetArm() time.sleep(0.1) self.enable() + self.enable_gripper() # Enable gripper after arm is enabled self.gotoZero() time.sleep(1) self.init_vel_controller() @@ -60,11 +70,17 @@ def enable(self): pass time.sleep(0.01) print(f"[PiperArm] Enabled") - self.arm.MotionCtrl_2(0x01, 0x01, 80, 0x00) + # self.arm.ModeCtrl( + # ctrl_mode=0x01, # CAN command mode + # move_mode=0x01, # “Move-J”, but ignored in MIT + # move_spd_rate_ctrl=100, # doesn’t matter in MIT + # is_mit_mode=0xAD # <-- the magic flag + # ) + self.arm.MotionCtrl_2(0x01, 0x01, 80, 0xAD) def gotoZero(self): factor = 1000 - position = [57.0, 0.0, 250.0, 0, 90.0, 0, 0] + position = [57.0, 0.0, 250.0, 0, 97.0, 0, 0] X = round(position[0] * factor) Y = round(position[1] * factor) Z = round(position[2] * factor) @@ -80,28 +96,44 @@ def gotoZero(self): def softStop(self): self.gotoZero() time.sleep(1) - self.arm.MotionCtrl_2(0x01, 0x00, 100, 0x00) + self.arm.MotionCtrl_2( + 0x01, + 0x00, + 100, + ) self.arm.MotionCtrl_1(0x01, 0, 0) time.sleep(3) - def cmd_ee_pose_values(self, x, y, z, r, p, y_): + def cmd_ee_pose_values(self, x, y, z, r, p, y_, line_mode=False): """Command end-effector to target pose in space (position + Euler angles)""" factor = 1000 - pose = [x * factor, y * factor, z * factor, r * factor, p * factor, y_ * factor] - self.arm.MotionCtrl_2(0x01, 0x00, 100, 0x00) + pose = [ + x * factor * factor, + y * factor * factor, + z * factor * factor, + r * factor, + p * factor, + y_ * factor, + ] + self.arm.MotionCtrl_2(0x01, 0x02 if line_mode else 0x00, 100, 0x00) self.arm.EndPoseCtrl( int(pose[0]), int(pose[1]), int(pose[2]), int(pose[3]), int(pose[4]), int(pose[5]) ) - def cmd_ee_pose(self, pose: Pose): + def cmd_ee_pose(self, pose: Pose, line_mode=False): """Command end-effector to target pose using Pose message""" # Convert quaternion to euler angles euler = 
quaternion_to_euler(pose.orientation, degrees=True) - + # Command the pose self.cmd_ee_pose_values( - pose.position.x, pose.position.y, pose.position.z, - euler[0], euler[1], euler[2] + pose.position.x, + pose.position.y, + pose.position.z, + euler[0], + euler[1], + euler[2], + line_mode, ) def get_ee_pose(self): @@ -111,16 +143,16 @@ def get_ee_pose(self): # Extract individual pose values and convert to base units # Position values are divided by 1000 to convert from SDK units to meters # Rotation values are divided by 1000 to convert from SDK units to radians - x = pose.end_pose.X_axis / factor / factor # Convert mm to m - y = pose.end_pose.Y_axis / factor / factor # Convert mm to m - z = pose.end_pose.Z_axis / factor / factor # Convert mm to m - rx = pose.end_pose.RX_axis / factor - ry = pose.end_pose.RY_axis / factor + x = pose.end_pose.X_axis / factor / factor # Convert mm to m + y = pose.end_pose.Y_axis / factor / factor # Convert mm to m + z = pose.end_pose.Z_axis / factor / factor # Convert mm to m + rx = pose.end_pose.RX_axis / factor + ry = pose.end_pose.RY_axis / factor rz = pose.end_pose.RZ_axis / factor # Create position vector (already in meters) position = Vector3(x, y, z) - + orientation = euler_to_quaternion(Vector3(rx, ry, rz), degrees=True) return Pose(position, orientation) @@ -133,9 +165,23 @@ def cmd_gripper_ctrl(self, position): self.arm.GripperCtrl(abs(round(position)), 250, 0x01, 0) print(f"[PiperArm] Commanding gripper position: {position}") + def enable_gripper(self): + """Enable the gripper using the initialization sequence""" + print("[PiperArm] Enabling gripper...") + while not self.arm.EnablePiper(): + time.sleep(0.01) + self.arm.GripperCtrl(0, 1000, 0x02, 0) + self.arm.GripperCtrl(0, 1000, 0x01, 0) + print("[PiperArm] Gripper enabled") + + def release_gripper(self): + """Release gripper by opening to 100mm (10cm)""" + print("[PiperArm] Releasing gripper (opening to 100mm)...") + self.cmd_gripper_ctrl(0.1) # 0.1m = 100mm = 10cm + def 
resetArm(self): self.arm.MotionCtrl_1(0x02, 0, 0) - self.arm.MotionCtrl_2(0, 0, 0, 0x00) + self.arm.MotionCtrl_2(0, 0, 0, 0xAD) print(f"[PiperArm] Resetting arm") def init_vel_controller(self): @@ -147,15 +193,8 @@ def init_vel_controller(self): self.dt = 0.01 def cmd_vel(self, x_dot, y_dot, z_dot, R_dot, P_dot, Y_dot): - x_dot = x_dot * 1000 - y_dot = y_dot * 1000 - z_dot = z_dot * 1000 - R_dot = R_dot * 1000 - P_dot = P_dot * 1000 - Y_dot = Y_dot * 1000 - joint_state = self.arm.GetArmJointMsgs().joint_state - # print(f"[PiperArm] Current Joints: {joint_state}", type(joint_state)) + # print(f"[PiperArm] Current Joints (direct): {joint_state}", type(joint_state)) joint_angles = np.array( [ joint_state.joint_1, @@ -168,8 +207,7 @@ def cmd_vel(self, x_dot, y_dot, z_dot, R_dot, P_dot, Y_dot): ) # print(f"[PiperArm] Current Joints: {joint_angles}", type(joint_angles)) factor = 57295.7795 # 1000*180/3.1415926 - joint_angles = joint_angles * factor # convert to radians - # print(f"[PiperArm] Current Joints: {joint_angles}", type(joint_angles)) + joint_angles = joint_angles / factor # convert to radians q = np.array( [ @@ -181,12 +219,14 @@ def cmd_vel(self, x_dot, y_dot, z_dot, R_dot, P_dot, Y_dot): joint_angles[5], ] ) - # print(f"[PiperArm] Current Joints: {q}") - time.sleep(0.005) + J = self.chain.jacobian(q) + self.J_pinv = np.linalg.pinv(J) dq = self.J_pinv @ np.array([x_dot, y_dot, z_dot, R_dot, P_dot, Y_dot]) * self.dt newq = q + dq - self.arm.MotionCtrl_2(0x01, 0x01, 100, 0x00) + newq = newq * factor + + self.arm.MotionCtrl_2(0x01, 0x01, 100, 0xAD) self.arm.JointCtrl( int(round(newq[0])), int(round(newq[1])), @@ -195,23 +235,46 @@ def cmd_vel(self, x_dot, y_dot, z_dot, R_dot, P_dot, Y_dot): int(round(newq[4])), int(round(newq[5])), ) + time.sleep(self.dt) # print(f"[PiperArm] Moving to Joints to : {newq}") - def cmd_vel_ee(self, x_dot, y_dot, z_dot, R_dot, P_dot, Y_dot): + def cmd_vel_ee(self, x_dot, y_dot, z_dot, RX_dot, PY_dot, YZ_dot): factor = 1000 x_dot = 
x_dot * factor y_dot = y_dot * factor z_dot = z_dot * factor - R_dot = R_dot * factor - P_dot = P_dot * factor - Y_dot = Y_dot * factor + RX_dot = RX_dot * factor + PY_dot = PY_dot * factor + YZ_dot = YZ_dot * factor + + current_pose_msg = self.get_ee_pose() + + # Convert quaternion to euler angles + quat = [ + current_pose_msg.orientation.x, + current_pose_msg.orientation.y, + current_pose_msg.orientation.z, + current_pose_msg.orientation.w, + ] + rotation = R.from_quat(quat) + euler = rotation.as_euler("xyz") # Returns [rx, ry, rz] in radians + + # Create current pose array [x, y, z, rx, ry, rz] + current_pose = np.array( + [ + current_pose_msg.position.x, + current_pose_msg.position.y, + current_pose_msg.position.z, + euler[0], + euler[1], + euler[2], + ] + ) - current_pose = self.get_EE_pose() - current_pose = np.array(current_pose) - current_pose = current_pose + # Apply velocity increment current_pose = current_pose + np.array([x_dot, y_dot, z_dot, R_dot, P_dot, Y_dot]) * self.dt - current_pose = current_pose - self.cmd_EE_pose( + + self.cmd_ee_pose_values( current_pose[0], current_pose[1], current_pose[2], @@ -230,6 +293,108 @@ def disable(self): self.arm.DisconnectPort() +class VelocityController(Module): + cmd_vel: In[Twist] = None + + def __init__(self, arm, period=0.01, *args, **kwargs): + super().__init__(*args, **kwargs) + self.arm = arm + self.period = period + self.latest_cmd = None + self.last_cmd_time = None + + @rpc + def start(self): + self.cmd_vel.subscribe(self.handle_cmd_vel) + + def control_loop(): + while True: + # Check for timeout (1 second) + if self.last_cmd_time and (time.time() - self.last_cmd_time) > 1.0: + print("No velocity command received for 1 second, stopping control loop") + break + + cmd_vel = self.latest_cmd + + joint_state = self.arm.GetArmJointMsgs().joint_state + # print(f"[PiperArm] Current Joints (direct): {joint_state}", type(joint_state)) + joint_angles = np.array( + [ + joint_state.joint_1, + joint_state.joint_2, + 
joint_state.joint_3, + joint_state.joint_4, + joint_state.joint_5, + joint_state.joint_6, + ] + ) + factor = 57295.7795 # 1000*180/3.1415926 + joint_angles = joint_angles / factor # convert to radians + q = np.array( + [ + joint_angles[0], + joint_angles[1], + joint_angles[2], + joint_angles[3], + joint_angles[4], + joint_angles[5], + ] + ) + + J = self.chain.jacobian(q) + self.J_pinv = np.linalg.pinv(J) + dq = ( + self.J_pinv + @ np.array( + [ + cmd_vel.linear.X, + cmd_vel.linear.y, + cmd_vel.linear.z, + cmd_vel.angular.x, + cmd_vel.angular.y, + cmd_vel.angular.z, + ] + ) + * self.dt + ) + newq = q + dq + + newq = newq * factor # convert radians to scaled degree units for joint control + + self.arm.MotionCtrl_2(0x01, 0x01, 100, 0xAD) + self.arm.JointCtrl( + int(round(newq[0])), + int(round(newq[1])), + int(round(newq[2])), + int(round(newq[3])), + int(round(newq[4])), + int(round(newq[5])), + ) + time.sleep(self.period) + + thread = threading.Thread(target=control_loop, daemon=True) + thread.start() + + def handle_cmd_vel(self, cmd_vel: Twist): + self.latest_cmd = cmd_vel + self.last_cmd_time = time.time() + + +@pytest.mark.tool +def run_velocity_controller(): + lcmservice.autoconf() + dimos = core.start(2) + + velocity_controller = dimos.deploy(VelocityController, arm=arm, period=0.01) + velocity_controller.cmd_vel.transport = core.LCMTransport("/cmd_vel", Twist) + + velocity_controller.start() + + print("Velocity controller started") + while True: + time.sleep(1) + + if __name__ == "__main__": arm = PiperArm() @@ -287,4 +452,4 @@ def teleop_linear_vel(arm): f"Current linear velocity: x={x_dot:.3f} m/s, y={y_dot:.3f} m/s, z={z_dot:.3f} m/s" ) - teleop_linear_vel(arm) + run_velocity_controller() From b1804a3998c359cc813d1a58adcfde0d2f878e93 Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Fri, 18 Jul 2025 22:48:54 -0700 Subject: [PATCH 70/89] removed submodule --- dimos-lcm | 1 - 1 file changed, 1 deletion(-) delete mode 160000 dimos-lcm diff --git a/dimos-lcm 
b/dimos-lcm deleted file mode 160000 index 403afa2fdb..0000000000 --- a/dimos-lcm +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 403afa2fdba3232d98719f426fbd8d7d94e0e549 From d1199b38449edecaf30e3ac079bd550f414a7b32 Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Fri, 18 Jul 2025 23:11:46 -0700 Subject: [PATCH 71/89] switch to using lcm types for transform utils --- dimos/hardware/piper_arm.py | 6 ++--- dimos/manipulation/visual_servoing/pbvs.py | 6 ++--- dimos/utils/transform_utils.py | 26 ++++++++++++---------- 3 files changed, 19 insertions(+), 19 deletions(-) diff --git a/dimos/hardware/piper_arm.py b/dimos/hardware/piper_arm.py index c661ad37f6..9fa3e06503 100644 --- a/dimos/hardware/piper_arm.py +++ b/dimos/hardware/piper_arm.py @@ -130,9 +130,9 @@ def cmd_ee_pose(self, pose: Pose, line_mode=False): pose.position.x, pose.position.y, pose.position.z, - euler[0], - euler[1], - euler[2], + euler.x, + euler.y, + euler.z, line_mode, ) diff --git a/dimos/manipulation/visual_servoing/pbvs.py b/dimos/manipulation/visual_servoing/pbvs.py index e90f8e6996..dca5cbb0bb 100644 --- a/dimos/manipulation/visual_servoing/pbvs.py +++ b/dimos/manipulation/visual_servoing/pbvs.py @@ -290,9 +290,7 @@ def _update_target_grasp_pose(self, ee_pose: Pose): target_pos = self.current_target.bbox.center.position # Calculate orientation pointing from target towards EE - yaw_to_ee = yaw_towards_point( - Vector3(target_pos.x, target_pos.y, target_pos.z), ee_pose.position - ) + yaw_to_ee = yaw_towards_point(target_pos, ee_pose.position) # Create target pose with proper orientation # Convert grasp pitch from degrees to radians with mapping: @@ -438,7 +436,7 @@ def get_object_pose_camera_frame( Tuple of (position, rotation) in camera frame """ # Calculate orientation pointing at camera - yaw_to_camera = yaw_towards_point(Vector3(object_pos.x, object_pos.y, object_pos.z)) + yaw_to_camera = yaw_towards_point(object_pos) # Convert euler angles to quaternion using utility function euler = 
Vector3(0.0, 0.0, yaw_to_camera) # Level grasp diff --git a/dimos/utils/transform_utils.py b/dimos/utils/transform_utils.py index 689091bc3b..143b74a33c 100644 --- a/dimos/utils/transform_utils.py +++ b/dimos/utils/transform_utils.py @@ -13,11 +13,11 @@ # limitations under the License. import numpy as np -from typing import Tuple, Dict, Any +from typing import Tuple import logging from scipy.spatial.transform import Rotation as R -from dimos.msgs.geometry_msgs import Pose, Vector3, Quaternion +from dimos_lcm.geometry_msgs import Pose, Point, Vector3, Quaternion logger = logging.getLogger(__name__) @@ -69,7 +69,7 @@ def matrix_to_pose(T: np.ndarray) -> Pose: Pose object with position and orientation (quaternion) """ # Extract position - pos = Vector3(T[0, 3], T[1, 3], T[2, 3]) + pos = Point(T[0, 3], T[1, 3], T[2, 3]) # Extract rotation matrix and convert to quaternion Rot = T[:3, :3] @@ -149,7 +149,7 @@ def optical_to_robot_frame(pose: Pose) -> Pose: quat_robot = R.from_matrix(R_robot).as_quat() # [x, y, z, w] return Pose( - Vector3(robot_x, robot_y, robot_z), + Point(robot_x, robot_y, robot_z), Quaternion(quat_robot[0], quat_robot[1], quat_robot[2], quat_robot[3]), ) @@ -191,12 +191,12 @@ def robot_to_optical_frame(pose: Pose) -> Pose: quat_optical = R.from_matrix(R_optical).as_quat() # [x, y, z, w] return Pose( - Vector3(optical_x, optical_y, optical_z), + Point(optical_x, optical_y, optical_z), Quaternion(quat_optical[0], quat_optical[1], quat_optical[2], quat_optical[3]), ) -def yaw_towards_point(position: Vector3, target_point: Vector3 = Vector3(0.0, 0.0, 0.0)) -> float: +def yaw_towards_point(position: Point, target_point: Point = None) -> float: """ Calculate yaw angle from target point to position (away from target). This is commonly used for object orientation in grasping applications. @@ -209,29 +209,31 @@ def yaw_towards_point(position: Vector3, target_point: Vector3 = Vector3(0.0, 0. 
Returns: Yaw angle in radians pointing from target_point to position """ + if target_point is None: + target_point = Point(0.0, 0.0, 0.0) direction_x = position.x - target_point.x direction_y = position.y - target_point.y return np.arctan2(direction_y, direction_x) def transform_robot_to_map( - robot_position: Vector3, robot_rotation: Vector3, position: Vector3, rotation: Vector3 -) -> Tuple[Vector3, Vector3]: + robot_position: Point, robot_rotation: Vector3, position: Point, rotation: Vector3 +) -> Tuple[Point, Vector3]: """Transform position and rotation from robot frame to map frame. Args: robot_position: Current robot position in map frame robot_rotation: Current robot rotation in map frame - position: Position in robot frame as Vector3 (x, y, z) + position: Position in robot frame as Point (x, y, z) rotation: Rotation in robot frame as Vector3 (roll, pitch, yaw) in radians Returns: Tuple of (transformed_position, transformed_rotation) where: - - transformed_position: Vector3 (x, y, z) in map frame + - transformed_position: Point (x, y, z) in map frame - transformed_rotation: Vector3 (roll, pitch, yaw) in map frame Example: - obj_pos_robot = Vector3(1.0, 0.5, 0.0) # 1m forward, 0.5m left of robot + obj_pos_robot = Point(1.0, 0.5, 0.0) # 1m forward, 0.5m left of robot obj_rot_robot = Vector3(0.0, 0.0, 0.0) # No rotation relative to robot map_pos, map_rot = transform_robot_to_map(robot_position, robot_rotation, obj_pos_robot, obj_rot_robot) @@ -262,7 +264,7 @@ def transform_robot_to_map( map_pitch = robot_rot.y + rot_pitch # Add robot's pitch map_yaw_rot = normalize_angle(robot_yaw + rot_yaw) # Add robot's yaw and normalize - transformed_position = Vector3(map_x, map_y, map_z) + transformed_position = Point(map_x, map_y, map_z) transformed_rotation = Vector3(map_roll, map_pitch, map_yaw_rot) return transformed_position, transformed_rotation From e29b5636fb1613c234677d9c8779420ff9c3467c Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Mon, 21 Jul 2025 11:07:24 -0700 
Subject: [PATCH 72/89] Perfected state machine, cleaned everything up --- dimos/hardware/piper_arm.py | 4 +- .../visual_servoing/manipulation.py | 742 ++++++++++++++++++ dimos/manipulation/visual_servoing/pbvs.py | 104 +-- dimos/utils/transform_utils.py | 7 + tests/test_ibvs.py | 269 +------ 5 files changed, 805 insertions(+), 321 deletions(-) create mode 100644 dimos/manipulation/visual_servoing/manipulation.py diff --git a/dimos/hardware/piper_arm.py b/dimos/hardware/piper_arm.py index 9fa3e06503..3ec4f216f7 100644 --- a/dimos/hardware/piper_arm.py +++ b/dimos/hardware/piper_arm.py @@ -45,9 +45,9 @@ def __init__(self, arm_name: str = "arm"): self.init_can() self.arm = C_PiperInterface_V2() self.arm.ConnectPort() - time.sleep(0.1) + time.sleep(0.5) self.resetArm() - time.sleep(0.1) + time.sleep(0.5) self.enable() self.enable_gripper() # Enable gripper after arm is enabled self.gotoZero() diff --git a/dimos/manipulation/visual_servoing/manipulation.py b/dimos/manipulation/visual_servoing/manipulation.py new file mode 100644 index 0000000000..9db79595f7 --- /dev/null +++ b/dimos/manipulation/visual_servoing/manipulation.py @@ -0,0 +1,742 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Manipulation system for robotic grasping with visual servoing. +Handles grasping logic, state machine, and hardware coordination. 
+""" + +import cv2 +import time +from typing import Optional, Tuple, Any +from enum import Enum +from collections import deque + +import numpy as np + +from dimos.manipulation.visual_servoing.detection3d import Detection3DProcessor +from dimos.manipulation.visual_servoing.pbvs import PBVS +from dimos.perception.common.utils import ( + find_clicked_detection, + bbox2d_to_corners, +) +from dimos.manipulation.visual_servoing.utils import ( + match_detection_by_id, +) +from dimos.utils.transform_utils import ( + pose_to_matrix, + matrix_to_pose, + create_transform_from_6dof, + compose_transforms, +) +from dimos.utils.logging_config import setup_logger +from dimos_lcm.geometry_msgs import Vector3, Pose +from dimos_lcm.vision_msgs import Detection3DArray, Detection2DArray + +logger = setup_logger("dimos.manipulation.manipulation") + + +class GraspStage(Enum): + """Enum for different grasp stages.""" + + IDLE = "idle" # No target set + PRE_GRASP = "pre_grasp" # Target set, moving to pre-grasp position + GRASP = "grasp" # Executing final grasp + CLOSE_AND_LIFT = "close_and_lift" # Close gripper and lift + + +class Manipulation: + """ + High-level manipulation orchestrator for visual servoing and grasping. + + Handles: + - State machine for grasping sequences + - Grasp execution logic + - Coordination between perception and control + + This class is hardware-agnostic and accepts camera and arm objects. + """ + + def __init__( + self, + camera: Any, # Generic camera object with required interface + arm: Any, # Generic arm object with required interface + camera_intrinsics: list, # [fx, fy, cx, cy] + direct_ee_control: bool = True, + ee_to_camera_6dof: Optional[list] = None, + ): + """ + Initialize manipulation system. 
+ + Args: + camera: Camera object with capture_frame_with_pose() method + arm: Robot arm object with get_ee_pose(), cmd_ee_pose(), cmd_vel_ee(), + cmd_gripper_ctrl(), release_gripper(), softStop(), gotoZero(), and disable() methods + camera_intrinsics: Camera intrinsics [fx, fy, cx, cy] + direct_ee_control: If True, use direct EE pose control; if False, use velocity control + ee_to_camera_6dof: EE to camera transform [x, y, z, rx, ry, rz] in meters and radians + """ + self.camera = camera + self.arm = arm + self.direct_ee_control = direct_ee_control + + # Default EE to camera transform if not provided + if ee_to_camera_6dof is None: + ee_to_camera_6dof = [-0.06, 0.03, -0.05, 0.0, -1.57, 0.0] + + # Create transform matrices + pos = Vector3(ee_to_camera_6dof[0], ee_to_camera_6dof[1], ee_to_camera_6dof[2]) + rot = Vector3(ee_to_camera_6dof[3], ee_to_camera_6dof[4], ee_to_camera_6dof[5]) + self.T_ee_to_camera = create_transform_from_6dof(pos, rot) + + # Initialize processors + self.detector = Detection3DProcessor(camera_intrinsics) + self.pbvs = PBVS( + position_gain=0.3, + rotation_gain=0.2, + target_tolerance=0.05, + direct_ee_control=direct_ee_control, + ) + + # Control state + self.last_valid_target = None + self.waiting_for_reach = False # True when waiting for robot to reach commanded pose + self.last_commanded_pose = None # Last pose sent to robot + self.target_updated = False # True when target has been updated with fresh detections + + # Grasp parameters + self.grasp_width_offset = 0.03 # Default grasp width offset + self.grasp_pitch_degrees = 30.0 # Default grasp pitch in degrees + self.pregrasp_distance = 0.3 # Distance to maintain before grasping (m) + self.grasp_distance = 0.01 # Distance for final grasp approach (m) + self.grasp_close_delay = 3.0 # Time to wait at grasp pose before closing (seconds) + self.grasp_reached_time = None # Time when grasp pose was reached + + # Grasp stage tracking + self.grasp_stage = GraspStage.IDLE + + # Pose stabilization 
tracking + self.pose_history_size = 4 # Number of poses to check for stabilization + self.pose_stabilization_threshold = 0.005 # 1cm threshold for stabilization + self.stabilization_timeout = 10.0 # Timeout in seconds before giving up + self.stabilization_start_time = None # Time when stabilization started + self.reached_poses = deque( + maxlen=self.pose_history_size + ) # Only stores poses that were reached + self.adjustment_count = 0 + + # State for visualization + self.current_visualization = None + self.last_rgb = None + self.last_detection_3d_array = None + self.last_detection_2d_array = None + self.last_camera_pose = None + self.last_target_tracked = False + + logger.info( + f"Initialized Manipulation system in {'Direct EE' if direct_ee_control else 'Velocity'} control mode" + ) + + def set_grasp_stage(self, stage: GraspStage): + """ + Set the grasp stage. + + Args: + stage: The new grasp stage + """ + self.grasp_stage = stage + logger.info(f"Set grasp stage to: {stage.value}") + + def set_grasp_pitch(self, pitch_degrees: float): + """ + Set the grasp pitch angle. + + Args: + pitch_degrees: Grasp pitch angle in degrees (0-90) + 0 = level grasp, 90 = top-down grasp + """ + # Clamp to valid range + pitch_degrees = max(0.0, min(90.0, pitch_degrees)) + self.grasp_pitch_degrees = pitch_degrees + self.pbvs.set_grasp_pitch(pitch_degrees) + logger.info(f"Set grasp pitch to: {pitch_degrees} degrees") + + def reset_to_idle(self): + """Reset the manipulation system to IDLE state.""" + self.pbvs.clear_target() + self.grasp_stage = GraspStage.IDLE + self.reached_poses.clear() + self.adjustment_count = 0 + self.waiting_for_reach = False + self.last_commanded_pose = None + self.target_updated = False + self.stabilization_start_time = None + self.grasp_reached_time = None + + def execute_pre_grasp(self, detection_3d_array: Detection3DArray) -> bool: + """ + Execute pre-grasp stage: visual servoing to pre-grasp position. 
+ + Args: + detection_3d_array: Current 3D detections + + Returns: + True if target is being tracked + """ + # Get EE pose + ee_pose = self.arm.get_ee_pose() + + # PBVS control with pre-grasp distance + vel_cmd, ang_vel_cmd, _, target_tracked, target_pose = self.pbvs.compute_control( + ee_pose, detection_3d_array, self.pregrasp_distance + ) + + if ( + self.stabilization_start_time + and (time.time() - self.stabilization_start_time) > self.stabilization_timeout + ): + logger.warning( + f"Failed to get stable grasp after {self.stabilization_timeout} seconds, resetting" + ) + self.arm.gotoZero() + time.sleep(1.0) + self.reset_to_idle() + return False + + # Set target_updated flag if target was successfully tracked + if target_tracked and target_pose: + self.target_updated = True + self.last_valid_target = self.pbvs.get_current_target() + + # Handle direct EE control + if self.direct_ee_control and target_pose and target_tracked: + # Check if we have enough reached poses and they're stable + if self.check_target_stabilized(): + logger.info("Target stabilized, transitioning to GRASP stage") + self.grasp_stage = GraspStage.GRASP + self.adjustment_count = 0 + self.waiting_for_reach = False + elif not self.waiting_for_reach and self.target_updated: + # Command the pose only if target has been updated + self.arm.cmd_ee_pose(target_pose) + self.last_commanded_pose = target_pose + self.waiting_for_reach = True + self.target_updated = False # Reset flag after commanding + self.adjustment_count += 1 + + elapsed_time = ( + time.time() - self.stabilization_start_time + if self.stabilization_start_time + else 0 + ) + logger.info( + f"Commanded target pose: pos=({target_pose.position.x:.3f}, " + f"{target_pose.position.y:.3f}, {target_pose.position.z:.3f}), " + f"attempt {self.adjustment_count} (elapsed: {elapsed_time:.1f}s)" + ) + + # Sleep for 200ms after commanding to avoid rapid commands + time.sleep(0.2) + + elif not self.direct_ee_control and vel_cmd and ang_vel_cmd: + # 
Velocity control + self.arm.cmd_vel_ee( + vel_cmd.x, vel_cmd.y, vel_cmd.z, ang_vel_cmd.x, ang_vel_cmd.y, ang_vel_cmd.z + ) + + return target_tracked + + def execute_grasp(self, detection_3d_array: Detection3DArray) -> bool: + """ + Execute grasp stage: move to final grasp position. + + Args: + detection_3d_array: Current 3D detections + + Returns: + True if target is being tracked + """ + if not self.waiting_for_reach and self.last_valid_target: + # Get EE pose + ee_pose = self.arm.get_ee_pose() + + # PBVS control with grasp distance + vel_cmd, ang_vel_cmd, _, target_tracked, target_pose = self.pbvs.compute_control( + ee_pose, detection_3d_array, self.grasp_distance + ) + + if self.direct_ee_control and target_pose and target_tracked: + # Get object size and calculate gripper opening + object_size = self.last_valid_target.bbox.size + object_width = object_size.x + gripper_opening = object_width + self.grasp_width_offset + gripper_opening = max(0.005, min(gripper_opening, 0.1)) + + logger.info(f"Executing grasp: opening gripper to {gripper_opening * 1000:.1f}mm") + print(f"Executing grasp: opening gripper to {gripper_opening * 1000:.1f}mm") + + # Command gripper to open and move to grasp pose + self.arm.cmd_gripper_ctrl(gripper_opening) + self.arm.cmd_ee_pose(target_pose, line_mode=True) + self.waiting_for_reach = True + logger.info("Grasp pose commanded") + + return target_tracked + + return False + + def execute_close_and_lift(self): + """Execute the close and lift sequence.""" + logger.info("Executing CLOSE_AND_LIFT sequence") + + # Close gripper + logger.info("Closing gripper") + self.arm.cmd_gripper_ctrl(0.0) # Close gripper completely + time.sleep(0.5) # Wait for gripper to close + + # Return to home position + logger.info("Returning to home position") + self.arm.gotoZero() + + # Reset to IDLE after completion + logger.info("Grasp sequence completed, returning to IDLE") + self.reset_to_idle() + + def capture_and_process( + self, + ) -> Tuple[ + 
Optional[np.ndarray], Optional[Detection3DArray], Optional[Detection2DArray], Optional[Pose] + ]: + """ + Capture frame from camera and process detections. + + Returns: + Tuple of (rgb_image, detection_3d_array, detection_2d_array, camera_pose) + Returns None values if capture fails + """ + # Capture frame + bgr, _, depth, _ = self.camera.capture_frame_with_pose() + if bgr is None or depth is None: + return None, None, None, None + + # Process + rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) + + # Get EE pose from robot (this serves as our odometry) + ee_pose = self.arm.get_ee_pose() + + # Transform EE pose to camera pose + ee_transform = pose_to_matrix(ee_pose) + camera_transform = compose_transforms(ee_transform, self.T_ee_to_camera) + camera_pose = matrix_to_pose(camera_transform) + + # Process detections using camera transform + detection_3d_array, detection_2d_array = self.detector.process_frame( + rgb, depth, camera_transform + ) + + return rgb, detection_3d_array, detection_2d_array, camera_pose + + def pick_target(self, x: int, y: int) -> bool: + """ + Select a target object at the given pixel coordinates. 
+ + Args: + x: X coordinate in image + y: Y coordinate in image + + Returns: + True if a target was successfully selected + """ + if not self.last_detection_2d_array or not self.last_detection_3d_array: + logger.warning("No detections available for target selection") + return False + + clicked_3d = find_clicked_detection( + (x, y), self.last_detection_2d_array.detections, self.last_detection_3d_array.detections + ) + if clicked_3d: + self.pbvs.set_target(clicked_3d) + self.grasp_stage = GraspStage.PRE_GRASP # Transition from IDLE to PRE_GRASP + self.reached_poses.clear() # Clear pose history + self.adjustment_count = 0 # Reset adjustment counter + self.waiting_for_reach = False # Ensure we're not stuck in waiting state + self.last_commanded_pose = None + self.stabilization_start_time = time.time() # Start the timeout timer + logger.info(f"Target selected at ({x}, {y})") + return True + return False + + def create_visualization( + self, + rgb: np.ndarray, + detection_3d_array: Detection3DArray, + detection_2d_array: Detection2DArray, + camera_pose: Pose, + target_tracked: bool, + ) -> np.ndarray: + """ + Create visualization with detections and status overlays. 
+ + Args: + rgb: RGB image + detection_3d_array: 3D detections + detection_2d_array: 2D detections + camera_pose: Current camera pose + target_tracked: Whether target is being tracked + + Returns: + BGR image with visualizations + """ + # Create visualization with position overlays + viz = self.detector.visualize_detections( + rgb, detection_3d_array.detections, detection_2d_array.detections + ) + + # Add PBVS status overlay + viz = self.pbvs.create_status_overlay(viz, self.grasp_stage) + + # Highlight target + current_target = self.pbvs.get_current_target() + if target_tracked and current_target: + det_2d = match_detection_by_id( + current_target, detection_3d_array.detections, detection_2d_array.detections + ) + if det_2d and det_2d.bbox: + x1, y1, x2, y2 = bbox2d_to_corners(det_2d.bbox) + x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) + + cv2.rectangle(viz, (x1, y1), (x2, y2), (0, 255, 0), 3) + cv2.putText( + viz, "TARGET", (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2 + ) + + # Convert back to BGR for OpenCV display + viz_bgr = cv2.cvtColor(viz, cv2.COLOR_RGB2BGR) + + # Add pose info + mode_text = "Direct EE" if self.direct_ee_control else "Velocity" + cv2.putText( + viz_bgr, + f"Eye-in-Hand ({mode_text})", + (10, 30), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + (0, 255, 255), + 1, + ) + + # Get EE pose for display + ee_pose = self.arm.get_ee_pose() + + camera_text = f"Camera: ({camera_pose.position.x:.2f}, {camera_pose.position.y:.2f}, {camera_pose.position.z:.2f})m" + cv2.putText(viz_bgr, camera_text, (10, 50), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 0), 1) + + ee_text = ( + f"EE: ({ee_pose.position.x:.2f}, {ee_pose.position.y:.2f}, {ee_pose.position.z:.2f})m" + ) + cv2.putText(viz_bgr, ee_text, (10, 70), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1) + + # Add control status for direct EE mode + if self.direct_ee_control: + if self.grasp_stage == GraspStage.IDLE: + status_text = "IDLE - Click object to select target" + status_color = (100, 
100, 100) + elif self.grasp_stage == GraspStage.PRE_GRASP: + if self.waiting_for_reach: + status_text = "PRE-GRASP - Waiting for robot to reach target..." + status_color = (255, 255, 0) + else: + poses_text = f" ({len(self.reached_poses)}/{self.pose_history_size} poses)" + elapsed_time = ( + time.time() - self.stabilization_start_time + if self.stabilization_start_time + else 0 + ) + time_text = f" [{elapsed_time:.1f}s/{self.stabilization_timeout:.0f}s]" + status_text = f"PRE-GRASP - Collecting stable poses{poses_text}{time_text}" + status_color = (0, 255, 255) + elif self.grasp_stage == GraspStage.GRASP: + if self.grasp_reached_time: + time_remaining = self.grasp_close_delay - ( + time.time() - self.grasp_reached_time + ) + status_text = f"GRASP - Waiting to close ({time_remaining:.1f}s)" + else: + status_text = "GRASP - Moving to grasp pose" + status_color = (0, 255, 0) + else: # CLOSE_AND_LIFT + status_text = "CLOSE_AND_LIFT - Closing gripper and lifting" + status_color = (255, 0, 255) + + cv2.putText( + viz_bgr, status_text, (10, 90), cv2.FONT_HERSHEY_SIMPLEX, 0.5, status_color, 1 + ) + cv2.putText( + viz_bgr, + "s=STOP | h=HOME | SPACE=FORCE GRASP | g=RELEASE", + (10, 110), + cv2.FONT_HERSHEY_SIMPLEX, + 0.4, + (255, 255, 255), + 1, + ) + + return viz_bgr + + def update(self) -> bool: + """ + Main update function that handles capture, processing, control, and visualization. 
+ + Returns: + True if update was successful, False if capture failed + """ + # Always capture frame for visualization + bgr, _, depth, _ = self.camera.capture_frame_with_pose() + if bgr is None or depth is None: + return False + + rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) + + # If waiting for robot to reach target, check if reached + if self.waiting_for_reach and self.last_commanded_pose: + ee_pose = self.arm.get_ee_pose() + + if self.grasp_stage == GraspStage.GRASP: + # Check if grasp pose is reached + grasp_distance = self.grasp_distance + reached = self.pbvs.is_target_reached(ee_pose, grasp_distance) + + if reached and not self.grasp_reached_time: + # First time reaching grasp pose + self.grasp_reached_time = time.time() + logger.info( + f"Robot reached grasp pose, waiting {self.grasp_close_delay}s before closing gripper" + ) + + # Wait for delay then transition to CLOSE_AND_LIFT + if ( + self.grasp_reached_time + and (time.time() - self.grasp_reached_time) >= self.grasp_close_delay + ): + logger.info( + f"Waited {self.grasp_close_delay}s at grasp pose, transitioning to CLOSE_AND_LIFT" + ) + self.grasp_stage = GraspStage.CLOSE_AND_LIFT + self.waiting_for_reach = False + else: + # For PRE_GRASP stage, check if reached + grasp_distance = ( + self.pregrasp_distance + if self.grasp_stage == GraspStage.PRE_GRASP + else self.grasp_distance + ) + reached = self.pbvs.is_target_reached(ee_pose, grasp_distance) + + if reached: + logger.info("Robot reached commanded pose") + self.waiting_for_reach = False + self.reached_poses.append(self.last_commanded_pose) + self.target_updated = False # Reset flag so we wait for fresh update + time.sleep(0.3) + + # Create basic visualization while waiting + self.current_visualization = self._create_waiting_visualization(rgb) + return True + + # Normal processing when not waiting + # Get EE pose and camera transform + ee_pose = self.arm.get_ee_pose() + ee_transform = pose_to_matrix(ee_pose) + camera_transform = 
compose_transforms(ee_transform, self.T_ee_to_camera) + camera_pose = matrix_to_pose(camera_transform) + + # Process detections + detection_3d_array, detection_2d_array = self.detector.process_frame( + rgb, depth, camera_transform + ) + + # Store for target selection + self.last_rgb = rgb + self.last_detection_3d_array = detection_3d_array + self.last_detection_2d_array = detection_2d_array + self.last_camera_pose = camera_pose + + # Execute stage-specific logic + target_tracked = False + + if self.grasp_stage == GraspStage.IDLE: + # Nothing to do in IDLE + pass + elif self.grasp_stage == GraspStage.PRE_GRASP: + if detection_3d_array: + target_tracked = self.execute_pre_grasp(detection_3d_array) + self.last_target_tracked = target_tracked + elif self.grasp_stage == GraspStage.GRASP: + if detection_3d_array: + target_tracked = self.execute_grasp(detection_3d_array) + self.last_target_tracked = target_tracked + elif self.grasp_stage == GraspStage.CLOSE_AND_LIFT: + # No visual servoing needed for close and lift + self.execute_close_and_lift() + + # Create full visualization + if detection_3d_array and detection_2d_array and camera_pose: + self.current_visualization = self.create_visualization( + rgb, detection_3d_array, detection_2d_array, camera_pose, target_tracked + ) + else: + # Basic visualization with just the RGB image + self.current_visualization = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR) + + return True + + def get_visualization(self) -> Optional[np.ndarray]: + """ + Get the current visualization image. + + Returns: + BGR image with visualizations, or None if no visualization available + """ + return self.current_visualization + + def handle_keyboard_command(self, key: int) -> str: + """ + Handle keyboard commands for robot control. 
+ + Args: + key: Keyboard key code + + Returns: + Action taken as string, or empty string if no action + """ + if key == ord("r"): + self.reset_to_idle() + return "reset" + elif key == ord("s"): + print("SOFT STOP - Emergency stopping robot!") + self.arm.softStop() + return "stop" + elif key == ord("h"): + print("GO HOME - Returning to safe position...") + self.arm.gotoZero() + return "home" + elif key == ord(" ") and self.direct_ee_control and self.pbvs.target_grasp_pose: + # Manual override - immediately transition to GRASP if in PRE_GRASP + if self.grasp_stage == GraspStage.PRE_GRASP: + logger.info("Manual grasp execution requested") + self.set_grasp_stage(GraspStage.GRASP) + print("Executing target pose") + return "execute" + elif key == 82: # Up arrow - increase pitch + new_pitch = min(90.0, self.grasp_pitch_degrees + 15.0) + self.set_grasp_pitch(new_pitch) + print(f"Grasp pitch: {new_pitch:.0f} degrees") + return "pitch_up" + elif key == 84: # Down arrow - decrease pitch + new_pitch = max(0.0, self.grasp_pitch_degrees - 15.0) + self.set_grasp_pitch(new_pitch) + print(f"Grasp pitch: {new_pitch:.0f} degrees") + return "pitch_down" + elif key == ord("g"): + print("Opening gripper") + self.arm.release_gripper() + return "release" + + return "" + + def _create_waiting_visualization(self, rgb: np.ndarray) -> np.ndarray: + """ + Create a simple visualization while waiting for robot to reach pose. 
+ + Args: + rgb: RGB image + + Returns: + BGR image with waiting status + """ + viz_bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR) + + # Add waiting status + cv2.putText( + viz_bgr, + "WAITING FOR ROBOT TO REACH TARGET...", + (10, 30), + cv2.FONT_HERSHEY_SIMPLEX, + 0.7, + (0, 255, 255), + 2, + ) + + # Add current stage info + stage_text = f"Stage: {self.grasp_stage.value.upper()}" + cv2.putText( + viz_bgr, + stage_text, + (10, 60), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + (255, 255, 0), + 1, + ) + + # Add progress info based on stage + if self.grasp_stage == GraspStage.PRE_GRASP: + progress_text = f"Reached poses: {len(self.reached_poses)}/{self.pose_history_size}" + elif self.grasp_stage == GraspStage.GRASP and self.grasp_reached_time: + time_remaining = max( + 0, self.grasp_close_delay - (time.time() - self.grasp_reached_time) + ) + progress_text = f"Closing gripper in: {time_remaining:.1f}s" + else: + progress_text = "" + + if progress_text: + cv2.putText( + viz_bgr, + progress_text, + (10, 90), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + (0, 255, 255), + 1, + ) + + return viz_bgr + + def check_target_stabilized(self) -> bool: + """ + Check if the commanded poses have stabilized. 
+ + Returns: + True if poses are stable, False otherwise + """ + if len(self.reached_poses) < self.reached_poses.maxlen: + return False # Not enough poses yet + + # Extract positions + positions = np.array( + [[p.position.x, p.position.y, p.position.z] for p in self.reached_poses] + ) + + # Calculate standard deviation for each axis + std_devs = np.std(positions, axis=0) + + # Check if all axes are below threshold + return np.all(std_devs < self.pose_stabilization_threshold) + + def cleanup(self): + """Clean up resources (detector only, hardware cleanup is caller's responsibility).""" + self.detector.cleanup() + logger.info("Cleaned up manipulation system resources") diff --git a/dimos/manipulation/visual_servoing/pbvs.py b/dimos/manipulation/visual_servoing/pbvs.py index dca5cbb0bb..ad49b49856 100644 --- a/dimos/manipulation/visual_servoing/pbvs.py +++ b/dimos/manipulation/visual_servoing/pbvs.py @@ -19,8 +19,6 @@ import numpy as np from typing import Optional, Tuple -from enum import Enum - from scipy.spatial.transform import Rotation as R from dimos_lcm.geometry_msgs import Pose, Vector3, Quaternion, Point from dimos_lcm.vision_msgs import Detection3D, Detection3DArray @@ -39,13 +37,6 @@ logger = setup_logger("dimos.manipulation.pbvs") -class GraspStage(Enum): - """Enum for different grasp stages.""" - - PRE_GRASP = "pre_grasp" - GRASP = "grasp" - - class PBVS: """ High-level Position-Based Visual Servoing orchestrator. 
@@ -67,10 +58,8 @@ def __init__( max_velocity: float = 0.1, # m/s max_angular_velocity: float = 0.5, # rad/s target_tolerance: float = 0.01, # 1cm - max_tracking_distance_threshold: float = 0.1, # Max distance for target tracking (m) - min_size_similarity: float = 0.7, # Min size similarity threshold (0.0-1.0) - pregrasp_distance: float = 0.15, # 15cm pregrasp distance - grasp_distance: float = 0.05, # 5cm grasp distance (final approach) + max_tracking_distance_threshold: float = 0.08, # Max distance for target tracking (m) + min_size_similarity: float = 0.6, # Min size similarity threshold (0.0-1.0) direct_ee_control: bool = False, # If True, output target poses instead of velocities ): """ @@ -84,8 +73,6 @@ def __init__( target_tolerance: Distance threshold for considering target reached (m) max_tracking_distance: Maximum distance for valid target tracking (m) min_size_similarity: Minimum size similarity for valid target tracking (0.0-1.0) - pregrasp_distance: Distance to maintain before grasping (m) - grasp_distance: Distance for final grasp approach (m) direct_ee_control: If True, output target poses instead of velocity commands """ # Initialize low-level controller only if not in direct control mode @@ -106,8 +93,6 @@ def __init__( # Target tracking parameters self.max_tracking_distance_threshold = max_tracking_distance_threshold self.min_size_similarity = min_size_similarity - self.pregrasp_distance = pregrasp_distance - self.grasp_distance = grasp_distance self.direct_ee_control = direct_ee_control self.grasp_pitch_degrees = ( 45.0 # Default grasp pitch in degrees (45° between level and top-down) @@ -116,7 +101,6 @@ def __init__( # Target state self.current_target = None self.target_grasp_pose = None - self.grasp_stage = GraspStage.PRE_GRASP # For direct control mode visualization self.last_position_error = None @@ -124,7 +108,6 @@ def __init__( logger.info( f"Initialized PBVS system with controller gains: pos={position_gain}, rot={rotation_gain}, " - 
f"pregrasp_distance={pregrasp_distance}m, grasp_distance={grasp_distance}m, " f"tracking_thresholds: distance={max_tracking_distance_threshold}m, size={min_size_similarity:.2f}" ) @@ -141,7 +124,6 @@ def set_target(self, target_object: Detection3D) -> bool: if target_object and target_object.bbox and target_object.bbox.center: self.current_target = target_object self.target_grasp_pose = None # Will be computed when needed - self.grasp_stage = GraspStage.PRE_GRASP # Reset to pre-grasp stage logger.info(f"New target set: ID {target_object.id}") return True return False @@ -150,7 +132,6 @@ def clear_target(self): """Clear the current target.""" self.current_target = None self.target_grasp_pose = None - self.grasp_stage = GraspStage.PRE_GRASP self.last_position_error = None self.last_target_reached = False if self.controller: @@ -166,15 +147,6 @@ def get_current_target(self) -> Optional[Detection3D]: """ return self.current_target - def set_grasp_stage(self, stage: GraspStage): - """ - Set the grasp stage. - - Args: - stage: The new grasp stage - """ - self.grasp_stage = stage - def set_grasp_pitch(self, pitch_degrees: float): """ Set the grasp pitch angle in degrees. @@ -190,7 +162,7 @@ def set_grasp_pitch(self, pitch_degrees: float): # Reset target grasp pose to recompute with new pitch self.target_grasp_pose = None - def is_target_reached(self, ee_pose: Pose) -> bool: + def is_target_reached(self, ee_pose: Pose, grasp_distance: float) -> bool: """ Check if the current target stage has been reached. 
@@ -209,18 +181,7 @@ def is_target_reached(self, ee_pose: Pose) -> bool: error_z = self.target_grasp_pose.position.z - ee_pose.position.z error_magnitude = np.sqrt(error_x**2 + error_y**2 + error_z**2) - stage_reached = error_magnitude < self.target_tolerance - - # Handle stage transitions - if stage_reached and self.grasp_stage == GraspStage.PRE_GRASP: - return True # Signal that pre-grasp target was reached - elif stage_reached and self.grasp_stage == GraspStage.GRASP: - # Grasp reached, clear target - logger.info("Grasp position reached, clearing target") - self.clear_target() - return True - - return False + return error_magnitude < self.target_tolerance def update_target_tracking(self, new_detections: Detection3DArray) -> bool: """ @@ -272,12 +233,13 @@ def update_target_tracking(self, new_detections: Detection3DArray) -> bool: ) return False - def _update_target_grasp_pose(self, ee_pose: Pose): + def _update_target_grasp_pose(self, ee_pose: Pose, grasp_distance: float): """ Update target grasp pose based on current target and EE pose. 
Args: ee_pose: Current end-effector pose + grasp_distance: Distance to maintain from target (pregrasp or grasp distance) """ if ( not self.current_target @@ -304,12 +266,7 @@ def _update_target_grasp_pose(self, ee_pose: Pose): target_pose = Pose(target_pos, target_orientation) # Apply grasp distance - distance = ( - self.pregrasp_distance - if self.grasp_stage == GraspStage.PRE_GRASP - else self.grasp_distance - ) - self.target_grasp_pose = self._apply_grasp_distance(target_pose, distance) + self.target_grasp_pose = self._apply_grasp_distance(target_pose, grasp_distance) def _apply_grasp_distance(self, target_pose: Pose, distance: float) -> Pose: """ @@ -344,7 +301,10 @@ def _apply_grasp_distance(self, target_pose: Pose, distance: float) -> Pose: return Pose(offset_position, target_pose.orientation) def compute_control( - self, ee_pose: Pose, new_detections: Optional[Detection3DArray] = None + self, + ee_pose: Pose, + new_detections: Optional[Detection3DArray] = None, + grasp_distance: float = 0.15, ) -> Tuple[Optional[Vector3], Optional[Vector3], bool, bool, Optional[Pose]]: """ Compute PBVS control with position and orientation servoing. 
@@ -352,6 +312,7 @@ def compute_control( Args: ee_pose: Current end-effector pose new_detections: Optional new detections for target tracking + grasp_distance: Distance to maintain from target (meters) Returns: Tuple of (velocity_command, angular_velocity_command, target_reached, has_target, target_pose) @@ -381,8 +342,10 @@ def compute_control( # Update target grasp pose if not self.current_target: logger.info("No current target") + return None, None, False, False, None - self._update_target_grasp_pose(ee_pose) + # Update target grasp pose with provided distance + self._update_target_grasp_pose(ee_pose, grasp_distance) if self.target_grasp_pose is None: logger.warning("Failed to compute grasp pose") @@ -397,15 +360,7 @@ def compute_control( ) # Check if target reached using our separate function - target_reached = self.is_target_reached(ee_pose) - - # If stage transitioned, recompute target grasp pose - if ( - target_reached - and self.grasp_stage == GraspStage.GRASP - and self.target_grasp_pose is None - ): - self._update_target_grasp_pose(ee_pose) + target_reached = self.is_target_reached(ee_pose, grasp_distance) # Return appropriate values based on control mode if self.direct_ee_control: @@ -422,50 +377,31 @@ def compute_control( ) return velocity_cmd, angular_velocity_cmd, target_reached, target_tracked, None - def get_object_pose_camera_frame( - self, object_pos: Vector3, camera_pose: Pose - ) -> Tuple[Vector3, Quaternion]: - """ - Get object pose in camera frame coordinates with orientation. 
- - Args: - object_pos: Object position in camera frame - camera_pose: Current camera pose - - Returns: - Tuple of (position, rotation) in camera frame - """ - # Calculate orientation pointing at camera - yaw_to_camera = yaw_towards_point(object_pos) - - # Convert euler angles to quaternion using utility function - euler = Vector3(0.0, 0.0, yaw_to_camera) # Level grasp - orientation = euler_to_quaternion(euler) - - return object_pos, orientation - def create_status_overlay( self, image: np.ndarray, + grasp_stage=None, ) -> np.ndarray: """ Create PBVS status overlay on image. Args: image: Input image + grasp_stage: Current grasp stage (optional) Returns: Image with PBVS status overlay """ if self.direct_ee_control: # Use direct control overlay + stage_value = grasp_stage.value if grasp_stage else "idle" return create_pbvs_status_overlay( image, self.current_target, self.last_position_error, self.last_target_reached, self.target_grasp_pose, - self.grasp_stage.value, + stage_value, is_direct_control=True, ) else: diff --git a/dimos/utils/transform_utils.py b/dimos/utils/transform_utils.py index 143b74a33c..5aa33bccce 100644 --- a/dimos/utils/transform_utils.py +++ b/dimos/utils/transform_utils.py @@ -47,6 +47,13 @@ def pose_to_matrix(pose: Pose) -> np.ndarray: # Create rotation matrix from quaternion using scipy quat = [pose.orientation.x, pose.orientation.y, pose.orientation.z, pose.orientation.w] + + # Check for zero norm quaternion and use identity if invalid + quat_norm = np.linalg.norm(quat) + if quat_norm == 0.0: + # Use identity quaternion [0, 0, 0, 1] if zero norm detected + quat = [0.0, 0.0, 0.0, 1.0] + rotation = R.from_quat(quat) Rot = rotation.as_matrix() diff --git a/tests/test_ibvs.py b/tests/test_ibvs.py index 33774ad030..6299b57185 100644 --- a/tests/test_ibvs.py +++ b/tests/test_ibvs.py @@ -23,27 +23,6 @@ import cv2 import sys -import time - - -from dimos.hardware.zed_camera import ZEDCamera -from dimos.hardware.piper_arm import PiperArm -from 
dimos.manipulation.visual_servoing.detection3d import Detection3DProcessor -from dimos.manipulation.visual_servoing.pbvs import PBVS, GraspStage -from dimos.perception.common.utils import ( - find_clicked_detection, - bbox2d_to_corners, -) -from dimos.manipulation.visual_servoing.utils import ( - match_detection_by_id, -) -from dimos.utils.transform_utils import ( - pose_to_matrix, - matrix_to_pose, - create_transform_from_6dof, - compose_transforms, -) -from dimos_lcm.geometry_msgs import Vector3 try: import pyzed.sl as sl @@ -51,6 +30,10 @@ print("Error: ZED SDK not installed.") sys.exit(1) +from dimos.hardware.zed_camera import ZEDCamera +from dimos.hardware.piper_arm import PiperArm +from dimos.manipulation.visual_servoing.manipulation import Manipulation + # Global for mouse events mouse_click = None @@ -62,50 +45,12 @@ def mouse_callback(event, x, y, _flags, _param): mouse_click = (x, y) -def execute_grasp(arm, target_object, target_pose, grasp_width_offset: float = 0.02) -> bool: - """ - Execute grasping by opening gripper to accommodate target object. 
- - Args: - arm: Robot arm interface with gripper control - target_object: Detection3D with size information - grasp_width_offset: Additional width to add to object size for gripper opening - - Returns: - True if grasp was executed, False if no target or no size data - """ - if not target_object: - print("❌ No target object provided for grasping") - return False - - if not target_object.bbox or not target_object.bbox.size: - print("❌ Target has no size information for grasping") - return False - - # Get object size from detection3d data (already in meters) - object_size = target_object.bbox.size - object_width = object_size.x - - # Calculate gripper opening with offset - gripper_opening = object_width + grasp_width_offset - - # Clamp gripper opening to reasonable limits (0.5cm to 10cm) - gripper_opening = max(0.005, min(gripper_opening, 0.1)) - - print(f"🤏 Executing grasp: opening gripper to {gripper_opening * 1000:.1f}mm") - - # Command gripper to open - arm.cmd_gripper_ctrl(gripper_opening) - arm.cmd_ee_pose(target_pose, line_mode=True) - - return True - - def main(): global mouse_click # Configuration DIRECT_EE_CONTROL = True # True: direct EE pose control, False: velocity control + INITIAL_GRASP_PITCH_DEGREES = 30 # 0° = level grasp, 90° = top-down grasp print("=== PBVS Eye-in-Hand Test ===") print("Using EE pose as odometry for camera pose") @@ -133,17 +78,9 @@ def main(): print("Initialized Piper arm") except Exception as e: print(f"Failed to initialize Piper arm: {e}") + zed.close() return - # Define EE to camera transform (adjust these values for your setup) - # Format: [x, y, z, rx, ry, rz] in meters and radians - ee_to_camera_6dof = [-0.06, 0.03, -0.05, 0.0, -1.57, 0.0] - - # Create transform matrices - pos = Vector3(ee_to_camera_6dof[0], ee_to_camera_6dof[1], ee_to_camera_6dof[2]) - rot = Vector3(ee_to_camera_6dof[3], ee_to_camera_6dof[4], ee_to_camera_6dof[5]) - T_ee_to_camera = create_transform_from_6dof(pos, rot) - # Get camera intrinsics cam_info = 
zed.get_camera_info() intrinsics = [ @@ -153,195 +90,57 @@ def main(): cam_info["left_cam"]["cy"], ] - # Initialize processors - detector = Detection3DProcessor(intrinsics) - pbvs = PBVS( - position_gain=0.3, - rotation_gain=0.2, - target_tolerance=0.05, - pregrasp_distance=0.25, - grasp_distance=0.01, - direct_ee_control=DIRECT_EE_CONTROL, - ) + # Initialize manipulation system + try: + manipulation = Manipulation( + camera=zed, + arm=arm, + camera_intrinsics=intrinsics, + direct_ee_control=DIRECT_EE_CONTROL, + ee_to_camera_6dof=[-0.06, 0.03, -0.05, 0.0, -1.57, 0.0], # Adjust for your setup + ) + except Exception as e: + print(f"Failed to initialize manipulation system: {e}") + zed.close() + arm.disable() + return - # Set custom grasp pitch (60 degrees - between level and top-down) - GRASP_PITCH_DEGREES = 0 # 0° = level grasp, 90° = top-down grasp - pbvs.set_grasp_pitch(GRASP_PITCH_DEGREES) + # Set initial grasp pitch + manipulation.set_grasp_pitch(INITIAL_GRASP_PITCH_DEGREES) # Setup window cv2.namedWindow("PBVS") cv2.setMouseCallback("PBVS", mouse_callback) - # Control state for direct EE mode - execute_target = False # Only move when space is pressed - last_valid_target = None - - # Rate limiting for pose execution - MIN_EXECUTION_PERIOD = 1.0 # Minimum seconds between pose executions - last_execution_time = 0 - try: while True: - # Capture - bgr, _, depth, _ = zed.capture_frame_with_pose() - if bgr is None or depth is None: + # Update manipulation system + if not manipulation.update(): continue - # Process - rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) - - # Get EE pose from robot (this serves as our odometry) - ee_pose = arm.get_ee_pose() - - # Transform EE pose to camera pose - ee_transform = pose_to_matrix(ee_pose) - camera_transform = compose_transforms(ee_transform, T_ee_to_camera) - camera_pose = matrix_to_pose(camera_transform) - - # Process detections using camera transform - detection_3d_array, detection_2d_array = detector.process_frame( - rgb, depth, 
camera_transform - ) - - # Handle click + # Handle mouse click if mouse_click: - clicked_3d = find_clicked_detection( - mouse_click, detection_2d_array.detections, detection_3d_array.detections - ) - if clicked_3d: - pbvs.set_target(clicked_3d) + x, y = mouse_click + manipulation.pick_target(x, y) mouse_click = None - # Create visualization with position overlays - viz = detector.visualize_detections( - rgb, detection_3d_array.detections, detection_2d_array.detections - ) - - # PBVS control - vel_cmd, ang_vel_cmd, reached, target_tracked, target_pose = pbvs.compute_control( - ee_pose, detection_3d_array - ) - - # Apply commands to robot based on control mode - if DIRECT_EE_CONTROL and target_pose: - # Check if enough time has passed since last execution - current_time = time.time() - if current_time - last_execution_time >= MIN_EXECUTION_PERIOD: - # Direct EE pose control - print( - f"🎯 EXECUTING target pose: pos=({target_pose.position.x:.3f}, {target_pose.position.y:.3f}, {target_pose.position.z:.3f})" - ) - last_valid_target = pbvs.get_current_target() - if pbvs.grasp_stage == GraspStage.PRE_GRASP: - arm.cmd_ee_pose(target_pose) - last_execution_time = current_time - elif pbvs.grasp_stage == GraspStage.GRASP and execute_target: - execute_grasp(arm, last_valid_target, target_pose, grasp_width_offset=0.03) - last_execution_time = current_time - execute_target = False # Reset flag after execution - elif not DIRECT_EE_CONTROL and vel_cmd and ang_vel_cmd: - # Velocity control - arm.cmd_vel_ee( - vel_cmd.x, vel_cmd.y, vel_cmd.z, ang_vel_cmd.x, ang_vel_cmd.y, ang_vel_cmd.z - ) - - # Add PBVS status overlay - viz = pbvs.create_status_overlay(viz) - - # Highlight target - current_target = pbvs.get_current_target() - if target_tracked and current_target: - det_2d = match_detection_by_id( - current_target, detection_3d_array.detections, detection_2d_array.detections - ) - if det_2d and det_2d.bbox: - x1, y1, x2, y2 = bbox2d_to_corners(det_2d.bbox) - x1, y1, x2, y2 = 
int(x1), int(y1), int(x2), int(y2) - - cv2.rectangle(viz, (x1, y1), (x2, y2), (0, 255, 0), 3) - cv2.putText( - viz, "TARGET", (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2 - ) - - # Convert back to BGR for OpenCV display - viz_bgr = cv2.cvtColor(viz, cv2.COLOR_RGB2BGR) - - # Add pose info - mode_text = "Direct EE" if DIRECT_EE_CONTROL else "Velocity" - cv2.putText( - viz_bgr, - f"Eye-in-Hand ({mode_text})", - (10, 30), - cv2.FONT_HERSHEY_SIMPLEX, - 0.5, - (0, 255, 255), - 1, - ) - - camera_text = f"Camera: ({camera_pose.position.x:.2f}, {camera_pose.position.y:.2f}, {camera_pose.position.z:.2f})m" - cv2.putText( - viz_bgr, camera_text, (10, 50), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 0), 1 - ) - - ee_text = f"EE: ({ee_pose.position.x:.2f}, {ee_pose.position.y:.2f}, {ee_pose.position.z:.2f})m" - cv2.putText(viz_bgr, ee_text, (10, 70), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1) - - # Add control status - if DIRECT_EE_CONTROL: - status_text = ( - "Target Ready - Press SPACE to execute" if target_pose else "No target selected" - ) - status_color = (0, 255, 255) if target_pose else (100, 100, 100) - cv2.putText( - viz_bgr, status_text, (10, 90), cv2.FONT_HERSHEY_SIMPLEX, 0.5, status_color, 1 - ) - cv2.putText( - viz_bgr, - "s=STOP | h=HOME | SPACE=EXECUTE | g=RELEASE", - (10, 110), - cv2.FONT_HERSHEY_SIMPLEX, - 0.4, - (255, 255, 255), - 1, - ) - - # Display - cv2.imshow("PBVS", viz_bgr) + # Get and display visualization + viz = manipulation.get_visualization() + if viz is not None: + cv2.imshow("PBVS", viz) # Handle keyboard input key = cv2.waitKey(1) & 0xFF if key == ord("q"): break - elif key == ord("r"): - pbvs.clear_target() - elif key == ord("s"): - print("🛑 SOFT STOP - Emergency stopping robot!") - arm.softStop() - elif key == ord("h"): - print("🏠 GO HOME - Returning to safe position...") - arm.gotoZero() - elif key == ord(" ") and DIRECT_EE_CONTROL and target_pose: - execute_target = True - if pbvs.grasp_stage == GraspStage.PRE_GRASP: - 
pbvs.set_grasp_stage(GraspStage.GRASP) - print("⚡ Executing target pose") - elif key == 82: # Up arrow - increase pitch - new_pitch = min(90.0, pbvs.grasp_pitch_degrees + 15.0) - pbvs.set_grasp_pitch(new_pitch) - print(f"↑ Grasp pitch: {new_pitch:.0f}°") - elif key == 84: # Down arrow - decrease pitch - new_pitch = max(0.0, pbvs.grasp_pitch_degrees - 15.0) - pbvs.set_grasp_pitch(new_pitch) - print(f"↓ Grasp pitch: {new_pitch:.0f}°") - elif key == ord("g"): - print("🖐️ Opening gripper") - arm.release_gripper() + else: + manipulation.handle_keyboard_command(key) except KeyboardInterrupt: pass finally: cv2.destroyAllWindows() - detector.cleanup() + manipulation.cleanup() zed.close() arm.disable() From 61b7cb973b992429e12a7e710b9b75504192e188 Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Mon, 21 Jul 2025 15:21:58 -0700 Subject: [PATCH 73/89] further cleanup --- dimos/hardware/piper_arm.py | 38 +- .../visual_servoing/detection3d.py | 2 +- .../visual_servoing/manipulation.py | 475 +++++++++--------- tests/test_ibvs.py | 12 +- 4 files changed, 262 insertions(+), 265 deletions(-) diff --git a/dimos/hardware/piper_arm.py b/dimos/hardware/piper_arm.py index 3ec4f216f7..02c26733b2 100644 --- a/dimos/hardware/piper_arm.py +++ b/dimos/hardware/piper_arm.py @@ -28,17 +28,19 @@ import select from scipy.spatial.transform import Rotation as R from dimos.utils.transform_utils import euler_to_quaternion, quaternion_to_euler +from dimos.utils.logging_config import setup_logger -import random import threading import pytest import dimos.core as core import dimos.protocol.service.lcmservice as lcmservice -from dimos.core import In, Module, Out, rpc +from dimos.core import In, Module, rpc from dimos_lcm.geometry_msgs import Pose, Vector3, Twist +logger = setup_logger("dimos.hardware.piper_arm") + class PiperArm: def __init__(self, arm_name: str = "arm"): @@ -69,7 +71,7 @@ def enable(self): while not self.arm.EnablePiper(): pass time.sleep(0.01) - print(f"[PiperArm] Enabled") + 
logger.info("Arm enabled") # self.arm.ModeCtrl( # ctrl_mode=0x01, # CAN command mode # move_mode=0x01, # “Move-J”, but ignored in MIT @@ -88,7 +90,7 @@ def gotoZero(self): RY = round(position[4] * factor) RZ = round(position[5] * factor) joint_6 = round(position[6] * factor) - print(X, Y, Z, RX, RY, RZ) + logger.debug(f"Going to zero position: X={X}, Y={Y}, Z={Z}, RX={RX}, RY={RY}, RZ={RZ}") self.arm.MotionCtrl_2(0x01, 0x00, 100, 0x00) self.arm.EndPoseCtrl(X, Y, Z, RX, RY, RZ) self.arm.GripperCtrl(0, 1000, 0x01, 0) @@ -157,32 +159,32 @@ def get_ee_pose(self): return Pose(position, orientation) - def cmd_gripper_ctrl(self, position): + def cmd_gripper_ctrl(self, position, effort=250): """Command end-effector gripper""" factor = 1000 position = position * factor * factor - self.arm.GripperCtrl(abs(round(position)), 250, 0x01, 0) - print(f"[PiperArm] Commanding gripper position: {position}") + self.arm.GripperCtrl(abs(round(position)), effort, 0x01, 0) + logger.debug(f"Commanding gripper position: {position}mm") def enable_gripper(self): """Enable the gripper using the initialization sequence""" - print("[PiperArm] Enabling gripper...") + logger.info("Enabling gripper...") while not self.arm.EnablePiper(): time.sleep(0.01) self.arm.GripperCtrl(0, 1000, 0x02, 0) self.arm.GripperCtrl(0, 1000, 0x01, 0) - print("[PiperArm] Gripper enabled") + logger.info("Gripper enabled") def release_gripper(self): """Release gripper by opening to 100mm (10cm)""" - print("[PiperArm] Releasing gripper (opening to 100mm)...") + logger.info("Releasing gripper (opening to 100mm)") self.cmd_gripper_ctrl(0.1) # 0.1m = 100mm = 10cm def resetArm(self): self.arm.MotionCtrl_1(0x02, 0, 0) self.arm.MotionCtrl_2(0, 0, 0, 0xAD) - print(f"[PiperArm] Resetting arm") + logger.info("Resetting arm") def init_vel_controller(self): self.chain = kp.build_serial_chain_from_urdf( @@ -272,7 +274,9 @@ def cmd_vel_ee(self, x_dot, y_dot, z_dot, RX_dot, PY_dot, YZ_dot): ) # Apply velocity increment - current_pose = 
current_pose + np.array([x_dot, y_dot, z_dot, R_dot, P_dot, Y_dot]) * self.dt + current_pose = ( + current_pose + np.array([x_dot, y_dot, z_dot, RX_dot, PY_dot, YZ_dot]) * self.dt + ) self.cmd_ee_pose_values( current_pose[0], @@ -311,7 +315,9 @@ def control_loop(): while True: # Check for timeout (1 second) if self.last_cmd_time and (time.time() - self.last_cmd_time) > 1.0: - print("No velocity command received for 1 second, stopping control loop") + logger.warning( + "No velocity command received for 1 second, stopping control loop" + ) break cmd_vel = self.latest_cmd @@ -390,7 +396,7 @@ def run_velocity_controller(): velocity_controller.start() - print("Velocity controller started") + logger.info("Velocity controller started") while True: time.sleep(1) @@ -437,7 +443,7 @@ def teleop_linear_vel(arm): elif key == "s": z_dot -= 0.01 elif key == "q": - print("Exiting teleop.") + logger.info("Exiting teleop") arm.disable() break @@ -448,7 +454,7 @@ def teleop_linear_vel(arm): # Only linear velocities, angular set to zero arm.cmd_vel_ee(x_dot, y_dot, z_dot, 0, 0, 0) - print( + logger.debug( f"Current linear velocity: x={x_dot:.3f} m/s, y={y_dot:.3f} m/s, z={z_dot:.3f} m/s" ) diff --git a/dimos/manipulation/visual_servoing/detection3d.py b/dimos/manipulation/visual_servoing/detection3d.py index b54bb81fd3..01f51cf2b3 100644 --- a/dimos/manipulation/visual_servoing/detection3d.py +++ b/dimos/manipulation/visual_servoing/detection3d.py @@ -66,7 +66,7 @@ def __init__( min_confidence: float = 0.6, min_points: int = 30, max_depth: float = 1.0, - max_object_size: float = 0.2, + max_object_size: float = 0.15, ): """ Initialize the real-time 3D detection processor. 
diff --git a/dimos/manipulation/visual_servoing/manipulation.py b/dimos/manipulation/visual_servoing/manipulation.py index 9db79595f7..295f6424cb 100644 --- a/dimos/manipulation/visual_servoing/manipulation.py +++ b/dimos/manipulation/visual_servoing/manipulation.py @@ -72,7 +72,6 @@ def __init__( self, camera: Any, # Generic camera object with required interface arm: Any, # Generic arm object with required interface - camera_intrinsics: list, # [fx, fy, cx, cy] direct_ee_control: bool = True, ee_to_camera_6dof: Optional[list] = None, ): @@ -80,10 +79,9 @@ def __init__( Initialize manipulation system. Args: - camera: Camera object with capture_frame_with_pose() method + camera: Camera object with capture_frame_with_pose() and calculate_intrinsics() methods arm: Robot arm object with get_ee_pose(), cmd_ee_pose(), cmd_vel_ee(), cmd_gripper_ctrl(), release_gripper(), softStop(), gotoZero(), and disable() methods - camera_intrinsics: Camera intrinsics [fx, fy, cx, cy] direct_ee_control: If True, use direct EE pose control; if False, use velocity control ee_to_camera_6dof: EE to camera transform [x, y, z, rx, ry, rz] in meters and radians """ @@ -100,6 +98,15 @@ def __init__( rot = Vector3(ee_to_camera_6dof[3], ee_to_camera_6dof[4], ee_to_camera_6dof[5]) self.T_ee_to_camera = create_transform_from_6dof(pos, rot) + # Get camera intrinsics + cam_intrinsics = camera.calculate_intrinsics() + camera_intrinsics = [ + cam_intrinsics["focal_length_x"], + cam_intrinsics["focal_length_y"], + cam_intrinsics["principal_point_x"], + cam_intrinsics["principal_point_y"], + ] + # Initialize processors self.detector = Detection3DProcessor(camera_intrinsics) self.pbvs = PBVS( @@ -114,6 +121,8 @@ def __init__( self.waiting_for_reach = False # True when waiting for robot to reach commanded pose self.last_commanded_pose = None # Last pose sent to robot self.target_updated = False # True when target has been updated with fresh detections + self.waiting_start_time = None # Time when waiting 
for reach started + self.reach_pose_timeout = 10.0 # Timeout for reaching commanded pose (seconds) # Grasp parameters self.grasp_width_offset = 0.03 # Default grasp width offset @@ -122,14 +131,15 @@ def __init__( self.grasp_distance = 0.01 # Distance for final grasp approach (m) self.grasp_close_delay = 3.0 # Time to wait at grasp pose before closing (seconds) self.grasp_reached_time = None # Time when grasp pose was reached + self.gripper_max_opening = 0.07 # Maximum gripper opening (m) # Grasp stage tracking self.grasp_stage = GraspStage.IDLE # Pose stabilization tracking self.pose_history_size = 4 # Number of poses to check for stabilization - self.pose_stabilization_threshold = 0.005 # 1cm threshold for stabilization - self.stabilization_timeout = 10.0 # Timeout in seconds before giving up + self.pose_stabilization_threshold = 0.01 # 1cm threshold for stabilization + self.stabilization_timeout = 15.0 # Timeout in seconds before giving up self.stabilization_start_time = None # Time when stabilization started self.reached_poses = deque( maxlen=self.pose_history_size @@ -138,15 +148,11 @@ def __init__( # State for visualization self.current_visualization = None - self.last_rgb = None self.last_detection_3d_array = None self.last_detection_2d_array = None - self.last_camera_pose = None self.last_target_tracked = False - logger.info( - f"Initialized Manipulation system in {'Direct EE' if direct_ee_control else 'Velocity'} control mode" - ) + # Log initialization only if needed for debugging def set_grasp_stage(self, stage: GraspStage): """ @@ -156,7 +162,7 @@ def set_grasp_stage(self, stage: GraspStage): stage: The new grasp stage """ self.grasp_stage = stage - logger.info(f"Set grasp stage to: {stage.value}") + logger.info(f"Grasp stage: {stage.value}") def set_grasp_pitch(self, pitch_degrees: float): """ @@ -170,7 +176,22 @@ def set_grasp_pitch(self, pitch_degrees: float): pitch_degrees = max(0.0, min(90.0, pitch_degrees)) self.grasp_pitch_degrees = pitch_degrees 
self.pbvs.set_grasp_pitch(pitch_degrees) - logger.info(f"Set grasp pitch to: {pitch_degrees} degrees") + + def _check_reach_timeout(self) -> bool: + """ + Check if robot has exceeded timeout while reaching pose. + + Returns: + True if timeout exceeded, False otherwise + """ + if ( + self.waiting_start_time + and (time.time() - self.waiting_start_time) > self.reach_pose_timeout + ): + logger.warning(f"Robot failed to reach pose within {self.reach_pose_timeout}s timeout") + self.reset_to_idle() + return True + return False def reset_to_idle(self): """Reset the manipulation system to IDLE state.""" @@ -183,6 +204,17 @@ def reset_to_idle(self): self.target_updated = False self.stabilization_start_time = None self.grasp_reached_time = None + self.waiting_start_time = None + + def execute_idle(self) -> bool: + """ + Execute idle stage: just visualization, no control. + + Returns: + False (no target tracked in idle) + """ + # Nothing to do in idle + return False def execute_pre_grasp(self, detection_3d_array: Detection3DArray) -> bool: """ @@ -197,11 +229,25 @@ def execute_pre_grasp(self, detection_3d_array: Detection3DArray) -> bool: # Get EE pose ee_pose = self.arm.get_ee_pose() - # PBVS control with pre-grasp distance - vel_cmd, ang_vel_cmd, _, target_tracked, target_pose = self.pbvs.compute_control( - ee_pose, detection_3d_array, self.pregrasp_distance - ) + # Check if waiting for robot to reach commanded pose + if self.waiting_for_reach and self.last_commanded_pose: + # Check for timeout + if self._check_reach_timeout(): + return False + + reached = self.pbvs.is_target_reached(ee_pose, self.pregrasp_distance) + + if reached: + self.waiting_for_reach = False + self.waiting_start_time = None + self.reached_poses.append(self.last_commanded_pose) + self.target_updated = False # Reset flag so we wait for fresh update + time.sleep(0.3) + + # While waiting, don't process new commands + return self.last_target_tracked + # Check timeout if ( self.stabilization_start_time and 
(time.time() - self.stabilization_start_time) > self.stabilization_timeout @@ -214,6 +260,11 @@ def execute_pre_grasp(self, detection_3d_array: Detection3DArray) -> bool: self.reset_to_idle() return False + # PBVS control with pre-grasp distance + vel_cmd, ang_vel_cmd, _, target_tracked, target_pose = self.pbvs.compute_control( + ee_pose, detection_3d_array, self.pregrasp_distance + ) + # Set target_updated flag if target was successfully tracked if target_tracked and target_pose: self.target_updated = True @@ -223,7 +274,7 @@ def execute_pre_grasp(self, detection_3d_array: Detection3DArray) -> bool: if self.direct_ee_control and target_pose and target_tracked: # Check if we have enough reached poses and they're stable if self.check_target_stabilized(): - logger.info("Target stabilized, transitioning to GRASP stage") + logger.info("Target stabilized, transitioning to GRASP") self.grasp_stage = GraspStage.GRASP self.adjustment_count = 0 self.waiting_for_reach = False @@ -232,19 +283,11 @@ def execute_pre_grasp(self, detection_3d_array: Detection3DArray) -> bool: self.arm.cmd_ee_pose(target_pose) self.last_commanded_pose = target_pose self.waiting_for_reach = True + self.waiting_start_time = time.time() # Start timeout timer self.target_updated = False # Reset flag after commanding self.adjustment_count += 1 - elapsed_time = ( - time.time() - self.stabilization_start_time - if self.stabilization_start_time - else 0 - ) - logger.info( - f"Commanded target pose: pos=({target_pose.position.x:.3f}, " - f"{target_pose.position.y:.3f}, {target_pose.position.z:.3f}), " - f"attempt {self.adjustment_count} (elapsed: {elapsed_time:.1f}s)" - ) + # Command sent to robot # Sleep for 200ms after commanding to avoid rapid commands time.sleep(0.2) @@ -267,12 +310,39 @@ def execute_grasp(self, detection_3d_array: Detection3DArray) -> bool: Returns: True if target is being tracked """ - if not self.waiting_for_reach and self.last_valid_target: - # Get EE pose - ee_pose = 
self.arm.get_ee_pose() + # Get EE pose + ee_pose = self.arm.get_ee_pose() + + # Check if waiting for robot to reach grasp pose + if self.waiting_for_reach: + # Check for timeout + if self._check_reach_timeout(): + return False + + reached = self.pbvs.is_target_reached(ee_pose, self.grasp_distance) + + if reached and not self.grasp_reached_time: + # First time reaching grasp pose + self.grasp_reached_time = time.time() + self.waiting_start_time = None # Reset timeout timer + # Robot reached grasp pose + + # Wait for delay then transition to CLOSE_AND_LIFT + if ( + self.grasp_reached_time + and (time.time() - self.grasp_reached_time) >= self.grasp_close_delay + ): + logger.info("Grasp delay completed, closing gripper") + self.grasp_stage = GraspStage.CLOSE_AND_LIFT + self.waiting_for_reach = False + # While waiting, don't process new commands + return self.last_target_tracked + + # Only command grasp if not waiting and have valid target + if self.last_valid_target: # PBVS control with grasp distance - vel_cmd, ang_vel_cmd, _, target_tracked, target_pose = self.pbvs.compute_control( + _, _, _, target_tracked, target_pose = self.pbvs.compute_control( ee_pose, detection_3d_array, self.grasp_distance ) @@ -281,16 +351,15 @@ def execute_grasp(self, detection_3d_array: Detection3DArray) -> bool: object_size = self.last_valid_target.bbox.size object_width = object_size.x gripper_opening = object_width + self.grasp_width_offset - gripper_opening = max(0.005, min(gripper_opening, 0.1)) + gripper_opening = max(0.005, min(gripper_opening, self.gripper_max_opening)) - logger.info(f"Executing grasp: opening gripper to {gripper_opening * 1000:.1f}mm") - print(f"Executing grasp: opening gripper to {gripper_opening * 1000:.1f}mm") + logger.info(f"Executing grasp: gripper={gripper_opening * 1000:.1f}mm") # Command gripper to open and move to grasp pose self.arm.cmd_gripper_ctrl(gripper_opening) self.arm.cmd_ee_pose(target_pose, line_mode=True) self.waiting_for_reach = True - 
logger.info("Grasp pose commanded") + self.waiting_start_time = time.time() # Start timeout timer return target_tracked @@ -298,19 +367,15 @@ def execute_grasp(self, detection_3d_array: Detection3DArray) -> bool: def execute_close_and_lift(self): """Execute the close and lift sequence.""" - logger.info("Executing CLOSE_AND_LIFT sequence") - # Close gripper - logger.info("Closing gripper") self.arm.cmd_gripper_ctrl(0.0) # Close gripper completely time.sleep(0.5) # Wait for gripper to close # Return to home position - logger.info("Returning to home position") self.arm.gotoZero() # Reset to IDLE after completion - logger.info("Grasp sequence completed, returning to IDLE") + logger.info("Grasp sequence completed") self.reset_to_idle() def capture_and_process( @@ -330,18 +395,15 @@ def capture_and_process( if bgr is None or depth is None: return None, None, None, None - # Process rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) - # Get EE pose from robot (this serves as our odometry) + # Get EE pose and camera transform ee_pose = self.arm.get_ee_pose() - - # Transform EE pose to camera pose ee_transform = pose_to_matrix(ee_pose) camera_transform = compose_transforms(ee_transform, self.T_ee_to_camera) camera_pose = matrix_to_pose(camera_transform) - # Process detections using camera transform + # Process detections detection_3d_array, detection_2d_array = self.detector.process_frame( rgb, depth, camera_transform ) @@ -368,133 +430,18 @@ def pick_target(self, x: int, y: int) -> bool: ) if clicked_3d: self.pbvs.set_target(clicked_3d) + logger.info( + f"Target selected: ID={clicked_3d.id}, pos=({clicked_3d.bbox.center.position.x:.3f}, {clicked_3d.bbox.center.position.y:.3f}, {clicked_3d.bbox.center.position.z:.3f})" + ) self.grasp_stage = GraspStage.PRE_GRASP # Transition from IDLE to PRE_GRASP self.reached_poses.clear() # Clear pose history self.adjustment_count = 0 # Reset adjustment counter self.waiting_for_reach = False # Ensure we're not stuck in waiting state 
self.last_commanded_pose = None self.stabilization_start_time = time.time() # Start the timeout timer - logger.info(f"Target selected at ({x}, {y})") return True return False - def create_visualization( - self, - rgb: np.ndarray, - detection_3d_array: Detection3DArray, - detection_2d_array: Detection2DArray, - camera_pose: Pose, - target_tracked: bool, - ) -> np.ndarray: - """ - Create visualization with detections and status overlays. - - Args: - rgb: RGB image - detection_3d_array: 3D detections - detection_2d_array: 2D detections - camera_pose: Current camera pose - target_tracked: Whether target is being tracked - - Returns: - BGR image with visualizations - """ - # Create visualization with position overlays - viz = self.detector.visualize_detections( - rgb, detection_3d_array.detections, detection_2d_array.detections - ) - - # Add PBVS status overlay - viz = self.pbvs.create_status_overlay(viz, self.grasp_stage) - - # Highlight target - current_target = self.pbvs.get_current_target() - if target_tracked and current_target: - det_2d = match_detection_by_id( - current_target, detection_3d_array.detections, detection_2d_array.detections - ) - if det_2d and det_2d.bbox: - x1, y1, x2, y2 = bbox2d_to_corners(det_2d.bbox) - x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) - - cv2.rectangle(viz, (x1, y1), (x2, y2), (0, 255, 0), 3) - cv2.putText( - viz, "TARGET", (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2 - ) - - # Convert back to BGR for OpenCV display - viz_bgr = cv2.cvtColor(viz, cv2.COLOR_RGB2BGR) - - # Add pose info - mode_text = "Direct EE" if self.direct_ee_control else "Velocity" - cv2.putText( - viz_bgr, - f"Eye-in-Hand ({mode_text})", - (10, 30), - cv2.FONT_HERSHEY_SIMPLEX, - 0.5, - (0, 255, 255), - 1, - ) - - # Get EE pose for display - ee_pose = self.arm.get_ee_pose() - - camera_text = f"Camera: ({camera_pose.position.x:.2f}, {camera_pose.position.y:.2f}, {camera_pose.position.z:.2f})m" - cv2.putText(viz_bgr, camera_text, (10, 50), 
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 0), 1) - - ee_text = ( - f"EE: ({ee_pose.position.x:.2f}, {ee_pose.position.y:.2f}, {ee_pose.position.z:.2f})m" - ) - cv2.putText(viz_bgr, ee_text, (10, 70), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1) - - # Add control status for direct EE mode - if self.direct_ee_control: - if self.grasp_stage == GraspStage.IDLE: - status_text = "IDLE - Click object to select target" - status_color = (100, 100, 100) - elif self.grasp_stage == GraspStage.PRE_GRASP: - if self.waiting_for_reach: - status_text = "PRE-GRASP - Waiting for robot to reach target..." - status_color = (255, 255, 0) - else: - poses_text = f" ({len(self.reached_poses)}/{self.pose_history_size} poses)" - elapsed_time = ( - time.time() - self.stabilization_start_time - if self.stabilization_start_time - else 0 - ) - time_text = f" [{elapsed_time:.1f}s/{self.stabilization_timeout:.0f}s]" - status_text = f"PRE-GRASP - Collecting stable poses{poses_text}{time_text}" - status_color = (0, 255, 255) - elif self.grasp_stage == GraspStage.GRASP: - if self.grasp_reached_time: - time_remaining = self.grasp_close_delay - ( - time.time() - self.grasp_reached_time - ) - status_text = f"GRASP - Waiting to close ({time_remaining:.1f}s)" - else: - status_text = "GRASP - Moving to grasp pose" - status_color = (0, 255, 0) - else: # CLOSE_AND_LIFT - status_text = "CLOSE_AND_LIFT - Closing gripper and lifting" - status_color = (255, 0, 255) - - cv2.putText( - viz_bgr, status_text, (10, 90), cv2.FONT_HERSHEY_SIMPLEX, 0.5, status_color, 1 - ) - cv2.putText( - viz_bgr, - "s=STOP | h=HOME | SPACE=FORCE GRASP | g=RELEASE", - (10, 110), - cv2.FONT_HERSHEY_SIMPLEX, - 0.4, - (255, 255, 255), - 1, - ) - - return viz_bgr - def update(self) -> bool: """ Main update function that handles capture, processing, control, and visualization. 
@@ -502,97 +449,35 @@ def update(self) -> bool: Returns: True if update was successful, False if capture failed """ - # Always capture frame for visualization - bgr, _, depth, _ = self.camera.capture_frame_with_pose() - if bgr is None or depth is None: + # Capture and process frame + rgb, detection_3d_array, detection_2d_array, camera_pose = self.capture_and_process() + if rgb is None: return False - rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) - - # If waiting for robot to reach target, check if reached - if self.waiting_for_reach and self.last_commanded_pose: - ee_pose = self.arm.get_ee_pose() - - if self.grasp_stage == GraspStage.GRASP: - # Check if grasp pose is reached - grasp_distance = self.grasp_distance - reached = self.pbvs.is_target_reached(ee_pose, grasp_distance) - - if reached and not self.grasp_reached_time: - # First time reaching grasp pose - self.grasp_reached_time = time.time() - logger.info( - f"Robot reached grasp pose, waiting {self.grasp_close_delay}s before closing gripper" - ) - - # Wait for delay then transition to CLOSE_AND_LIFT - if ( - self.grasp_reached_time - and (time.time() - self.grasp_reached_time) >= self.grasp_close_delay - ): - logger.info( - f"Waited {self.grasp_close_delay}s at grasp pose, transitioning to CLOSE_AND_LIFT" - ) - self.grasp_stage = GraspStage.CLOSE_AND_LIFT - self.waiting_for_reach = False - else: - # For PRE_GRASP stage, check if reached - grasp_distance = ( - self.pregrasp_distance - if self.grasp_stage == GraspStage.PRE_GRASP - else self.grasp_distance - ) - reached = self.pbvs.is_target_reached(ee_pose, grasp_distance) - - if reached: - logger.info("Robot reached commanded pose") - self.waiting_for_reach = False - self.reached_poses.append(self.last_commanded_pose) - self.target_updated = False # Reset flag so we wait for fresh update - time.sleep(0.3) - - # Create basic visualization while waiting - self.current_visualization = self._create_waiting_visualization(rgb) - return True - - # Normal processing 
when not waiting - # Get EE pose and camera transform - ee_pose = self.arm.get_ee_pose() - ee_transform = pose_to_matrix(ee_pose) - camera_transform = compose_transforms(ee_transform, self.T_ee_to_camera) - camera_pose = matrix_to_pose(camera_transform) - - # Process detections - detection_3d_array, detection_2d_array = self.detector.process_frame( - rgb, depth, camera_transform - ) - # Store for target selection - self.last_rgb = rgb self.last_detection_3d_array = detection_3d_array self.last_detection_2d_array = detection_2d_array - self.last_camera_pose = camera_pose # Execute stage-specific logic target_tracked = False if self.grasp_stage == GraspStage.IDLE: - # Nothing to do in IDLE - pass + target_tracked = self.execute_idle() elif self.grasp_stage == GraspStage.PRE_GRASP: if detection_3d_array: target_tracked = self.execute_pre_grasp(detection_3d_array) - self.last_target_tracked = target_tracked elif self.grasp_stage == GraspStage.GRASP: if detection_3d_array: target_tracked = self.execute_grasp(detection_3d_array) - self.last_target_tracked = target_tracked elif self.grasp_stage == GraspStage.CLOSE_AND_LIFT: - # No visual servoing needed for close and lift self.execute_close_and_lift() - # Create full visualization - if detection_3d_array and detection_2d_array and camera_pose: + self.last_target_tracked = target_tracked + + # Create visualization + if self.waiting_for_reach: + self.current_visualization = self._create_waiting_visualization(rgb) + elif detection_3d_array and detection_2d_array and camera_pose: self.current_visualization = self.create_visualization( rgb, detection_3d_array, detection_2d_array, camera_pose, target_tracked ) @@ -635,7 +520,6 @@ def handle_keyboard_command(self, key: int) -> str: elif key == ord(" ") and self.direct_ee_control and self.pbvs.target_grasp_pose: # Manual override - immediately transition to GRASP if in PRE_GRASP if self.grasp_stage == GraspStage.PRE_GRASP: - logger.info("Manual grasp execution requested") 
self.set_grasp_stage(GraspStage.GRASP) print("Executing target pose") return "execute" @@ -656,6 +540,94 @@ def handle_keyboard_command(self, key: int) -> str: return "" + def create_visualization( + self, + rgb: np.ndarray, + detection_3d_array: Detection3DArray, + detection_2d_array: Detection2DArray, + camera_pose: Pose, + target_tracked: bool, + ) -> np.ndarray: + """ + Create visualization with detections and status overlays. + + Args: + rgb: RGB image + detection_3d_array: 3D detections + detection_2d_array: 2D detections + camera_pose: Current camera pose + target_tracked: Whether target is being tracked + + Returns: + BGR image with visualizations + """ + # Create visualization with position overlays + viz = self.detector.visualize_detections( + rgb, detection_3d_array.detections, detection_2d_array.detections + ) + + # Add PBVS status overlay + viz = self.pbvs.create_status_overlay(viz, self.grasp_stage) + + # Highlight target + current_target = self.pbvs.get_current_target() + if target_tracked and current_target: + det_2d = match_detection_by_id( + current_target, detection_3d_array.detections, detection_2d_array.detections + ) + if det_2d and det_2d.bbox: + x1, y1, x2, y2 = bbox2d_to_corners(det_2d.bbox) + x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) + + cv2.rectangle(viz, (x1, y1), (x2, y2), (0, 255, 0), 3) + cv2.putText( + viz, "TARGET", (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2 + ) + + # Convert back to BGR for OpenCV display + viz_bgr = cv2.cvtColor(viz, cv2.COLOR_RGB2BGR) + + # Add pose info + mode_text = "Direct EE" if self.direct_ee_control else "Velocity" + cv2.putText( + viz_bgr, + f"Eye-in-Hand ({mode_text})", + (10, 30), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + (0, 255, 255), + 1, + ) + + # Get EE pose for display + ee_pose = self.arm.get_ee_pose() + + camera_text = f"Camera: ({camera_pose.position.x:.2f}, {camera_pose.position.y:.2f}, {camera_pose.position.z:.2f})m" + cv2.putText(viz_bgr, camera_text, (10, 50), 
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 0), 1) + + ee_text = ( + f"EE: ({ee_pose.position.x:.2f}, {ee_pose.position.y:.2f}, {ee_pose.position.z:.2f})m" + ) + cv2.putText(viz_bgr, ee_text, (10, 70), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1) + + # Add control status for direct EE mode + if self.direct_ee_control: + status_text, status_color = self._get_status_text_and_color() + cv2.putText( + viz_bgr, status_text, (10, 90), cv2.FONT_HERSHEY_SIMPLEX, 0.5, status_color, 1 + ) + cv2.putText( + viz_bgr, + "s=STOP | h=HOME | SPACE=FORCE GRASP | g=RELEASE", + (10, 110), + cv2.FONT_HERSHEY_SIMPLEX, + 0.4, + (255, 255, 255), + 1, + ) + + return viz_bgr + def _create_waiting_visualization(self, rgb: np.ndarray) -> np.ndarray: """ Create a simple visualization while waiting for robot to reach pose. @@ -715,6 +687,36 @@ def _create_waiting_visualization(self, rgb: np.ndarray) -> np.ndarray: return viz_bgr + def _get_status_text_and_color(self) -> Tuple[str, Tuple[int, int, int]]: + """ + Get status text and color based on current stage and state. 
+ + Returns: + Tuple of (status_text, status_color) + """ + if self.grasp_stage == GraspStage.IDLE: + return "IDLE - Click object to select target", (100, 100, 100) + elif self.grasp_stage == GraspStage.PRE_GRASP: + if self.waiting_for_reach: + return "PRE-GRASP - Waiting for robot to reach target...", (255, 255, 0) + else: + poses_text = f" ({len(self.reached_poses)}/{self.pose_history_size} poses)" + elapsed_time = ( + time.time() - self.stabilization_start_time + if self.stabilization_start_time + else 0 + ) + time_text = f" [{elapsed_time:.1f}s/{self.stabilization_timeout:.0f}s]" + return f"PRE-GRASP - Collecting stable poses{poses_text}{time_text}", (0, 255, 255) + elif self.grasp_stage == GraspStage.GRASP: + if self.grasp_reached_time: + time_remaining = self.grasp_close_delay - (time.time() - self.grasp_reached_time) + return f"GRASP - Waiting to close ({time_remaining:.1f}s)", (0, 255, 0) + else: + return "GRASP - Moving to grasp pose", (0, 255, 0) + else: # CLOSE_AND_LIFT + return "CLOSE_AND_LIFT - Closing gripper and lifting", (255, 0, 255) + def check_target_stabilized(self) -> bool: """ Check if the commanded poses have stabilized. 
@@ -739,4 +741,3 @@ def check_target_stabilized(self) -> bool: def cleanup(self): """Clean up resources (detector only, hardware cleanup is caller's responsibility).""" self.detector.cleanup() - logger.info("Cleaned up manipulation system resources") diff --git a/tests/test_ibvs.py b/tests/test_ibvs.py index 6299b57185..b56b93eec6 100644 --- a/tests/test_ibvs.py +++ b/tests/test_ibvs.py @@ -50,7 +50,7 @@ def main(): # Configuration DIRECT_EE_CONTROL = True # True: direct EE pose control, False: velocity control - INITIAL_GRASP_PITCH_DEGREES = 30 # 0° = level grasp, 90° = top-down grasp + INITIAL_GRASP_PITCH_DEGREES = 45 # 0° = level grasp, 90° = top-down grasp print("=== PBVS Eye-in-Hand Test ===") print("Using EE pose as odometry for camera pose") @@ -81,21 +81,11 @@ def main(): zed.close() return - # Get camera intrinsics - cam_info = zed.get_camera_info() - intrinsics = [ - cam_info["left_cam"]["fx"], - cam_info["left_cam"]["fy"], - cam_info["left_cam"]["cx"], - cam_info["left_cam"]["cy"], - ] - # Initialize manipulation system try: manipulation = Manipulation( camera=zed, arm=arm, - camera_intrinsics=intrinsics, direct_ee_control=DIRECT_EE_CONTROL, ee_to_camera_6dof=[-0.06, 0.03, -0.05, 0.0, -1.57, 0.0], # Adjust for your setup ) From d51e58dbc9825caec1a311ba0bf4fe38f25123b7 Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Mon, 21 Jul 2025 20:30:56 -0700 Subject: [PATCH 74/89] fixed bugs, refactoring --- dimos/hardware/piper_arm.py | 77 +++- .../visual_servoing/manipulation.py | 400 ++++++++++++------ dimos/manipulation/visual_servoing/pbvs.py | 45 +- tests/test_ibvs.py | 7 +- 4 files changed, 369 insertions(+), 160 deletions(-) diff --git a/dimos/hardware/piper_arm.py b/dimos/hardware/piper_arm.py index 02c26733b2..91e61a1f13 100644 --- a/dimos/hardware/piper_arm.py +++ b/dimos/hardware/piper_arm.py @@ -16,6 +16,7 @@ from typing import ( Optional, + Tuple, ) from piper_sdk import * # from the official Piper SDK import numpy as np @@ -82,7 +83,22 @@ def 
enable(self): def gotoZero(self): factor = 1000 - position = [57.0, 0.0, 250.0, 0, 97.0, 0, 0] + position = [57.0, 0.0, 215.0, 0, 90.0, 0, 0] + X = round(position[0] * factor) + Y = round(position[1] * factor) + Z = round(position[2] * factor) + RX = round(position[3] * factor) + RY = round(position[4] * factor) + RZ = round(position[5] * factor) + joint_6 = round(position[6] * factor) + logger.debug(f"Going to zero position: X={X}, Y={Y}, Z={Z}, RX={RX}, RY={RY}, RZ={RZ}") + self.arm.MotionCtrl_2(0x01, 0x00, 100, 0x00) + self.arm.EndPoseCtrl(X, Y, Z, RX, RY, RZ) + self.arm.GripperCtrl(0, 1000, 0x01, 0) + + def gotoObserve(self): + factor = 1000 + position = [57.0, 0.0, 280.0, 0, 120.0, 0, 0] X = round(position[0] * factor) Y = round(position[1] * factor) Z = round(position[2] * factor) @@ -159,12 +175,13 @@ def get_ee_pose(self): return Pose(position, orientation) - def cmd_gripper_ctrl(self, position, effort=250): + def cmd_gripper_ctrl(self, position, effort=0.25): """Command end-effector gripper""" factor = 1000 - position = position * factor * factor + position = position * factor * factor # meters + effort = effort * factor # N/m - self.arm.GripperCtrl(abs(round(position)), effort, 0x01, 0) + self.arm.GripperCtrl(abs(round(position)), abs(round(effort)), 0x01, 0) logger.debug(f"Commanding gripper position: {position}mm") def enable_gripper(self): @@ -181,6 +198,58 @@ def release_gripper(self): logger.info("Releasing gripper (opening to 100mm)") self.cmd_gripper_ctrl(0.1) # 0.1m = 100mm = 10cm + def get_gripper_feedback(self) -> Tuple[float, float]: + """ + Get current gripper feedback. 
+ + Returns: + Tuple of (angle_degrees, effort) where: + - angle_degrees: Current gripper angle in degrees + - effort: Current gripper effort (0.0 to 1.0 range) + """ + gripper_msg = self.arm.GetArmGripperMsgs() + angle_degrees = ( + gripper_msg.gripper_state.grippers_angle / 1000.0 + ) # Convert from SDK units to degrees + effort = gripper_msg.gripper_state.grippers_effort / 1000.0 # Convert from SDK units to N/m + return angle_degrees, effort + + def close_gripper(self, commanded_effort: float = 0.25) -> Tuple[bool, bool]: + """ + Close the gripper and check if an object is grasped. + + Args: + commanded_effort: Effort to use when closing gripper (default 0.25 N/m) + + Returns: + Tuple of (gripper_closed, object_grasped) where: + - gripper_closed: True if gripper reached near-zero position + - object_grasped: True if effort > 80% of commanded effort (object detected) + """ + # Command gripper to close (0.0 position) + self.cmd_gripper_ctrl(0.0, effort=commanded_effort) + + # Wait for gripper to close + time.sleep(1.0) + + # Get gripper feedback + angle_degrees, actual_effort = self.get_gripper_feedback() + + # Check if gripper is closed (angle close to 0 within threshold) + angle_threshold = 0.02 # m + gripper_closed = abs(angle_degrees) < angle_threshold + + # Check if object is grasped (effort > 80% of commanded effort) + effort_threshold = 0.8 * commanded_effort + object_present = abs(actual_effort) > effort_threshold + + if object_present: + logger.info(f"Object detected in gripper (effort: {actual_effort:.3f} N/m)") + else: + logger.info(f"No object detected (effort: {actual_effort:.3f} N/m)") + + return gripper_closed, object_present + def resetArm(self): self.arm.MotionCtrl_1(0x02, 0, 0) self.arm.MotionCtrl_2(0, 0, 0, 0xAD) diff --git a/dimos/manipulation/visual_servoing/manipulation.py b/dimos/manipulation/visual_servoing/manipulation.py index 295f6424cb..e640a73cd8 100644 --- a/dimos/manipulation/visual_servoing/manipulation.py +++ 
b/dimos/manipulation/visual_servoing/manipulation.py @@ -53,7 +53,43 @@ class GraspStage(Enum): IDLE = "idle" # No target set PRE_GRASP = "pre_grasp" # Target set, moving to pre-grasp position GRASP = "grasp" # Executing final grasp - CLOSE_AND_LIFT = "close_and_lift" # Close gripper and lift + CLOSE_AND_RETRACT = "close_and_retract" # Close gripper and retract + + +class Feedback: + """ + Feedback data returned by the manipulation system update. + + Contains comprehensive state information about the manipulation process. + """ + + def __init__( + self, + grasp_stage: GraspStage, + target_tracked: bool, + last_commanded_pose: Optional[Pose] = None, + current_ee_pose: Optional[Pose] = None, + current_camera_pose: Optional[Pose] = None, + target_pose: Optional[Pose] = None, + waiting_for_reach: bool = False, + pose_count: int = 0, + max_poses: int = 0, + stabilization_time: float = 0.0, + grasp_successful: Optional[bool] = None, + adjustment_count: int = 0, + ): + self.grasp_stage = grasp_stage + self.target_tracked = target_tracked + self.last_commanded_pose = last_commanded_pose + self.current_ee_pose = current_ee_pose + self.current_camera_pose = current_camera_pose + self.target_pose = target_pose + self.waiting_for_reach = waiting_for_reach + self.pose_count = pose_count + self.max_poses = max_poses + self.stabilization_time = stabilization_time + self.grasp_successful = grasp_successful + self.adjustment_count = adjustment_count class Manipulation: @@ -72,7 +108,6 @@ def __init__( self, camera: Any, # Generic camera object with required interface arm: Any, # Generic arm object with required interface - direct_ee_control: bool = True, ee_to_camera_6dof: Optional[list] = None, ): """ @@ -80,18 +115,16 @@ def __init__( Args: camera: Camera object with capture_frame_with_pose() and calculate_intrinsics() methods - arm: Robot arm object with get_ee_pose(), cmd_ee_pose(), cmd_vel_ee(), - cmd_gripper_ctrl(), release_gripper(), softStop(), gotoZero(), and disable() 
methods - direct_ee_control: If True, use direct EE pose control; if False, use velocity control + arm: Robot arm object with get_ee_pose(), cmd_ee_pose(), + cmd_gripper_ctrl(), release_gripper(), softStop(), gotoZero(), gotoObserve(), and disable() methods ee_to_camera_6dof: EE to camera transform [x, y, z, rx, ry, rz] in meters and radians """ self.camera = camera self.arm = arm - self.direct_ee_control = direct_ee_control # Default EE to camera transform if not provided if ee_to_camera_6dof is None: - ee_to_camera_6dof = [-0.06, 0.03, -0.05, 0.0, -1.57, 0.0] + ee_to_camera_6dof = [-0.065, 0.03, -0.105, 0.0, -1.57, 0.0] # Create transform matrices pos = Vector3(ee_to_camera_6dof[0], ee_to_camera_6dof[1], ee_to_camera_6dof[2]) @@ -110,10 +143,7 @@ def __init__( # Initialize processors self.detector = Detection3DProcessor(camera_intrinsics) self.pbvs = PBVS( - position_gain=0.3, - rotation_gain=0.2, target_tolerance=0.05, - direct_ee_control=direct_ee_control, ) # Control state @@ -127,9 +157,9 @@ def __init__( # Grasp parameters self.grasp_width_offset = 0.03 # Default grasp width offset self.grasp_pitch_degrees = 30.0 # Default grasp pitch in degrees - self.pregrasp_distance = 0.3 # Distance to maintain before grasping (m) - self.grasp_distance = 0.01 # Distance for final grasp approach (m) - self.grasp_close_delay = 3.0 # Time to wait at grasp pose before closing (seconds) + self.pregrasp_distance = 0.25 # Distance to maintain before grasping (m) + self.grasp_distance_range = 0.03 # Range for grasp distance mapping (±5cm = ±0.05m) + self.grasp_close_delay = 2.0 # Time to wait at grasp pose before closing (seconds) self.grasp_reached_time = None # Time when grasp pose was reached self.gripper_max_opening = 0.07 # Maximum gripper opening (m) @@ -150,9 +180,13 @@ def __init__( self.current_visualization = None self.last_detection_3d_array = None self.last_detection_2d_array = None - self.last_target_tracked = False - # Log initialization only if needed for 
debugging + # Grasp result + self.last_grasp_successful = None # True if last grasp was successful + self.final_pregrasp_pose = None # Store the final pre-grasp pose for retraction + + # Go to observe position + self.arm.gotoObserve() def set_grasp_stage(self, stage: GraspStage): """ @@ -205,27 +239,18 @@ def reset_to_idle(self): self.stabilization_start_time = None self.grasp_reached_time = None self.waiting_start_time = None + self.last_grasp_successful = None + self.final_pregrasp_pose = None - def execute_idle(self) -> bool: - """ - Execute idle stage: just visualization, no control. + self.arm.gotoObserve() - Returns: - False (no target tracked in idle) - """ + def execute_idle(self): + """Execute idle stage: just visualization, no control.""" # Nothing to do in idle - return False - - def execute_pre_grasp(self, detection_3d_array: Detection3DArray) -> bool: - """ - Execute pre-grasp stage: visual servoing to pre-grasp position. - - Args: - detection_3d_array: Current 3D detections + pass - Returns: - True if target is being tracked - """ + def execute_pre_grasp(self, detection_3d_array: Detection3DArray): + """Execute pre-grasp stage: visual servoing to pre-grasp position.""" # Get EE pose ee_pose = self.arm.get_ee_pose() @@ -233,9 +258,9 @@ def execute_pre_grasp(self, detection_3d_array: Detection3DArray) -> bool: if self.waiting_for_reach and self.last_commanded_pose: # Check for timeout if self._check_reach_timeout(): - return False + return - reached = self.pbvs.is_target_reached(ee_pose, self.pregrasp_distance) + reached = self.pbvs.is_target_reached(ee_pose) if reached: self.waiting_for_reach = False @@ -245,7 +270,7 @@ def execute_pre_grasp(self, detection_3d_array: Detection3DArray) -> bool: time.sleep(0.3) # While waiting, don't process new commands - return self.last_target_tracked + return # Check timeout if ( @@ -255,26 +280,30 @@ def execute_pre_grasp(self, detection_3d_array: Detection3DArray) -> bool: logger.warning( f"Failed to get stable 
grasp after {self.stabilization_timeout} seconds, resetting" ) - self.arm.gotoZero() - time.sleep(1.0) + self.reset_to_idle() - return False + return + + # Update tracking with new detections + target_tracked = False + if detection_3d_array: + target_tracked = self.pbvs.update_tracking(detection_3d_array) + if target_tracked: + self.target_updated = True + self.last_valid_target = self.pbvs.get_current_target() # PBVS control with pre-grasp distance - vel_cmd, ang_vel_cmd, _, target_tracked, target_pose = self.pbvs.compute_control( - ee_pose, detection_3d_array, self.pregrasp_distance + _, _, _, has_target, target_pose = self.pbvs.compute_control( + ee_pose, self.pregrasp_distance ) - # Set target_updated flag if target was successfully tracked - if target_tracked and target_pose: - self.target_updated = True - self.last_valid_target = self.pbvs.get_current_target() - - # Handle direct EE control - if self.direct_ee_control and target_pose and target_tracked: + # Handle pose control + if target_pose and has_target: # Check if we have enough reached poses and they're stable if self.check_target_stabilized(): logger.info("Target stabilized, transitioning to GRASP") + # Store the final pre-grasp pose for retraction + self.final_pregrasp_pose = self.last_commanded_pose self.grasp_stage = GraspStage.GRASP self.adjustment_count = 0 self.waiting_for_reach = False @@ -287,29 +316,11 @@ def execute_pre_grasp(self, detection_3d_array: Detection3DArray) -> bool: self.target_updated = False # Reset flag after commanding self.adjustment_count += 1 - # Command sent to robot - # Sleep for 200ms after commanding to avoid rapid commands time.sleep(0.2) - elif not self.direct_ee_control and vel_cmd and ang_vel_cmd: - # Velocity control - self.arm.cmd_vel_ee( - vel_cmd.x, vel_cmd.y, vel_cmd.z, ang_vel_cmd.x, ang_vel_cmd.y, ang_vel_cmd.z - ) - - return target_tracked - - def execute_grasp(self, detection_3d_array: Detection3DArray) -> bool: - """ - Execute grasp stage: move to final 
grasp position. - - Args: - detection_3d_array: Current 3D detections - - Returns: - True if target is being tracked - """ + def execute_grasp(self, detection_3d_array: Detection3DArray): + """Execute grasp stage: move to final grasp position.""" # Get EE pose ee_pose = self.arm.get_ee_pose() @@ -317,9 +328,9 @@ def execute_grasp(self, detection_3d_array: Detection3DArray) -> bool: if self.waiting_for_reach: # Check for timeout if self._check_reach_timeout(): - return False + return - reached = self.pbvs.is_target_reached(ee_pose, self.grasp_distance) + reached = self.pbvs.is_target_reached(ee_pose) if reached and not self.grasp_reached_time: # First time reaching grasp pose @@ -327,26 +338,37 @@ def execute_grasp(self, detection_3d_array: Detection3DArray) -> bool: self.waiting_start_time = None # Reset timeout timer # Robot reached grasp pose - # Wait for delay then transition to CLOSE_AND_LIFT + # Wait for delay then transition to CLOSE_AND_RETRACT if ( self.grasp_reached_time and (time.time() - self.grasp_reached_time) >= self.grasp_close_delay ): logger.info("Grasp delay completed, closing gripper") - self.grasp_stage = GraspStage.CLOSE_AND_LIFT + self.grasp_stage = GraspStage.CLOSE_AND_RETRACT self.waiting_for_reach = False # While waiting, don't process new commands - return self.last_target_tracked + return # Only command grasp if not waiting and have valid target if self.last_valid_target: - # PBVS control with grasp distance - _, _, _, target_tracked, target_pose = self.pbvs.compute_control( - ee_pose, detection_3d_array, self.grasp_distance + # Update tracking with new detections + if detection_3d_array: + self.pbvs.update_tracking(detection_3d_array) + + # Calculate grasp distance based on pitch angle + # Maps from -grasp_distance_range to +grasp_distance_range based on pitch + # 0° pitch (level grasp) -> -5cm (move 5cm closer to object) + # 90° pitch (top-down) -> +5cm (stay 5cm farther from object) + normalized_pitch = self.grasp_pitch_degrees / 90.0 
# 0.0 to 1.0 + grasp_distance = -self.grasp_distance_range + ( + 2 * self.grasp_distance_range * normalized_pitch ) - if self.direct_ee_control and target_pose and target_tracked: + # PBVS control with calculated grasp distance + _, _, _, has_target, target_pose = self.pbvs.compute_control(ee_pose, grasp_distance) + + if target_pose and has_target: # Get object size and calculate gripper opening object_size = self.last_valid_target.bbox.size object_width = object_size.x @@ -361,22 +383,47 @@ def execute_grasp(self, detection_3d_array: Detection3DArray) -> bool: self.waiting_for_reach = True self.waiting_start_time = time.time() # Start timeout timer - return target_tracked + def execute_close_and_retract(self): + """Execute the close and retract sequence.""" + # Close gripper and check if object is grasped (only once) + if self.last_grasp_successful is None: + _, object_present = self.arm.close_gripper() + self.last_grasp_successful = object_present - return False + if object_present: + logger.info("Object successfully grasped!") + else: + logger.warning("No object detected in gripper") - def execute_close_and_lift(self): - """Execute the close and lift sequence.""" - # Close gripper - self.arm.cmd_gripper_ctrl(0.0) # Close gripper completely - time.sleep(0.5) # Wait for gripper to close + # Get EE pose + ee_pose = self.arm.get_ee_pose() - # Return to home position - self.arm.gotoZero() + # Check if waiting for robot to reach retraction pose + if self.waiting_for_reach: + # Check for timeout + if self._check_reach_timeout(): + return - # Reset to IDLE after completion - logger.info("Grasp sequence completed") - self.reset_to_idle() + # Temporarily set PBVS target to check if reached + original_target = self.pbvs.target_grasp_pose + self.pbvs.target_grasp_pose = self.final_pregrasp_pose + reached = self.pbvs.is_target_reached(ee_pose) + self.pbvs.target_grasp_pose = original_target + + if reached: + logger.info("Reached pre-grasp retraction position") + 
self.waiting_for_reach = False + logger.info( + f"Grasp sequence completed (object_grasped={self.last_grasp_successful})" + ) + self.reset_to_idle() + + else: + # Command retraction to pre-grasp + logger.info("Retracting to pre-grasp position") + self.arm.cmd_ee_pose(self.final_pregrasp_pose, line_mode=True) + self.waiting_for_reach = True + self.waiting_start_time = time.time() def capture_and_process( self, @@ -442,37 +489,36 @@ def pick_target(self, x: int, y: int) -> bool: return True return False - def update(self) -> bool: + def update(self) -> Optional[Feedback]: """ Main update function that handles capture, processing, control, and visualization. Returns: - True if update was successful, False if capture failed + Feedback object with current state information, or None if capture failed """ # Capture and process frame rgb, detection_3d_array, detection_2d_array, camera_pose = self.capture_and_process() if rgb is None: - return False + return None # Store for target selection self.last_detection_3d_array = detection_3d_array self.last_detection_2d_array = detection_2d_array # Execute stage-specific logic - target_tracked = False - if self.grasp_stage == GraspStage.IDLE: - target_tracked = self.execute_idle() + self.execute_idle() elif self.grasp_stage == GraspStage.PRE_GRASP: if detection_3d_array: - target_tracked = self.execute_pre_grasp(detection_3d_array) + self.execute_pre_grasp(detection_3d_array) elif self.grasp_stage == GraspStage.GRASP: if detection_3d_array: - target_tracked = self.execute_grasp(detection_3d_array) - elif self.grasp_stage == GraspStage.CLOSE_AND_LIFT: - self.execute_close_and_lift() + self.execute_grasp(detection_3d_array) + elif self.grasp_stage == GraspStage.CLOSE_AND_RETRACT: + self.execute_close_and_retract() - self.last_target_tracked = target_tracked + # Update tracking status from PBVS + target_tracked = self.pbvs.get_current_target() is not None # Create visualization if self.waiting_for_reach: @@ -485,7 +531,34 @@ def 
update(self) -> bool: # Basic visualization with just the RGB image self.current_visualization = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR) - return True + # Get current EE pose + ee_pose = self.arm.get_ee_pose() + + # Calculate stabilization time if applicable + stabilization_time = 0.0 + if self.stabilization_start_time: + stabilization_time = time.time() - self.stabilization_start_time + + # Get target pose from PBVS if available + target_pose = None + if self.pbvs.target_grasp_pose: + target_pose = self.pbvs.target_grasp_pose + + # Create and return feedback + return Feedback( + grasp_stage=self.grasp_stage, + target_tracked=target_tracked, + last_commanded_pose=self.last_commanded_pose, + current_ee_pose=ee_pose, + current_camera_pose=camera_pose, + target_pose=target_pose, + waiting_for_reach=self.waiting_for_reach, + pose_count=len(self.reached_poses), + max_poses=self.pose_history_size, + stabilization_time=stabilization_time, + grasp_successful=self.last_grasp_successful, + adjustment_count=self.adjustment_count, + ) def get_visualization(self) -> Optional[np.ndarray]: """ @@ -513,11 +586,7 @@ def handle_keyboard_command(self, key: int) -> str: print("SOFT STOP - Emergency stopping robot!") self.arm.softStop() return "stop" - elif key == ord("h"): - print("GO HOME - Returning to safe position...") - self.arm.gotoZero() - return "home" - elif key == ord(" ") and self.direct_ee_control and self.pbvs.target_grasp_pose: + elif key == ord(" ") and self.pbvs.target_grasp_pose: # Manual override - immediately transition to GRASP if in PRE_GRASP if self.grasp_stage == GraspStage.PRE_GRASP: self.set_grasp_stage(GraspStage.GRASP) @@ -588,10 +657,9 @@ def create_visualization( viz_bgr = cv2.cvtColor(viz, cv2.COLOR_RGB2BGR) # Add pose info - mode_text = "Direct EE" if self.direct_ee_control else "Velocity" cv2.putText( viz_bgr, - f"Eye-in-Hand ({mode_text})", + "Eye-in-Hand Visual Servoing", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.5, @@ -610,21 +678,18 @@ def 
create_visualization( ) cv2.putText(viz_bgr, ee_text, (10, 70), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1) - # Add control status for direct EE mode - if self.direct_ee_control: - status_text, status_color = self._get_status_text_and_color() - cv2.putText( - viz_bgr, status_text, (10, 90), cv2.FONT_HERSHEY_SIMPLEX, 0.5, status_color, 1 - ) - cv2.putText( - viz_bgr, - "s=STOP | h=HOME | SPACE=FORCE GRASP | g=RELEASE", - (10, 110), - cv2.FONT_HERSHEY_SIMPLEX, - 0.4, - (255, 255, 255), - 1, - ) + # Add control status + status_text, status_color = self._get_status_text_and_color() + cv2.putText(viz_bgr, status_text, (10, 90), cv2.FONT_HERSHEY_SIMPLEX, 0.5, status_color, 1) + cv2.putText( + viz_bgr, + "s=STOP | r=RESET | SPACE=FORCE GRASP | g=RELEASE", + (10, 110), + cv2.FONT_HERSHEY_SIMPLEX, + 0.4, + (255, 255, 255), + 1, + ) return viz_bgr @@ -714,8 +779,8 @@ def _get_status_text_and_color(self) -> Tuple[str, Tuple[int, int, int]]: return f"GRASP - Waiting to close ({time_remaining:.1f}s)", (0, 255, 0) else: return "GRASP - Moving to grasp pose", (0, 255, 0) - else: # CLOSE_AND_LIFT - return "CLOSE_AND_LIFT - Closing gripper and lifting", (255, 0, 255) + else: # CLOSE_AND_RETRACT + return "CLOSE_AND_RETRACT - Closing gripper and retracting", (255, 0, 255) def check_target_stabilized(self) -> bool: """ @@ -738,6 +803,93 @@ def check_target_stabilized(self) -> bool: # Check if all axes are below threshold return np.all(std_devs < self.pose_stabilization_threshold) + def pick_and_place( + self, object_point: Tuple[int, int], target_point: Optional[Tuple[int, int]] = None + ) -> bool: + """ + Execute a complete pick and place operation. + + Similar to navigate_path_local, this function handles the complete pick operation + autonomously, including object selection, grasping, and optional placement. 
+ + Args: + object_point: (x, y) pixel coordinates of the object to pick + target_point: Optional (x, y) pixel coordinates for placement (not implemented yet) + + Returns: + True if object was successfully picked, False otherwise + """ + # Validate input + if not isinstance(object_point, tuple) or len(object_point) != 2: + logger.error(f"Invalid object_point: {object_point}. Expected (x, y) tuple.") + return False + + logger.info(f"Starting pick operation at pixel ({object_point[0]}, {object_point[1]})") + + # Reset to ensure clean state + self.reset_to_idle() + + # Configuration + max_operation_time = 60.0 # Maximum time for complete pick operation + perception_init_time = 2.0 # Time to allow perception to stabilize + + # Wait for perception to initialize + init_start = time.time() + perception_ready = False + + while (time.time() - init_start) < perception_init_time: + feedback = self.update() + if feedback is not None: + perception_ready = True + time.sleep(0.1) + + if not perception_ready: + logger.error("Perception system failed to initialize") + return False + + # Select the target object + x, y = object_point + try: + if not self.pick_target(x, y): + logger.error(f"No valid object detected at pixel ({x}, {y})") + return False + except Exception as e: + logger.error(f"Exception during target selection: {e}") + return False + + # Execute pick operation + operation_start = time.time() + + while (time.time() - operation_start) < max_operation_time: + try: + # Update the manipulation system + feedback = self.update() + if feedback is None: + logger.error("Lost perception during pick operation") + self.reset_to_idle() + return False + + # Check if grasp sequence completed + if feedback.grasp_successful is not None: + if feedback.grasp_successful: + logger.info("Object successfully grasped") + if target_point: + logger.info("Place operation not yet implemented") + return True + else: + logger.warning("Grasp attempt failed - no object detected in gripper") + return 
False + + except Exception as e: + logger.error(f"Unexpected error during pick operation: {e}") + self.reset_to_idle() + return False + + # Operation timeout + logger.error(f"Pick operation exceeded maximum time of {max_operation_time}s") + self.reset_to_idle() + return False + def cleanup(self): """Clean up resources (detector only, hardware cleanup is caller's responsibility).""" self.detector.cleanup() diff --git a/dimos/manipulation/visual_servoing/pbvs.py b/dimos/manipulation/visual_servoing/pbvs.py index ad49b49856..61afad667b 100644 --- a/dimos/manipulation/visual_servoing/pbvs.py +++ b/dimos/manipulation/visual_servoing/pbvs.py @@ -58,9 +58,9 @@ def __init__( max_velocity: float = 0.1, # m/s max_angular_velocity: float = 0.5, # rad/s target_tolerance: float = 0.01, # 1cm - max_tracking_distance_threshold: float = 0.08, # Max distance for target tracking (m) + max_tracking_distance_threshold: float = 0.1, # Max distance for target tracking (m) min_size_similarity: float = 0.6, # Min size similarity threshold (0.0-1.0) - direct_ee_control: bool = False, # If True, output target poses instead of velocities + direct_ee_control: bool = True, # If True, output target poses instead of velocities ): """ Initialize PBVS system. @@ -162,7 +162,7 @@ def set_grasp_pitch(self, pitch_degrees: float): # Reset target grasp pose to recompute with new pitch self.target_grasp_pose = None - def is_target_reached(self, ee_pose: Pose, grasp_distance: float) -> bool: + def is_target_reached(self, ee_pose: Pose) -> bool: """ Check if the current target stage has been reached. 
@@ -183,17 +183,18 @@ def is_target_reached(self, ee_pose: Pose, grasp_distance: float) -> bool: error_magnitude = np.sqrt(error_x**2 + error_y**2 + error_z**2) return error_magnitude < self.target_tolerance - def update_target_tracking(self, new_detections: Detection3DArray) -> bool: + def update_tracking(self, new_detections: Optional[Detection3DArray] = None) -> bool: """ - Update target by matching to closest object in new detections. + Update target tracking with new detections. If tracking is lost, keeps the old target pose. Args: - new_detections: List of newly detected objects + new_detections: Optional new detections for target tracking Returns: True if target was successfully tracked, False if lost (but target is kept) """ + # Check if we have a current target if ( not self.current_target or not self.current_target.bbox @@ -201,7 +202,9 @@ def update_target_tracking(self, new_detections: Detection3DArray) -> bool: ): return False - if not new_detections or new_detections.detections_length == 0: + # Try to update target tracking if new detections provided + # Continue with last known pose even if tracking is lost + if new_detections is None or new_detections.detections_length == 0: logger.debug("No detections for target tracking - using last known pose") return False @@ -257,7 +260,7 @@ def _update_target_grasp_pose(self, ee_pose: Pose, grasp_distance: float): # Create target pose with proper orientation # Convert grasp pitch from degrees to radians with mapping: # 0° (level) -> π/2 (1.57 rad), 90° (top-down) -> π (3.14 rad) - pitch_radians = 1.57 + (self.grasp_pitch_degrees * np.pi / 180.0 / 2.0) + pitch_radians = 1.57 + np.radians(self.grasp_pitch_degrees) # Convert euler angles to quaternion using utility function euler = Vector3(0.0, pitch_radians, yaw_to_ee) # roll=0, pitch=mapped, yaw=calculated @@ -303,7 +306,6 @@ def _apply_grasp_distance(self, target_pose: Pose, distance: float) -> Pose: def compute_control( self, ee_pose: Pose, - new_detections: 
Optional[Detection3DArray] = None, grasp_distance: float = 0.15, ) -> Tuple[Optional[Vector3], Optional[Vector3], bool, bool, Optional[Pose]]: """ @@ -311,7 +313,6 @@ def compute_control( Args: ee_pose: Current end-effector pose - new_detections: Optional new detections for target tracking grasp_distance: Distance to maintain from target (meters) Returns: @@ -330,20 +331,6 @@ def compute_control( ): return None, None, False, False, None - # Try to update target tracking if new detections provided - # Continue with last known pose even if tracking is lost - target_tracked = False - if new_detections is not None: - if self.update_target_tracking(new_detections): - target_tracked = True - else: - target_tracked = False - - # Update target grasp pose - if not self.current_target: - logger.info("No current target") - return None, None, False, False, None - # Update target grasp pose with provided distance self._update_target_grasp_pose(ee_pose, grasp_distance) @@ -360,22 +347,24 @@ def compute_control( ) # Check if target reached using our separate function - target_reached = self.is_target_reached(ee_pose, grasp_distance) + target_reached = self.is_target_reached(ee_pose) # Return appropriate values based on control mode if self.direct_ee_control: # Direct control mode if self.target_grasp_pose: self.last_target_reached = target_reached - return None, None, target_reached, target_tracked, self.target_grasp_pose + # Return has_target=True since we have a target + return None, None, target_reached, True, self.target_grasp_pose else: - return None, None, False, target_tracked, None + return None, None, False, True, None else: # Velocity control mode - use controller velocity_cmd, angular_velocity_cmd, controller_reached = ( self.controller.compute_control(ee_pose, self.target_grasp_pose) ) - return velocity_cmd, angular_velocity_cmd, target_reached, target_tracked, None + # Return has_target=True since we have a target, regardless of tracking status + return velocity_cmd, 
angular_velocity_cmd, target_reached, True, None def create_status_overlay( self, diff --git a/tests/test_ibvs.py b/tests/test_ibvs.py index b56b93eec6..4af038b0a0 100644 --- a/tests/test_ibvs.py +++ b/tests/test_ibvs.py @@ -50,7 +50,7 @@ def main(): # Configuration DIRECT_EE_CONTROL = True # True: direct EE pose control, False: velocity control - INITIAL_GRASP_PITCH_DEGREES = 45 # 0° = level grasp, 90° = top-down grasp + INITIAL_GRASP_PITCH_DEGREES = 30 # 0° = level grasp, 90° = top-down grasp print("=== PBVS Eye-in-Hand Test ===") print("Using EE pose as odometry for camera pose") @@ -86,8 +86,6 @@ def main(): manipulation = Manipulation( camera=zed, arm=arm, - direct_ee_control=DIRECT_EE_CONTROL, - ee_to_camera_6dof=[-0.06, 0.03, -0.05, 0.0, -1.57, 0.0], # Adjust for your setup ) except Exception as e: print(f"Failed to initialize manipulation system: {e}") @@ -105,7 +103,8 @@ def main(): try: while True: # Update manipulation system - if not manipulation.update(): + feedback = manipulation.update() + if feedback is None: continue # Handle mouse click From c43bd53d2d65f22431ad8d30c7815a3612d3c538 Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Tue, 22 Jul 2025 12:14:22 -0700 Subject: [PATCH 75/89] more cleanup --- dimos/hardware/piper_arm.py | 27 ++- .../visual_servoing/manipulation.py | 184 ++++++++---------- 2 files changed, 93 insertions(+), 118 deletions(-) diff --git a/dimos/hardware/piper_arm.py b/dimos/hardware/piper_arm.py index 91e61a1f13..910083ed3e 100644 --- a/dimos/hardware/piper_arm.py +++ b/dimos/hardware/piper_arm.py @@ -214,31 +214,30 @@ def get_gripper_feedback(self) -> Tuple[float, float]: effort = gripper_msg.gripper_state.grippers_effort / 1000.0 # Convert from SDK units to N/m return angle_degrees, effort - def close_gripper(self, commanded_effort: float = 0.25) -> Tuple[bool, bool]: + def close_gripper(self, commanded_effort: float = 0.25) -> None: """ - Close the gripper and check if an object is grasped. + Close the gripper. 
Args: commanded_effort: Effort to use when closing gripper (default 0.25 N/m) - - Returns: - Tuple of (gripper_closed, object_grasped) where: - - gripper_closed: True if gripper reached near-zero position - - object_grasped: True if effort > 80% of commanded effort (object detected) """ # Command gripper to close (0.0 position) self.cmd_gripper_ctrl(0.0, effort=commanded_effort) + logger.info("Closing gripper") + + def gripper_object_detected(self, commanded_effort: float = 0.25) -> bool: + """ + Check if an object is detected in the gripper based on effort feedback. - # Wait for gripper to close - time.sleep(1.0) + Args: + commanded_effort: The effort that was used when closing gripper (default 0.25 N/m) + Returns: + True if object is detected in gripper, False otherwise + """ # Get gripper feedback angle_degrees, actual_effort = self.get_gripper_feedback() - # Check if gripper is closed (angle close to 0 within threshold) - angle_threshold = 0.02 # m - gripper_closed = abs(angle_degrees) < angle_threshold - # Check if object is grasped (effort > 80% of commanded effort) effort_threshold = 0.8 * commanded_effort object_present = abs(actual_effort) > effort_threshold @@ -248,7 +247,7 @@ def close_gripper(self, commanded_effort: float = 0.25) -> Tuple[bool, bool]: else: logger.info(f"No object detected (effort: {actual_effort:.3f} N/m)") - return gripper_closed, object_present + return object_present def resetArm(self): self.arm.MotionCtrl_1(0x02, 0, 0) diff --git a/dimos/manipulation/visual_servoing/manipulation.py b/dimos/manipulation/visual_servoing/manipulation.py index e640a73cd8..bf9713c905 100644 --- a/dimos/manipulation/visual_servoing/manipulation.py +++ b/dimos/manipulation/visual_servoing/manipulation.py @@ -182,7 +182,7 @@ def __init__( self.last_detection_2d_array = None # Grasp result - self.last_grasp_successful = None # True if last grasp was successful + self.pick_success = None # True if last grasp was successful self.final_pregrasp_pose = None # 
Store the final pre-grasp pose for retraction # Go to observe position @@ -227,6 +227,25 @@ def _check_reach_timeout(self) -> bool: return True return False + def _update_tracking(self, detection_3d_array: Optional[Detection3DArray]) -> bool: + """ + Update tracking with new detections in a compact way. + + Args: + detection_3d_array: Optional detection array + + Returns: + True if target is tracked + """ + if not detection_3d_array: + return False + + target_tracked = self.pbvs.update_tracking(detection_3d_array) + if target_tracked: + self.target_updated = True + self.last_valid_target = self.pbvs.get_current_target() + return target_tracked + def reset_to_idle(self): """Reset the manipulation system to IDLE state.""" self.pbvs.clear_target() @@ -239,7 +258,7 @@ def reset_to_idle(self): self.stabilization_start_time = None self.grasp_reached_time = None self.waiting_start_time = None - self.last_grasp_successful = None + self.pick_success = None self.final_pregrasp_pose = None self.arm.gotoObserve() @@ -249,9 +268,8 @@ def execute_idle(self): # Nothing to do in idle pass - def execute_pre_grasp(self, detection_3d_array: Detection3DArray): + def execute_pre_grasp(self): """Execute pre-grasp stage: visual servoing to pre-grasp position.""" - # Get EE pose ee_pose = self.arm.get_ee_pose() # Check if waiting for robot to reach commanded pose @@ -272,7 +290,7 @@ def execute_pre_grasp(self, detection_3d_array: Detection3DArray): # While waiting, don't process new commands return - # Check timeout + # Check stabilization timeout if ( self.stabilization_start_time and (time.time() - self.stabilization_start_time) > self.stabilization_timeout @@ -280,18 +298,9 @@ def execute_pre_grasp(self, detection_3d_array: Detection3DArray): logger.warning( f"Failed to get stable grasp after {self.stabilization_timeout} seconds, resetting" ) - self.reset_to_idle() return - # Update tracking with new detections - target_tracked = False - if detection_3d_array: - target_tracked = 
self.pbvs.update_tracking(detection_3d_array) - if target_tracked: - self.target_updated = True - self.last_valid_target = self.pbvs.get_current_target() - # PBVS control with pre-grasp distance _, _, _, has_target, target_pose = self.pbvs.compute_control( ee_pose, self.pregrasp_distance @@ -302,7 +311,6 @@ def execute_pre_grasp(self, detection_3d_array: Detection3DArray): # Check if we have enough reached poses and they're stable if self.check_target_stabilized(): logger.info("Target stabilized, transitioning to GRASP") - # Store the final pre-grasp pose for retraction self.final_pregrasp_pose = self.last_commanded_pose self.grasp_stage = GraspStage.GRASP self.adjustment_count = 0 @@ -312,33 +320,25 @@ def execute_pre_grasp(self, detection_3d_array: Detection3DArray): self.arm.cmd_ee_pose(target_pose) self.last_commanded_pose = target_pose self.waiting_for_reach = True - self.waiting_start_time = time.time() # Start timeout timer - self.target_updated = False # Reset flag after commanding + self.waiting_start_time = time.time() + self.target_updated = False self.adjustment_count += 1 - - # Sleep for 200ms after commanding to avoid rapid commands time.sleep(0.2) - def execute_grasp(self, detection_3d_array: Detection3DArray): + def execute_grasp(self): """Execute grasp stage: move to final grasp position.""" - # Get EE pose ee_pose = self.arm.get_ee_pose() - # Check if waiting for robot to reach grasp pose + # Handle waiting with special grasp logic if self.waiting_for_reach: - # Check for timeout if self._check_reach_timeout(): return - reached = self.pbvs.is_target_reached(ee_pose) - - if reached and not self.grasp_reached_time: - # First time reaching grasp pose + if self.pbvs.is_target_reached(ee_pose) and not self.grasp_reached_time: self.grasp_reached_time = time.time() - self.waiting_start_time = None # Reset timeout timer - # Robot reached grasp pose + self.waiting_start_time = None - # Wait for delay then transition to CLOSE_AND_RETRACT + # Check if delay 
completed if ( self.grasp_reached_time and (time.time() - self.grasp_reached_time) >= self.grasp_close_delay @@ -346,21 +346,12 @@ def execute_grasp(self, detection_3d_array: Detection3DArray): logger.info("Grasp delay completed, closing gripper") self.grasp_stage = GraspStage.CLOSE_AND_RETRACT self.waiting_for_reach = False - - # While waiting, don't process new commands return # Only command grasp if not waiting and have valid target if self.last_valid_target: - # Update tracking with new detections - if detection_3d_array: - self.pbvs.update_tracking(detection_3d_array) - - # Calculate grasp distance based on pitch angle - # Maps from -grasp_distance_range to +grasp_distance_range based on pitch - # 0° pitch (level grasp) -> -5cm (move 5cm closer to object) - # 90° pitch (top-down) -> +5cm (stay 5cm farther from object) - normalized_pitch = self.grasp_pitch_degrees / 90.0 # 0.0 to 1.0 + # Calculate grasp distance based on pitch angle (0° -> -5cm, 90° -> +5cm) + normalized_pitch = self.grasp_pitch_degrees / 90.0 grasp_distance = -self.grasp_distance_range + ( 2 * self.grasp_distance_range * normalized_pitch ) @@ -369,42 +360,29 @@ def execute_grasp(self, detection_3d_array: Detection3DArray): _, _, _, has_target, target_pose = self.pbvs.compute_control(ee_pose, grasp_distance) if target_pose and has_target: - # Get object size and calculate gripper opening - object_size = self.last_valid_target.bbox.size - object_width = object_size.x - gripper_opening = object_width + self.grasp_width_offset - gripper_opening = max(0.005, min(gripper_opening, self.gripper_max_opening)) + # Calculate gripper opening + object_width = self.last_valid_target.bbox.size.x + gripper_opening = max( + 0.005, min(object_width + self.grasp_width_offset, self.gripper_max_opening) + ) logger.info(f"Executing grasp: gripper={gripper_opening * 1000:.1f}mm") - # Command gripper to open and move to grasp pose + # Command gripper and pose self.arm.cmd_gripper_ctrl(gripper_opening) 
self.arm.cmd_ee_pose(target_pose, line_mode=True) self.waiting_for_reach = True - self.waiting_start_time = time.time() # Start timeout timer + self.waiting_start_time = time.time() def execute_close_and_retract(self): - """Execute the close and retract sequence.""" - # Close gripper and check if object is grasped (only once) - if self.last_grasp_successful is None: - _, object_present = self.arm.close_gripper() - self.last_grasp_successful = object_present - - if object_present: - logger.info("Object successfully grasped!") - else: - logger.warning("No object detected in gripper") - - # Get EE pose + """Execute the retraction sequence after gripper has been closed.""" ee_pose = self.arm.get_ee_pose() - # Check if waiting for robot to reach retraction pose if self.waiting_for_reach: - # Check for timeout if self._check_reach_timeout(): return - # Temporarily set PBVS target to check if reached + # Check if reached retraction pose original_target = self.pbvs.target_grasp_pose self.pbvs.target_grasp_pose = self.final_pregrasp_pose reached = self.pbvs.is_target_reached(ee_pose) @@ -413,15 +391,18 @@ def execute_close_and_retract(self): if reached: logger.info("Reached pre-grasp retraction position") self.waiting_for_reach = False - logger.info( - f"Grasp sequence completed (object_grasped={self.last_grasp_successful})" - ) + self.pick_success = self.arm.gripper_object_detected() + logger.info(f"Grasp sequence completed") + if self.pick_success: + logger.info("Object successfully grasped!") + else: + logger.warning("No object detected in gripper") self.reset_to_idle() - else: # Command retraction to pre-grasp logger.info("Retracting to pre-grasp position") self.arm.cmd_ee_pose(self.final_pregrasp_pose, line_mode=True) + self.arm.close_gripper() self.waiting_for_reach = True self.waiting_start_time = time.time() @@ -505,58 +486,53 @@ def update(self) -> Optional[Feedback]: self.last_detection_3d_array = detection_3d_array self.last_detection_2d_array = 
detection_2d_array - # Execute stage-specific logic - if self.grasp_stage == GraspStage.IDLE: - self.execute_idle() - elif self.grasp_stage == GraspStage.PRE_GRASP: - if detection_3d_array: - self.execute_pre_grasp(detection_3d_array) - elif self.grasp_stage == GraspStage.GRASP: - if detection_3d_array: - self.execute_grasp(detection_3d_array) - elif self.grasp_stage == GraspStage.CLOSE_AND_RETRACT: - self.execute_close_and_retract() + # Update tracking if we have detections and not in IDLE or CLOSE_AND_RETRACT + # Only update if not waiting for reach (to ensure fresh updates after reaching) + if ( + detection_3d_array + and self.grasp_stage in [GraspStage.PRE_GRASP, GraspStage.GRASP] + and not self.waiting_for_reach + ): + self._update_tracking(detection_3d_array) - # Update tracking status from PBVS + # Execute stage-specific logic + stage_handlers = { + GraspStage.IDLE: self.execute_idle, + GraspStage.PRE_GRASP: self.execute_pre_grasp, + GraspStage.GRASP: self.execute_grasp, + GraspStage.CLOSE_AND_RETRACT: self.execute_close_and_retract, + } + if self.grasp_stage in stage_handlers: + stage_handlers[self.grasp_stage]() + + # Get tracking status and create visualization target_tracked = self.pbvs.get_current_target() is not None - - # Create visualization - if self.waiting_for_reach: - self.current_visualization = self._create_waiting_visualization(rgb) - elif detection_3d_array and detection_2d_array and camera_pose: - self.current_visualization = self.create_visualization( + self.current_visualization = ( + self._create_waiting_visualization(rgb) + if self.waiting_for_reach + else self.create_visualization( rgb, detection_3d_array, detection_2d_array, camera_pose, target_tracked ) - else: - # Basic visualization with just the RGB image - self.current_visualization = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR) - - # Get current EE pose - ee_pose = self.arm.get_ee_pose() - - # Calculate stabilization time if applicable - stabilization_time = 0.0 - if 
self.stabilization_start_time: - stabilization_time = time.time() - self.stabilization_start_time - - # Get target pose from PBVS if available - target_pose = None - if self.pbvs.target_grasp_pose: - target_pose = self.pbvs.target_grasp_pose + if detection_3d_array and detection_2d_array and camera_pose + else cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR) + ) # Create and return feedback + ee_pose = self.arm.get_ee_pose() return Feedback( grasp_stage=self.grasp_stage, target_tracked=target_tracked, last_commanded_pose=self.last_commanded_pose, current_ee_pose=ee_pose, current_camera_pose=camera_pose, - target_pose=target_pose, + target_pose=self.pbvs.target_grasp_pose, waiting_for_reach=self.waiting_for_reach, pose_count=len(self.reached_poses), max_poses=self.pose_history_size, - stabilization_time=stabilization_time, - grasp_successful=self.last_grasp_successful, + stabilization_time=time.time() - self.stabilization_start_time + if self.stabilization_start_time + else 0.0, + grasp_successful=self.pick_success, adjustment_count=self.adjustment_count, ) From b3484b685e521bda8b0537c929dfb7bcfe7a9a4e Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Tue, 22 Jul 2025 19:01:56 -0700 Subject: [PATCH 76/89] turned ZED into a module and simplified the visualization logic --- dimos/hardware/zed_camera.py | 354 +++++++++++- .../visual_servoing/manipulation.py | 303 +---------- dimos/manipulation/visual_servoing/pbvs.py | 39 +- dimos/manipulation/visual_servoing/utils.py | 508 ++++++++---------- tests/test_zed_module.py | 276 ++++++++++ 5 files changed, 873 insertions(+), 607 deletions(-) create mode 100644 tests/test_zed_module.py diff --git a/dimos/hardware/zed_camera.py b/dimos/hardware/zed_camera.py index a2ceeba54e..df7ea7bf3a 100644 --- a/dimos/hardware/zed_camera.py +++ b/dimos/hardware/zed_camera.py @@ -17,6 +17,10 @@ import open3d as o3d from typing import Optional, Tuple, Dict, Any import logging +import time +import threading +from reactivex import interval +from reactivex 
import operators as ops try: import pyzed.sl as sl @@ -25,8 +29,17 @@ logging.warning("ZED SDK not found. Please install pyzed to use ZED camera functionality.") from dimos.hardware.stereo_camera import StereoCamera +from dimos.core import Module, Out, rpc +from dimos.utils.logging_config import setup_logger -logger = logging.getLogger(__name__) +# Import LCM message types +from dimos_lcm.sensor_msgs import Image +from dimos_lcm.sensor_msgs import CameraInfo +from dimos_lcm.geometry_msgs import PoseStamped +from dimos_lcm.std_msgs import Header, Time +from dimos_lcm.geometry_msgs import Pose, Point, Quaternion + +logger = setup_logger(__name__) class ZEDCamera(StereoCamera): @@ -512,3 +525,342 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): """Context manager exit.""" self.close() + + +class ZEDModule(Module): + """ + Dask module for ZED camera that publishes sensor data via LCM. + + Publishes: + - /zed/color_image: RGB camera images + - /zed/depth_image: Depth images + - /zed/camera_info: Camera calibration information + - /zed/pose: Camera pose (if tracking enabled) + """ + + # Define LCM outputs + color_image: Out[Image] = None + depth_image: Out[Image] = None + camera_info: Out[CameraInfo] = None + pose: Out[PoseStamped] = None + + def __init__( + self, + camera_id: int = 0, + resolution: str = "HD720", + depth_mode: str = "NEURAL", + fps: int = 30, + enable_tracking: bool = True, + enable_imu_fusion: bool = True, + set_floor_as_origin: bool = True, + publish_rate: float = 30.0, + frame_id: str = "zed_camera", + **kwargs, + ): + """ + Initialize ZED Module. 
+ + Args: + camera_id: Camera ID (0 for first ZED) + resolution: Resolution string ("HD720", "HD1080", "HD2K", "VGA") + depth_mode: Depth mode string ("NEURAL", "ULTRA", "QUALITY", "PERFORMANCE") + fps: Camera frame rate + enable_tracking: Enable positional tracking + enable_imu_fusion: Enable IMU fusion for tracking + set_floor_as_origin: Set floor as origin for tracking + publish_rate: Rate to publish messages (Hz) + frame_id: TF frame ID for messages + """ + super().__init__(**kwargs) + + self.camera_id = camera_id + self.fps = fps + self.enable_tracking = enable_tracking + self.enable_imu_fusion = enable_imu_fusion + self.set_floor_as_origin = set_floor_as_origin + self.publish_rate = publish_rate + self.frame_id = frame_id + + # Convert string parameters to ZED enums + self.resolution = getattr(sl.RESOLUTION, resolution, sl.RESOLUTION.HD720) + self.depth_mode = getattr(sl.DEPTH_MODE, depth_mode, sl.DEPTH_MODE.NEURAL) + + # Internal state + self.zed_camera = None + self._running = False + self._subscription = None + self._sequence = 0 + + logger.info(f"ZEDModule initialized for camera {camera_id}") + + @rpc + def start(self): + """Start the ZED module and begin publishing data.""" + if self._running: + logger.warning("ZED module already running") + return + + try: + # Initialize ZED camera + self.zed_camera = ZEDCamera( + camera_id=self.camera_id, + resolution=self.resolution, + depth_mode=self.depth_mode, + fps=self.fps, + ) + + # Open camera + if not self.zed_camera.open(): + logger.error("Failed to open ZED camera") + return + + # Enable tracking if requested + if self.enable_tracking: + success = self.zed_camera.enable_positional_tracking( + enable_imu_fusion=self.enable_imu_fusion, + set_floor_as_origin=self.set_floor_as_origin, + enable_pose_smoothing=True, + enable_area_memory=True, + ) + if not success: + logger.warning("Failed to enable positional tracking") + self.enable_tracking = False + + # Publish camera info once at startup + 
self._publish_camera_info() + + # Start periodic frame capture and publishing + self._running = True + publish_interval = 1.0 / self.publish_rate + + self._subscription = interval(publish_interval).subscribe( + lambda _: self._capture_and_publish() + ) + + logger.info(f"ZED module started, publishing at {self.publish_rate} Hz") + + except Exception as e: + logger.error(f"Error starting ZED module: {e}") + self._running = False + + @rpc + def stop(self): + """Stop the ZED module.""" + if not self._running: + return + + self._running = False + + # Stop subscription + if self._subscription: + self._subscription.dispose() + self._subscription = None + + # Close camera + if self.zed_camera: + self.zed_camera.close() + self.zed_camera = None + + logger.info("ZED module stopped") + + def _capture_and_publish(self): + """Capture frame and publish all data.""" + if not self._running or not self.zed_camera: + return + + try: + # Capture frame with pose + left_img, _, depth, pose_data = self.zed_camera.capture_frame_with_pose() + + if left_img is None or depth is None: + return + + # Get timestamp + timestamp_ns = time.time_ns() + timestamp = Time(sec=timestamp_ns // 1_000_000_000, nsec=timestamp_ns % 1_000_000_000) + + # Create header + header = Header(seq=self._sequence, stamp=timestamp, frame_id=self.frame_id) + self._sequence += 1 + + # Publish color image + self._publish_color_image(left_img, header) + + # Publish depth image + self._publish_depth_image(depth, header) + + # Publish pose if tracking enabled and valid + if self.enable_tracking and pose_data and pose_data.get("valid", False): + self._publish_pose(pose_data, header) + + except Exception as e: + logger.error(f"Error in capture and publish: {e}") + + def _publish_color_image(self, image: np.ndarray, header: Header): + """Publish color image as LCM message.""" + try: + # Convert BGR to RGB if needed + if len(image.shape) == 3 and image.shape[2] == 3: + image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + else: 
+ image_rgb = image + + # Create LCM Image message + height, width = image_rgb.shape[:2] + encoding = "rgb8" if len(image_rgb.shape) == 3 else "mono8" + step = width * (3 if len(image_rgb.shape) == 3 else 1) + data = image_rgb.tobytes() + + msg = Image( + data_length=len(data), + header=header, + height=height, + width=width, + encoding=encoding, + is_bigendian=0, + step=step, + data=data, + ) + + self.color_image.publish(msg) + + except Exception as e: + logger.error(f"Error publishing color image: {e}") + + def _publish_depth_image(self, depth: np.ndarray, header: Header): + """Publish depth image as LCM message.""" + try: + # Depth is float32 in meters + height, width = depth.shape[:2] + encoding = "32FC1" # 32-bit float, single channel + step = width * 4 # 4 bytes per float + data = depth.astype(np.float32).tobytes() + + msg = Image( + data_length=len(data), + header=header, + height=height, + width=width, + encoding=encoding, + is_bigendian=0, + step=step, + data=data, + ) + + self.depth_image.publish(msg) + + except Exception as e: + logger.error(f"Error publishing depth image: {e}") + + def _publish_camera_info(self): + """Publish camera calibration information.""" + try: + info = self.zed_camera.get_camera_info() + if not info: + return + + # Get calibration parameters + left_cam = info.get("left_cam", {}) + resolution = info.get("resolution", {}) + + # Create CameraInfo message + header = Header(seq=0, stamp=Time(sec=int(time.time()), nsec=0), frame_id=self.frame_id) + + # Create camera matrix K (3x3) + K = [ + left_cam.get("fx", 0), + 0, + left_cam.get("cx", 0), + 0, + left_cam.get("fy", 0), + left_cam.get("cy", 0), + 0, + 0, + 1, + ] + + # Distortion coefficients + D = [ + left_cam.get("k1", 0), + left_cam.get("k2", 0), + left_cam.get("p1", 0), + left_cam.get("p2", 0), + left_cam.get("k3", 0), + ] + + # Identity rotation matrix + R = [1, 0, 0, 0, 1, 0, 0, 0, 1] + + # Projection matrix P (3x4) + P = [ + left_cam.get("fx", 0), + 0, + left_cam.get("cx", 0), 
+ 0, + 0, + left_cam.get("fy", 0), + left_cam.get("cy", 0), + 0, + 0, + 0, + 1, + 0, + ] + + msg = CameraInfo( + D_length=len(D), + header=header, + height=resolution.get("height", 0), + width=resolution.get("width", 0), + distortion_model="plumb_bob", + D=D, + K=K, + R=R, + P=P, + binning_x=0, + binning_y=0, + ) + + self.camera_info.publish(msg) + logger.info("Published camera info") + + except Exception as e: + logger.error(f"Error publishing camera info: {e}") + + def _publish_pose(self, pose_data: Dict[str, Any], header: Header): + """Publish camera pose as PoseStamped message.""" + try: + position = pose_data.get("position", [0, 0, 0]) + rotation = pose_data.get("rotation", [0, 0, 0, 1]) # quaternion [x,y,z,w] + + # Create Pose message + pose = Pose( + position=Point(x=position[0], y=position[1], z=position[2]), + orientation=Quaternion(x=rotation[0], y=rotation[1], z=rotation[2], w=rotation[3]), + ) + + # Create PoseStamped message + msg = PoseStamped(header=header, pose=pose) + + self.pose.publish(msg) + + except Exception as e: + logger.error(f"Error publishing pose: {e}") + + @rpc + def get_camera_info(self) -> Dict[str, Any]: + """Get camera information and calibration parameters.""" + if self.zed_camera: + return self.zed_camera.get_camera_info() + return {} + + @rpc + def get_pose(self) -> Optional[Dict[str, Any]]: + """Get current camera pose if tracking is enabled.""" + if self.zed_camera and self.enable_tracking: + return self.zed_camera.get_pose() + return None + + def cleanup(self): + """Clean up resources on module destruction.""" + self.stop() diff --git a/dimos/manipulation/visual_servoing/manipulation.py b/dimos/manipulation/visual_servoing/manipulation.py index bf9713c905..38a57244ed 100644 --- a/dimos/manipulation/visual_servoing/manipulation.py +++ b/dimos/manipulation/visual_servoing/manipulation.py @@ -29,10 +29,9 @@ from dimos.manipulation.visual_servoing.pbvs import PBVS from dimos.perception.common.utils import ( find_clicked_detection, 
- bbox2d_to_corners, ) from dimos.manipulation.visual_servoing.utils import ( - match_detection_by_id, + create_manipulation_visualization, ) from dimos.utils.transform_utils import ( pose_to_matrix, @@ -72,11 +71,7 @@ def __init__( current_camera_pose: Optional[Pose] = None, target_pose: Optional[Pose] = None, waiting_for_reach: bool = False, - pose_count: int = 0, - max_poses: int = 0, - stabilization_time: float = 0.0, grasp_successful: Optional[bool] = None, - adjustment_count: int = 0, ): self.grasp_stage = grasp_stage self.target_tracked = target_tracked @@ -85,11 +80,7 @@ def __init__( self.current_camera_pose = current_camera_pose self.target_pose = target_pose self.waiting_for_reach = waiting_for_reach - self.pose_count = pose_count - self.max_poses = max_poses - self.stabilization_time = stabilization_time self.grasp_successful = grasp_successful - self.adjustment_count = adjustment_count class Manipulation: @@ -505,21 +496,12 @@ def update(self) -> Optional[Feedback]: if self.grasp_stage in stage_handlers: stage_handlers[self.grasp_stage]() - # Get tracking status and create visualization + # Get tracking status target_tracked = self.pbvs.get_current_target() is not None - self.current_visualization = ( - self._create_waiting_visualization(rgb) - if self.waiting_for_reach - else self.create_visualization( - rgb, detection_3d_array, detection_2d_array, camera_pose, target_tracked - ) - if detection_3d_array and detection_2d_array and camera_pose - else cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR) - ) - # Create and return feedback + # Create feedback ee_pose = self.arm.get_ee_pose() - return Feedback( + feedback = Feedback( grasp_stage=self.grasp_stage, target_tracked=target_tracked, last_commanded_pose=self.last_commanded_pose, @@ -527,15 +509,16 @@ def update(self) -> Optional[Feedback]: current_camera_pose=camera_pose, target_pose=self.pbvs.target_grasp_pose, waiting_for_reach=self.waiting_for_reach, - pose_count=len(self.reached_poses), - 
max_poses=self.pose_history_size, - stabilization_time=time.time() - self.stabilization_start_time - if self.stabilization_start_time - else 0.0, grasp_successful=self.pick_success, - adjustment_count=self.adjustment_count, ) + # Create simple visualization using feedback + self.current_visualization = create_manipulation_visualization( + rgb, feedback, detection_3d_array, detection_2d_array + ) + + return feedback + def get_visualization(self) -> Optional[np.ndarray]: """ Get the current visualization image. @@ -585,179 +568,6 @@ def handle_keyboard_command(self, key: int) -> str: return "" - def create_visualization( - self, - rgb: np.ndarray, - detection_3d_array: Detection3DArray, - detection_2d_array: Detection2DArray, - camera_pose: Pose, - target_tracked: bool, - ) -> np.ndarray: - """ - Create visualization with detections and status overlays. - - Args: - rgb: RGB image - detection_3d_array: 3D detections - detection_2d_array: 2D detections - camera_pose: Current camera pose - target_tracked: Whether target is being tracked - - Returns: - BGR image with visualizations - """ - # Create visualization with position overlays - viz = self.detector.visualize_detections( - rgb, detection_3d_array.detections, detection_2d_array.detections - ) - - # Add PBVS status overlay - viz = self.pbvs.create_status_overlay(viz, self.grasp_stage) - - # Highlight target - current_target = self.pbvs.get_current_target() - if target_tracked and current_target: - det_2d = match_detection_by_id( - current_target, detection_3d_array.detections, detection_2d_array.detections - ) - if det_2d and det_2d.bbox: - x1, y1, x2, y2 = bbox2d_to_corners(det_2d.bbox) - x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) - - cv2.rectangle(viz, (x1, y1), (x2, y2), (0, 255, 0), 3) - cv2.putText( - viz, "TARGET", (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2 - ) - - # Convert back to BGR for OpenCV display - viz_bgr = cv2.cvtColor(viz, cv2.COLOR_RGB2BGR) - - # Add pose info - 
cv2.putText( - viz_bgr, - "Eye-in-Hand Visual Servoing", - (10, 30), - cv2.FONT_HERSHEY_SIMPLEX, - 0.5, - (0, 255, 255), - 1, - ) - - # Get EE pose for display - ee_pose = self.arm.get_ee_pose() - - camera_text = f"Camera: ({camera_pose.position.x:.2f}, {camera_pose.position.y:.2f}, {camera_pose.position.z:.2f})m" - cv2.putText(viz_bgr, camera_text, (10, 50), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 0), 1) - - ee_text = ( - f"EE: ({ee_pose.position.x:.2f}, {ee_pose.position.y:.2f}, {ee_pose.position.z:.2f})m" - ) - cv2.putText(viz_bgr, ee_text, (10, 70), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1) - - # Add control status - status_text, status_color = self._get_status_text_and_color() - cv2.putText(viz_bgr, status_text, (10, 90), cv2.FONT_HERSHEY_SIMPLEX, 0.5, status_color, 1) - cv2.putText( - viz_bgr, - "s=STOP | r=RESET | SPACE=FORCE GRASP | g=RELEASE", - (10, 110), - cv2.FONT_HERSHEY_SIMPLEX, - 0.4, - (255, 255, 255), - 1, - ) - - return viz_bgr - - def _create_waiting_visualization(self, rgb: np.ndarray) -> np.ndarray: - """ - Create a simple visualization while waiting for robot to reach pose. 
- - Args: - rgb: RGB image - - Returns: - BGR image with waiting status - """ - viz_bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR) - - # Add waiting status - cv2.putText( - viz_bgr, - "WAITING FOR ROBOT TO REACH TARGET...", - (10, 30), - cv2.FONT_HERSHEY_SIMPLEX, - 0.7, - (0, 255, 255), - 2, - ) - - # Add current stage info - stage_text = f"Stage: {self.grasp_stage.value.upper()}" - cv2.putText( - viz_bgr, - stage_text, - (10, 60), - cv2.FONT_HERSHEY_SIMPLEX, - 0.5, - (255, 255, 0), - 1, - ) - - # Add progress info based on stage - if self.grasp_stage == GraspStage.PRE_GRASP: - progress_text = f"Reached poses: {len(self.reached_poses)}/{self.pose_history_size}" - elif self.grasp_stage == GraspStage.GRASP and self.grasp_reached_time: - time_remaining = max( - 0, self.grasp_close_delay - (time.time() - self.grasp_reached_time) - ) - progress_text = f"Closing gripper in: {time_remaining:.1f}s" - else: - progress_text = "" - - if progress_text: - cv2.putText( - viz_bgr, - progress_text, - (10, 90), - cv2.FONT_HERSHEY_SIMPLEX, - 0.5, - (0, 255, 255), - 1, - ) - - return viz_bgr - - def _get_status_text_and_color(self) -> Tuple[str, Tuple[int, int, int]]: - """ - Get status text and color based on current stage and state. 
- - Returns: - Tuple of (status_text, status_color) - """ - if self.grasp_stage == GraspStage.IDLE: - return "IDLE - Click object to select target", (100, 100, 100) - elif self.grasp_stage == GraspStage.PRE_GRASP: - if self.waiting_for_reach: - return "PRE-GRASP - Waiting for robot to reach target...", (255, 255, 0) - else: - poses_text = f" ({len(self.reached_poses)}/{self.pose_history_size} poses)" - elapsed_time = ( - time.time() - self.stabilization_start_time - if self.stabilization_start_time - else 0 - ) - time_text = f" [{elapsed_time:.1f}s/{self.stabilization_timeout:.0f}s]" - return f"PRE-GRASP - Collecting stable poses{poses_text}{time_text}", (0, 255, 255) - elif self.grasp_stage == GraspStage.GRASP: - if self.grasp_reached_time: - time_remaining = self.grasp_close_delay - (time.time() - self.grasp_reached_time) - return f"GRASP - Waiting to close ({time_remaining:.1f}s)", (0, 255, 0) - else: - return "GRASP - Moving to grasp pose", (0, 255, 0) - else: # CLOSE_AND_RETRACT - return "CLOSE_AND_RETRACT - Closing gripper and retracting", (255, 0, 255) - def check_target_stabilized(self) -> bool: """ Check if the commanded poses have stabilized. @@ -778,94 +588,3 @@ def check_target_stabilized(self) -> bool: # Check if all axes are below threshold return np.all(std_devs < self.pose_stabilization_threshold) - - def pick_and_place( - self, object_point: Tuple[int, int], target_point: Optional[Tuple[int, int]] = None - ) -> bool: - """ - Execute a complete pick and place operation. - - Similar to navigate_path_local, this function handles the complete pick operation - autonomously, including object selection, grasping, and optional placement. 
- - Args: - object_point: (x, y) pixel coordinates of the object to pick - target_point: Optional (x, y) pixel coordinates for placement (not implemented yet) - - Returns: - True if object was successfully picked, False otherwise - """ - # Validate input - if not isinstance(object_point, tuple) or len(object_point) != 2: - logger.error(f"Invalid object_point: {object_point}. Expected (x, y) tuple.") - return False - - logger.info(f"Starting pick operation at pixel ({object_point[0]}, {object_point[1]})") - - # Reset to ensure clean state - self.reset_to_idle() - - # Configuration - max_operation_time = 60.0 # Maximum time for complete pick operation - perception_init_time = 2.0 # Time to allow perception to stabilize - - # Wait for perception to initialize - init_start = time.time() - perception_ready = False - - while (time.time() - init_start) < perception_init_time: - feedback = self.update() - if feedback is not None: - perception_ready = True - time.sleep(0.1) - - if not perception_ready: - logger.error("Perception system failed to initialize") - return False - - # Select the target object - x, y = object_point - try: - if not self.pick_target(x, y): - logger.error(f"No valid object detected at pixel ({x}, {y})") - return False - except Exception as e: - logger.error(f"Exception during target selection: {e}") - return False - - # Execute pick operation - operation_start = time.time() - - while (time.time() - operation_start) < max_operation_time: - try: - # Update the manipulation system - feedback = self.update() - if feedback is None: - logger.error("Lost perception during pick operation") - self.reset_to_idle() - return False - - # Check if grasp sequence completed - if feedback.grasp_successful is not None: - if feedback.grasp_successful: - logger.info("Object successfully grasped") - if target_point: - logger.info("Place operation not yet implemented") - return True - else: - logger.warning("Grasp attempt failed - no object detected in gripper") - return 
False - - except Exception as e: - logger.error(f"Unexpected error during pick operation: {e}") - self.reset_to_idle() - return False - - # Operation timeout - logger.error(f"Pick operation exceeded maximum time of {max_operation_time}s") - self.reset_to_idle() - return False - - def cleanup(self): - """Clean up resources (detector only, hardware cleanup is caller's responsibility).""" - self.detector.cleanup() diff --git a/dimos/manipulation/visual_servoing/pbvs.py b/dimos/manipulation/visual_servoing/pbvs.py index 61afad667b..da8c6c7dca 100644 --- a/dimos/manipulation/visual_servoing/pbvs.py +++ b/dimos/manipulation/visual_servoing/pbvs.py @@ -30,8 +30,7 @@ ) from dimos.manipulation.visual_servoing.utils import ( find_best_object_match, - create_pbvs_status_overlay, - create_pbvs_controller_overlay, + create_pbvs_visualization, ) logger = setup_logger("dimos.manipulation.pbvs") @@ -381,25 +380,14 @@ def create_status_overlay( Returns: Image with PBVS status overlay """ - if self.direct_ee_control: - # Use direct control overlay - stage_value = grasp_stage.value if grasp_stage else "idle" - return create_pbvs_status_overlay( - image, - self.current_target, - self.last_position_error, - self.last_target_reached, - self.target_grasp_pose, - stage_value, - is_direct_control=True, - ) - else: - # Use controller's overlay for velocity mode - return self.controller.create_status_overlay( - image, - self.current_target, - self.direct_ee_control, - ) + stage_value = grasp_stage.value if grasp_stage else "idle" + return create_pbvs_visualization( + image, + self.current_target, + self.last_position_error, + self.last_target_reached, + stage_value, + ) class PBVSController: @@ -574,7 +562,6 @@ def create_status_overlay( self, image: np.ndarray, current_target: Optional[Detection3D] = None, - direct_ee_control: bool = False, ) -> np.ndarray: """ Create PBVS status overlay on image. 
@@ -582,18 +569,14 @@ def create_status_overlay( Args: image: Input image current_target: Current target object Detection3D (for display) - direct_ee_control: Whether in direct EE control mode Returns: Image with PBVS status overlay """ - return create_pbvs_controller_overlay( + return create_pbvs_visualization( image, current_target, self.last_position_error, - self.last_rotation_error, - self.last_velocity_cmd, - self.last_angular_velocity_cmd, self.last_target_reached, - direct_ee_control, + "velocity_control", ) diff --git a/dimos/manipulation/visual_servoing/utils.py b/dimos/manipulation/visual_servoing/utils.py index 4e8a0a81b7..6d07183104 100644 --- a/dimos/manipulation/visual_servoing/utils.py +++ b/dimos/manipulation/visual_servoing/utils.py @@ -233,6 +233,228 @@ def estimate_object_depth( # ============= Visualization Functions ============= +def create_manipulation_visualization( + rgb_image: np.ndarray, + feedback, + detection_3d_array=None, + detection_2d_array=None, +) -> np.ndarray: + """ + Create simple visualization for manipulation class using feedback. 
+ + Args: + rgb_image: RGB image array + feedback: Feedback object containing all state information + detection_3d_array: Optional 3D detections for object visualization + detection_2d_array: Optional 2D detections for object visualization + + Returns: + BGR image with visualization overlays + """ + # Convert to BGR for OpenCV + viz = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR) + + # Draw detections if available + if detection_3d_array and detection_2d_array: + # Extract 2D bboxes + bboxes_2d = [] + for det_2d in detection_2d_array.detections: + if det_2d.bbox: + x1 = det_2d.bbox.center.position.x - det_2d.bbox.size_x / 2 + y1 = det_2d.bbox.center.position.y - det_2d.bbox.size_y / 2 + x2 = det_2d.bbox.center.position.x + det_2d.bbox.size_x / 2 + y2 = det_2d.bbox.center.position.y + det_2d.bbox.size_y / 2 + bboxes_2d.append([x1, y1, x2, y2]) + + # Draw basic detections + rgb_with_detections = visualize_detections_3d( + rgb_image, detection_3d_array.detections, show_coordinates=True, bboxes_2d=bboxes_2d + ) + viz = cv2.cvtColor(rgb_with_detections, cv2.COLOR_RGB2BGR) + + # Add manipulation status overlay + status_y = 30 + cv2.putText( + viz, + "Eye-in-Hand Visual Servoing", + (10, status_y), + cv2.FONT_HERSHEY_SIMPLEX, + 0.6, + (0, 255, 255), + 2, + ) + + # Stage information + stage_text = f"Stage: {feedback.grasp_stage.value.upper()}" + stage_color = { + "idle": (100, 100, 100), + "pre_grasp": (0, 255, 255), + "grasp": (0, 255, 0), + "close_and_retract": (255, 0, 255), + }.get(feedback.grasp_stage.value, (255, 255, 255)) + + cv2.putText( + viz, + stage_text, + (10, status_y + 25), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + stage_color, + 1, + ) + + # Target tracking status + if feedback.target_tracked: + cv2.putText( + viz, + "Target: TRACKED", + (10, status_y + 45), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + (0, 255, 0), + 1, + ) + elif feedback.grasp_stage.value != "idle": + cv2.putText( + viz, + "Target: LOST", + (10, status_y + 45), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + (0, 
0, 255), + 1, + ) + + # Waiting status + if feedback.waiting_for_reach: + cv2.putText( + viz, + "Status: WAITING FOR ROBOT", + (10, status_y + 65), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + (255, 255, 0), + 1, + ) + + # Grasp result + if feedback.grasp_successful is not None: + result_text = "Grasp: SUCCESS" if feedback.grasp_successful else "Grasp: FAILED" + result_color = (0, 255, 0) if feedback.grasp_successful else (0, 0, 255) + cv2.putText( + viz, + result_text, + (10, status_y + 85), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + result_color, + 2, + ) + + # Control hints (bottom of image) + hint_text = "Click object to grasp | s=STOP | r=RESET | g=RELEASE" + cv2.putText( + viz, + hint_text, + (10, viz.shape[0] - 10), + cv2.FONT_HERSHEY_SIMPLEX, + 0.4, + (200, 200, 200), + 1, + ) + + return viz + + +def create_pbvs_visualization( + image: np.ndarray, + current_target=None, + position_error=None, + target_reached=False, + grasp_stage="idle", +) -> np.ndarray: + """ + Create simple PBVS visualization overlay. 
+ + Args: + image: Input image (RGB or BGR) + current_target: Current target Detection3D + position_error: Position error Vector3 + target_reached: Whether target is reached + grasp_stage: Current grasp stage string + + Returns: + Image with PBVS overlay + """ + viz = image.copy() + + # Only show PBVS info if we have a target + if current_target is None: + return viz + + # Create status panel at bottom + height, width = viz.shape[:2] + panel_height = 100 + panel_y = height - panel_height + + # Semi-transparent overlay + overlay = viz.copy() + cv2.rectangle(overlay, (0, panel_y), (width, height), (0, 0, 0), -1) + viz = cv2.addWeighted(viz, 0.7, overlay, 0.3, 0) + + # PBVS Status + y_offset = panel_y + 20 + cv2.putText( + viz, + "PBVS Control", + (10, y_offset), + cv2.FONT_HERSHEY_SIMPLEX, + 0.6, + (0, 255, 255), + 2, + ) + + # Position error + if position_error: + error_mag = np.linalg.norm([position_error.x, position_error.y, position_error.z]) + error_text = f"Error: {error_mag * 100:.1f}cm" + error_color = (0, 255, 0) if target_reached else (0, 255, 255) + cv2.putText( + viz, + error_text, + (10, y_offset + 25), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + error_color, + 1, + ) + + # Stage + cv2.putText( + viz, + f"Stage: {grasp_stage}", + (10, y_offset + 45), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + (255, 150, 255), + 1, + ) + + # Target reached indicator + if target_reached: + cv2.putText( + viz, + "TARGET REACHED", + (width - 150, y_offset + 25), + cv2.FONT_HERSHEY_SIMPLEX, + 0.6, + (0, 255, 0), + 2, + ) + + return viz + + def visualize_detections_3d( rgb_image: np.ndarray, detections: List[Detection3D], @@ -313,292 +535,6 @@ def visualize_detections_3d( return viz -def create_pbvs_status_overlay( - image: np.ndarray, - current_target: Optional[Detection3D], - position_error: Optional[Vector3], - target_reached: bool, - target_grasp_pose: Optional[Pose], - grasp_stage: str, - is_direct_control: bool = False, -) -> np.ndarray: - """ - Create PBVS status overlay for direct 
control mode. - - Args: - image: Input image - current_target: Current target Detection3D - position_error: Position error vector - target_reached: Whether target is reached - target_grasp_pose: Target grasp pose - grasp_stage: Current grasp stage - is_direct_control: Whether in direct control mode - - Returns: - Image with status overlay - """ - viz_img = image.copy() - height, width = image.shape[:2] - - # Status panel - if current_target is not None: - panel_height = 175 # Adjusted panel for target, grasp pose, stage, and distance info - panel_y = height - panel_height - overlay = viz_img.copy() - cv2.rectangle(overlay, (0, panel_y), (width, height), (0, 0, 0), -1) - viz_img = cv2.addWeighted(viz_img, 0.7, overlay, 0.3, 0) - - # Status text - y = panel_y + 20 - mode_text = "Direct EE" if is_direct_control else "Velocity" - cv2.putText( - viz_img, - f"PBVS Status ({mode_text})", - (10, y), - cv2.FONT_HERSHEY_SIMPLEX, - 0.6, - (0, 255, 255), - 2, - ) - - # Add frame info - cv2.putText( - viz_img, "Frame: Camera", (250, y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 0), 1 - ) - - if position_error: - error_mag = np.linalg.norm( - [ - position_error.x, - position_error.y, - position_error.z, - ] - ) - color = (0, 255, 0) if target_reached else (0, 255, 255) - - cv2.putText( - viz_img, - f"Pos Error: {error_mag:.3f}m ({error_mag * 100:.1f}cm)", - (10, y + 25), - cv2.FONT_HERSHEY_SIMPLEX, - 0.5, - color, - 1, - ) - - cv2.putText( - viz_img, - f"XYZ: ({position_error.x:.3f}, {position_error.y:.3f}, {position_error.z:.3f})", - (10, y + 45), - cv2.FONT_HERSHEY_SIMPLEX, - 0.4, - (200, 200, 200), - 1, - ) - - # Show target and grasp poses - if current_target and current_target.bbox and current_target.bbox.center: - target_pos = current_target.bbox.center.position - cv2.putText( - viz_img, - f"Target: ({target_pos.x:.3f}, {target_pos.y:.3f}, {target_pos.z:.3f})", - (10, y + 65), - cv2.FONT_HERSHEY_SIMPLEX, - 0.4, - (255, 255, 0), - 1, - ) - - if target_grasp_pose: - 
grasp_pos = target_grasp_pose.position - cv2.putText( - viz_img, - f"Grasp: ({grasp_pos.x:.3f}, {grasp_pos.y:.3f}, {grasp_pos.z:.3f})", - (10, y + 80), - cv2.FONT_HERSHEY_SIMPLEX, - 0.4, - (0, 255, 255), - 1, - ) - - # Show pregrasp distance if we have both poses - if current_target and current_target.bbox and current_target.bbox.center: - target_pos = current_target.bbox.center.position - distance = np.sqrt( - (grasp_pos.x - target_pos.x) ** 2 - + (grasp_pos.y - target_pos.y) ** 2 - + (grasp_pos.z - target_pos.z) ** 2 - ) - - # Show current stage and distance - stage_text = f"Stage: {grasp_stage}" - cv2.putText( - viz_img, - stage_text, - (10, y + 95), - cv2.FONT_HERSHEY_SIMPLEX, - 0.4, - (255, 150, 255), - 1, - ) - - distance_text = f"Distance: {distance * 1000:.1f}mm" - cv2.putText( - viz_img, - distance_text, - (10, y + 110), - cv2.FONT_HERSHEY_SIMPLEX, - 0.4, - (255, 200, 0), - 1, - ) - - if target_reached: - cv2.putText( - viz_img, - "TARGET REACHED", - (width - 150, y + 25), - cv2.FONT_HERSHEY_SIMPLEX, - 0.6, - (0, 255, 0), - 2, - ) - - return viz_img - - -def create_pbvs_controller_overlay( - image: np.ndarray, - current_target: Optional[Detection3D], - position_error: Optional[Vector3], - rotation_error: Optional[Vector3], - velocity_cmd: Optional[Vector3], - angular_velocity_cmd: Optional[Vector3], - target_reached: bool, - direct_ee_control: bool = False, -) -> np.ndarray: - """ - Create PBVS controller status overlay on image. 
- - Args: - image: Input image - current_target: Current target Detection3D (for display) - position_error: Position error vector - rotation_error: Rotation error vector - velocity_cmd: Linear velocity command - angular_velocity_cmd: Angular velocity command - target_reached: Whether target is reached - direct_ee_control: Whether in direct EE control mode - - Returns: - Image with PBVS status overlay - """ - viz_img = image.copy() - height, width = image.shape[:2] - - # Status panel - if current_target is not None: - panel_height = 160 # Adjusted panel height - panel_y = height - panel_height - overlay = viz_img.copy() - cv2.rectangle(overlay, (0, panel_y), (width, height), (0, 0, 0), -1) - viz_img = cv2.addWeighted(viz_img, 0.7, overlay, 0.3, 0) - - # Status text - y = panel_y + 20 - mode_text = "Direct EE" if direct_ee_control else "Velocity" - cv2.putText( - viz_img, - f"PBVS Status ({mode_text})", - (10, y), - cv2.FONT_HERSHEY_SIMPLEX, - 0.6, - (0, 255, 255), - 2, - ) - - # Add frame info - cv2.putText( - viz_img, "Frame: Camera", (250, y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 0), 1 - ) - - if position_error: - error_mag = np.linalg.norm( - [ - position_error.x, - position_error.y, - position_error.z, - ] - ) - color = (0, 255, 0) if target_reached else (0, 255, 255) - - cv2.putText( - viz_img, - f"Pos Error: {error_mag:.3f}m ({error_mag * 100:.1f}cm)", - (10, y + 25), - cv2.FONT_HERSHEY_SIMPLEX, - 0.5, - color, - 1, - ) - - cv2.putText( - viz_img, - f"XYZ: ({position_error.x:.3f}, {position_error.y:.3f}, {position_error.z:.3f})", - (10, y + 45), - cv2.FONT_HERSHEY_SIMPLEX, - 0.4, - (200, 200, 200), - 1, - ) - - if velocity_cmd and not direct_ee_control: - cv2.putText( - viz_img, - f"Lin Vel: ({velocity_cmd.x:.2f}, {velocity_cmd.y:.2f}, {velocity_cmd.z:.2f})m/s", - (10, y + 65), - cv2.FONT_HERSHEY_SIMPLEX, - 0.5, - (255, 200, 0), - 1, - ) - - if rotation_error: - cv2.putText( - viz_img, - f"Rot Error: ({rotation_error.x:.2f}, {rotation_error.y:.2f}, 
{rotation_error.z:.2f})rad", - (10, y + 85), - cv2.FONT_HERSHEY_SIMPLEX, - 0.4, - (200, 200, 200), - 1, - ) - - if angular_velocity_cmd and not direct_ee_control: - cv2.putText( - viz_img, - f"Ang Vel: ({angular_velocity_cmd.x:.2f}, {angular_velocity_cmd.y:.2f}, {angular_velocity_cmd.z:.2f})rad/s", - (10, y + 105), - cv2.FONT_HERSHEY_SIMPLEX, - 0.5, - (255, 200, 0), - 1, - ) - - if target_reached: - cv2.putText( - viz_img, - "TARGET REACHED", - (width - 150, y + 25), - cv2.FONT_HERSHEY_SIMPLEX, - 0.6, - (0, 255, 0), - 2, - ) - - return viz_img - - def match_detection_by_id( detection_3d: Detection3D, detections_3d: List[Detection3D], detections_2d: List[Detection2D] ) -> Optional[Detection2D]: diff --git a/tests/test_zed_module.py b/tests/test_zed_module.py new file mode 100644 index 0000000000..fbc99a54a4 --- /dev/null +++ b/tests/test_zed_module.py @@ -0,0 +1,276 @@ +#!/usr/bin/env python3 +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Test script for ZED Module with LCM visualization.""" + +import asyncio +import threading +import time +from typing import Optional +import numpy as np +import cv2 + +from dimos import core +from dimos.hardware.zed_camera import ZEDModule +from dimos.protocol import pubsub +from dimos.utils.logging_config import setup_logger +from dimos.perception.common.utils import colorize_depth + +# Import LCM message types +from dimos_lcm.sensor_msgs import Image as LCMImage +from dimos_lcm.sensor_msgs import CameraInfo +from dimos_lcm.geometry_msgs import PoseStamped +from dimos.protocol.pubsub.lcmpubsub import LCM, Topic + +logger = setup_logger("test_zed_module") + + +class ZEDVisualizationNode: + """Node that subscribes to ZED topics and visualizes the data.""" + + def __init__(self): + self.lcm = LCM() + self.latest_color = None + self.latest_depth = None + self.latest_pose = None + self.camera_info = None + self._running = False + + # Subscribe to topics + self.color_topic = Topic("/zed/color_image", LCMImage) + self.depth_topic = Topic("/zed/depth_image", LCMImage) + self.camera_info_topic = Topic("/zed/camera_info", CameraInfo) + self.pose_topic = Topic("/zed/pose", PoseStamped) + + def start(self): + """Start the visualization node.""" + self._running = True + self.lcm.start() + + # Subscribe to topics + self.lcm.subscribe(self.color_topic, self._on_color_image) + self.lcm.subscribe(self.depth_topic, self._on_depth_image) + self.lcm.subscribe(self.camera_info_topic, self._on_camera_info) + self.lcm.subscribe(self.pose_topic, self._on_pose) + + logger.info("Visualization node started, subscribed to ZED topics") + + def stop(self): + """Stop the visualization node.""" + self._running = False + cv2.destroyAllWindows() + + def _on_color_image(self, msg: LCMImage, topic: str): + """Handle color image messages.""" + try: + # Convert LCM message to numpy array + data = np.frombuffer(msg.data, dtype=np.uint8) + + if msg.encoding == "rgb8": + image = 
data.reshape((msg.height, msg.width, 3)) + # Convert RGB to BGR for OpenCV + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + elif msg.encoding == "mono8": + image = data.reshape((msg.height, msg.width)) + else: + logger.warning(f"Unsupported encoding: {msg.encoding}") + return + + self.latest_color = image + logger.debug(f"Received color image: {msg.width}x{msg.height}") + + except Exception as e: + logger.error(f"Error processing color image: {e}") + + def _on_depth_image(self, msg: LCMImage, topic: str): + """Handle depth image messages.""" + try: + # Convert LCM message to numpy array + if msg.encoding == "32FC1": + data = np.frombuffer(msg.data, dtype=np.float32) + depth = data.reshape((msg.height, msg.width)) + else: + logger.warning(f"Unsupported depth encoding: {msg.encoding}") + return + + self.latest_depth = depth + logger.debug(f"Received depth image: {msg.width}x{msg.height}") + + except Exception as e: + logger.error(f"Error processing depth image: {e}") + + def _on_camera_info(self, msg: CameraInfo, topic: str): + """Handle camera info messages.""" + self.camera_info = msg + logger.info( + f"Received camera info: {msg.width}x{msg.height}, distortion model: {msg.distortion_model}" + ) + + def _on_pose(self, msg: PoseStamped, topic: str): + """Handle pose messages.""" + self.latest_pose = msg + pos = msg.pose.position + ori = msg.pose.orientation + logger.debug( + f"Pose: pos=({pos.x:.2f}, {pos.y:.2f}, {pos.z:.2f}), " + + f"ori=({ori.x:.2f}, {ori.y:.2f}, {ori.z:.2f}, {ori.w:.2f})" + ) + + def visualize(self): + """Run visualization loop.""" + while self._running: + # Create visualization + vis_images = [] + + # Color image + if self.latest_color is not None: + color_vis = self.latest_color.copy() + + # Add pose text if available + if self.latest_pose is not None: + pos = self.latest_pose.pose.position + text = f"Pose: ({pos.x:.2f}, {pos.y:.2f}, {pos.z:.2f})" + cv2.putText( + color_vis, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2 + ) + 
+ vis_images.append(("ZED Color", color_vis)) + + # Depth image + if self.latest_depth is not None: + depth_colorized = colorize_depth(self.latest_depth, max_depth=5.0) + if depth_colorized is not None: + # Convert RGB to BGR for OpenCV + depth_colorized = cv2.cvtColor(depth_colorized, cv2.COLOR_RGB2BGR) + + # Add depth stats + valid_mask = np.isfinite(self.latest_depth) & (self.latest_depth > 0) + if np.any(valid_mask): + min_depth = np.min(self.latest_depth[valid_mask]) + max_depth = np.max(self.latest_depth[valid_mask]) + mean_depth = np.mean(self.latest_depth[valid_mask]) + + text = f"Depth: min={min_depth:.2f}m, max={max_depth:.2f}m, mean={mean_depth:.2f}m" + cv2.putText( + depth_colorized, + text, + (10, 30), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + (255, 255, 255), + 1, + ) + + vis_images.append(("ZED Depth", depth_colorized)) + + # Show windows + for name, image in vis_images: + cv2.imshow(name, image) + + # Handle key press + key = cv2.waitKey(1) & 0xFF + if key == ord("q"): + logger.info("Quit requested") + self._running = False + break + elif key == ord("s"): + # Save images + if self.latest_color is not None: + cv2.imwrite("zed_color.png", self.latest_color) + logger.info("Saved color image to zed_color.png") + if self.latest_depth is not None: + np.save("zed_depth.npy", self.latest_depth) + logger.info("Saved depth data to zed_depth.npy") + + time.sleep(0.03) # ~30 FPS + + +async def test_zed_module(): + """Test the ZED Module with visualization.""" + logger.info("Starting ZED Module test") + + # Start Dask + dimos = core.start(1) + + # Enable LCM auto-configuration + pubsub.lcm.autoconf() + + try: + # Deploy ZED module + logger.info("Deploying ZED module...") + zed = dimos.deploy( + ZEDModule, + camera_id=0, + resolution="HD720", + depth_mode="NEURAL", + fps=30, + enable_tracking=True, + publish_rate=10.0, # 10 Hz for testing + frame_id="zed_camera", + ) + + # Configure LCM transports + zed.color_image.transport = core.LCMTransport("/zed/color_image", 
LCMImage) + zed.depth_image.transport = core.LCMTransport("/zed/depth_image", LCMImage) + zed.camera_info.transport = core.LCMTransport("/zed/camera_info", CameraInfo) + zed.pose.transport = core.LCMTransport("/zed/pose", PoseStamped) + + # Print module info + logger.info("ZED Module configured:") + print(zed.io().result()) + + # Start ZED module + logger.info("Starting ZED module...") + zed.start() + + # Give module time to initialize + await asyncio.sleep(2) + + # Create and start visualization node + viz_node = ZEDVisualizationNode() + viz_node.start() + + # Run visualization in separate thread + viz_thread = threading.Thread(target=viz_node.visualize, daemon=True) + viz_thread.start() + + logger.info("ZED Module running. Press 'q' in image window to quit, 's' to save images.") + + # Keep running until visualization stops + while viz_node._running: + await asyncio.sleep(0.1) + + # Stop ZED module + logger.info("Stopping ZED module...") + zed.stop() + + # Stop visualization + viz_node.stop() + + except Exception as e: + logger.error(f"Error in test: {e}") + import traceback + + traceback.print_exc() + + finally: + # Clean up + dimos.close() + logger.info("Test completed") + + +if __name__ == "__main__": + # Run the test + asyncio.run(test_zed_module()) From 42b634c2a7b8c57726a7282d19199793ae855cc0 Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Tue, 22 Jul 2025 20:52:34 -0700 Subject: [PATCH 77/89] pick and place as module fully working --- dimos/hardware/zed_camera.py | 4 +- .../visual_servoing/manipulation_module.py | 730 ++++++++++++++++++ tests/test_ibvs.py | 1 - tests/test_pick_and_place_module.py | 289 +++++++ 4 files changed, 1022 insertions(+), 2 deletions(-) create mode 100644 dimos/manipulation/visual_servoing/manipulation_module.py create mode 100644 tests/test_pick_and_place_module.py diff --git a/dimos/hardware/zed_camera.py b/dimos/hardware/zed_camera.py index df7ea7bf3a..7ee2aed634 100644 --- a/dimos/hardware/zed_camera.py +++ 
b/dimos/hardware/zed_camera.py @@ -689,6 +689,9 @@ def _capture_and_publish(self): # Publish depth image self._publish_depth_image(depth, header) + # Publish camera info periodically + self._publish_camera_info() + # Publish pose if tracking enabled and valid if self.enable_tracking and pose_data and pose_data.get("valid", False): self._publish_pose(pose_data, header) @@ -822,7 +825,6 @@ def _publish_camera_info(self): ) self.camera_info.publish(msg) - logger.info("Published camera info") except Exception as e: logger.error(f"Error publishing camera info: {e}") diff --git a/dimos/manipulation/visual_servoing/manipulation_module.py b/dimos/manipulation/visual_servoing/manipulation_module.py new file mode 100644 index 0000000000..b2967f9bd9 --- /dev/null +++ b/dimos/manipulation/visual_servoing/manipulation_module.py @@ -0,0 +1,730 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Manipulation module for robotic grasping with visual servoing. +Handles grasping logic, state machine, and hardware coordination as a Dimos module. 
+""" + +import cv2 +import time +import threading +from typing import Optional, Tuple, Any, Dict +from enum import Enum +from collections import deque + +import numpy as np + +from dimos.core import Module, In, Out, rpc +from dimos_lcm.sensor_msgs import Image, CameraInfo +from dimos_lcm.geometry_msgs import Vector3, Pose +from dimos_lcm.vision_msgs import Detection3DArray, Detection2DArray + +from dimos.hardware.piper_arm import PiperArm +from dimos.manipulation.visual_servoing.detection3d import Detection3DProcessor +from dimos.manipulation.visual_servoing.pbvs import PBVS +from dimos.perception.common.utils import find_clicked_detection +from dimos.manipulation.visual_servoing.utils import create_manipulation_visualization +from dimos.utils.transform_utils import ( + pose_to_matrix, + matrix_to_pose, + create_transform_from_6dof, + compose_transforms, +) +from dimos.utils.logging_config import setup_logger + +logger = setup_logger("dimos.manipulation.manipulation_module") + + +class GraspStage(Enum): + """Enum for different grasp stages.""" + + IDLE = "idle" # No target set + PRE_GRASP = "pre_grasp" # Target set, moving to pre-grasp position + GRASP = "grasp" # Executing final grasp + CLOSE_AND_RETRACT = "close_and_retract" # Close gripper and retract + + +class Feedback: + """ + Feedback data returned by the manipulation system update. + + Contains comprehensive state information about the manipulation process. 
+ """ + + def __init__( + self, + grasp_stage: GraspStage, + target_tracked: bool, + last_commanded_pose: Optional[Pose] = None, + current_ee_pose: Optional[Pose] = None, + current_camera_pose: Optional[Pose] = None, + target_pose: Optional[Pose] = None, + waiting_for_reach: bool = False, + grasp_successful: Optional[bool] = None, + ): + self.grasp_stage = grasp_stage + self.target_tracked = target_tracked + self.last_commanded_pose = last_commanded_pose + self.current_ee_pose = current_ee_pose + self.current_camera_pose = current_camera_pose + self.target_pose = target_pose + self.waiting_for_reach = waiting_for_reach + self.grasp_successful = grasp_successful + + +class ManipulationModule(Module): + """ + Manipulation module for visual servoing and grasping. + + Subscribes to: + - ZED RGB images + - ZED depth images + - ZED camera info + + Publishes: + - Visualization images + + RPC methods: + - handle_keyboard_command: Process keyboard input + - pick_and_place: Execute pick and place task + """ + + # LCM inputs + rgb_image: In[Image] = None + depth_image: In[Image] = None + camera_info: In[CameraInfo] = None + + # LCM outputs + viz_image: Out[Image] = None + + def __init__( + self, + ee_to_camera_6dof: Optional[list] = None, + **kwargs, + ): + """ + Initialize manipulation module. 
+ + Args: + ee_to_camera_6dof: EE to camera transform [x, y, z, rx, ry, rz] in meters and radians + """ + super().__init__(**kwargs) + + # Initialize arm directly + self.arm = PiperArm() + + # Default EE to camera transform if not provided + if ee_to_camera_6dof is None: + ee_to_camera_6dof = [-0.065, 0.03, -0.105, 0.0, -1.57, 0.0] + + # Create transform matrices + pos = Vector3(ee_to_camera_6dof[0], ee_to_camera_6dof[1], ee_to_camera_6dof[2]) + rot = Vector3(ee_to_camera_6dof[3], ee_to_camera_6dof[4], ee_to_camera_6dof[5]) + self.T_ee_to_camera = create_transform_from_6dof(pos, rot) + + # Camera intrinsics will be set when camera info is received + self.camera_intrinsics = None + self.detector = None + self.pbvs = None + + # Control state + self.last_valid_target = None + self.waiting_for_reach = False + self.last_commanded_pose = None + self.target_updated = False + self.waiting_start_time = None + self.reach_pose_timeout = 10.0 + + # Grasp parameters + self.grasp_width_offset = 0.03 + self.grasp_pitch_degrees = 30.0 + self.pregrasp_distance = 0.25 + self.grasp_distance_range = 0.03 + self.grasp_close_delay = 2.0 + self.grasp_reached_time = None + self.gripper_max_opening = 0.07 + + # Grasp stage tracking + self.grasp_stage = GraspStage.IDLE + + # Pose stabilization tracking + self.pose_history_size = 4 + self.pose_stabilization_threshold = 0.01 + self.stabilization_timeout = 15.0 + self.stabilization_start_time = None + self.reached_poses = deque(maxlen=self.pose_history_size) + self.adjustment_count = 0 + + # State for visualization + self.current_visualization = None + self.last_detection_3d_array = None + self.last_detection_2d_array = None + + # Grasp result and task tracking + self.pick_success = None + self.final_pregrasp_pose = None + self.task_failed = False # New variable for tracking task failure + + # Task control + self.task_running = False + self.task_thread = None + self.stop_event = threading.Event() + + # Latest sensor data + self.latest_rgb = 
None + self.latest_depth = None + self.latest_camera_info = None + + # Target selection + self.target_click = None + + # Move arm to observe position on init + self.arm.gotoObserve() + + @rpc + def start(self): + """Start the manipulation module.""" + # Subscribe to camera data + self.rgb_image.subscribe(self._on_rgb_image) + self.depth_image.subscribe(self._on_depth_image) + self.camera_info.subscribe(self._on_camera_info) + + logger.info("Manipulation module started") + + @rpc + def stop(self): + """Stop the manipulation module.""" + # Stop any running task + self.stop_event.set() + if self.task_thread and self.task_thread.is_alive(): + self.task_thread.join(timeout=5.0) + + # Disable arm + self.arm.disable() + logger.info("Manipulation module stopped") + + def _on_rgb_image(self, msg: Image): + """Handle RGB image messages.""" + try: + # Convert LCM message to numpy array + data = np.frombuffer(msg.data, dtype=np.uint8) + if msg.encoding == "rgb8": + self.latest_rgb = data.reshape((msg.height, msg.width, 3)) + else: + logger.warning(f"Unsupported RGB encoding: {msg.encoding}") + except Exception as e: + logger.error(f"Error processing RGB image: {e}") + + def _on_depth_image(self, msg: Image): + """Handle depth image messages.""" + try: + # Convert LCM message to numpy array + if msg.encoding == "32FC1": + data = np.frombuffer(msg.data, dtype=np.float32) + self.latest_depth = data.reshape((msg.height, msg.width)) + else: + logger.warning(f"Unsupported depth encoding: {msg.encoding}") + except Exception as e: + logger.error(f"Error processing depth image: {e}") + + def _on_camera_info(self, msg: CameraInfo): + """Handle camera info messages.""" + try: + # Extract camera intrinsics + self.camera_intrinsics = [ + msg.K[0], # fx + msg.K[4], # fy + msg.K[2], # cx + msg.K[5], # cy + ] + + # Initialize processors if not already done + if self.detector is None: + self.detector = Detection3DProcessor(self.camera_intrinsics) + self.pbvs = PBVS(target_tolerance=0.05) + 
logger.info("Initialized detection and PBVS processors") + + self.latest_camera_info = msg + except Exception as e: + logger.error(f"Error processing camera info: {e}") + + @rpc + def handle_keyboard_command(self, key: str) -> str: + """ + Handle keyboard commands for robot control. + + Args: + key: Keyboard key as string + + Returns: + Action taken as string, or empty string if no action + """ + key_code = ord(key) if len(key) == 1 else int(key) + + if key_code == ord("r"): + self.reset_to_idle() + return "reset" + elif key_code == ord("s"): + logger.info("SOFT STOP - Emergency stopping robot!") + self.arm.softStop() + return "stop" + elif key_code == ord(" ") and self.pbvs and self.pbvs.target_grasp_pose: + # Manual override - immediately transition to GRASP if in PRE_GRASP + if self.grasp_stage == GraspStage.PRE_GRASP: + self.set_grasp_stage(GraspStage.GRASP) + logger.info("Executing target pose") + return "execute" + elif key_code == 82: # Up arrow - increase pitch + new_pitch = min(90.0, self.grasp_pitch_degrees + 15.0) + self.set_grasp_pitch(new_pitch) + logger.info(f"Grasp pitch: {new_pitch:.0f} degrees") + return "pitch_up" + elif key_code == 84: # Down arrow - decrease pitch + new_pitch = max(0.0, self.grasp_pitch_degrees - 15.0) + self.set_grasp_pitch(new_pitch) + logger.info(f"Grasp pitch: {new_pitch:.0f} degrees") + return "pitch_down" + elif key_code == ord("g"): + logger.info("Opening gripper") + self.arm.release_gripper() + return "release" + + return "" + + @rpc + def pick_and_place(self, target_x: int = None, target_y: int = None) -> Dict[str, Any]: + """ + Start a pick and place task. 
+ + Args: + target_x: Optional X coordinate of target object + target_y: Optional Y coordinate of target object + + Returns: + Dict with status and message + """ + if self.task_running: + return {"status": "error", "message": "Task already running"} + + if self.camera_intrinsics is None: + return {"status": "error", "message": "Camera not initialized"} + + # Set target if coordinates provided + if target_x is not None and target_y is not None: + self.target_click = (target_x, target_y) + + # Reset task state + self.task_failed = False + self.stop_event.clear() + + # Ensure any previous thread has finished + if self.task_thread and self.task_thread.is_alive(): + self.stop_event.set() + self.task_thread.join(timeout=1.0) + + # Start task in separate thread + self.task_thread = threading.Thread(target=self._run_pick_and_place, daemon=True) + self.task_thread.start() + + return {"status": "started", "message": "Pick and place task started"} + + def _run_pick_and_place(self): + """Run the pick and place task loop.""" + self.task_running = True + logger.info("Starting pick and place task") + + try: + while not self.stop_event.is_set(): + # Check for task failure + if self.task_failed: + logger.error("Task failed, terminating pick and place") + self.stop_event.set() + break + + # Update manipulation system + feedback = self.update() + if feedback is None: + time.sleep(0.01) + continue + + # Check if task is complete + if feedback.grasp_successful is not None: + if feedback.grasp_successful: + logger.info("Pick and place completed successfully!") + else: + logger.warning("Pick and place failed - no object detected") + # Reset to idle state and stop the event loop + self.reset_to_idle() + self.stop_event.set() + break + + # Small delay to prevent CPU overload + time.sleep(0.01) + + except Exception as e: + logger.error(f"Error in pick and place task: {e}") + self.task_failed = True + finally: + self.task_running = False + logger.info("Pick and place task ended") + + def 
set_grasp_stage(self, stage: GraspStage): + """Set the grasp stage.""" + self.grasp_stage = stage + logger.info(f"Grasp stage: {stage.value}") + + def set_grasp_pitch(self, pitch_degrees: float): + """Set the grasp pitch angle.""" + pitch_degrees = max(0.0, min(90.0, pitch_degrees)) + self.grasp_pitch_degrees = pitch_degrees + if self.pbvs: + self.pbvs.set_grasp_pitch(pitch_degrees) + + def _check_reach_timeout(self) -> bool: + """Check if robot has exceeded timeout while reaching pose.""" + if ( + self.waiting_start_time + and (time.time() - self.waiting_start_time) > self.reach_pose_timeout + ): + logger.warning(f"Robot failed to reach pose within {self.reach_pose_timeout}s timeout") + self.task_failed = True + self.reset_to_idle() + return True + return False + + def _update_tracking(self, detection_3d_array: Optional[Detection3DArray]) -> bool: + """Update tracking with new detections.""" + if not detection_3d_array or not self.pbvs: + return False + + target_tracked = self.pbvs.update_tracking(detection_3d_array) + if target_tracked: + self.target_updated = True + self.last_valid_target = self.pbvs.get_current_target() + return target_tracked + + def reset_to_idle(self): + """Reset the manipulation system to IDLE state.""" + if self.pbvs: + self.pbvs.clear_target() + self.grasp_stage = GraspStage.IDLE + self.reached_poses.clear() + self.adjustment_count = 0 + self.waiting_for_reach = False + self.last_commanded_pose = None + self.target_updated = False + self.stabilization_start_time = None + self.grasp_reached_time = None + self.waiting_start_time = None + self.pick_success = None + self.final_pregrasp_pose = None + + self.arm.gotoObserve() + + def execute_idle(self): + """Execute idle stage: just visualization, no control.""" + pass + + def execute_pre_grasp(self): + """Execute pre-grasp stage: visual servoing to pre-grasp position.""" + ee_pose = self.arm.get_ee_pose() + + # Check if waiting for robot to reach commanded pose + if self.waiting_for_reach and 
self.last_commanded_pose: + # Check for timeout + if self._check_reach_timeout(): + return + + reached = self.pbvs.is_target_reached(ee_pose) + + if reached: + self.waiting_for_reach = False + self.waiting_start_time = None + self.reached_poses.append(self.last_commanded_pose) + self.target_updated = False + time.sleep(0.3) + + return + + # Check stabilization timeout + if ( + self.stabilization_start_time + and (time.time() - self.stabilization_start_time) > self.stabilization_timeout + ): + logger.warning( + f"Failed to get stable grasp after {self.stabilization_timeout} seconds, resetting" + ) + self.task_failed = True + self.reset_to_idle() + return + + # PBVS control with pre-grasp distance + _, _, _, has_target, target_pose = self.pbvs.compute_control( + ee_pose, self.pregrasp_distance + ) + + # Handle pose control + if target_pose and has_target: + # Check if we have enough reached poses and they're stable + if self.check_target_stabilized(): + logger.info("Target stabilized, transitioning to GRASP") + self.final_pregrasp_pose = self.last_commanded_pose + self.grasp_stage = GraspStage.GRASP + self.adjustment_count = 0 + self.waiting_for_reach = False + elif not self.waiting_for_reach and self.target_updated: + # Command the pose only if target has been updated + self.arm.cmd_ee_pose(target_pose) + self.last_commanded_pose = target_pose + self.waiting_for_reach = True + self.waiting_start_time = time.time() + self.target_updated = False + self.adjustment_count += 1 + time.sleep(0.2) + + def execute_grasp(self): + """Execute grasp stage: move to final grasp position.""" + ee_pose = self.arm.get_ee_pose() + + # Handle waiting with special grasp logic + if self.waiting_for_reach: + if self._check_reach_timeout(): + return + + if self.pbvs.is_target_reached(ee_pose) and not self.grasp_reached_time: + self.grasp_reached_time = time.time() + self.waiting_start_time = None + + # Check if delay completed + if ( + self.grasp_reached_time + and (time.time() - 
self.grasp_reached_time) >= self.grasp_close_delay + ): + logger.info("Grasp delay completed, closing gripper") + self.grasp_stage = GraspStage.CLOSE_AND_RETRACT + self.waiting_for_reach = False + return + + # Only command grasp if not waiting and have valid target + if self.last_valid_target: + # Calculate grasp distance based on pitch angle + normalized_pitch = self.grasp_pitch_degrees / 90.0 + grasp_distance = -self.grasp_distance_range + ( + 2 * self.grasp_distance_range * normalized_pitch + ) + + # PBVS control with calculated grasp distance + _, _, _, has_target, target_pose = self.pbvs.compute_control(ee_pose, grasp_distance) + + if target_pose and has_target: + # Calculate gripper opening + object_width = self.last_valid_target.bbox.size.x + gripper_opening = max( + 0.005, min(object_width + self.grasp_width_offset, self.gripper_max_opening) + ) + + logger.info(f"Executing grasp: gripper={gripper_opening * 1000:.1f}mm") + + # Command gripper and pose + self.arm.cmd_gripper_ctrl(gripper_opening) + self.arm.cmd_ee_pose(target_pose, line_mode=True) + self.waiting_for_reach = True + self.waiting_start_time = time.time() + + def execute_close_and_retract(self): + """Execute the retraction sequence after gripper has been closed.""" + ee_pose = self.arm.get_ee_pose() + + if self.waiting_for_reach: + if self._check_reach_timeout(): + return + + # Check if reached retraction pose + original_target = self.pbvs.target_grasp_pose + self.pbvs.target_grasp_pose = self.final_pregrasp_pose + reached = self.pbvs.is_target_reached(ee_pose) + self.pbvs.target_grasp_pose = original_target + + if reached: + logger.info("Reached pre-grasp retraction position") + self.waiting_for_reach = False + self.pick_success = self.arm.gripper_object_detected() + logger.info(f"Grasp sequence completed") + if self.pick_success: + logger.info("Object successfully grasped!") + else: + logger.warning("No object detected in gripper") + self.task_failed = True + # Don't reset to idle here - let 
the task loop handle it after detecting completion + else: + # Command retraction to pre-grasp + logger.info("Retracting to pre-grasp position") + self.arm.cmd_ee_pose(self.final_pregrasp_pose, line_mode=True) + self.arm.close_gripper() + self.waiting_for_reach = True + self.waiting_start_time = time.time() + + def capture_and_process( + self, + ) -> Tuple[ + Optional[np.ndarray], Optional[Detection3DArray], Optional[Detection2DArray], Optional[Pose] + ]: + """Capture frame from camera data and process detections.""" + # Check if we have all required data + if self.latest_rgb is None or self.latest_depth is None or self.detector is None: + return None, None, None, None + + # Get EE pose and camera transform + ee_pose = self.arm.get_ee_pose() + ee_transform = pose_to_matrix(ee_pose) + camera_transform = compose_transforms(ee_transform, self.T_ee_to_camera) + camera_pose = matrix_to_pose(camera_transform) + + # Process detections + detection_3d_array, detection_2d_array = self.detector.process_frame( + self.latest_rgb, self.latest_depth, camera_transform + ) + + return self.latest_rgb, detection_3d_array, detection_2d_array, camera_pose + + def pick_target(self, x: int, y: int) -> bool: + """Select a target object at the given pixel coordinates.""" + if not self.last_detection_2d_array or not self.last_detection_3d_array: + logger.warning("No detections available for target selection") + return False + + clicked_3d = find_clicked_detection( + (x, y), self.last_detection_2d_array.detections, self.last_detection_3d_array.detections + ) + if clicked_3d and self.pbvs: + self.pbvs.set_target(clicked_3d) + logger.info( + f"Target selected: ID={clicked_3d.id}, pos=({clicked_3d.bbox.center.position.x:.3f}, {clicked_3d.bbox.center.position.y:.3f}, {clicked_3d.bbox.center.position.z:.3f})" + ) + self.grasp_stage = GraspStage.PRE_GRASP + self.reached_poses.clear() + self.adjustment_count = 0 + self.waiting_for_reach = False + self.last_commanded_pose = None + 
self.stabilization_start_time = time.time() + return True + return False + + def update(self) -> Optional[Dict[str, Any]]: + """Main update function that handles capture, processing, control, and visualization.""" + # Capture and process frame + rgb, detection_3d_array, detection_2d_array, camera_pose = self.capture_and_process() + if rgb is None: + return None + + # Store for target selection + self.last_detection_3d_array = detection_3d_array + self.last_detection_2d_array = detection_2d_array + + # Handle target selection if click is pending + if self.target_click: + x, y = self.target_click + if self.pick_target(x, y): + self.target_click = None + + # Update tracking if we have detections and not in IDLE or CLOSE_AND_RETRACT + if ( + detection_3d_array + and self.grasp_stage in [GraspStage.PRE_GRASP, GraspStage.GRASP] + and not self.waiting_for_reach + ): + self._update_tracking(detection_3d_array) + + # Execute stage-specific logic + stage_handlers = { + GraspStage.IDLE: self.execute_idle, + GraspStage.PRE_GRASP: self.execute_pre_grasp, + GraspStage.GRASP: self.execute_grasp, + GraspStage.CLOSE_AND_RETRACT: self.execute_close_and_retract, + } + if self.grasp_stage in stage_handlers: + stage_handlers[self.grasp_stage]() + + # Get tracking status + target_tracked = self.pbvs.get_current_target() is not None if self.pbvs else False + + # Create feedback object + ee_pose = self.arm.get_ee_pose() + feedback = Feedback( + grasp_stage=self.grasp_stage, + target_tracked=target_tracked, + last_commanded_pose=self.last_commanded_pose, + current_ee_pose=ee_pose, + current_camera_pose=camera_pose, + target_pose=self.pbvs.target_grasp_pose if self.pbvs else None, + waiting_for_reach=self.waiting_for_reach, + grasp_successful=self.pick_success, + ) + + # Create visualization only if task is running + if self.task_running: + self.current_visualization = create_manipulation_visualization( + rgb, feedback, detection_3d_array, detection_2d_array + ) + + # Publish visualization 
+ if self.current_visualization is not None: + self._publish_visualization(self.current_visualization) + + return feedback + + def _publish_visualization(self, viz_image: np.ndarray): + """Publish visualization image to LCM.""" + try: + # Convert BGR to RGB for publishing + viz_rgb = cv2.cvtColor(viz_image, cv2.COLOR_BGR2RGB) + + # Create LCM Image message + height, width = viz_rgb.shape[:2] + data = viz_rgb.tobytes() + + msg = Image( + data_length=len(data), + height=height, + width=width, + encoding="rgb8", + is_bigendian=0, + step=width * 3, + data=data, + ) + + self.viz_image.publish(msg) + except Exception as e: + logger.error(f"Error publishing visualization: {e}") + + def check_target_stabilized(self) -> bool: + """Check if the commanded poses have stabilized.""" + if len(self.reached_poses) < self.reached_poses.maxlen: + return False + + # Extract positions + positions = np.array( + [[p.position.x, p.position.y, p.position.z] for p in self.reached_poses] + ) + + # Calculate standard deviation for each axis + std_devs = np.std(positions, axis=0) + + # Check if all axes are below threshold + return np.all(std_devs < self.pose_stabilization_threshold) + + def cleanup(self): + """Clean up resources on module destruction.""" + self.stop() diff --git a/tests/test_ibvs.py b/tests/test_ibvs.py index 4af038b0a0..0192b1aa56 100644 --- a/tests/test_ibvs.py +++ b/tests/test_ibvs.py @@ -129,7 +129,6 @@ def main(): pass finally: cv2.destroyAllWindows() - manipulation.cleanup() zed.close() arm.disable() diff --git a/tests/test_pick_and_place_module.py b/tests/test_pick_and_place_module.py new file mode 100644 index 0000000000..dd7ce174e2 --- /dev/null +++ b/tests/test_pick_and_place_module.py @@ -0,0 +1,289 @@ +#!/usr/bin/env python3 +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Test script for pick and place manipulation module. +Subscribes to visualization images and handles mouse/keyboard input. +""" + +import cv2 +import sys +import asyncio +import threading +import time +import numpy as np +from typing import Optional + +try: + import pyzed.sl as sl +except ImportError: + print("Error: ZED SDK not installed.") + sys.exit(1) + +from dimos import core +from dimos.hardware.zed_camera import ZEDModule +from dimos.manipulation.visual_servoing.manipulation_module import ManipulationModule +from dimos.protocol import pubsub +from dimos.utils.logging_config import setup_logger + +# Import LCM message types +from dimos_lcm.sensor_msgs import Image as LCMImage +from dimos_lcm.sensor_msgs import CameraInfo +from dimos.protocol.pubsub.lcmpubsub import LCM, Topic + +logger = setup_logger("test_pick_and_place_module") + +# Global for mouse events +mouse_click = None +camera_mouse_click = None +current_window = None + + +def mouse_callback(event, x, y, _flags, param): + global mouse_click, camera_mouse_click + window_name = param + if event == cv2.EVENT_LBUTTONDOWN: + if window_name == "Camera Feed": + camera_mouse_click = (x, y) + else: + mouse_click = (x, y) + + +class VisualizationNode: + """Node that subscribes to visualization images and handles user input.""" + + def __init__(self, manipulation_module): + self.lcm = LCM() + self.latest_viz = None + self.latest_camera = None + self._running = False + self.manipulation = manipulation_module + + # Subscribe to visualization topic + self.viz_topic = Topic("/manipulation/viz", 
LCMImage) + self.camera_topic = Topic("/zed/color_image", LCMImage) + + def start(self): + """Start the visualization node.""" + self._running = True + self.lcm.start() + + # Subscribe to visualization topic + self.lcm.subscribe(self.viz_topic, self._on_viz_image) + # Subscribe to camera topic for point selection + self.lcm.subscribe(self.camera_topic, self._on_camera_image) + + logger.info("Visualization node started") + + def stop(self): + """Stop the visualization node.""" + self._running = False + cv2.destroyAllWindows() + + def _on_viz_image(self, msg: LCMImage, topic: str): + """Handle visualization image messages.""" + try: + # Convert LCM message to numpy array + data = np.frombuffer(msg.data, dtype=np.uint8) + if msg.encoding == "rgb8": + image = data.reshape((msg.height, msg.width, 3)) + # Convert RGB to BGR for OpenCV + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + self.latest_viz = image + except Exception as e: + logger.error(f"Error processing viz image: {e}") + + def _on_camera_image(self, msg: LCMImage, topic: str): + """Handle camera image messages.""" + try: + # Convert LCM message to numpy array + data = np.frombuffer(msg.data, dtype=np.uint8) + if msg.encoding == "rgb8": + image = data.reshape((msg.height, msg.width, 3)) + # Convert RGB to BGR for OpenCV + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + self.latest_camera = image + except Exception as e: + logger.error(f"Error processing camera image: {e}") + + def run_visualization(self): + """Run the visualization loop with user interaction.""" + global mouse_click, camera_mouse_click + + # Setup windows + cv2.namedWindow("Pick and Place") + cv2.setMouseCallback("Pick and Place", mouse_callback, "Pick and Place") + + cv2.namedWindow("Camera Feed") + cv2.setMouseCallback("Camera Feed", mouse_callback, "Camera Feed") + + print("=== Pick and Place Module Test ===") + print("Control mode: Module-based with LCM communication") + print("Click objects to select targets | 'r' - reset | 'q' - 
quit") + print("SAFETY CONTROLS:") + print(" 's' - SOFT STOP (emergency stop)") + print(" 'g' - RELEASE GRIPPER (open gripper)") + print(" 'SPACE' - EXECUTE target pose (manual override)") + print("GRASP PITCH CONTROLS:") + print(" '↑' - Increase grasp pitch by 15° (towards top-down)") + print(" '↓' - Decrease grasp pitch by 15° (towards level)") + print(" 'p' - Start pick and place task") + print("\nNOTE: Click on objects in the Camera Feed window to select targets!") + + while self._running: + # Show camera feed (always available) + if self.latest_camera is not None: + cv2.imshow("Camera Feed", self.latest_camera) + + # Show visualization if available + if self.latest_viz is not None: + cv2.imshow("Pick and Place", self.latest_viz) + + # Handle keyboard input + key = cv2.waitKey(1) & 0xFF + if key != 255: # Key was pressed + if key == ord("q"): + logger.info("Quit requested") + self._running = False + break + elif key == ord("p"): + # Start pick and place task + if mouse_click: + x, y = mouse_click + result = self.manipulation.pick_and_place(x, y) + logger.info(f"Pick and place task: {result}") + mouse_click = None + else: + result = self.manipulation.pick_and_place() + logger.info(f"Pick and place task (no target): {result}") + else: + # Send keyboard command to manipulation module + if key in [82, 84]: # Arrow keys + action = self.manipulation.handle_keyboard_command(str(key)) + else: + action = self.manipulation.handle_keyboard_command(chr(key)) + if action: + logger.info(f"Action: {action}") + + # Handle mouse click from Camera Feed window + if camera_mouse_click: + # Start pick and place task with the clicked point + x, y = camera_mouse_click + result = self.manipulation.pick_and_place(x, y) + logger.info(f"Started pick and place at ({x}, {y}) from camera feed: {result}") + camera_mouse_click = None + + # Handle mouse click from Pick and Place window (if viz is running) + elif mouse_click and self.latest_viz is not None: + # If there's a pending click and 
we're not running a task, start one + x, y = mouse_click + result = self.manipulation.pick_and_place(x, y) + logger.info(f"Started pick and place at ({x}, {y}): {result}") + mouse_click = None + + time.sleep(0.03) # ~30 FPS + + +async def test_pick_and_place_module(): + """Test the pick and place manipulation module.""" + logger.info("Starting Pick and Place Module test") + + # Start Dask + dimos = core.start(2) # Need 2 workers for ZED and manipulation modules + + # Enable LCM auto-configuration + pubsub.lcm.autoconf() + + try: + # Deploy ZED module + logger.info("Deploying ZED module...") + zed = dimos.deploy( + ZEDModule, + camera_id=0, + resolution="HD720", + depth_mode="NEURAL", + fps=30, + enable_tracking=False, # We don't need tracking for manipulation + publish_rate=30.0, + frame_id="zed_camera", + ) + + # Configure ZED LCM transports + zed.color_image.transport = core.LCMTransport("/zed/color_image", LCMImage) + zed.depth_image.transport = core.LCMTransport("/zed/depth_image", LCMImage) + zed.camera_info.transport = core.LCMTransport("/zed/camera_info", CameraInfo) + + # Deploy manipulation module + logger.info("Deploying manipulation module...") + manipulation = dimos.deploy(ManipulationModule) + + # Connect manipulation inputs to ZED outputs + manipulation.rgb_image.connect(zed.color_image) + manipulation.depth_image.connect(zed.depth_image) + manipulation.camera_info.connect(zed.camera_info) + + # Configure manipulation output + manipulation.viz_image.transport = core.LCMTransport("/manipulation/viz", LCMImage) + + # Print module info + logger.info("Modules configured:") + print("\nZED Module:") + print(zed.io().result()) + print("\nManipulation Module:") + print(manipulation.io().result()) + + # Start modules + logger.info("Starting modules...") + zed.start() + manipulation.start() + + # Give modules time to initialize + await asyncio.sleep(2) + + # Create and start visualization node + viz_node = VisualizationNode(manipulation) + viz_node.start() + + 
# Run visualization in separate thread + viz_thread = threading.Thread(target=viz_node.run_visualization, daemon=True) + viz_thread.start() + + # Keep running until visualization stops + while viz_node._running: + await asyncio.sleep(0.1) + + # Stop modules + logger.info("Stopping modules...") + manipulation.stop() + zed.stop() + + # Stop visualization + viz_node.stop() + + except Exception as e: + logger.error(f"Error in test: {e}") + import traceback + + traceback.print_exc() + + finally: + # Clean up + dimos.close() + logger.info("Test completed") + + +if __name__ == "__main__": + # Run the test + asyncio.run(test_pick_and_place_module()) From a58c6c5e7766f31ff843c515bb1361ed846d7a27 Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Wed, 23 Jul 2025 02:39:43 -0700 Subject: [PATCH 78/89] fully implemented place in pick and place, cleaned up significantly --- dimos/hardware/piper_arm.py | 3 +- .../visual_servoing/detection3d.py | 50 +- .../visual_servoing/manipulation.py | 590 ------------------ .../visual_servoing/manipulation_module.py | 219 ++++++- dimos/manipulation/visual_servoing/pbvs.py | 102 +-- dimos/manipulation/visual_servoing/utils.py | 339 +++++++++- dimos/perception/common/utils.py | 99 ++- dimos/perception/grasp_generation/utils.py | 44 +- dimos/perception/pointcloud/utils.py | 43 +- tests/test_ibvs.py | 137 ---- .../test_manipulation_perception_pipeline.py | 2 +- tests/test_pick_and_place_module.py | 170 ++++- 12 files changed, 780 insertions(+), 1018 deletions(-) delete mode 100644 dimos/manipulation/visual_servoing/manipulation.py delete mode 100644 tests/test_ibvs.py diff --git a/dimos/hardware/piper_arm.py b/dimos/hardware/piper_arm.py index 910083ed3e..9921c53c8a 100644 --- a/dimos/hardware/piper_arm.py +++ b/dimos/hardware/piper_arm.py @@ -109,7 +109,6 @@ def gotoObserve(self): logger.debug(f"Going to zero position: X={X}, Y={Y}, Z={Z}, RX={RX}, RY={RY}, RZ={RZ}") self.arm.MotionCtrl_2(0x01, 0x00, 100, 0x00) self.arm.EndPoseCtrl(X, Y, Z, RX, RY, 
RZ) - self.arm.GripperCtrl(0, 1000, 0x01, 0) def softStop(self): self.gotoZero() @@ -214,7 +213,7 @@ def get_gripper_feedback(self) -> Tuple[float, float]: effort = gripper_msg.gripper_state.grippers_effort / 1000.0 # Convert from SDK units to N/m return angle_degrees, effort - def close_gripper(self, commanded_effort: float = 0.25) -> None: + def close_gripper(self, commanded_effort: float = 0.5) -> None: """ Close the gripper. diff --git a/dimos/manipulation/visual_servoing/detection3d.py b/dimos/manipulation/visual_servoing/detection3d.py index 01f51cf2b3..9eaf48d774 100644 --- a/dimos/manipulation/visual_servoing/detection3d.py +++ b/dimos/manipulation/visual_servoing/detection3d.py @@ -40,13 +40,10 @@ Point2D, ) from dimos_lcm.std_msgs import Header -from dimos.manipulation.visual_servoing.utils import estimate_object_depth, visualize_detections_3d -from dimos.utils.transform_utils import ( - optical_to_robot_frame, - pose_to_matrix, - matrix_to_pose, - euler_to_quaternion, - compose_transforms, +from dimos.manipulation.visual_servoing.utils import ( + estimate_object_depth, + visualize_detections_3d, + transform_pose, ) logger = setup_logger("dimos.perception.detection3d") @@ -180,8 +177,8 @@ def process_frame( obj_cam_orientation = pose.get( "rotation", np.array([0.0, 0.0, 0.0]) ) # Default to no rotation - transformed_pose = self._transform_object_pose( - obj_cam_pos, obj_cam_orientation, transform + transformed_pose = transform_pose( + obj_cam_pos, obj_cam_orientation, transform, to_robot=True ) center_pose = transformed_pose else: @@ -240,41 +237,6 @@ def process_frame( ), ) - def _transform_object_pose( - self, obj_pos: np.ndarray, obj_orientation: np.ndarray, transform_matrix: np.ndarray - ) -> Pose: - """ - Transform object pose from optical frame to desired frame using transformation matrix. 
- - Args: - obj_pos: Object position in optical frame [x, y, z] - obj_orientation: Object orientation in optical frame [roll, pitch, yaw] in radians - transform_matrix: 4x4 transformation matrix from camera frame to desired frame - - Returns: - Object pose in desired frame as Pose - """ - # Create object pose in optical frame - # Convert euler angles to quaternion using utility function - euler_vector = Vector3(obj_orientation[0], obj_orientation[1], obj_orientation[2]) - obj_orientation_quat = euler_to_quaternion(euler_vector) - - obj_pose_optical = Pose(Point(obj_pos[0], obj_pos[1], obj_pos[2]), obj_orientation_quat) - - # Transform object pose from optical frame to robot frame convention first - obj_pose_robot_frame = optical_to_robot_frame(obj_pose_optical) - - # Create transformation matrix from object pose (relative to camera) - T_camera_object = pose_to_matrix(obj_pose_robot_frame) - - # Use compose_transforms to combine transformations - T_desired_object = compose_transforms(transform_matrix, T_camera_object) - - # Convert back to pose - desired_pose = matrix_to_pose(T_desired_object) - - return desired_pose - def visualize_detections( self, rgb_image: np.ndarray, diff --git a/dimos/manipulation/visual_servoing/manipulation.py b/dimos/manipulation/visual_servoing/manipulation.py deleted file mode 100644 index 38a57244ed..0000000000 --- a/dimos/manipulation/visual_servoing/manipulation.py +++ /dev/null @@ -1,590 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Manipulation system for robotic grasping with visual servoing. -Handles grasping logic, state machine, and hardware coordination. -""" - -import cv2 -import time -from typing import Optional, Tuple, Any -from enum import Enum -from collections import deque - -import numpy as np - -from dimos.manipulation.visual_servoing.detection3d import Detection3DProcessor -from dimos.manipulation.visual_servoing.pbvs import PBVS -from dimos.perception.common.utils import ( - find_clicked_detection, -) -from dimos.manipulation.visual_servoing.utils import ( - create_manipulation_visualization, -) -from dimos.utils.transform_utils import ( - pose_to_matrix, - matrix_to_pose, - create_transform_from_6dof, - compose_transforms, -) -from dimos.utils.logging_config import setup_logger -from dimos_lcm.geometry_msgs import Vector3, Pose -from dimos_lcm.vision_msgs import Detection3DArray, Detection2DArray - -logger = setup_logger("dimos.manipulation.manipulation") - - -class GraspStage(Enum): - """Enum for different grasp stages.""" - - IDLE = "idle" # No target set - PRE_GRASP = "pre_grasp" # Target set, moving to pre-grasp position - GRASP = "grasp" # Executing final grasp - CLOSE_AND_RETRACT = "close_and_retract" # Close gripper and retract - - -class Feedback: - """ - Feedback data returned by the manipulation system update. - - Contains comprehensive state information about the manipulation process. 
- """ - - def __init__( - self, - grasp_stage: GraspStage, - target_tracked: bool, - last_commanded_pose: Optional[Pose] = None, - current_ee_pose: Optional[Pose] = None, - current_camera_pose: Optional[Pose] = None, - target_pose: Optional[Pose] = None, - waiting_for_reach: bool = False, - grasp_successful: Optional[bool] = None, - ): - self.grasp_stage = grasp_stage - self.target_tracked = target_tracked - self.last_commanded_pose = last_commanded_pose - self.current_ee_pose = current_ee_pose - self.current_camera_pose = current_camera_pose - self.target_pose = target_pose - self.waiting_for_reach = waiting_for_reach - self.grasp_successful = grasp_successful - - -class Manipulation: - """ - High-level manipulation orchestrator for visual servoing and grasping. - - Handles: - - State machine for grasping sequences - - Grasp execution logic - - Coordination between perception and control - - This class is hardware-agnostic and accepts camera and arm objects. - """ - - def __init__( - self, - camera: Any, # Generic camera object with required interface - arm: Any, # Generic arm object with required interface - ee_to_camera_6dof: Optional[list] = None, - ): - """ - Initialize manipulation system. 
- - Args: - camera: Camera object with capture_frame_with_pose() and calculate_intrinsics() methods - arm: Robot arm object with get_ee_pose(), cmd_ee_pose(), - cmd_gripper_ctrl(), release_gripper(), softStop(), gotoZero(), gotoObserve(), and disable() methods - ee_to_camera_6dof: EE to camera transform [x, y, z, rx, ry, rz] in meters and radians - """ - self.camera = camera - self.arm = arm - - # Default EE to camera transform if not provided - if ee_to_camera_6dof is None: - ee_to_camera_6dof = [-0.065, 0.03, -0.105, 0.0, -1.57, 0.0] - - # Create transform matrices - pos = Vector3(ee_to_camera_6dof[0], ee_to_camera_6dof[1], ee_to_camera_6dof[2]) - rot = Vector3(ee_to_camera_6dof[3], ee_to_camera_6dof[4], ee_to_camera_6dof[5]) - self.T_ee_to_camera = create_transform_from_6dof(pos, rot) - - # Get camera intrinsics - cam_intrinsics = camera.calculate_intrinsics() - camera_intrinsics = [ - cam_intrinsics["focal_length_x"], - cam_intrinsics["focal_length_y"], - cam_intrinsics["principal_point_x"], - cam_intrinsics["principal_point_y"], - ] - - # Initialize processors - self.detector = Detection3DProcessor(camera_intrinsics) - self.pbvs = PBVS( - target_tolerance=0.05, - ) - - # Control state - self.last_valid_target = None - self.waiting_for_reach = False # True when waiting for robot to reach commanded pose - self.last_commanded_pose = None # Last pose sent to robot - self.target_updated = False # True when target has been updated with fresh detections - self.waiting_start_time = None # Time when waiting for reach started - self.reach_pose_timeout = 10.0 # Timeout for reaching commanded pose (seconds) - - # Grasp parameters - self.grasp_width_offset = 0.03 # Default grasp width offset - self.grasp_pitch_degrees = 30.0 # Default grasp pitch in degrees - self.pregrasp_distance = 0.25 # Distance to maintain before grasping (m) - self.grasp_distance_range = 0.03 # Range for grasp distance mapping (±5cm = ±0.05m) - self.grasp_close_delay = 2.0 # Time to wait at grasp 
pose before closing (seconds) - self.grasp_reached_time = None # Time when grasp pose was reached - self.gripper_max_opening = 0.07 # Maximum gripper opening (m) - - # Grasp stage tracking - self.grasp_stage = GraspStage.IDLE - - # Pose stabilization tracking - self.pose_history_size = 4 # Number of poses to check for stabilization - self.pose_stabilization_threshold = 0.01 # 1cm threshold for stabilization - self.stabilization_timeout = 15.0 # Timeout in seconds before giving up - self.stabilization_start_time = None # Time when stabilization started - self.reached_poses = deque( - maxlen=self.pose_history_size - ) # Only stores poses that were reached - self.adjustment_count = 0 - - # State for visualization - self.current_visualization = None - self.last_detection_3d_array = None - self.last_detection_2d_array = None - - # Grasp result - self.pick_success = None # True if last grasp was successful - self.final_pregrasp_pose = None # Store the final pre-grasp pose for retraction - - # Go to observe position - self.arm.gotoObserve() - - def set_grasp_stage(self, stage: GraspStage): - """ - Set the grasp stage. - - Args: - stage: The new grasp stage - """ - self.grasp_stage = stage - logger.info(f"Grasp stage: {stage.value}") - - def set_grasp_pitch(self, pitch_degrees: float): - """ - Set the grasp pitch angle. - - Args: - pitch_degrees: Grasp pitch angle in degrees (0-90) - 0 = level grasp, 90 = top-down grasp - """ - # Clamp to valid range - pitch_degrees = max(0.0, min(90.0, pitch_degrees)) - self.grasp_pitch_degrees = pitch_degrees - self.pbvs.set_grasp_pitch(pitch_degrees) - - def _check_reach_timeout(self) -> bool: - """ - Check if robot has exceeded timeout while reaching pose. 
- - Returns: - True if timeout exceeded, False otherwise - """ - if ( - self.waiting_start_time - and (time.time() - self.waiting_start_time) > self.reach_pose_timeout - ): - logger.warning(f"Robot failed to reach pose within {self.reach_pose_timeout}s timeout") - self.reset_to_idle() - return True - return False - - def _update_tracking(self, detection_3d_array: Optional[Detection3DArray]) -> bool: - """ - Update tracking with new detections in a compact way. - - Args: - detection_3d_array: Optional detection array - - Returns: - True if target is tracked - """ - if not detection_3d_array: - return False - - target_tracked = self.pbvs.update_tracking(detection_3d_array) - if target_tracked: - self.target_updated = True - self.last_valid_target = self.pbvs.get_current_target() - return target_tracked - - def reset_to_idle(self): - """Reset the manipulation system to IDLE state.""" - self.pbvs.clear_target() - self.grasp_stage = GraspStage.IDLE - self.reached_poses.clear() - self.adjustment_count = 0 - self.waiting_for_reach = False - self.last_commanded_pose = None - self.target_updated = False - self.stabilization_start_time = None - self.grasp_reached_time = None - self.waiting_start_time = None - self.pick_success = None - self.final_pregrasp_pose = None - - self.arm.gotoObserve() - - def execute_idle(self): - """Execute idle stage: just visualization, no control.""" - # Nothing to do in idle - pass - - def execute_pre_grasp(self): - """Execute pre-grasp stage: visual servoing to pre-grasp position.""" - ee_pose = self.arm.get_ee_pose() - - # Check if waiting for robot to reach commanded pose - if self.waiting_for_reach and self.last_commanded_pose: - # Check for timeout - if self._check_reach_timeout(): - return - - reached = self.pbvs.is_target_reached(ee_pose) - - if reached: - self.waiting_for_reach = False - self.waiting_start_time = None - self.reached_poses.append(self.last_commanded_pose) - self.target_updated = False # Reset flag so we wait for fresh 
update - time.sleep(0.3) - - # While waiting, don't process new commands - return - - # Check stabilization timeout - if ( - self.stabilization_start_time - and (time.time() - self.stabilization_start_time) > self.stabilization_timeout - ): - logger.warning( - f"Failed to get stable grasp after {self.stabilization_timeout} seconds, resetting" - ) - self.reset_to_idle() - return - - # PBVS control with pre-grasp distance - _, _, _, has_target, target_pose = self.pbvs.compute_control( - ee_pose, self.pregrasp_distance - ) - - # Handle pose control - if target_pose and has_target: - # Check if we have enough reached poses and they're stable - if self.check_target_stabilized(): - logger.info("Target stabilized, transitioning to GRASP") - self.final_pregrasp_pose = self.last_commanded_pose - self.grasp_stage = GraspStage.GRASP - self.adjustment_count = 0 - self.waiting_for_reach = False - elif not self.waiting_for_reach and self.target_updated: - # Command the pose only if target has been updated - self.arm.cmd_ee_pose(target_pose) - self.last_commanded_pose = target_pose - self.waiting_for_reach = True - self.waiting_start_time = time.time() - self.target_updated = False - self.adjustment_count += 1 - time.sleep(0.2) - - def execute_grasp(self): - """Execute grasp stage: move to final grasp position.""" - ee_pose = self.arm.get_ee_pose() - - # Handle waiting with special grasp logic - if self.waiting_for_reach: - if self._check_reach_timeout(): - return - - if self.pbvs.is_target_reached(ee_pose) and not self.grasp_reached_time: - self.grasp_reached_time = time.time() - self.waiting_start_time = None - - # Check if delay completed - if ( - self.grasp_reached_time - and (time.time() - self.grasp_reached_time) >= self.grasp_close_delay - ): - logger.info("Grasp delay completed, closing gripper") - self.grasp_stage = GraspStage.CLOSE_AND_RETRACT - self.waiting_for_reach = False - return - - # Only command grasp if not waiting and have valid target - if 
self.last_valid_target: - # Calculate grasp distance based on pitch angle (0° -> -5cm, 90° -> +5cm) - normalized_pitch = self.grasp_pitch_degrees / 90.0 - grasp_distance = -self.grasp_distance_range + ( - 2 * self.grasp_distance_range * normalized_pitch - ) - - # PBVS control with calculated grasp distance - _, _, _, has_target, target_pose = self.pbvs.compute_control(ee_pose, grasp_distance) - - if target_pose and has_target: - # Calculate gripper opening - object_width = self.last_valid_target.bbox.size.x - gripper_opening = max( - 0.005, min(object_width + self.grasp_width_offset, self.gripper_max_opening) - ) - - logger.info(f"Executing grasp: gripper={gripper_opening * 1000:.1f}mm") - - # Command gripper and pose - self.arm.cmd_gripper_ctrl(gripper_opening) - self.arm.cmd_ee_pose(target_pose, line_mode=True) - self.waiting_for_reach = True - self.waiting_start_time = time.time() - - def execute_close_and_retract(self): - """Execute the retraction sequence after gripper has been closed.""" - ee_pose = self.arm.get_ee_pose() - - if self.waiting_for_reach: - if self._check_reach_timeout(): - return - - # Check if reached retraction pose - original_target = self.pbvs.target_grasp_pose - self.pbvs.target_grasp_pose = self.final_pregrasp_pose - reached = self.pbvs.is_target_reached(ee_pose) - self.pbvs.target_grasp_pose = original_target - - if reached: - logger.info("Reached pre-grasp retraction position") - self.waiting_for_reach = False - self.pick_success = self.arm.gripper_object_detected() - logger.info(f"Grasp sequence completed") - if self.pick_success: - logger.info("Object successfully grasped!") - else: - logger.warning("No object detected in gripper") - self.reset_to_idle() - else: - # Command retraction to pre-grasp - logger.info("Retracting to pre-grasp position") - self.arm.cmd_ee_pose(self.final_pregrasp_pose, line_mode=True) - self.arm.close_gripper() - self.waiting_for_reach = True - self.waiting_start_time = time.time() - - def 
capture_and_process( - self, - ) -> Tuple[ - Optional[np.ndarray], Optional[Detection3DArray], Optional[Detection2DArray], Optional[Pose] - ]: - """ - Capture frame from camera and process detections. - - Returns: - Tuple of (rgb_image, detection_3d_array, detection_2d_array, camera_pose) - Returns None values if capture fails - """ - # Capture frame - bgr, _, depth, _ = self.camera.capture_frame_with_pose() - if bgr is None or depth is None: - return None, None, None, None - - rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) - - # Get EE pose and camera transform - ee_pose = self.arm.get_ee_pose() - ee_transform = pose_to_matrix(ee_pose) - camera_transform = compose_transforms(ee_transform, self.T_ee_to_camera) - camera_pose = matrix_to_pose(camera_transform) - - # Process detections - detection_3d_array, detection_2d_array = self.detector.process_frame( - rgb, depth, camera_transform - ) - - return rgb, detection_3d_array, detection_2d_array, camera_pose - - def pick_target(self, x: int, y: int) -> bool: - """ - Select a target object at the given pixel coordinates. 
- - Args: - x: X coordinate in image - y: Y coordinate in image - - Returns: - True if a target was successfully selected - """ - if not self.last_detection_2d_array or not self.last_detection_3d_array: - logger.warning("No detections available for target selection") - return False - - clicked_3d = find_clicked_detection( - (x, y), self.last_detection_2d_array.detections, self.last_detection_3d_array.detections - ) - if clicked_3d: - self.pbvs.set_target(clicked_3d) - logger.info( - f"Target selected: ID={clicked_3d.id}, pos=({clicked_3d.bbox.center.position.x:.3f}, {clicked_3d.bbox.center.position.y:.3f}, {clicked_3d.bbox.center.position.z:.3f})" - ) - self.grasp_stage = GraspStage.PRE_GRASP # Transition from IDLE to PRE_GRASP - self.reached_poses.clear() # Clear pose history - self.adjustment_count = 0 # Reset adjustment counter - self.waiting_for_reach = False # Ensure we're not stuck in waiting state - self.last_commanded_pose = None - self.stabilization_start_time = time.time() # Start the timeout timer - return True - return False - - def update(self) -> Optional[Feedback]: - """ - Main update function that handles capture, processing, control, and visualization. 
- - Returns: - Feedback object with current state information, or None if capture failed - """ - # Capture and process frame - rgb, detection_3d_array, detection_2d_array, camera_pose = self.capture_and_process() - if rgb is None: - return None - - # Store for target selection - self.last_detection_3d_array = detection_3d_array - self.last_detection_2d_array = detection_2d_array - - # Update tracking if we have detections and not in IDLE or CLOSE_AND_RETRACT - # Only update if not waiting for reach (to ensure fresh updates after reaching) - if ( - detection_3d_array - and self.grasp_stage in [GraspStage.PRE_GRASP, GraspStage.GRASP] - and not self.waiting_for_reach - ): - self._update_tracking(detection_3d_array) - - # Execute stage-specific logic - stage_handlers = { - GraspStage.IDLE: self.execute_idle, - GraspStage.PRE_GRASP: self.execute_pre_grasp, - GraspStage.GRASP: self.execute_grasp, - GraspStage.CLOSE_AND_RETRACT: self.execute_close_and_retract, - } - if self.grasp_stage in stage_handlers: - stage_handlers[self.grasp_stage]() - - # Get tracking status - target_tracked = self.pbvs.get_current_target() is not None - - # Create feedback - ee_pose = self.arm.get_ee_pose() - feedback = Feedback( - grasp_stage=self.grasp_stage, - target_tracked=target_tracked, - last_commanded_pose=self.last_commanded_pose, - current_ee_pose=ee_pose, - current_camera_pose=camera_pose, - target_pose=self.pbvs.target_grasp_pose, - waiting_for_reach=self.waiting_for_reach, - grasp_successful=self.pick_success, - ) - - # Create simple visualization using feedback - self.current_visualization = create_manipulation_visualization( - rgb, feedback, detection_3d_array, detection_2d_array - ) - - return feedback - - def get_visualization(self) -> Optional[np.ndarray]: - """ - Get the current visualization image. 
- - Returns: - BGR image with visualizations, or None if no visualization available - """ - return self.current_visualization - - def handle_keyboard_command(self, key: int) -> str: - """ - Handle keyboard commands for robot control. - - Args: - key: Keyboard key code - - Returns: - Action taken as string, or empty string if no action - """ - if key == ord("r"): - self.reset_to_idle() - return "reset" - elif key == ord("s"): - print("SOFT STOP - Emergency stopping robot!") - self.arm.softStop() - return "stop" - elif key == ord(" ") and self.pbvs.target_grasp_pose: - # Manual override - immediately transition to GRASP if in PRE_GRASP - if self.grasp_stage == GraspStage.PRE_GRASP: - self.set_grasp_stage(GraspStage.GRASP) - print("Executing target pose") - return "execute" - elif key == 82: # Up arrow - increase pitch - new_pitch = min(90.0, self.grasp_pitch_degrees + 15.0) - self.set_grasp_pitch(new_pitch) - print(f"Grasp pitch: {new_pitch:.0f} degrees") - return "pitch_up" - elif key == 84: # Down arrow - decrease pitch - new_pitch = max(0.0, self.grasp_pitch_degrees - 15.0) - self.set_grasp_pitch(new_pitch) - print(f"Grasp pitch: {new_pitch:.0f} degrees") - return "pitch_down" - elif key == ord("g"): - print("Opening gripper") - self.arm.release_gripper() - return "release" - - return "" - - def check_target_stabilized(self) -> bool: - """ - Check if the commanded poses have stabilized. 
- - Returns: - True if poses are stable, False otherwise - """ - if len(self.reached_poses) < self.reached_poses.maxlen: - return False # Not enough poses yet - - # Extract positions - positions = np.array( - [[p.position.x, p.position.y, p.position.z] for p in self.reached_poses] - ) - - # Calculate standard deviation for each axis - std_devs = np.std(positions, axis=0) - - # Check if all axes are below threshold - return np.all(std_devs < self.pose_stabilization_threshold) diff --git a/dimos/manipulation/visual_servoing/manipulation_module.py b/dimos/manipulation/visual_servoing/manipulation_module.py index b2967f9bd9..8c345db6a3 100644 --- a/dimos/manipulation/visual_servoing/manipulation_module.py +++ b/dimos/manipulation/visual_servoing/manipulation_module.py @@ -28,14 +28,21 @@ from dimos.core import Module, In, Out, rpc from dimos_lcm.sensor_msgs import Image, CameraInfo -from dimos_lcm.geometry_msgs import Vector3, Pose +from dimos_lcm.geometry_msgs import Vector3, Pose, Point, Quaternion from dimos_lcm.vision_msgs import Detection3DArray, Detection2DArray from dimos.hardware.piper_arm import PiperArm from dimos.manipulation.visual_servoing.detection3d import Detection3DProcessor from dimos.manipulation.visual_servoing.pbvs import PBVS from dimos.perception.common.utils import find_clicked_detection -from dimos.manipulation.visual_servoing.utils import create_manipulation_visualization +from dimos.manipulation.visual_servoing.utils import ( + create_manipulation_visualization, + select_points_from_depth, + transform_points_3d, + update_target_grasp_pose, + apply_grasp_distance, + is_target_reached, +) from dimos.utils.transform_utils import ( pose_to_matrix, matrix_to_pose, @@ -54,6 +61,8 @@ class GraspStage(Enum): PRE_GRASP = "pre_grasp" # Target set, moving to pre-grasp position GRASP = "grasp" # Executing final grasp CLOSE_AND_RETRACT = "close_and_retract" # Close gripper and retract + PLACE = "place" # Move to place position and release object + RETRACT 
= "retract" # Retract from place position class Feedback: @@ -72,7 +81,7 @@ def __init__( current_camera_pose: Optional[Pose] = None, target_pose: Optional[Pose] = None, waiting_for_reach: bool = False, - grasp_successful: Optional[bool] = None, + success: Optional[bool] = None, ): self.grasp_stage = grasp_stage self.target_tracked = target_tracked @@ -81,7 +90,7 @@ def __init__( self.current_camera_pose = current_camera_pose self.target_pose = target_pose self.waiting_for_reach = waiting_for_reach - self.grasp_successful = grasp_successful + self.success = success class ManipulationModule(Module): @@ -176,6 +185,7 @@ def __init__( self.pick_success = None self.final_pregrasp_pose = None self.task_failed = False # New variable for tracking task failure + self.overall_success = None # Track overall pick and place success # Task control self.task_running = False @@ -190,6 +200,11 @@ def __init__( # Target selection self.target_click = None + # Place target position and object info + self.place_target_position = None + self.target_object_height = None + self.place_pose = None # Store the calculated place pose for retraction + # Move arm to observe position on init self.arm.gotoObserve() @@ -274,11 +289,15 @@ def handle_keyboard_command(self, key: str) -> str: key_code = ord(key) if len(key) == 1 else int(key) if key_code == ord("r"): + self.stop_event.set() + self.task_running = False self.reset_to_idle() return "reset" elif key_code == ord("s"): logger.info("SOFT STOP - Emergency stopping robot!") self.arm.softStop() + self.stop_event.set() + self.task_running = False return "stop" elif key_code == ord(" ") and self.pbvs and self.pbvs.target_grasp_pose: # Manual override - immediately transition to GRASP if in PRE_GRASP @@ -304,13 +323,17 @@ def handle_keyboard_command(self, key: str) -> str: return "" @rpc - def pick_and_place(self, target_x: int = None, target_y: int = None) -> Dict[str, Any]: + def pick_and_place( + self, target_x: int = None, target_y: int = 
None, place_x: int = None, place_y: int = None + ) -> Dict[str, Any]: """ Start a pick and place task. Args: target_x: Optional X coordinate of target object target_y: Optional Y coordinate of target object + place_x: Optional X coordinate of place location + place_y: Optional Y coordinate of place location Returns: Dict with status and message @@ -325,6 +348,45 @@ def pick_and_place(self, target_x: int = None, target_y: int = None) -> Dict[str if target_x is not None and target_y is not None: self.target_click = (target_x, target_y) + # Process place location if provided + if place_x is not None and self.latest_depth is not None: + # Select points around the place location from depth image + points_3d_camera = select_points_from_depth( + self.latest_depth, + (place_x, place_y), + self.camera_intrinsics, + radius=10, # 10 pixel radius around place point + ) + + if points_3d_camera.size > 0: + # Get current camera transform to transform points to world frame + ee_pose = self.arm.get_ee_pose() + ee_transform = pose_to_matrix(ee_pose) + camera_transform = compose_transforms(ee_transform, self.T_ee_to_camera) + + # Transform points from camera frame to world frame + points_3d_world = transform_points_3d( + points_3d_camera, + camera_transform, + to_robot=True, # Convert from optical to robot frame + ) + + # Average the 3D points to get place position + place_position = np.mean(points_3d_world, axis=0) + + # Create place target pose with same orientation as current EE + # For now, just store the position - full implementation will come later + self.place_target_position = place_position + logger.info( + f"Place target set at position: ({place_position[0]:.3f}, {place_position[1]:.3f}, {place_position[2]:.3f})" + ) + logger.info("Note: Z-offset will be applied once target object is detected") + else: + logger.warning("No valid depth points found at place location") + self.place_target_position = None + else: + self.place_target_position = None + # Reset task state 
self.task_failed = False self.stop_event.clear() @@ -360,8 +422,8 @@ def _run_pick_and_place(self): continue # Check if task is complete - if feedback.grasp_successful is not None: - if feedback.grasp_successful: + if feedback.success is not None: + if feedback.success: logger.info("Pick and place completed successfully!") else: logger.warning("Pick and place failed - no object detected") @@ -430,6 +492,8 @@ def reset_to_idle(self): self.waiting_start_time = None self.pick_success = None self.final_pregrasp_pose = None + self.overall_success = None + self.place_pose = None self.arm.gotoObserve() @@ -447,7 +511,9 @@ def execute_pre_grasp(self): if self._check_reach_timeout(): return - reached = self.pbvs.is_target_reached(ee_pose) + reached = is_target_reached( + self.last_commanded_pose, ee_pose, self.pbvs.target_tolerance + ) if reached: self.waiting_for_reach = False @@ -503,7 +569,10 @@ def execute_grasp(self): if self._check_reach_timeout(): return - if self.pbvs.is_target_reached(ee_pose) and not self.grasp_reached_time: + if ( + is_target_reached(self.pbvs.target_grasp_pose, ee_pose, self.pbvs.target_tolerance) + and not self.grasp_reached_time + ): self.grasp_reached_time = time.time() self.waiting_start_time = None @@ -552,10 +621,9 @@ def execute_close_and_retract(self): return # Check if reached retraction pose - original_target = self.pbvs.target_grasp_pose - self.pbvs.target_grasp_pose = self.final_pregrasp_pose - reached = self.pbvs.is_target_reached(ee_pose) - self.pbvs.target_grasp_pose = original_target + reached = is_target_reached( + self.final_pregrasp_pose, ee_pose, self.pbvs.target_tolerance + ) if reached: logger.info("Reached pre-grasp retraction position") @@ -564,10 +632,17 @@ def execute_close_and_retract(self): logger.info(f"Grasp sequence completed") if self.pick_success: logger.info("Object successfully grasped!") + # Transition to PLACE stage if place position is available + if self.place_target_position is not None: + 
logger.info("Transitioning to PLACE stage") + self.grasp_stage = GraspStage.PLACE + else: + # No place position, just mark as overall success + self.overall_success = True else: logger.warning("No object detected in gripper") self.task_failed = True - # Don't reset to idle here - let the task loop handle it after detecting completion + self.overall_success = False else: # Command retraction to pre-grasp logger.info("Retracting to pre-grasp position") @@ -576,6 +651,80 @@ def execute_close_and_retract(self): self.waiting_for_reach = True self.waiting_start_time = time.time() + def execute_place(self): + """Execute place stage: move to place position and release object.""" + ee_pose = self.arm.get_ee_pose() + + if self.waiting_for_reach: + if self._check_reach_timeout(): + return + + # Check if reached place pose + place_pose = self.get_place_target_pose() + if place_pose: + reached = is_target_reached(place_pose, ee_pose, self.pbvs.target_tolerance) + + if reached: + logger.info("Reached place position, releasing gripper") + self.arm.release_gripper() + time.sleep(1.0) # Give time for gripper to open + + # Store the place pose for retraction + self.place_pose = place_pose + + # Transition to RETRACT stage + logger.info("Transitioning to RETRACT stage") + self.grasp_stage = GraspStage.RETRACT + self.waiting_for_reach = False + else: + # Get place pose and command movement + place_pose = self.get_place_target_pose() + if place_pose: + logger.info("Moving to place position") + self.arm.cmd_ee_pose(place_pose, line_mode=True) + self.waiting_for_reach = True + self.waiting_start_time = time.time() + else: + logger.error("Failed to get place target pose") + self.task_failed = True + self.overall_success = False + + def execute_retract(self): + """Execute retract stage: retract from place position.""" + ee_pose = self.arm.get_ee_pose() + + if self.waiting_for_reach: + if self._check_reach_timeout(): + return + + # Check if reached retract pose + if self.place_pose: + 
reached = is_target_reached(self.retract_pose, ee_pose, self.pbvs.target_tolerance) + + if reached: + logger.info("Reached retract position") + # Return to observe position + logger.info("Returning to observe position") + self.arm.gotoObserve() + self.arm.close_gripper() + + # Mark overall success + self.overall_success = True + logger.info("Pick and place completed successfully!") + self.waiting_for_reach = False + else: + # Calculate and command retract pose + if self.place_pose: + self.retract_pose = apply_grasp_distance(self.place_pose, self.pregrasp_distance) + logger.info("Retracting from place position") + self.arm.cmd_ee_pose(self.retract_pose, line_mode=True) + self.waiting_for_reach = True + self.waiting_start_time = time.time() + else: + logger.error("No place pose stored for retraction") + self.task_failed = True + self.overall_success = False + def capture_and_process( self, ) -> Tuple[ @@ -610,6 +759,12 @@ def pick_target(self, x: int, y: int) -> bool: ) if clicked_3d and self.pbvs: self.pbvs.set_target(clicked_3d) + + # Store target object height (z dimension) + if clicked_3d.bbox and clicked_3d.bbox.size: + self.target_object_height = clicked_3d.bbox.size.z + logger.info(f"Target object height: {self.target_object_height:.3f}m") + logger.info( f"Target selected: ID={clicked_3d.id}, pos=({clicked_3d.bbox.center.position.x:.3f}, {clicked_3d.bbox.center.position.y:.3f}, {clicked_3d.bbox.center.position.z:.3f})" ) @@ -653,6 +808,8 @@ def update(self) -> Optional[Dict[str, Any]]: GraspStage.PRE_GRASP: self.execute_pre_grasp, GraspStage.GRASP: self.execute_grasp, GraspStage.CLOSE_AND_RETRACT: self.execute_close_and_retract, + GraspStage.PLACE: self.execute_place, + GraspStage.RETRACT: self.execute_retract, } if self.grasp_stage in stage_handlers: stage_handlers[self.grasp_stage]() @@ -670,7 +827,7 @@ def update(self) -> Optional[Dict[str, Any]]: current_camera_pose=camera_pose, target_pose=self.pbvs.target_grasp_pose if self.pbvs else None, 
waiting_for_reach=self.waiting_for_reach, - grasp_successful=self.pick_success, + success=self.overall_success, ) # Create visualization only if task is running @@ -725,6 +882,38 @@ def check_target_stabilized(self) -> bool: # Check if all axes are below threshold return np.all(std_devs < self.pose_stabilization_threshold) + def get_place_target_pose(self) -> Optional[Pose]: + """Get the place target pose with z-offset applied based on object height.""" + if self.place_target_position is None: + return None + + # Create a copy of the place position + place_pos = self.place_target_position.copy() + + # Apply z-offset if target object height is known + if self.target_object_height is not None: + z_offset = self.target_object_height / 2.0 + place_pos[2] += z_offset + 0.05 + logger.info(f"Applied z-offset of {z_offset:.3f}m to place position") + + # Create place pose + place_center_pose = Pose( + Point(place_pos[0], place_pos[1], place_pos[2]), Quaternion(0.0, 0.0, 0.0, 1.0) + ) + + # Get current EE pose + ee_pose = self.arm.get_ee_pose() + + # Use update_target_grasp_pose with no grasp distance and current pitch angle + place_pose = update_target_grasp_pose( + place_center_pose, + ee_pose, + grasp_distance=0.0, # No grasp distance for placing + grasp_pitch_degrees=self.grasp_pitch_degrees, # Use current grasp pitch + ) + + return place_pose + def cleanup(self): """Clean up resources on module destruction.""" self.stop() diff --git a/dimos/manipulation/visual_servoing/pbvs.py b/dimos/manipulation/visual_servoing/pbvs.py index da8c6c7dca..c207b0e49c 100644 --- a/dimos/manipulation/visual_servoing/pbvs.py +++ b/dimos/manipulation/visual_servoing/pbvs.py @@ -29,8 +29,10 @@ euler_to_quaternion, ) from dimos.manipulation.visual_servoing.utils import ( + update_target_grasp_pose, find_best_object_match, create_pbvs_visualization, + is_target_reached, ) logger = setup_logger("dimos.manipulation.pbvs") @@ -161,27 +163,6 @@ def set_grasp_pitch(self, pitch_degrees: float): # 
Reset target grasp pose to recompute with new pitch self.target_grasp_pose = None - def is_target_reached(self, ee_pose: Pose) -> bool: - """ - Check if the current target stage has been reached. - - Args: - ee_pose: Current end-effector pose - - Returns: - True if current stage target is reached, False otherwise - """ - if not self.target_grasp_pose: - return False - - # Calculate position error - error_x = self.target_grasp_pose.position.x - ee_pose.position.x - error_y = self.target_grasp_pose.position.y - ee_pose.position.y - error_z = self.target_grasp_pose.position.z - ee_pose.position.z - - error_magnitude = np.sqrt(error_x**2 + error_y**2 + error_z**2) - return error_magnitude < self.target_tolerance - def update_tracking(self, new_detections: Optional[Detection3DArray] = None) -> bool: """ Update target tracking with new detections. @@ -235,73 +216,6 @@ def update_tracking(self, new_detections: Optional[Detection3DArray] = None) -> ) return False - def _update_target_grasp_pose(self, ee_pose: Pose, grasp_distance: float): - """ - Update target grasp pose based on current target and EE pose. 
- - Args: - ee_pose: Current end-effector pose - grasp_distance: Distance to maintain from target (pregrasp or grasp distance) - """ - if ( - not self.current_target - or not self.current_target.bbox - or not self.current_target.bbox.center - ): - return - - # Get target position - target_pos = self.current_target.bbox.center.position - - # Calculate orientation pointing from target towards EE - yaw_to_ee = yaw_towards_point(target_pos, ee_pose.position) - - # Create target pose with proper orientation - # Convert grasp pitch from degrees to radians with mapping: - # 0° (level) -> π/2 (1.57 rad), 90° (top-down) -> π (3.14 rad) - pitch_radians = 1.57 + np.radians(self.grasp_pitch_degrees) - - # Convert euler angles to quaternion using utility function - euler = Vector3(0.0, pitch_radians, yaw_to_ee) # roll=0, pitch=mapped, yaw=calculated - target_orientation = euler_to_quaternion(euler) - - target_pose = Pose(target_pos, target_orientation) - - # Apply grasp distance - self.target_grasp_pose = self._apply_grasp_distance(target_pose, grasp_distance) - - def _apply_grasp_distance(self, target_pose: Pose, distance: float) -> Pose: - """ - Apply grasp distance offset to target pose along its approach direction. 
- - Args: - target_pose: Target grasp pose - distance: Distance to offset along the approach direction (meters) - - Returns: - Target pose offset by the specified distance along its approach direction - """ - # Convert pose to transformation matrix to extract rotation - T_target = pose_to_matrix(target_pose) - rotation_matrix = T_target[:3, :3] - - # Define the approach vector based on the target pose orientation - # Assuming the gripper approaches along its local -z axis (common for downward grasps) - # You can change this to [1, 0, 0] for x-axis or [0, 1, 0] for y-axis based on your gripper - approach_vector_local = np.array([0, 0, -1]) - - # Transform approach vector to world coordinates - approach_vector_world = rotation_matrix @ approach_vector_local - - # Apply offset along the approach direction - offset_position = Point( - target_pose.position.x + distance * approach_vector_world[0], - target_pose.position.y + distance * approach_vector_world[1], - target_pose.position.z + distance * approach_vector_world[2], - ) - - return Pose(offset_position, target_pose.orientation) - def compute_control( self, ee_pose: Pose, @@ -323,15 +237,13 @@ def compute_control( - target_pose: Target EE pose (only in direct_ee_control mode, otherwise None) """ # Check if we have a target - if ( - not self.current_target - or not self.current_target.bbox - or not self.current_target.bbox.center - ): + if not self.current_target: return None, None, False, False, None # Update target grasp pose with provided distance - self._update_target_grasp_pose(ee_pose, grasp_distance) + self.target_grasp_pose = update_target_grasp_pose( + self.current_target.bbox.center, ee_pose, grasp_distance, self.grasp_pitch_degrees + ) if self.target_grasp_pose is None: logger.warning("Failed to compute grasp pose") @@ -346,7 +258,7 @@ def compute_control( ) # Check if target reached using our separate function - target_reached = self.is_target_reached(ee_pose) + target_reached = 
is_target_reached(self.target_grasp_pose, ee_pose, self.target_tolerance) # Return appropriate values based on control mode if self.direct_ee_control: diff --git a/dimos/manipulation/visual_servoing/utils.py b/dimos/manipulation/visual_servoing/utils.py index 6d07183104..6b00964775 100644 --- a/dimos/manipulation/visual_servoing/utils.py +++ b/dimos/manipulation/visual_servoing/utils.py @@ -13,13 +13,320 @@ # limitations under the License. import numpy as np -from typing import Dict, Any, Optional, List, Tuple +from typing import Dict, Any, Optional, List, Tuple, Union from dataclasses import dataclass from dimos_lcm.geometry_msgs import Pose, Vector3, Quaternion, Point from dimos_lcm.vision_msgs import Detection3D, Detection2D import cv2 from dimos.perception.detection2d.utils import plot_results +from dimos.perception.common.utils import project_2d_points_to_3d +from dimos.utils.transform_utils import ( + optical_to_robot_frame, + robot_to_optical_frame, + pose_to_matrix, + matrix_to_pose, + euler_to_quaternion, + compose_transforms, + yaw_towards_point, +) + + +def match_detection_by_id( + detection_3d: Detection3D, detections_3d: List[Detection3D], detections_2d: List[Detection2D] +) -> Optional[Detection2D]: + """ + Find the corresponding Detection2D for a given Detection3D. + + Args: + detection_3d: The Detection3D to match + detections_3d: List of all Detection3D objects + detections_2d: List of all Detection2D objects (must be 1:1 correspondence) + + Returns: + Corresponding Detection2D if found, None otherwise + """ + for i, det_3d in enumerate(detections_3d): + if det_3d.id == detection_3d.id and i < len(detections_2d): + return detections_2d[i] + return None + + +def transform_pose( + obj_pos: np.ndarray, + obj_orientation: np.ndarray, + transform_matrix: np.ndarray, + to_optical: bool = False, + to_robot: bool = False, +) -> Pose: + """ + Transform object pose with optional frame convention conversion. 
+ + Args: + obj_pos: Object position [x, y, z] + obj_orientation: Object orientation [roll, pitch, yaw] in radians + transform_matrix: 4x4 transformation matrix from camera frame to desired frame + to_optical: If True, input is in robot frame → convert result to optical frame + to_robot: If True, input is in optical frame → convert to robot frame first + + Returns: + Object pose in desired frame as Pose + """ + # Create object pose from input + # Convert euler angles to quaternion using utility function + euler_vector = Vector3(obj_orientation[0], obj_orientation[1], obj_orientation[2]) + obj_orientation_quat = euler_to_quaternion(euler_vector) + + input_pose = Pose(Point(obj_pos[0], obj_pos[1], obj_pos[2]), obj_orientation_quat) + + # Apply input frame conversion based on flags + if to_robot: + # Input is in optical frame → convert to robot frame first + pose_for_transform = optical_to_robot_frame(input_pose) + else: + # Default or to_optical: use input pose as-is + pose_for_transform = input_pose + + # Create transformation matrix from pose (relative to camera) + T_camera_object = pose_to_matrix(pose_for_transform) + + # Use compose_transforms to combine transformations + T_desired_object = compose_transforms(transform_matrix, T_camera_object) + + # Convert back to pose + result_pose = matrix_to_pose(T_desired_object) + + # Apply output frame conversion based on flags + if to_optical: + # Input was robot frame → convert result to optical frame + desired_pose = robot_to_optical_frame(result_pose) + else: + # Default or to_robot: use result as-is + desired_pose = result_pose + + return desired_pose + + +def transform_points_3d( + points_3d: np.ndarray, + transform_matrix: np.ndarray, + to_optical: bool = False, + to_robot: bool = False, +) -> np.ndarray: + """ + Transform 3D points with optional frame convention conversion. + Applies the same transformation pipeline as transform_pose but for multiple points. 
+ + Args: + points_3d: Nx3 array of 3D points [x, y, z] + transform_matrix: 4x4 transformation matrix from camera frame to desired frame + to_optical: If True, input is in robot frame → convert result to optical frame + to_robot: If True, input is in optical frame → convert to robot frame first + + Returns: + Nx3 array of transformed 3D points in desired frame + """ + if points_3d.size == 0: + return np.zeros((0, 3), dtype=np.float32) + + # Ensure points_3d is the right shape + points_3d = np.asarray(points_3d) + if points_3d.ndim == 1: + points_3d = points_3d.reshape(1, -1) + + transformed_points = [] + + for point in points_3d: + # Create pose with identity orientation for each point + input_point_pose = Pose( + Point(point[0], point[1], point[2]), + Quaternion(0.0, 0.0, 0.0, 1.0), # Identity quaternion + ) + + # Apply input frame conversion based on flags + if to_robot: + # Input is in optical frame → convert to robot frame first + pose_for_transform = optical_to_robot_frame(input_point_pose) + else: + # Default or to_optical: use input pose as-is + pose_for_transform = input_point_pose + + # Create transformation matrix from point pose (relative to camera) + T_camera_point = pose_to_matrix(pose_for_transform) + + # Use compose_transforms to combine transformations + T_desired_point = compose_transforms(transform_matrix, T_camera_point) + + # Convert back to pose + result_pose = matrix_to_pose(T_desired_point) + + # Apply output frame conversion based on flags + if to_optical: + # Input was robot frame → convert result to optical frame + desired_pose = robot_to_optical_frame(result_pose) + else: + # Default or to_robot: use result as-is + desired_pose = result_pose + + transformed_point = [ + desired_pose.position.x, + desired_pose.position.y, + desired_pose.position.z, + ] + transformed_points.append(transformed_point) + + return np.array(transformed_points, dtype=np.float32) + + +def select_points_from_depth( + depth_image: np.ndarray, + target_point: 
Tuple[int, int], + camera_intrinsics: Union[List[float], np.ndarray], + radius: int = 5, +) -> np.ndarray: + """ + Select points around a target point within a bounding box and project them to 3D. + + Args: + depth_image: Depth image in meters (H, W) + target_point: (x, y) target point coordinates + radius: Half-width of the bounding box (so bbox size is radius*2 x radius*2) + camera_intrinsics: Camera parameters as [fx, fy, cx, cy] list or 3x3 matrix + + Returns: + Nx3 array of 3D points (X, Y, Z) in camera frame + """ + x_target, y_target = target_point + height, width = depth_image.shape + + # Define bounding box around target point + x_min = max(0, x_target - radius) + x_max = min(width, x_target + radius) + y_min = max(0, y_target - radius) + y_max = min(height, y_target + radius) + + # Create coordinate grids for the bounding box (vectorized) + y_coords, x_coords = np.meshgrid(range(y_min, y_max), range(x_min, x_max), indexing="ij") + + # Flatten to get all coordinate pairs + x_flat = x_coords.flatten() + y_flat = y_coords.flatten() + + # Extract corresponding depth values using advanced indexing + depth_flat = depth_image[y_flat, x_flat] + + # Create mask for valid depth values + valid_mask = (depth_flat > 0) & np.isfinite(depth_flat) + + # Early exit if no valid points + if not np.any(valid_mask): + return np.zeros((0, 3), dtype=np.float32) + + # Filter to get valid points and depths + points_2d = np.column_stack([x_flat[valid_mask], y_flat[valid_mask]]).astype(np.float32) + depth_values = depth_flat[valid_mask].astype(np.float32) + + # Use the common utility function for 3D projection + points_3d = project_2d_points_to_3d(points_2d, depth_values, camera_intrinsics) + + return points_3d + + +def update_target_grasp_pose( + target_pose: Pose, ee_pose: Pose, grasp_distance: float = 0.0, grasp_pitch_degrees: float = 45.0 +) -> Optional[Pose]: + """ + Update target grasp pose based on current target pose and EE pose. 
+ + Args: + target_pose: Target pose to grasp + ee_pose: Current end-effector pose + grasp_distance: Distance to maintain from target (pregrasp or grasp distance) + grasp_pitch_degrees: Grasp pitch angle in degrees (0° = level, 90° = top-down; default 45°) + + Returns: + Target grasp pose or None if target is invalid + """ + + # Get target position + target_pos = target_pose.position + + # Calculate orientation pointing from target towards EE + yaw_to_ee = yaw_towards_point(target_pos, ee_pose.position) + + # Create target pose with proper orientation + # Convert grasp pitch from degrees to radians with mapping: + # 0° (level) -> π/2 (1.57 rad), 90° (top-down) -> π (3.14 rad) + pitch_radians = 1.57 + np.radians(grasp_pitch_degrees) + + # Convert euler angles to quaternion using utility function + euler = Vector3(0.0, pitch_radians, yaw_to_ee) # roll=0, pitch=mapped, yaw=calculated + target_orientation = euler_to_quaternion(euler) + + updated_pose = Pose(target_pos, target_orientation) + + if grasp_distance > 0.0: + # Apply grasp distance + return apply_grasp_distance(updated_pose, grasp_distance) + else: + return updated_pose + + +def apply_grasp_distance(target_pose: Pose, distance: float) -> Pose: + """ + Apply grasp distance offset to target pose along its approach direction. 
+ + Args: + target_pose: Target grasp pose + distance: Distance to offset along the approach direction (meters) + + Returns: + Target pose offset by the specified distance along its approach direction + """ + # Convert pose to transformation matrix to extract rotation + T_target = pose_to_matrix(target_pose) + rotation_matrix = T_target[:3, :3] + + # Define the approach vector based on the target pose orientation + # Assuming the gripper approaches along its local -z axis (common for downward grasps) + # You can change this to [1, 0, 0] for x-axis or [0, 1, 0] for y-axis based on your gripper + approach_vector_local = np.array([0, 0, -1]) + + # Transform approach vector to world coordinates + approach_vector_world = rotation_matrix @ approach_vector_local + + # Apply offset along the approach direction + offset_position = Point( + target_pose.position.x + distance * approach_vector_world[0], + target_pose.position.y + distance * approach_vector_world[1], + target_pose.position.z + distance * approach_vector_world[2], + ) + + return Pose(offset_position, target_pose.orientation) + + +def is_target_reached(target_pose: Pose, current_pose: Pose, tolerance: float = 0.01) -> bool: + """ + Check if the target pose has been reached within tolerance. 
+ + Args: + target_pose: Target pose to reach + current_pose: Current pose (e.g., end-effector pose) + tolerance: Distance threshold for considering target reached (meters, default 0.01 = 1cm) + + Returns: + True if target is reached within tolerance, False otherwise + """ + if not target_pose: + return False + + # Calculate position error + error_x = target_pose.position.x - current_pose.position.x + error_y = target_pose.position.y - current_pose.position.y + error_z = target_pose.position.z - current_pose.position.z + + error_magnitude = np.sqrt(error_x**2 + error_y**2 + error_z**2) + return error_magnitude < tolerance @dataclass @@ -291,6 +598,8 @@ def create_manipulation_visualization( "pre_grasp": (0, 255, 255), "grasp": (0, 255, 0), "close_and_retract": (255, 0, 255), + "place": (0, 150, 255), + "retract": (255, 150, 0), }.get(feedback.grasp_stage.value, (255, 255, 255)) cv2.putText( @@ -337,10 +646,10 @@ def create_manipulation_visualization( 1, ) - # Grasp result - if feedback.grasp_successful is not None: - result_text = "Grasp: SUCCESS" if feedback.grasp_successful else "Grasp: FAILED" - result_color = (0, 255, 0) if feedback.grasp_successful else (0, 0, 255) + # Overall result + if feedback.success is not None: + result_text = "Pick & Place: SUCCESS" if feedback.success else "Pick & Place: FAILED" + result_color = (0, 255, 0) if feedback.success else (0, 0, 255) cv2.putText( viz, result_text, @@ -533,23 +842,3 @@ def visualize_detections_3d( ) return viz - - -def match_detection_by_id( - detection_3d: Detection3D, detections_3d: List[Detection3D], detections_2d: List[Detection2D] -) -> Optional[Detection2D]: - """ - Find the corresponding Detection2D for a given Detection3D. 
- - Args: - detection_3d: The Detection3D to match - detections_3d: List of all Detection3D objects - detections_2d: List of all Detection2D objects (must be 1:1 correspondence) - - Returns: - Corresponding Detection2D if found, None otherwise - """ - for i, det_3d in enumerate(detections_3d): - if det_3d.id == detection_3d.id and i < len(detections_2d): - return detections_2d[i] - return None diff --git a/dimos/perception/common/utils.py b/dimos/perception/common/utils.py index ce2a358661..10d05d9b4d 100644 --- a/dimos/perception/common/utils.py +++ b/dimos/perception/common/utils.py @@ -14,7 +14,7 @@ import cv2 import numpy as np -from typing import List, Tuple, Optional, Any +from typing import List, Tuple, Optional, Any, Union from dimos.types.manipulation import ObjectData from dimos.types.vector import Vector from dimos.utils.logging_config import setup_logger @@ -24,6 +24,103 @@ logger = setup_logger("dimos.perception.common.utils") +def project_3d_points_to_2d( + points_3d: np.ndarray, camera_intrinsics: Union[List[float], np.ndarray] +) -> np.ndarray: + """ + Project 3D points to 2D image coordinates using camera intrinsics. 
+ + Args: + points_3d: Nx3 array of 3D points (X, Y, Z) + camera_intrinsics: Camera parameters as [fx, fy, cx, cy] list or 3x3 matrix + + Returns: + Nx2 array of 2D image coordinates (u, v) + """ + if len(points_3d) == 0: + return np.zeros((0, 2), dtype=np.int32) + + # Filter out points with zero or negative depth + valid_mask = points_3d[:, 2] > 0 + if not np.any(valid_mask): + return np.zeros((0, 2), dtype=np.int32) + + valid_points = points_3d[valid_mask] + + # Extract camera parameters + if isinstance(camera_intrinsics, list) and len(camera_intrinsics) == 4: + fx, fy, cx, cy = camera_intrinsics + else: + camera_matrix = np.array(camera_intrinsics) + fx = camera_matrix[0, 0] + fy = camera_matrix[1, 1] + cx = camera_matrix[0, 2] + cy = camera_matrix[1, 2] + + # Project to image coordinates + u = (valid_points[:, 0] * fx / valid_points[:, 2]) + cx + v = (valid_points[:, 1] * fy / valid_points[:, 2]) + cy + + # Truncate to integer pixel coordinates + points_2d = np.column_stack([u, v]).astype(np.int32) + + return points_2d + + +def project_2d_points_to_3d( + points_2d: np.ndarray, + depth_values: np.ndarray, + camera_intrinsics: Union[List[float], np.ndarray], +) -> np.ndarray: + """ + Project 2D image points to 3D coordinates using depth values and camera intrinsics. 
+ + Args: + points_2d: Nx2 array of 2D image coordinates (u, v) + depth_values: N-length array of depth values (Z coordinates) for each point + camera_intrinsics: Camera parameters as [fx, fy, cx, cy] list or 3x3 matrix + + Returns: + Nx3 array of 3D points (X, Y, Z) + """ + if len(points_2d) == 0: + return np.zeros((0, 3), dtype=np.float32) + + # Ensure depth_values is a numpy array + depth_values = np.asarray(depth_values) + + # Filter out points with zero or negative depth + valid_mask = depth_values > 0 + if not np.any(valid_mask): + return np.zeros((0, 3), dtype=np.float32) + + valid_points_2d = points_2d[valid_mask] + valid_depths = depth_values[valid_mask] + + # Extract camera parameters + if isinstance(camera_intrinsics, list) and len(camera_intrinsics) == 4: + fx, fy, cx, cy = camera_intrinsics + else: + camera_matrix = np.array(camera_intrinsics) + fx = camera_matrix[0, 0] + fy = camera_matrix[1, 1] + cx = camera_matrix[0, 2] + cy = camera_matrix[1, 2] + + # Back-project to 3D coordinates + # X = (u - cx) * Z / fx + # Y = (v - cy) * Z / fy + # Z = depth + X = (valid_points_2d[:, 0] - cx) * valid_depths / fx + Y = (valid_points_2d[:, 1] - cy) * valid_depths / fy + Z = valid_depths + + # Stack into 3D points + points_3d = np.column_stack([X, Y, Z]).astype(np.float32) + + return points_3d + + def colorize_depth(depth_img: np.ndarray, max_depth: float = 5.0) -> Optional[np.ndarray]: """ Normalize and colorize depth image using COLORMAP_JET. 
diff --git a/dimos/perception/grasp_generation/utils.py b/dimos/perception/grasp_generation/utils.py index 94377363f2..ab0cfd0d15 100644 --- a/dimos/perception/grasp_generation/utils.py +++ b/dimos/perception/grasp_generation/utils.py @@ -18,49 +18,7 @@ import open3d as o3d import cv2 from typing import List, Dict, Tuple, Optional, Union - - -def project_3d_points_to_2d( - points_3d: np.ndarray, camera_intrinsics: Union[List[float], np.ndarray] -) -> np.ndarray: - """ - Project 3D points to 2D image coordinates using camera intrinsics. - - Args: - points_3d: Nx3 array of 3D points (X, Y, Z) - camera_intrinsics: Camera parameters as [fx, fy, cx, cy] list or 3x3 matrix - - Returns: - Nx2 array of 2D image coordinates (u, v) - """ - if len(points_3d) == 0: - return np.zeros((0, 2), dtype=np.int32) - - # Filter out points with zero or negative depth - valid_mask = points_3d[:, 2] > 0 - if not np.any(valid_mask): - return np.zeros((0, 2), dtype=np.int32) - - valid_points = points_3d[valid_mask] - - # Extract camera parameters - if isinstance(camera_intrinsics, list) and len(camera_intrinsics) == 4: - fx, fy, cx, cy = camera_intrinsics - else: - camera_matrix = np.array(camera_intrinsics) - fx = camera_matrix[0, 0] - fy = camera_matrix[1, 1] - cx = camera_matrix[0, 2] - cy = camera_matrix[1, 2] - - # Project to image coordinates - u = (valid_points[:, 0] * fx / valid_points[:, 2]) + cx - v = (valid_points[:, 1] * fy / valid_points[:, 2]) + cy - - # Round to integer pixel coordinates - points_2d = np.column_stack([u, v]).astype(np.int32) - - return points_2d +from dimos.perception.common.utils import project_3d_points_to_2d, project_2d_points_to_3d def create_gripper_geometry( diff --git a/dimos/perception/pointcloud/utils.py b/dimos/perception/pointcloud/utils.py index be65635393..b3c395bfa3 100644 --- a/dimos/perception/pointcloud/utils.py +++ b/dimos/perception/pointcloud/utils.py @@ -26,6 +26,7 @@ import open3d as o3d from typing import List, Optional, Tuple, Union, 
Dict, Any from scipy.spatial import cKDTree +from dimos.perception.common.utils import project_3d_points_to_2d def load_camera_matrix_from_yaml( @@ -304,48 +305,6 @@ def filter_point_cloud_radius( return pcd.remove_radius_outlier(nb_points=nb_points, radius=radius) -def project_3d_points_to_2d( - points_3d: np.ndarray, camera_intrinsics: Union[List[float], np.ndarray] -) -> np.ndarray: - """ - Project 3D points to 2D image coordinates using camera intrinsics. - - Args: - points_3d: Nx3 array of 3D points (X, Y, Z) - camera_intrinsics: Camera parameters as [fx, fy, cx, cy] list or 3x3 matrix - - Returns: - Nx2 array of 2D image coordinates (u, v) - """ - if len(points_3d) == 0: - return np.zeros((0, 2), dtype=np.int32) - - # Filter out points with zero or negative depth - valid_mask = points_3d[:, 2] > 0 - if not np.any(valid_mask): - return np.zeros((0, 2), dtype=np.int32) - - valid_points = points_3d[valid_mask] - - # Extract camera parameters - if isinstance(camera_intrinsics, list) and len(camera_intrinsics) == 4: - fx, fy, cx, cy = camera_intrinsics - else: - fx = camera_intrinsics[0, 0] - fy = camera_intrinsics[1, 1] - cx = camera_intrinsics[0, 2] - cy = camera_intrinsics[1, 2] - - # Project to image coordinates - u = (valid_points[:, 0] * fx / valid_points[:, 2]) + cx - v = (valid_points[:, 1] * fy / valid_points[:, 2]) + cy - - # Round to integer pixel coordinates - points_2d = np.column_stack([u, v]).astype(np.int32) - - return points_2d - - def overlay_point_clouds_on_image( base_image: np.ndarray, point_clouds: List[o3d.geometry.PointCloud], diff --git a/tests/test_ibvs.py b/tests/test_ibvs.py deleted file mode 100644 index 0192b1aa56..0000000000 --- a/tests/test_ibvs.py +++ /dev/null @@ -1,137 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Copyright 2025 Dimensional Inc. - -""" -Test script for PBVS with eye-in-hand configuration. -Uses EE pose as odometry for camera pose estimation. -Click on objects to select targets. -""" - -import cv2 -import sys - -try: - import pyzed.sl as sl -except ImportError: - print("Error: ZED SDK not installed.") - sys.exit(1) - -from dimos.hardware.zed_camera import ZEDCamera -from dimos.hardware.piper_arm import PiperArm -from dimos.manipulation.visual_servoing.manipulation import Manipulation - - -# Global for mouse events -mouse_click = None - - -def mouse_callback(event, x, y, _flags, _param): - global mouse_click - if event == cv2.EVENT_LBUTTONDOWN: - mouse_click = (x, y) - - -def main(): - global mouse_click - - # Configuration - DIRECT_EE_CONTROL = True # True: direct EE pose control, False: velocity control - INITIAL_GRASP_PITCH_DEGREES = 30 # 0° = level grasp, 90° = top-down grasp - - print("=== PBVS Eye-in-Hand Test ===") - print("Using EE pose as odometry for camera pose") - print(f"Control mode: {'Direct EE Pose' if DIRECT_EE_CONTROL else 'Velocity Commands'}") - print("Click objects to select targets | 'r' - reset | 'q' - quit") - if DIRECT_EE_CONTROL: - print("SAFETY CONTROLS:") - print(" 's' - SOFT STOP (emergency stop)") - print(" 'h' - GO HOME (return to safe position)") - print(" 'SPACE' - EXECUTE target pose (only moves when pressed)") - print(" 'g' - RELEASE GRIPPER (open gripper to 100mm)") - print("GRASP PITCH CONTROLS:") - print(" '↑' - Increase grasp pitch by 15° (towards top-down)") - print(" '↓' - Decrease grasp pitch by 15° 
(towards level)") - - # Initialize hardware - zed = ZEDCamera(resolution=sl.RESOLUTION.HD720, depth_mode=sl.DEPTH_MODE.NEURAL) - if not zed.open(): - print("Camera initialization failed!") - return - - # Initialize robot arm - try: - arm = PiperArm() - print("Initialized Piper arm") - except Exception as e: - print(f"Failed to initialize Piper arm: {e}") - zed.close() - return - - # Initialize manipulation system - try: - manipulation = Manipulation( - camera=zed, - arm=arm, - ) - except Exception as e: - print(f"Failed to initialize manipulation system: {e}") - zed.close() - arm.disable() - return - - # Set initial grasp pitch - manipulation.set_grasp_pitch(INITIAL_GRASP_PITCH_DEGREES) - - # Setup window - cv2.namedWindow("PBVS") - cv2.setMouseCallback("PBVS", mouse_callback) - - try: - while True: - # Update manipulation system - feedback = manipulation.update() - if feedback is None: - continue - - # Handle mouse click - if mouse_click: - x, y = mouse_click - manipulation.pick_target(x, y) - mouse_click = None - - # Get and display visualization - viz = manipulation.get_visualization() - if viz is not None: - cv2.imshow("PBVS", viz) - - # Handle keyboard input - key = cv2.waitKey(1) & 0xFF - if key == ord("q"): - break - else: - manipulation.handle_keyboard_command(key) - - except KeyboardInterrupt: - pass - finally: - cv2.destroyAllWindows() - zed.close() - arm.disable() - - -if __name__ == "__main__": - main() diff --git a/tests/test_manipulation_perception_pipeline.py b/tests/test_manipulation_perception_pipeline.py index 8b333ec310..227f991650 100644 --- a/tests/test_manipulation_perception_pipeline.py +++ b/tests/test_manipulation_perception_pipeline.py @@ -36,7 +36,7 @@ from dimos.stream.stereo_camera_streams.zed import ZEDCameraStream from dimos.web.robot_web_interface import RobotWebInterface from dimos.utils.logging_config import logger -from dimos.perception.manip_aio_pipeline import ManipulationPipeline +from dimos.manipulation.manip_aio_pipeline 
import ManipulationPipeline def monitor_grasps(pipeline): diff --git a/tests/test_pick_and_place_module.py b/tests/test_pick_and_place_module.py index dd7ce174e2..27924481af 100644 --- a/tests/test_pick_and_place_module.py +++ b/tests/test_pick_and_place_module.py @@ -49,6 +49,9 @@ mouse_click = None camera_mouse_click = None current_window = None +pick_location = None # Store pick location +place_location = None # Store place location +place_mode = False # Track if we're in place selection mode def mouse_callback(event, x, y, _flags, param): @@ -120,7 +123,7 @@ def _on_camera_image(self, msg: LCMImage, topic: str): def run_visualization(self): """Run the visualization loop with user interaction.""" - global mouse_click, camera_mouse_click + global mouse_click, camera_mouse_click, pick_location, place_location, place_mode # Setup windows cv2.namedWindow("Pick and Place") @@ -131,21 +134,89 @@ def run_visualization(self): print("=== Pick and Place Module Test ===") print("Control mode: Module-based with LCM communication") - print("Click objects to select targets | 'r' - reset | 'q' - quit") - print("SAFETY CONTROLS:") + print("\nPICK AND PLACE WORKFLOW:") + print("1. Click on an object to select PICK location") + print("2. Click again to select PLACE location (auto pick & place)") + print("3. 
OR press 'p' after first click for pick-only task") + print("\nCONTROLS:") + print(" 'p' - Execute pick-only task (after selecting pick location)") + print(" 'r' - Reset everything") + print(" 'q' - Quit") print(" 's' - SOFT STOP (emergency stop)") print(" 'g' - RELEASE GRIPPER (open gripper)") print(" 'SPACE' - EXECUTE target pose (manual override)") - print("GRASP PITCH CONTROLS:") + print("\nGRASP PITCH CONTROLS:") print(" '↑' - Increase grasp pitch by 15° (towards top-down)") print(" '↓' - Decrease grasp pitch by 15° (towards level)") - print(" 'p' - Start pick and place task") - print("\nNOTE: Click on objects in the Camera Feed window to select targets!") + print("\nNOTE: Click on objects in the Camera Feed window!") while self._running: - # Show camera feed (always available) + # Show camera feed with status overlay if self.latest_camera is not None: - cv2.imshow("Camera Feed", self.latest_camera) + display_image = self.latest_camera.copy() + + # Add status text + status_text = "" + if pick_location is None: + status_text = "Click to select PICK location" + color = (0, 255, 0) + elif place_location is None: + status_text = "Click to select PLACE location (or press 'p' for pick-only)" + color = (0, 255, 255) + else: + status_text = "Executing pick and place..." 
+ color = (255, 0, 255) + + cv2.putText( + display_image, status_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2 + ) + + # Draw pick location marker if set + if pick_location is not None: + # Simple circle marker + cv2.circle(display_image, pick_location, 10, (0, 255, 0), 2) + cv2.circle(display_image, pick_location, 2, (0, 255, 0), -1) + + # Simple label + cv2.putText( + display_image, + "PICK", + (pick_location[0] + 15, pick_location[1] + 5), + cv2.FONT_HERSHEY_SIMPLEX, + 0.6, + (0, 255, 0), + 2, + ) + + # Draw place location marker if set + if place_location is not None: + # Simple circle marker + cv2.circle(display_image, place_location, 10, (0, 255, 255), 2) + cv2.circle(display_image, place_location, 2, (0, 255, 255), -1) + + # Simple label + cv2.putText( + display_image, + "PLACE", + (place_location[0] + 15, place_location[1] + 5), + cv2.FONT_HERSHEY_SIMPLEX, + 0.6, + (0, 255, 255), + 2, + ) + + # Draw simple arrow between pick and place + if pick_location is not None: + cv2.arrowedLine( + display_image, + pick_location, + place_location, + (255, 255, 0), + 2, + tipLength=0.05, + ) + + cv2.imshow("Camera Feed", display_image) # Show visualization if available if self.latest_viz is not None: @@ -158,16 +229,32 @@ def run_visualization(self): logger.info("Quit requested") self._running = False break + elif key == ord("r"): + # Reset everything + pick_location = None + place_location = None + place_mode = False + logger.info("Reset pick and place selections") + # Also send reset to manipulation module + action = self.manipulation.handle_keyboard_command("r") + if action: + logger.info(f"Action: {action}") elif key == ord("p"): - # Start pick and place task - if mouse_click: - x, y = mouse_click - result = self.manipulation.pick_and_place(x, y) - logger.info(f"Pick and place task: {result}") - mouse_click = None + # Execute pick-only task if pick location is set + if pick_location is not None: + logger.info(f"Executing pick-only task at {pick_location}") 
+ result = self.manipulation.pick_and_place( + pick_location[0], + pick_location[1], + None, # No place location + None, + ) + logger.info(f"Pick task started: {result}") + # Clear selection after sending + pick_location = None + place_location = None else: - result = self.manipulation.pick_and_place() - logger.info(f"Pick and place task (no target): {result}") + logger.warning("Please select a pick location first!") else: # Send keyboard command to manipulation module if key in [82, 84]: # Arrow keys @@ -177,20 +264,57 @@ def run_visualization(self): if action: logger.info(f"Action: {action}") - # Handle mouse click from Camera Feed window + # Handle mouse clicks if camera_mouse_click: - # Start pick and place task with the clicked point x, y = camera_mouse_click - result = self.manipulation.pick_and_place(x, y) - logger.info(f"Started pick and place at ({x}, {y}) from camera feed: {result}") + + if pick_location is None: + # First click - set pick location + pick_location = (x, y) + logger.info(f"Pick location set at ({x}, {y})") + elif place_location is None: + # Second click - set place location and execute + place_location = (x, y) + logger.info(f"Place location set at ({x}, {y})") + logger.info(f"Executing pick at {pick_location} and place at ({x}, {y})") + + # Start pick and place task with both locations + result = self.manipulation.pick_and_place( + pick_location[0], pick_location[1], x, y + ) + logger.info(f"Pick and place task started: {result}") + + # Clear all points after sending mission + pick_location = None + place_location = None + camera_mouse_click = None # Handle mouse click from Pick and Place window (if viz is running) elif mouse_click and self.latest_viz is not None: - # If there's a pending click and we're not running a task, start one + # Similar logic for viz window clicks x, y = mouse_click - result = self.manipulation.pick_and_place(x, y) - logger.info(f"Started pick and place at ({x}, {y}): {result}") + + if pick_location is None: + # 
First click - set pick location + pick_location = (x, y) + logger.info(f"Pick location set at ({x}, {y}) from viz window") + elif place_location is None: + # Second click - set place location and execute + place_location = (x, y) + logger.info(f"Place location set at ({x}, {y}) from viz window") + logger.info(f"Executing pick at {pick_location} and place at ({x}, {y})") + + # Start pick and place task with both locations + result = self.manipulation.pick_and_place( + pick_location[0], pick_location[1], x, y + ) + logger.info(f"Pick and place task started: {result}") + + # Clear all points after sending mission + pick_location = None + place_location = None + mouse_click = None time.sleep(0.03) # ~30 FPS From 6da0956f0205bcde641a72deda12a1f37dc30c36 Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Wed, 23 Jul 2025 20:01:26 -0700 Subject: [PATCH 79/89] major refactor, added dynamic grasp angle --- .../visual_servoing/manipulation_module.py | 514 ++++++++++-------- dimos/manipulation/visual_servoing/pbvs.py | 67 ++- dimos/manipulation/visual_servoing/utils.py | 12 +- dimos/utils/transform_utils.py | 18 + tests/test_pick_and_place_module.py | 3 - 5 files changed, 330 insertions(+), 284 deletions(-) diff --git a/dimos/manipulation/visual_servoing/manipulation_module.py b/dimos/manipulation/visual_servoing/manipulation_module.py index 8c345db6a3..b8f3dbb9f0 100644 --- a/dimos/manipulation/visual_servoing/manipulation_module.py +++ b/dimos/manipulation/visual_servoing/manipulation_module.py @@ -57,26 +57,22 @@ class GraspStage(Enum): """Enum for different grasp stages.""" - IDLE = "idle" # No target set - PRE_GRASP = "pre_grasp" # Target set, moving to pre-grasp position - GRASP = "grasp" # Executing final grasp - CLOSE_AND_RETRACT = "close_and_retract" # Close gripper and retract - PLACE = "place" # Move to place position and release object - RETRACT = "retract" # Retract from place position + IDLE = "idle" + PRE_GRASP = "pre_grasp" + GRASP = "grasp" + CLOSE_AND_RETRACT = 
"close_and_retract" + PLACE = "place" + RETRACT = "retract" class Feedback: - """ - Feedback data returned by the manipulation system update. - - Contains comprehensive state information about the manipulation process. - """ + """Feedback data containing state information about the manipulation process.""" def __init__( self, grasp_stage: GraspStage, target_tracked: bool, - last_commanded_pose: Optional[Pose] = None, + current_executed_pose: Optional[Pose] = None, current_ee_pose: Optional[Pose] = None, current_camera_pose: Optional[Pose] = None, target_pose: Optional[Pose] = None, @@ -85,7 +81,7 @@ def __init__( ): self.grasp_stage = grasp_stage self.target_tracked = target_tracked - self.last_commanded_pose = last_commanded_pose + self.current_executed_pose = current_executed_pose self.current_ee_pose = current_ee_pose self.current_camera_pose = current_camera_pose self.target_pose = target_pose @@ -128,22 +124,21 @@ def __init__( Args: ee_to_camera_6dof: EE to camera transform [x, y, z, rx, ry, rz] in meters and radians + workspace_min_radius: Minimum workspace radius in meters + workspace_max_radius: Maximum workspace radius in meters + min_grasp_pitch_degrees: Minimum grasp pitch angle (at max radius) + max_grasp_pitch_degrees: Maximum grasp pitch angle (at min radius) """ super().__init__(**kwargs) - # Initialize arm directly self.arm = PiperArm() - # Default EE to camera transform if not provided if ee_to_camera_6dof is None: ee_to_camera_6dof = [-0.065, 0.03, -0.105, 0.0, -1.57, 0.0] - - # Create transform matrices pos = Vector3(ee_to_camera_6dof[0], ee_to_camera_6dof[1], ee_to_camera_6dof[2]) rot = Vector3(ee_to_camera_6dof[3], ee_to_camera_6dof[4], ee_to_camera_6dof[5]) self.T_ee_to_camera = create_transform_from_6dof(pos, rot) - # Camera intrinsics will be set when camera info is received self.camera_intrinsics = None self.detector = None self.pbvs = None @@ -151,20 +146,25 @@ def __init__( # Control state self.last_valid_target = None 
self.waiting_for_reach = False - self.last_commanded_pose = None + self.current_executed_pose = None # Track the actual pose sent to arm self.target_updated = False self.waiting_start_time = None self.reach_pose_timeout = 10.0 # Grasp parameters self.grasp_width_offset = 0.03 - self.grasp_pitch_degrees = 30.0 self.pregrasp_distance = 0.25 self.grasp_distance_range = 0.03 self.grasp_close_delay = 2.0 self.grasp_reached_time = None self.gripper_max_opening = 0.07 + # Workspace limits and dynamic pitch parameters + self.workspace_min_radius = 0.2 + self.workspace_max_radius = 0.75 + self.min_grasp_pitch_degrees = 5.0 + self.max_grasp_pitch_degrees = 75.0 + # Grasp stage tracking self.grasp_stage = GraspStage.IDLE @@ -176,6 +176,13 @@ def __init__( self.reached_poses = deque(maxlen=self.pose_history_size) self.adjustment_count = 0 + # Pose reachability tracking + self.ee_pose_history = deque(maxlen=20) # Keep history of EE poses + self.stuck_pose_threshold = 0.001 # 1mm movement threshold + self.stuck_pose_adjustment_degrees = 5.0 + self.stuck_count = 0 + self.max_stuck_reattempts = 7 + # State for visualization self.current_visualization = None self.last_detection_3d_array = None @@ -184,8 +191,8 @@ def __init__( # Grasp result and task tracking self.pick_success = None self.final_pregrasp_pose = None - self.task_failed = False # New variable for tracking task failure - self.overall_success = None # Track overall pick and place success + self.task_failed = False + self.overall_success = None # Task control self.task_running = False @@ -201,11 +208,12 @@ def __init__( self.target_click = None # Place target position and object info + self.home_pose = Pose(Point(0.0, 0.0, 0.0), Quaternion(0.0, 0.0, 0.0, 1.0)) self.place_target_position = None self.target_object_height = None - self.place_pose = None # Store the calculated place pose for retraction - - # Move arm to observe position on init + self.retract_distance = 0.12 + self.place_pose = None + self.retract_pose = 
None self.arm.gotoObserve() @rpc @@ -233,7 +241,6 @@ def stop(self): def _on_rgb_image(self, msg: Image): """Handle RGB image messages.""" try: - # Convert LCM message to numpy array data = np.frombuffer(msg.data, dtype=np.uint8) if msg.encoding == "rgb8": self.latest_rgb = data.reshape((msg.height, msg.width, 3)) @@ -245,7 +252,6 @@ def _on_rgb_image(self, msg: Image): def _on_depth_image(self, msg: Image): """Handle depth image messages.""" try: - # Convert LCM message to numpy array if msg.encoding == "32FC1": data = np.frombuffer(msg.data, dtype=np.float32) self.latest_depth = data.reshape((msg.height, msg.width)) @@ -257,15 +263,8 @@ def _on_depth_image(self, msg: Image): def _on_camera_info(self, msg: CameraInfo): """Handle camera info messages.""" try: - # Extract camera intrinsics - self.camera_intrinsics = [ - msg.K[0], # fx - msg.K[4], # fy - msg.K[2], # cx - msg.K[5], # cy - ] - - # Initialize processors if not already done + self.camera_intrinsics = [msg.K[0], msg.K[4], msg.K[2], msg.K[5]] + if self.detector is None: self.detector = Detection3DProcessor(self.camera_intrinsics) self.pbvs = PBVS(target_tolerance=0.05) @@ -300,21 +299,10 @@ def handle_keyboard_command(self, key: str) -> str: self.task_running = False return "stop" elif key_code == ord(" ") and self.pbvs and self.pbvs.target_grasp_pose: - # Manual override - immediately transition to GRASP if in PRE_GRASP if self.grasp_stage == GraspStage.PRE_GRASP: self.set_grasp_stage(GraspStage.GRASP) logger.info("Executing target pose") return "execute" - elif key_code == 82: # Up arrow - increase pitch - new_pitch = min(90.0, self.grasp_pitch_degrees + 15.0) - self.set_grasp_pitch(new_pitch) - logger.info(f"Grasp pitch: {new_pitch:.0f} degrees") - return "pitch_up" - elif key_code == 84: # Down arrow - decrease pitch - new_pitch = max(0.0, self.grasp_pitch_degrees - 15.0) - self.set_grasp_pitch(new_pitch) - logger.info(f"Grasp pitch: {new_pitch:.0f} degrees") - return "pitch_down" elif key_code == 
ord("g"): logger.info("Opening gripper") self.arm.release_gripper() @@ -344,59 +332,44 @@ def pick_and_place( if self.camera_intrinsics is None: return {"status": "error", "message": "Camera not initialized"} - # Set target if coordinates provided if target_x is not None and target_y is not None: self.target_click = (target_x, target_y) - - # Process place location if provided if place_x is not None and self.latest_depth is not None: - # Select points around the place location from depth image points_3d_camera = select_points_from_depth( self.latest_depth, (place_x, place_y), self.camera_intrinsics, - radius=10, # 10 pixel radius around place point + radius=10, ) if points_3d_camera.size > 0: - # Get current camera transform to transform points to world frame ee_pose = self.arm.get_ee_pose() ee_transform = pose_to_matrix(ee_pose) camera_transform = compose_transforms(ee_transform, self.T_ee_to_camera) - # Transform points from camera frame to world frame points_3d_world = transform_points_3d( points_3d_camera, camera_transform, - to_robot=True, # Convert from optical to robot frame + to_robot=True, ) - # Average the 3D points to get place position place_position = np.mean(points_3d_world, axis=0) - - # Create place target pose with same orientation as current EE - # For now, just store the position - full implementation will come later self.place_target_position = place_position logger.info( f"Place target set at position: ({place_position[0]:.3f}, {place_position[1]:.3f}, {place_position[2]:.3f})" ) - logger.info("Note: Z-offset will be applied once target object is detected") else: logger.warning("No valid depth points found at place location") self.place_target_position = None else: self.place_target_position = None - # Reset task state self.task_failed = False self.stop_event.clear() - # Ensure any previous thread has finished if self.task_thread and self.task_thread.is_alive(): self.stop_event.set() self.task_thread.join(timeout=1.0) - - # Start task in 
separate thread self.task_thread = threading.Thread(target=self._run_pick_and_place, daemon=True) self.task_thread.start() @@ -409,30 +382,25 @@ def _run_pick_and_place(self): try: while not self.stop_event.is_set(): - # Check for task failure if self.task_failed: logger.error("Task failed, terminating pick and place") self.stop_event.set() break - # Update manipulation system feedback = self.update() if feedback is None: time.sleep(0.01) continue - # Check if task is complete if feedback.success is not None: if feedback.success: logger.info("Pick and place completed successfully!") else: - logger.warning("Pick and place failed - no object detected") - # Reset to idle state and stop the event loop + logger.warning("Pick and place failed") self.reset_to_idle() self.stop_event.set() break - # Small delay to prevent CPU overload time.sleep(0.01) except Exception as e: @@ -447,22 +415,166 @@ def set_grasp_stage(self, stage: GraspStage): self.grasp_stage = stage logger.info(f"Grasp stage: {stage.value}") - def set_grasp_pitch(self, pitch_degrees: float): - """Set the grasp pitch angle.""" - pitch_degrees = max(0.0, min(90.0, pitch_degrees)) - self.grasp_pitch_degrees = pitch_degrees - if self.pbvs: - self.pbvs.set_grasp_pitch(pitch_degrees) + def calculate_dynamic_grasp_pitch(self, target_pose: Pose) -> float: + """ + Calculate grasp pitch dynamically based on distance from robot base. + Maps workspace radius to grasp pitch angle. 
- def _check_reach_timeout(self) -> bool: - """Check if robot has exceeded timeout while reaching pose.""" - if ( - self.waiting_start_time - and (time.time() - self.waiting_start_time) > self.reach_pose_timeout - ): - logger.warning(f"Robot failed to reach pose within {self.reach_pose_timeout}s timeout") + Args: + target_pose: Target pose + + Returns: + Grasp pitch angle in degrees + """ + # Calculate 3D distance from robot base (assumes robot at origin) + position = target_pose.position + distance = np.sqrt(position.x**2 + position.y**2 + position.z**2) + + # Clamp distance to workspace limits + distance = np.clip(distance, self.workspace_min_radius, self.workspace_max_radius) + + # Linear interpolation: min_radius -> max_pitch, max_radius -> min_pitch + # Normalized distance (0 to 1) + normalized_dist = (distance - self.workspace_min_radius) / ( + self.workspace_max_radius - self.workspace_min_radius + ) + + # Inverse mapping: closer objects need higher pitch + pitch_degrees = self.max_grasp_pitch_degrees - ( + normalized_dist * (self.max_grasp_pitch_degrees - self.min_grasp_pitch_degrees) + ) + + return pitch_degrees + + def check_within_workspace(self, target_pose: Pose) -> bool: + """ + Check if pose is within workspace limits and log error if not. + + Args: + target_pose: Target pose to validate + + Returns: + True if within workspace, False otherwise + """ + # Calculate 3D distance from robot base + position = target_pose.position + distance = np.sqrt(position.x**2 + position.y**2 + position.z**2) + + if not (self.workspace_min_radius <= distance <= self.workspace_max_radius): + logger.error( + f"Target outside workspace limits: distance {distance:.3f}m not in [{self.workspace_min_radius:.2f}, {self.workspace_max_radius:.2f}]" + ) + return False + + return True + + def _check_reach_timeout(self) -> Tuple[bool, float]: + """Check if robot has exceeded timeout while reaching pose. 
+ + Returns: + Tuple of (timed_out, time_elapsed) + """ + if self.waiting_start_time: + time_elapsed = time.time() - self.waiting_start_time + if time_elapsed > self.reach_pose_timeout: + logger.warning( + f"Robot failed to reach pose within {self.reach_pose_timeout}s timeout" + ) + self.task_failed = True + self.reset_to_idle() + return True, time_elapsed + return False, time_elapsed + return False, 0.0 + + def _check_if_stuck(self) -> bool: + """ + Check if robot is stuck by analyzing pose history. + + Returns: + True if the robot is stuck, False otherwise + """ + if len(self.ee_pose_history) < self.ee_pose_history.maxlen: + return False + + # Extract positions from pose history + positions = np.array( + [[p.position.x, p.position.y, p.position.z] for p in self.ee_pose_history] + ) + + # Calculate standard deviation of positions + std_devs = np.std(positions, axis=0) + # Check if all standard deviations are below stuck threshold + is_stuck = np.all(std_devs < self.stuck_pose_threshold) + + return is_stuck + + def check_reach_and_adjust(self, tolerance: Optional[float] = None) -> bool: + """ + Check if robot has reached the current executed pose while waiting. + Handles timeout internally by failing the task. + Also detects if the robot is stuck (not moving towards target). 
+ + Args: + tolerance: Optional tolerance override (uses PBVS tolerance if not provided) + + Returns: + True if reached, False if still waiting or not in waiting state + """ + if not self.waiting_for_reach or not self.current_executed_pose: + return False + + # Get current end-effector pose + ee_pose = self.arm.get_ee_pose() + target_pose = self.current_executed_pose + + # Check for timeout - this will fail task and reset if timeout occurred + timed_out, time_elapsed = self._check_reach_timeout() + if timed_out: + return False + + # Use provided tolerance or default to PBVS tolerance + if tolerance is None: + tolerance = self.pbvs.target_tolerance if self.pbvs else 0.01 + + # Add current pose to history + self.ee_pose_history.append(ee_pose) + + # Check if robot is stuck + is_stuck = self._check_if_stuck() + if is_stuck: + if self.grasp_stage == GraspStage.RETRACT: + return True + self.stuck_count += 1 + pitch_degrees = self.calculate_dynamic_grasp_pitch(target_pose) + if self.stuck_count % 2 == 0: + pitch_degrees += self.stuck_pose_adjustment_degrees * (1 + self.stuck_count // 2) + else: + pitch_degrees -= self.stuck_pose_adjustment_degrees * (1 + self.stuck_count // 2) + + pitch_degrees = max( + self.min_grasp_pitch_degrees, min(self.max_grasp_pitch_degrees, pitch_degrees) + ) + updated_target_pose = update_target_grasp_pose(target_pose, ee_pose, 0.0, pitch_degrees) + logger.info( + f"updated_target_pose: {updated_target_pose.position.x}, {updated_target_pose.position.y}, {updated_target_pose.position.z}" + ) + self.arm.cmd_ee_pose(updated_target_pose, line_mode=True) + self.current_executed_pose = updated_target_pose + self.ee_pose_history.clear() + self.waiting_for_reach = True + self.waiting_start_time = time.time() + return False + + if self.stuck_count >= self.max_stuck_reattempts: self.task_failed = True self.reset_to_idle() + return False + + if is_target_reached(target_pose, ee_pose, tolerance): + self.waiting_for_reach = False + self.waiting_start_time = 
None + self.stuck_count = 0 + self.ee_pose_history.clear() return True return False @@ -483,9 +595,10 @@ def reset_to_idle(self): self.pbvs.clear_target() self.grasp_stage = GraspStage.IDLE self.reached_poses.clear() + self.ee_pose_history.clear() self.adjustment_count = 0 self.waiting_for_reach = False - self.last_commanded_pose = None + self.current_executed_pose = None self.target_updated = False self.stabilization_start_time = None self.grasp_reached_time = None @@ -494,37 +607,23 @@ def reset_to_idle(self): self.final_pregrasp_pose = None self.overall_success = None self.place_pose = None + self.retract_pose = None + self.stuck_count = 0 self.arm.gotoObserve() def execute_idle(self): - """Execute idle stage: just visualization, no control.""" + """Execute idle stage.""" pass def execute_pre_grasp(self): """Execute pre-grasp stage: visual servoing to pre-grasp position.""" - ee_pose = self.arm.get_ee_pose() - - # Check if waiting for robot to reach commanded pose - if self.waiting_for_reach and self.last_commanded_pose: - # Check for timeout - if self._check_reach_timeout(): - return - - reached = is_target_reached( - self.last_commanded_pose, ee_pose, self.pbvs.target_tolerance - ) - - if reached: - self.waiting_for_reach = False - self.waiting_start_time = None - self.reached_poses.append(self.last_commanded_pose) + if self.waiting_for_reach: + if self.check_reach_and_adjust(): + self.reached_poses.append(self.current_executed_pose) self.target_updated = False time.sleep(0.3) - return - - # Check stabilization timeout if ( self.stabilization_start_time and (time.time() - self.stabilization_start_time) > self.stabilization_timeout @@ -536,24 +635,28 @@ def execute_pre_grasp(self): self.reset_to_idle() return - # PBVS control with pre-grasp distance + ee_pose = self.arm.get_ee_pose() + dynamic_pitch = self.calculate_dynamic_grasp_pitch(self.pbvs.current_target.bbox.center) + _, _, _, has_target, target_pose = self.pbvs.compute_control( - ee_pose, 
self.pregrasp_distance + ee_pose, self.pregrasp_distance, dynamic_pitch ) - - # Handle pose control if target_pose and has_target: - # Check if we have enough reached poses and they're stable + # Validate target pose is within workspace + if not self.check_within_workspace(target_pose): + self.task_failed = True + self.reset_to_idle() + return + if self.check_target_stabilized(): logger.info("Target stabilized, transitioning to GRASP") - self.final_pregrasp_pose = self.last_commanded_pose + self.final_pregrasp_pose = self.current_executed_pose self.grasp_stage = GraspStage.GRASP self.adjustment_count = 0 self.waiting_for_reach = False elif not self.waiting_for_reach and self.target_updated: - # Command the pose only if target has been updated self.arm.cmd_ee_pose(target_pose) - self.last_commanded_pose = target_pose + self.current_executed_pose = target_pose self.waiting_for_reach = True self.waiting_start_time = time.time() self.target_updated = False @@ -562,126 +665,94 @@ def execute_pre_grasp(self): def execute_grasp(self): """Execute grasp stage: move to final grasp position.""" - ee_pose = self.arm.get_ee_pose() - - # Handle waiting with special grasp logic - if self.waiting_for_reach: - if self._check_reach_timeout(): - return - - if ( - is_target_reached(self.pbvs.target_grasp_pose, ee_pose, self.pbvs.target_tolerance) - and not self.grasp_reached_time - ): + if self.waiting_for_reach and self.pbvs and self.pbvs.target_grasp_pose: + if self.check_reach_and_adjust() and not self.grasp_reached_time: self.grasp_reached_time = time.time() - self.waiting_start_time = None - # Check if delay completed if ( self.grasp_reached_time and (time.time() - self.grasp_reached_time) >= self.grasp_close_delay ): logger.info("Grasp delay completed, closing gripper") self.grasp_stage = GraspStage.CLOSE_AND_RETRACT - self.waiting_for_reach = False return - - # Only command grasp if not waiting and have valid target if self.last_valid_target: - # Calculate grasp distance based 
on pitch angle - normalized_pitch = self.grasp_pitch_degrees / 90.0 + # Calculate dynamic pitch for current target + dynamic_pitch = self.calculate_dynamic_grasp_pitch(self.last_valid_target.bbox.center) + normalized_pitch = dynamic_pitch / 90.0 grasp_distance = -self.grasp_distance_range + ( 2 * self.grasp_distance_range * normalized_pitch ) - # PBVS control with calculated grasp distance - _, _, _, has_target, target_pose = self.pbvs.compute_control(ee_pose, grasp_distance) + ee_pose = self.arm.get_ee_pose() + _, _, _, has_target, target_pose = self.pbvs.compute_control( + ee_pose, grasp_distance, dynamic_pitch + ) if target_pose and has_target: - # Calculate gripper opening + # Validate grasp pose is within workspace + if not self.check_within_workspace(target_pose): + self.task_failed = True + self.reset_to_idle() + return + object_width = self.last_valid_target.bbox.size.x gripper_opening = max( 0.005, min(object_width + self.grasp_width_offset, self.gripper_max_opening) ) logger.info(f"Executing grasp: gripper={gripper_opening * 1000:.1f}mm") - - # Command gripper and pose self.arm.cmd_gripper_ctrl(gripper_opening) self.arm.cmd_ee_pose(target_pose, line_mode=True) + self.current_executed_pose = target_pose self.waiting_for_reach = True self.waiting_start_time = time.time() def execute_close_and_retract(self): """Execute the retraction sequence after gripper has been closed.""" - ee_pose = self.arm.get_ee_pose() - - if self.waiting_for_reach: - if self._check_reach_timeout(): - return - - # Check if reached retraction pose - reached = is_target_reached( - self.final_pregrasp_pose, ee_pose, self.pbvs.target_tolerance - ) - - if reached: + if self.waiting_for_reach and self.final_pregrasp_pose: + if self.check_reach_and_adjust(): logger.info("Reached pre-grasp retraction position") - self.waiting_for_reach = False self.pick_success = self.arm.gripper_object_detected() - logger.info(f"Grasp sequence completed") if self.pick_success: logger.info("Object 
successfully grasped!") - # Transition to PLACE stage if place position is available if self.place_target_position is not None: logger.info("Transitioning to PLACE stage") self.grasp_stage = GraspStage.PLACE else: - # No place position, just mark as overall success self.overall_success = True else: logger.warning("No object detected in gripper") self.task_failed = True self.overall_success = False - else: - # Command retraction to pre-grasp + return + if not self.waiting_for_reach: logger.info("Retracting to pre-grasp position") self.arm.cmd_ee_pose(self.final_pregrasp_pose, line_mode=True) + self.current_executed_pose = self.final_pregrasp_pose self.arm.close_gripper() self.waiting_for_reach = True self.waiting_start_time = time.time() def execute_place(self): """Execute place stage: move to place position and release object.""" - ee_pose = self.arm.get_ee_pose() - if self.waiting_for_reach: - if self._check_reach_timeout(): - return - - # Check if reached place pose - place_pose = self.get_place_target_pose() - if place_pose: - reached = is_target_reached(place_pose, ee_pose, self.pbvs.target_tolerance) - - if reached: - logger.info("Reached place position, releasing gripper") - self.arm.release_gripper() - time.sleep(1.0) # Give time for gripper to open - - # Store the place pose for retraction - self.place_pose = place_pose + # Use the already executed pose instead of recalculating + if self.check_reach_and_adjust(): + logger.info("Reached place position, releasing gripper") + self.arm.release_gripper() + time.sleep(1.0) + self.place_pose = self.current_executed_pose + logger.info("Transitioning to RETRACT stage") + self.grasp_stage = GraspStage.RETRACT + return - # Transition to RETRACT stage - logger.info("Transitioning to RETRACT stage") - self.grasp_stage = GraspStage.RETRACT - self.waiting_for_reach = False - else: - # Get place pose and command movement + if not self.waiting_for_reach: place_pose = self.get_place_target_pose() if place_pose: 
logger.info("Moving to place position") self.arm.cmd_ee_pose(place_pose, line_mode=True) + self.current_executed_pose = place_pose self.waiting_for_reach = True self.waiting_start_time = time.time() else: @@ -691,33 +762,25 @@ def execute_place(self): def execute_retract(self): """Execute retract stage: retract from place position.""" - ee_pose = self.arm.get_ee_pose() - - if self.waiting_for_reach: - if self._check_reach_timeout(): - return + if self.waiting_for_reach and self.retract_pose: + if self.check_reach_and_adjust(): + logger.info("Reached retract position") + logger.info("Returning to observe position") + self.arm.gotoObserve() + self.arm.close_gripper() + self.overall_success = True + logger.info("Pick and place completed successfully!") + return - # Check if reached retract pose + if not self.waiting_for_reach: if self.place_pose: - reached = is_target_reached(self.retract_pose, ee_pose, self.pbvs.target_tolerance) - - if reached: - logger.info("Reached retract position") - # Return to observe position - logger.info("Returning to observe position") - self.arm.gotoObserve() - self.arm.close_gripper() - - # Mark overall success - self.overall_success = True - logger.info("Pick and place completed successfully!") - self.waiting_for_reach = False - else: - # Calculate and command retract pose - if self.place_pose: - self.retract_pose = apply_grasp_distance(self.place_pose, self.pregrasp_distance) + pose_pitch = self.calculate_dynamic_grasp_pitch(self.place_pose) + self.retract_pose = update_target_grasp_pose( + self.place_pose, self.home_pose, self.retract_distance, pose_pitch + ) logger.info("Retracting from place position") self.arm.cmd_ee_pose(self.retract_pose, line_mode=True) + self.current_executed_pose = self.retract_pose self.waiting_for_reach = True self.waiting_start_time = time.time() else: @@ -731,17 +794,13 @@ def capture_and_process( Optional[np.ndarray], Optional[Detection3DArray], Optional[Detection2DArray], Optional[Pose] ]: """Capture 
frame from camera data and process detections.""" - # Check if we have all required data if self.latest_rgb is None or self.latest_depth is None or self.detector is None: return None, None, None, None - # Get EE pose and camera transform ee_pose = self.arm.get_ee_pose() ee_transform = pose_to_matrix(ee_pose) camera_transform = compose_transforms(ee_transform, self.T_ee_to_camera) camera_pose = matrix_to_pose(camera_transform) - - # Process detections detection_3d_array, detection_2d_array = self.detector.process_frame( self.latest_rgb, self.latest_depth, camera_transform ) @@ -758,51 +817,49 @@ def pick_target(self, x: int, y: int) -> bool: (x, y), self.last_detection_2d_array.detections, self.last_detection_3d_array.detections ) if clicked_3d and self.pbvs: + # Validate workspace + if not self.check_within_workspace(clicked_3d.bbox.center): + self.task_failed = True + return False + self.pbvs.set_target(clicked_3d) - # Store target object height (z dimension) if clicked_3d.bbox and clicked_3d.bbox.size: self.target_object_height = clicked_3d.bbox.size.z logger.info(f"Target object height: {self.target_object_height:.3f}m") + position = clicked_3d.bbox.center.position logger.info( - f"Target selected: ID={clicked_3d.id}, pos=({clicked_3d.bbox.center.position.x:.3f}, {clicked_3d.bbox.center.position.y:.3f}, {clicked_3d.bbox.center.position.z:.3f})" + f"Target selected: ID={clicked_3d.id}, pos=({position.x:.3f}, {position.y:.3f}, {position.z:.3f})" ) self.grasp_stage = GraspStage.PRE_GRASP self.reached_poses.clear() self.adjustment_count = 0 self.waiting_for_reach = False - self.last_commanded_pose = None + self.current_executed_pose = None self.stabilization_start_time = time.time() return True return False def update(self) -> Optional[Dict[str, Any]]: """Main update function that handles capture, processing, control, and visualization.""" - # Capture and process frame rgb, detection_3d_array, detection_2d_array, camera_pose = self.capture_and_process() if rgb is 
None: return None - # Store for target selection self.last_detection_3d_array = detection_3d_array self.last_detection_2d_array = detection_2d_array - - # Handle target selection if click is pending if self.target_click: x, y = self.target_click if self.pick_target(x, y): self.target_click = None - # Update tracking if we have detections and not in IDLE or CLOSE_AND_RETRACT if ( detection_3d_array and self.grasp_stage in [GraspStage.PRE_GRASP, GraspStage.GRASP] and not self.waiting_for_reach ): self._update_tracking(detection_3d_array) - - # Execute stage-specific logic stage_handlers = { GraspStage.IDLE: self.execute_idle, GraspStage.PRE_GRASP: self.execute_pre_grasp, @@ -814,15 +871,12 @@ def update(self) -> Optional[Dict[str, Any]]: if self.grasp_stage in stage_handlers: stage_handlers[self.grasp_stage]() - # Get tracking status target_tracked = self.pbvs.get_current_target() is not None if self.pbvs else False - - # Create feedback object ee_pose = self.arm.get_ee_pose() feedback = Feedback( grasp_stage=self.grasp_stage, target_tracked=target_tracked, - last_commanded_pose=self.last_commanded_pose, + current_executed_pose=self.current_executed_pose, current_ee_pose=ee_pose, current_camera_pose=camera_pose, target_pose=self.pbvs.target_grasp_pose if self.pbvs else None, @@ -830,13 +884,11 @@ def update(self) -> Optional[Dict[str, Any]]: success=self.overall_success, ) - # Create visualization only if task is running if self.task_running: self.current_visualization = create_manipulation_visualization( rgb, feedback, detection_3d_array, detection_2d_array ) - # Publish visualization if self.current_visualization is not None: self._publish_visualization(self.current_visualization) @@ -845,10 +897,7 @@ def update(self) -> Optional[Dict[str, Any]]: def _publish_visualization(self, viz_image: np.ndarray): """Publish visualization image to LCM.""" try: - # Convert BGR to RGB for publishing viz_rgb = cv2.cvtColor(viz_image, cv2.COLOR_BGR2RGB) - - # Create LCM Image 
message height, width = viz_rgb.shape[:2] data = viz_rgb.tobytes() @@ -871,15 +920,10 @@ def check_target_stabilized(self) -> bool: if len(self.reached_poses) < self.reached_poses.maxlen: return False - # Extract positions positions = np.array( [[p.position.x, p.position.y, p.position.z] for p in self.reached_poses] ) - - # Calculate standard deviation for each axis std_devs = np.std(positions, axis=0) - - # Check if all axes are below threshold return np.all(std_devs < self.pose_stabilization_threshold) def get_place_target_pose(self) -> Optional[Pose]: @@ -887,29 +931,25 @@ def get_place_target_pose(self) -> Optional[Pose]: if self.place_target_position is None: return None - # Create a copy of the place position place_pos = self.place_target_position.copy() - - # Apply z-offset if target object height is known if self.target_object_height is not None: z_offset = self.target_object_height / 2.0 - place_pos[2] += z_offset + 0.05 - logger.info(f"Applied z-offset of {z_offset:.3f}m to place position") + place_pos[2] += z_offset + 0.1 - # Create place pose place_center_pose = Pose( Point(place_pos[0], place_pos[1], place_pos[2]), Quaternion(0.0, 0.0, 0.0, 1.0) ) - # Get current EE pose ee_pose = self.arm.get_ee_pose() - # Use update_target_grasp_pose with no grasp distance and current pitch angle + # Calculate dynamic pitch for place position + dynamic_pitch = self.calculate_dynamic_grasp_pitch(place_center_pose) + place_pose = update_target_grasp_pose( place_center_pose, ee_pose, - grasp_distance=0.0, # No grasp distance for placing - grasp_pitch_degrees=self.grasp_pitch_degrees, # Use current grasp pitch + grasp_distance=0.0, + grasp_pitch_degrees=dynamic_pitch, ) return place_pose diff --git a/dimos/manipulation/visual_servoing/pbvs.py b/dimos/manipulation/visual_servoing/pbvs.py index c207b0e49c..aef95066e5 100644 --- a/dimos/manipulation/visual_servoing/pbvs.py +++ b/dimos/manipulation/visual_servoing/pbvs.py @@ -18,16 +18,12 @@ """ import numpy as np -from 
typing import Optional, Tuple +from typing import Optional, Tuple, List +from collections import deque from scipy.spatial.transform import Rotation as R from dimos_lcm.geometry_msgs import Pose, Vector3, Quaternion, Point from dimos_lcm.vision_msgs import Detection3D, Detection3DArray from dimos.utils.logging_config import setup_logger -from dimos.utils.transform_utils import ( - yaw_towards_point, - pose_to_matrix, - euler_to_quaternion, -) from dimos.manipulation.visual_servoing.utils import ( update_target_grasp_pose, find_best_object_match, @@ -59,7 +55,7 @@ def __init__( max_velocity: float = 0.1, # m/s max_angular_velocity: float = 0.5, # rad/s target_tolerance: float = 0.01, # 1cm - max_tracking_distance_threshold: float = 0.1, # Max distance for target tracking (m) + max_tracking_distance_threshold: float = 0.15, # Max distance for target tracking (m) min_size_similarity: float = 0.6, # Min size similarity threshold (0.0-1.0) direct_ee_control: bool = True, # If True, output target poses instead of velocities ): @@ -95,14 +91,15 @@ def __init__( self.max_tracking_distance_threshold = max_tracking_distance_threshold self.min_size_similarity = min_size_similarity self.direct_ee_control = direct_ee_control - self.grasp_pitch_degrees = ( - 45.0 # Default grasp pitch in degrees (45° between level and top-down) - ) # Target state self.current_target = None self.target_grasp_pose = None + # Detection history for robust tracking + self.detection_history_size = 3 + self.detection_history = deque(maxlen=self.detection_history_size) + # For direct control mode visualization self.last_position_error = None self.last_target_reached = False @@ -135,6 +132,7 @@ def clear_target(self): self.target_grasp_pose = None self.last_position_error = None self.last_target_reached = False + self.detection_history.clear() if self.controller: self.controller.clear_state() logger.info("Target cleared") @@ -148,24 +146,9 @@ def get_current_target(self) -> Optional[Detection3D]: """ 
return self.current_target - def set_grasp_pitch(self, pitch_degrees: float): - """ - Set the grasp pitch angle in degrees. - - Args: - pitch_degrees: Grasp pitch angle in degrees (0-90) - 0° = level grasp (horizontal) - 90° = top-down grasp (vertical) - """ - # Clamp to valid range - pitch_degrees = max(0.0, min(90.0, pitch_degrees)) - self.grasp_pitch_degrees = pitch_degrees - # Reset target grasp pose to recompute with new pitch - self.target_grasp_pose = None - def update_tracking(self, new_detections: Optional[Detection3DArray] = None) -> bool: """ - Update target tracking with new detections. + Update target tracking with new detections using a rolling window. If tracking is lost, keeps the old target pose. Args: @@ -182,19 +165,31 @@ def update_tracking(self, new_detections: Optional[Detection3DArray] = None) -> ): return False - # Try to update target tracking if new detections provided - # Continue with last known pose even if tracking is lost - if new_detections is None or new_detections.detections_length == 0: - logger.debug("No detections for target tracking - using last known pose") + # Add new detections to history if provided + if new_detections is not None and new_detections.detections_length > 0: + self.detection_history.append(new_detections) + + # If no detection history, can't track + if not self.detection_history: + logger.debug("No detection history for target tracking - using last known pose") + return False + + # Collect all candidates from detection history + all_candidates = [] + for detection_array in self.detection_history: + all_candidates.extend(detection_array.detections) + + if not all_candidates: + logger.debug("No candidates in detection history") return False # Use stage-dependent distance threshold max_distance = self.max_tracking_distance_threshold - # Find best match using standardized utility function + # Find best match across all recent detections match_result = find_best_object_match( target_obj=self.current_target, - 
candidates=new_detections.detections, + candidates=all_candidates, max_distance=max_distance, min_size_similarity=self.min_size_similarity, ) @@ -210,7 +205,8 @@ def update_tracking(self, new_detections: Optional[Detection3DArray] = None) -> return True logger.debug( - f"Target tracking lost: distance={match_result.distance:.3f}m, " + f"Target tracking lost across {len(self.detection_history)} frames: " + f"distance={match_result.distance:.3f}m, " f"size_similarity={match_result.size_similarity:.2f}, " f"thresholds: distance={max_distance:.3f}m, size={self.min_size_similarity:.2f}" ) @@ -220,6 +216,7 @@ def compute_control( self, ee_pose: Pose, grasp_distance: float = 0.15, + grasp_pitch_degrees: float = 45.0, ) -> Tuple[Optional[Vector3], Optional[Vector3], bool, bool, Optional[Pose]]: """ Compute PBVS control with position and orientation servoing. @@ -240,9 +237,9 @@ def compute_control( if not self.current_target: return None, None, False, False, None - # Update target grasp pose with provided distance + # Update target grasp pose with provided distance and pitch self.target_grasp_pose = update_target_grasp_pose( - self.current_target.bbox.center, ee_pose, grasp_distance, self.grasp_pitch_degrees + self.current_target.bbox.center, ee_pose, grasp_distance, grasp_pitch_degrees ) if self.target_grasp_pose is None: diff --git a/dimos/manipulation/visual_servoing/utils.py b/dimos/manipulation/visual_servoing/utils.py index 6b00964775..f475cd70c8 100644 --- a/dimos/manipulation/visual_servoing/utils.py +++ b/dimos/manipulation/visual_servoing/utils.py @@ -29,6 +29,7 @@ euler_to_quaternion, compose_transforms, yaw_towards_point, + get_distance, ) @@ -317,15 +318,8 @@ def is_target_reached(target_pose: Pose, current_pose: Pose, tolerance: float = Returns: True if target is reached within tolerance, False otherwise """ - if not target_pose: - return False - - # Calculate position error - error_x = target_pose.position.x - current_pose.position.x - error_y = 
target_pose.position.y - current_pose.position.y - error_z = target_pose.position.z - current_pose.position.z - - error_magnitude = np.sqrt(error_x**2 + error_y**2 + error_z**2) + # Calculate position error using distance utility + error_magnitude = get_distance(target_pose, current_pose) return error_magnitude < tolerance diff --git a/dimos/utils/transform_utils.py b/dimos/utils/transform_utils.py index 5aa33bccce..eaedbcecf3 100644 --- a/dimos/utils/transform_utils.py +++ b/dimos/utils/transform_utils.py @@ -378,3 +378,21 @@ def quaternion_to_euler(quaternion: Quaternion, degrees: bool = False) -> Vector ) else: return Vector3(euler[0], euler[1], euler[2]) + + +def get_distance(pose1: Pose, pose2: Pose) -> float: + """ + Calculate Euclidean distance between two poses. + + Args: + pose1: First pose + pose2: Second pose + + Returns: + Euclidean distance between the two poses in meters + """ + dx = pose1.position.x - pose2.position.x + dy = pose1.position.y - pose2.position.y + dz = pose1.position.z - pose2.position.z + + return np.linalg.norm(np.array([dx, dy, dz])) diff --git a/tests/test_pick_and_place_module.py b/tests/test_pick_and_place_module.py index 27924481af..53f39be74e 100644 --- a/tests/test_pick_and_place_module.py +++ b/tests/test_pick_and_place_module.py @@ -145,9 +145,6 @@ def run_visualization(self): print(" 's' - SOFT STOP (emergency stop)") print(" 'g' - RELEASE GRIPPER (open gripper)") print(" 'SPACE' - EXECUTE target pose (manual override)") - print("\nGRASP PITCH CONTROLS:") - print(" '↑' - Increase grasp pitch by 15° (towards top-down)") - print(" '↓' - Decrease grasp pitch by 15° (towards level)") print("\nNOTE: Click on objects in the Camera Feed window!") while self._running: From 9adb80a7dae079691f93813ce62558cdb696f069 Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Wed, 23 Jul 2025 22:57:30 -0700 Subject: [PATCH 80/89] hacky bug fix --- .../visual_servoing/manipulation_module.py | 42 ++++++++----------- 1 file changed, 18 insertions(+), 
24 deletions(-) diff --git a/dimos/manipulation/visual_servoing/manipulation_module.py b/dimos/manipulation/visual_servoing/manipulation_module.py index b8f3dbb9f0..1a7e966826 100644 --- a/dimos/manipulation/visual_servoing/manipulation_module.py +++ b/dimos/manipulation/visual_servoing/manipulation_module.py @@ -134,7 +134,7 @@ def __init__( self.arm = PiperArm() if ee_to_camera_6dof is None: - ee_to_camera_6dof = [-0.065, 0.03, -0.105, 0.0, -1.57, 0.0] + ee_to_camera_6dof = [-0.065, 0.03, -0.095, 0.0, -1.57, 0.0] pos = Vector3(ee_to_camera_6dof[0], ee_to_camera_6dof[1], ee_to_camera_6dof[2]) rot = Vector3(ee_to_camera_6dof[3], ee_to_camera_6dof[4], ee_to_camera_6dof[5]) self.T_ee_to_camera = create_transform_from_6dof(pos, rot) @@ -149,7 +149,7 @@ def __init__( self.current_executed_pose = None # Track the actual pose sent to arm self.target_updated = False self.waiting_start_time = None - self.reach_pose_timeout = 10.0 + self.reach_pose_timeout = 20.0 # Grasp parameters self.grasp_width_offset = 0.03 @@ -163,7 +163,7 @@ def __init__( self.workspace_min_radius = 0.2 self.workspace_max_radius = 0.75 self.min_grasp_pitch_degrees = 5.0 - self.max_grasp_pitch_degrees = 75.0 + self.max_grasp_pitch_degrees = 60.0 # Grasp stage tracking self.grasp_stage = GraspStage.IDLE @@ -267,7 +267,7 @@ def _on_camera_info(self, msg: CameraInfo): if self.detector is None: self.detector = Detection3DProcessor(self.camera_intrinsics) - self.pbvs = PBVS(target_tolerance=0.05) + self.pbvs = PBVS() logger.info("Initialized detection and PBVS processors") self.latest_camera_info = msg @@ -508,15 +508,12 @@ def _check_if_stuck(self) -> bool: return is_stuck - def check_reach_and_adjust(self, tolerance: Optional[float] = None) -> bool: + def check_reach_and_adjust(self) -> bool: """ Check if robot has reached the current executed pose while waiting. Handles timeout internally by failing the task. Also detects if the robot is stuck (not moving towards target). 
- Args: - tolerance: Optional tolerance override (uses PBVS tolerance if not provided) - Returns: True if reached, False if still waiting or not in waiting state """ @@ -532,17 +529,17 @@ def check_reach_and_adjust(self, tolerance: Optional[float] = None) -> bool: if timed_out: return False - # Use provided tolerance or default to PBVS tolerance - if tolerance is None: - tolerance = self.pbvs.target_tolerance if self.pbvs else 0.01 - # Add current pose to history self.ee_pose_history.append(ee_pose) # Check if robot is stuck is_stuck = self._check_if_stuck() if is_stuck: - if self.grasp_stage == GraspStage.RETRACT: + if self.grasp_stage == GraspStage.RETRACT or self.grasp_stage == GraspStage.PLACE: + self.waiting_for_reach = False + self.waiting_start_time = None + self.stuck_count = 0 + self.ee_pose_history.clear() return True self.stuck_count += 1 pitch_degrees = self.calculate_dynamic_grasp_pitch(target_pose) @@ -555,10 +552,7 @@ def check_reach_and_adjust(self, tolerance: Optional[float] = None) -> bool: self.min_grasp_pitch_degrees, min(self.max_grasp_pitch_degrees, pitch_degrees) ) updated_target_pose = update_target_grasp_pose(target_pose, ee_pose, 0.0, pitch_degrees) - logger.info( - f"updated_target_pose: {updated_target_pose.position.x}, {updated_target_pose.position.y}, {updated_target_pose.position.z}" - ) - self.arm.cmd_ee_pose(updated_target_pose, line_mode=True) + self.arm.cmd_ee_pose(updated_target_pose) self.current_executed_pose = updated_target_pose self.ee_pose_history.clear() self.waiting_for_reach = True @@ -570,7 +564,7 @@ def check_reach_and_adjust(self, tolerance: Optional[float] = None) -> bool: self.reset_to_idle() return False - if is_target_reached(target_pose, ee_pose, tolerance): + if is_target_reached(target_pose, ee_pose, self.pbvs.target_tolerance): self.waiting_for_reach = False self.waiting_start_time = None self.stuck_count = 0 @@ -622,7 +616,7 @@ def execute_pre_grasp(self): if self.check_reach_and_adjust(): 
self.reached_poses.append(self.current_executed_pose) self.target_updated = False - time.sleep(0.3) + time.sleep(0.2) return if ( self.stabilization_start_time @@ -665,17 +659,17 @@ def execute_pre_grasp(self): def execute_grasp(self): """Execute grasp stage: move to final grasp position.""" - if self.waiting_for_reach and self.pbvs and self.pbvs.target_grasp_pose: + if self.waiting_for_reach: if self.check_reach_and_adjust() and not self.grasp_reached_time: self.grasp_reached_time = time.time() + return - if ( - self.grasp_reached_time - and (time.time() - self.grasp_reached_time) >= self.grasp_close_delay - ): + if self.grasp_reached_time: + if (time.time() - self.grasp_reached_time) >= self.grasp_close_delay: logger.info("Grasp delay completed, closing gripper") self.grasp_stage = GraspStage.CLOSE_AND_RETRACT return + if self.last_valid_target: # Calculate dynamic pitch for current target dynamic_pitch = self.calculate_dynamic_grasp_pitch(self.last_valid_target.bbox.center) From fe3c7859825eeaa95e8ecfb1322fdcb695e80dea Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Thu, 24 Jul 2025 22:50:04 -0700 Subject: [PATCH 81/89] Feat: fully working piper arm manipulation robot --- .../visual_servoing/manipulation_module.py | 16 +- dimos/robot/agilex/piper_arm.py | 201 ++++++++ dimos/robot/agilex/run.py | 196 ++++++++ dimos/skills/manipulation/pick_and_place.py | 439 ++++++++++++++++++ tests/test_pick_and_place_module.py | 123 ++--- tests/test_pick_and_place_skill.py | 154 ++++++ 6 files changed, 1037 insertions(+), 92 deletions(-) create mode 100644 dimos/robot/agilex/piper_arm.py create mode 100644 dimos/robot/agilex/run.py create mode 100644 dimos/skills/manipulation/pick_and_place.py create mode 100644 tests/test_pick_and_place_skill.py diff --git a/dimos/manipulation/visual_servoing/manipulation_module.py b/dimos/manipulation/visual_servoing/manipulation_module.py index 1a7e966826..65c440d24b 100644 --- a/dimos/manipulation/visual_servoing/manipulation_module.py +++ 
b/dimos/manipulation/visual_servoing/manipulation_module.py @@ -234,8 +234,8 @@ def stop(self): if self.task_thread and self.task_thread.is_alive(): self.task_thread.join(timeout=5.0) - # Disable arm - self.arm.disable() + # Reset to idle + self.reset_to_idle() logger.info("Manipulation module stopped") def _on_rgb_image(self, msg: Image): @@ -274,6 +274,13 @@ def _on_camera_info(self, msg: CameraInfo): except Exception as e: logger.error(f"Error processing camera info: {e}") + @rpc + def get_single_rgb_frame(self) -> Optional[np.ndarray]: + """ + get the latest rgb frame from the camera + """ + return self.latest_rgb + @rpc def handle_keyboard_command(self, key: str) -> str: """ @@ -948,6 +955,9 @@ def get_place_target_pose(self) -> Optional[Pose]: return place_pose + @rpc def cleanup(self): """Clean up resources on module destruction.""" - self.stop() + if self.detector and hasattr(self.detector, "cleanup"): + self.detector.cleanup() + self.arm.disable() diff --git a/dimos/robot/agilex/piper_arm.py b/dimos/robot/agilex/piper_arm.py new file mode 100644 index 0000000000..1bb43a76d8 --- /dev/null +++ b/dimos/robot/agilex/piper_arm.py @@ -0,0 +1,201 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import asyncio +import logging +from typing import Optional, List + +from dimos import core +from dimos.hardware.zed_camera import ZEDModule +from dimos.manipulation.visual_servoing.manipulation_module import ManipulationModule +from dimos_lcm.sensor_msgs import Image +from dimos.protocol import pubsub +from dimos.skills.skills import SkillLibrary +from dimos.types.robot_capabilities import RobotCapability +from dimos.robot.foxglove_bridge import FoxgloveBridge +from dimos.utils.logging_config import setup_logger + +# Import LCM message types +from dimos_lcm.sensor_msgs import CameraInfo + +logger = setup_logger("dimos.robot.agilex.piper_arm", level=logging.INFO) + + +class PiperArmRobot: + """Piper Arm robot with ZED camera and manipulation capabilities.""" + + def __init__(self, robot_capabilities: Optional[List[RobotCapability]] = None): + self.dimos = None + self.stereo_camera = None + self.manipulation_interface = None + self.skill_library = SkillLibrary() + + # Initialize capabilities + self.capabilities = robot_capabilities or [ + RobotCapability.VISION, + RobotCapability.MANIPULATION, + ] + + async def start(self): + """Start the robot modules.""" + # Start Dimos + self.dimos = core.start(2) # Need 2 workers for ZED and manipulation modules + self.foxglove_bridge = FoxgloveBridge() + + # Enable LCM auto-configuration + pubsub.lcm.autoconf() + + # Deploy ZED module + logger.info("Deploying ZED module...") + self.stereo_camera = self.dimos.deploy( + ZEDModule, + camera_id=0, + resolution="HD720", + depth_mode="NEURAL", + fps=30, + enable_tracking=False, # We don't need tracking for manipulation + publish_rate=30.0, + frame_id="zed_camera", + ) + + # Configure ZED LCM transports + self.stereo_camera.color_image.transport = core.LCMTransport("/zed/color_image", Image) + self.stereo_camera.depth_image.transport = core.LCMTransport("/zed/depth_image", Image) + self.stereo_camera.camera_info.transport = core.LCMTransport("/zed/camera_info", CameraInfo) + + # 
Deploy manipulation module + logger.info("Deploying manipulation module...") + self.manipulation_interface = self.dimos.deploy(ManipulationModule) + + # Connect manipulation inputs to ZED outputs + self.manipulation_interface.rgb_image.connect(self.stereo_camera.color_image) + self.manipulation_interface.depth_image.connect(self.stereo_camera.depth_image) + self.manipulation_interface.camera_info.connect(self.stereo_camera.camera_info) + + # Configure manipulation output + self.manipulation_interface.viz_image.transport = core.LCMTransport( + "/manipulation/viz", Image + ) + + # Print module info + logger.info("Modules configured:") + print("\nZED Module:") + print(self.stereo_camera.io().result()) + print("\nManipulation Module:") + print(self.manipulation_interface.io().result()) + + # Start modules + logger.info("Starting modules...") + self.foxglove_bridge.start() + self.stereo_camera.start() + self.manipulation_interface.start() + + # Give modules time to initialize + await asyncio.sleep(2) + + logger.info("PiperArmRobot initialized and started") + + def get_skills(self): + """Get the robot's skill library. + + Returns: + The robot's skill library for adding/managing skills + """ + return self.skill_library + + def pick_and_place( + self, pick_x: int, pick_y: int, place_x: Optional[int] = None, place_y: Optional[int] = None + ): + """Execute pick and place task. + + Args: + pick_x: X coordinate for pick location + pick_y: Y coordinate for pick location + place_x: X coordinate for place location (optional) + place_y: Y coordinate for place location (optional) + + Returns: + Result of the pick and place operation + """ + if self.manipulation_interface: + return self.manipulation_interface.pick_and_place(pick_x, pick_y, place_x, place_y) + else: + logger.error("Manipulation module not initialized") + return False + + def handle_keyboard_command(self, key: str): + """Pass keyboard commands to manipulation module. 
+ + Args: + key: Keyboard key pressed + + Returns: + Action taken or None + """ + if self.manipulation_interface: + return self.manipulation_interface.handle_keyboard_command(key) + else: + logger.error("Manipulation module not initialized") + return None + + def has_capability(self, capability: RobotCapability) -> bool: + """Check if the robot has a specific capability. + + Args: + capability: The capability to check for + + Returns: + bool: True if the robot has the capability + """ + return capability in self.capabilities + + def stop(self): + """Stop all modules and clean up.""" + logger.info("Stopping PiperArmRobot...") + + try: + if self.manipulation_interface: + self.manipulation_interface.stop() + self.manipulation_interface.cleanup() + + if self.stereo_camera: + self.stereo_camera.stop() + except Exception as e: + logger.warning(f"Error stopping modules: {e}") + + # Close dimos last to ensure workers are available for cleanup + if self.dimos: + self.dimos.close() + + logger.info("PiperArmRobot stopped") + + +async def run_piper_arm(): + """Run the Piper Arm robot.""" + robot = PiperArmRobot() + + await robot.start() + + # Keep the robot running + try: + while True: + await asyncio.sleep(1) + except KeyboardInterrupt: + logger.info("Keyboard interrupt received") + finally: + await robot.stop() + + +if __name__ == "__main__": + asyncio.run(run_piper_arm()) diff --git a/dimos/robot/agilex/run.py b/dimos/robot/agilex/run.py new file mode 100644 index 0000000000..c9e5a036d8 --- /dev/null +++ b/dimos/robot/agilex/run.py @@ -0,0 +1,196 @@ +#!/usr/bin/env python3 +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Run script for Piper Arm robot with Claude agent integration. +Provides manipulation capabilities with natural language interface. +""" + +import asyncio +import os +import sys +import time +from dotenv import load_dotenv + +import reactivex as rx +import reactivex.operators as ops + +from dimos.robot.agilex.piper_arm import PiperArmRobot +from dimos.agents.claude_agent import ClaudeAgent +from dimos.skills.manipulation.pick_and_place import PickAndPlace +from dimos.skills.kill_skill import KillSkill +from dimos.skills.observe import Observe +from dimos.web.robot_web_interface import RobotWebInterface +from dimos.stream.audio.pipelines import stt, tts +from dimos.utils.logging_config import setup_logger + +logger = setup_logger("dimos.robot.agilex.run") + +# Load environment variables +load_dotenv() + +# System prompt for the Piper Arm manipulation agent +SYSTEM_PROMPT = """You are an intelligent robotic assistant controlling a Piper Arm robot with advanced manipulation capabilities. Your primary role is to help users with pick and place tasks using natural language understanding. + +## Your Capabilities: +1. **Visual Perception**: You have access to a ZED stereo camera that provides RGB and depth information +2. **Object Manipulation**: You can pick up and place objects using a 6-DOF robotic arm with a gripper +3. 
**Language Understanding**: You use the Qwen vision-language model to identify objects and locations from natural language descriptions + +## Available Skills: +- **PickAndPlace**: Execute pick and place operations based on object and location descriptions + - Pick only: "Pick up the red mug" + - Pick and place: "Move the book to the shelf" +- **Observe**: Capture and analyze the current camera view +- **KillSkill**: Stop any currently running skill + +## Guidelines: +1. **Safety First**: Always ensure safe operation. If unsure about an object's graspability or a placement location's stability, ask for clarification +2. **Clear Communication**: Explain what you're doing and ask for confirmation when needed +3. **Error Handling**: If a task fails, explain why and suggest alternatives +4. **Precision**: When users give specific object descriptions, use them exactly as provided to the vision model + +## Interaction Examples: +- User: "Pick up the coffee mug" + You: "I'll pick up the coffee mug for you." [Execute PickAndPlace with object_query="coffee mug"] + +- User: "Put the toy on the table" + You: "I'll place the toy on the table." [Execute PickAndPlace with object_query="toy", target_query="on the table"] + +- User: "What do you see?" + You: "Let me take a look at the current view." [Execute Observe] + +Remember: You're here to assist with manipulation tasks. 
Be helpful, precise, and always prioritize safe operation of the robot.""" + + +def main(): + """Main entry point.""" + print("\n" + "=" * 60) + print("Piper Arm Robot with Claude Agent") + print("=" * 60) + print("\nThis system integrates:") + print(" - Piper Arm 6-DOF robot") + print(" - ZED stereo camera") + print(" - Claude AI for natural language understanding") + print(" - Qwen VLM for visual object detection") + print(" - Web interface with text and voice input") + print(" - Foxglove visualization via LCM") + print("\nStarting system...\n") + + # Check for API key + if not os.getenv("ANTHROPIC_API_KEY"): + print("WARNING: ANTHROPIC_API_KEY not found in environment") + print("Please set your API key in .env file or environment") + sys.exit(1) + + logger.info("Starting Piper Arm Robot with Agent") + + # Create robot instance + robot = PiperArmRobot() + + try: + # Start the robot (this is async, so we need asyncio.run) + logger.info("Initializing robot...") + asyncio.run(robot.start()) + logger.info("Robot initialized successfully") + + # Set up skill library + skills = robot.get_skills() + skills.add(PickAndPlace) + skills.add(Observe) + skills.add(KillSkill) + + # Create skill instances + skills.create_instance("PickAndPlace", robot=robot) + skills.create_instance("Observe", robot=robot) + skills.create_instance("KillSkill", robot=robot, skill_library=skills) + + logger.info(f"Skills registered: {[skill.__name__ for skill in skills.get_class_skills()]}") + + # Set up streams for agent and web interface + agent_response_subject = rx.subject.Subject() + agent_response_stream = agent_response_subject.pipe(ops.share()) + audio_subject = rx.subject.Subject() + + # Set up streams for web interface + streams = {} + + text_streams = { + "agent_responses": agent_response_stream, + } + + # Create web interface first (needed for agent) + try: + web_interface = RobotWebInterface( + port=5555, text_streams=text_streams, audio_subject=audio_subject, **streams + ) + 
logger.info("Web interface created successfully") + except Exception as e: + logger.error(f"Failed to create web interface: {e}") + raise + + # Set up speech-to-text + stt_node = stt() + stt_node.consume_audio(audio_subject.pipe(ops.share())) + + # Create Claude agent + agent = ClaudeAgent( + dev_name="piper_arm_agent", + input_query_stream=web_interface.query_stream, # Use text input from web interface + # input_query_stream=stt_node.emit_text(), # Uncomment to use voice input + skills=skills, + system_query=SYSTEM_PROMPT, + model_name="claude-3-5-haiku-latest", + thinking_budget_tokens=0, + max_output_tokens_per_request=4096, + ) + + # Subscribe to agent responses + agent.get_response_observable().subscribe(lambda x: agent_response_subject.on_next(x)) + + # Set up text-to-speech for agent responses + tts_node = tts() + tts_node.consume_text(agent.get_response_observable()) + + logger.info("=" * 60) + logger.info("Piper Arm Agent Ready!") + logger.info(f"Web interface available at: http://localhost:5555") + logger.info("Foxglove visualization available at: ws://localhost:8765") + logger.info("You can:") + logger.info(" - Type commands in the web interface") + logger.info(" - Use voice commands") + logger.info(" - Ask the robot to pick up objects") + logger.info(" - Ask the robot to move objects to locations") + logger.info("=" * 60) + + # Run web interface (this blocks) + web_interface.run() + + except KeyboardInterrupt: + logger.info("Keyboard interrupt received") + except Exception as e: + logger.error(f"Error running robot: {e}") + import traceback + + traceback.print_exc() + finally: + logger.info("Shutting down...") + # Stop the robot (this is also async) + robot.stop() + logger.info("Robot stopped") + + +if __name__ == "__main__": + main() diff --git a/dimos/skills/manipulation/pick_and_place.py b/dimos/skills/manipulation/pick_and_place.py new file mode 100644 index 0000000000..51bed5240e --- /dev/null +++ b/dimos/skills/manipulation/pick_and_place.py @@ 
-0,0 +1,439 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Pick and place skill for Piper Arm robot. + +This module provides a skill that uses Qwen VLM to identify pick and place +locations based on natural language queries, then executes the manipulation. +""" + +import json +import cv2 +import os +from typing import Optional, Tuple, Dict, Any +import numpy as np +from pydantic import Field + +from dimos.skills.manipulation.abstract_manipulation_skill import AbstractManipulationSkill +from dimos.models.qwen.video_query import query_single_frame +from dimos.utils.logging_config import setup_logger + +logger = setup_logger("dimos.skills.manipulation.pick_and_place") + + +def parse_qwen_points_response(response: str) -> Optional[Tuple[Tuple[int, int], Tuple[int, int]]]: + """ + Parse Qwen's response containing two points. 
+ + Args: + response: Qwen's response containing JSON with two points + + Returns: + Tuple of (pick_point, place_point) where each point is (x, y), or None if parsing fails + """ + try: + # Try to extract JSON from the response + start_idx = response.find("{") + end_idx = response.rfind("}") + 1 + + if start_idx >= 0 and end_idx > start_idx: + json_str = response[start_idx:end_idx] + result = json.loads(json_str) + + # Extract pick and place points + if "pick_point" in result and "place_point" in result: + pick = result["pick_point"] + place = result["place_point"] + + # Validate points have x,y coordinates + if ( + isinstance(pick, (list, tuple)) + and len(pick) >= 2 + and isinstance(place, (list, tuple)) + and len(place) >= 2 + ): + return (int(pick[0]), int(pick[1])), (int(place[0]), int(place[1])) + + except Exception as e: + logger.error(f"Error parsing Qwen points response: {e}") + logger.debug(f"Raw response: {response}") + + return None + + +def save_debug_image_with_points( + image: np.ndarray, + pick_point: Optional[Tuple[int, int]] = None, + place_point: Optional[Tuple[int, int]] = None, + filename_prefix: str = "qwen_debug", +) -> str: + """ + Save debug image with crosshairs marking pick and/or place points. 
+ + Args: + image: RGB image array + pick_point: (x, y) coordinates for pick location + place_point: (x, y) coordinates for place location + filename_prefix: Prefix for the saved filename + + Returns: + Path to the saved image + """ + # Create a copy to avoid modifying original + debug_image = image.copy() + + # Convert RGB to BGR for OpenCV if needed + if len(debug_image.shape) == 3 and debug_image.shape[2] == 3: + debug_image = cv2.cvtColor(debug_image, cv2.COLOR_RGB2BGR) + + # Draw pick point crosshair (green) + if pick_point: + x, y = pick_point + # Draw crosshair + cv2.drawMarker(debug_image, (x, y), (0, 255, 0), cv2.MARKER_CROSS, 30, 2) + # Draw circle + cv2.circle(debug_image, (x, y), 5, (0, 255, 0), -1) + # Add label + cv2.putText( + debug_image, "PICK", (x + 10, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2 + ) + + # Draw place point crosshair (cyan) + if place_point: + x, y = place_point + # Draw crosshair + cv2.drawMarker(debug_image, (x, y), (255, 255, 0), cv2.MARKER_CROSS, 30, 2) + # Draw circle + cv2.circle(debug_image, (x, y), 5, (255, 255, 0), -1) + # Add label + cv2.putText( + debug_image, "PLACE", (x + 10, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 0), 2 + ) + + # Draw arrow from pick to place if both exist + if pick_point and place_point: + cv2.arrowedLine(debug_image, pick_point, place_point, (255, 0, 255), 2, tipLength=0.03) + + # Generate filename with timestamp + filename = f"{filename_prefix}.png" + filepath = os.path.join(os.getcwd(), filename) + + # Save image + cv2.imwrite(filepath, debug_image) + logger.info(f"Debug image saved to: {filepath}") + + return filepath + + +def parse_qwen_single_point_response(response: str) -> Optional[Tuple[int, int]]: + """ + Parse Qwen's response containing a single point. 
+ + Args: + response: Qwen's response containing JSON with a point + + Returns: + Tuple of (x, y) or None if parsing fails + """ + try: + # Try to extract JSON from the response + start_idx = response.find("{") + end_idx = response.rfind("}") + 1 + + if start_idx >= 0 and end_idx > start_idx: + json_str = response[start_idx:end_idx] + result = json.loads(json_str) + + # Try different possible keys + point = None + for key in ["point", "location", "position", "coordinates"]: + if key in result: + point = result[key] + break + + # Validate point has x,y coordinates + if point and isinstance(point, (list, tuple)) and len(point) >= 2: + return int(point[0]), int(point[1]) + + except Exception as e: + logger.error(f"Error parsing Qwen single point response: {e}") + logger.debug(f"Raw response: {response}") + + return None + + +class PickAndPlace(AbstractManipulationSkill): + """ + A skill that performs pick and place operations using vision-language guidance. + + This skill uses Qwen VLM to identify objects and locations based on natural + language queries, then executes pick and place operations using the robot's + manipulation interface. + + Example usage: + # Just pick the object + skill = PickAndPlace(robot=robot, object_query="red mug") + + # Pick and place the object + skill = PickAndPlace(robot=robot, object_query="red mug", target_query="on the coaster") + + The skill uses the robot's stereo camera to capture RGB images and its manipulation + interface to execute the pick and place operation. It automatically handles coordinate + transformation from 2D pixel coordinates to 3D world coordinates. + """ + + object_query: str = Field( + "mug", + description="Natural language description of the object to pick (e.g., 'red mug', 'small box')", + ) + + target_query: Optional[str] = Field( + None, + description="Natural language description of where to place the object (e.g., 'on the table', 'in the basket'). 
If not provided, only pick operation will be performed.", + ) + + model_name: str = Field( + "qwen2.5-vl-72b-instruct", description="Qwen model to use for visual queries" + ) + + def __init__(self, robot=None, **data): + """ + Initialize the PickAndPlace skill. + + Args: + robot: The PiperArmRobot instance + **data: Additional configuration data + """ + super().__init__(robot=robot, **data) + + def _get_camera_frame(self) -> Optional[np.ndarray]: + """ + Get a single RGB frame from the robot's camera. + + Returns: + RGB image as numpy array or None if capture fails + """ + if not self._robot or not self._robot.manipulation_interface: + logger.error("Robot or stereo camera not available") + return None + + try: + # Use the RPC call to get a single RGB frame + rgb_frame = self._robot.manipulation_interface.get_single_rgb_frame() + if rgb_frame is None: + logger.error("Failed to capture RGB frame from camera") + return rgb_frame + except Exception as e: + logger.error(f"Error getting camera frame: {e}") + return None + + def _query_pick_and_place_points( + self, frame: np.ndarray + ) -> Optional[Tuple[Tuple[int, int], Tuple[int, int]]]: + """ + Query Qwen to get both pick and place points in a single query. + + Args: + frame: RGB image array + + Returns: + Tuple of (pick_point, place_point) or None if query fails + """ + # This method is only called when both object and target are specified + prompt = ( + f"Look at this image carefully. I need you to identify two specific locations:\n" + f"1. Find the {self.object_query} - this is the object I want to pick up\n" + f"2. 
Identify where to place it {self.target_query}\n\n" + "Instructions:\n" + "- The pick_point should be at the center or graspable part of the object\n" + "- The place_point should be a stable, flat surface at the target location\n" + "- Consider the object's size when choosing the placement point\n\n" + "Return ONLY a JSON object with this exact format:\n" + "{'pick_point': [x, y], 'place_point': [x, y]}\n" + "where [x, y] are pixel coordinates in the image." + ) + + try: + response = query_single_frame(frame, prompt, model_name=self.model_name) + return parse_qwen_points_response(response) + except Exception as e: + logger.error(f"Error querying Qwen for pick and place points: {e}") + return None + + def _query_single_point( + self, frame: np.ndarray, query: str, point_type: str + ) -> Optional[Tuple[int, int]]: + """ + Query Qwen to get a single point location. + + Args: + frame: RGB image array + query: Natural language description of what to find + point_type: Type of point ('pick' or 'place') for context + + Returns: + Tuple of (x, y) pixel coordinates or None if query fails + """ + if point_type == "pick": + prompt = ( + f"Look at this image carefully and find the {query}.\n\n" + "Instructions:\n" + "- Identify the exact object matching the description\n" + "- Choose the center point or the most graspable location on the object\n" + "- If multiple matching objects exist, choose the most prominent or accessible one\n" + "- Consider the object's shape and material when selecting the grasp point\n\n" + "Return ONLY a JSON object with this exact format:\n" + "{'point': [x, y]}\n" + "where [x, y] are the pixel coordinates of the optimal grasping point on the object." 
+ ) + else: # place + prompt = ( + f"Look at this image and identify where to place an object {query}.\n\n" + "Instructions:\n" + "- Find a stable, flat surface at the specified location\n" + "- Ensure the placement spot is clear of obstacles\n" + "- Consider the size of the object being placed\n" + "- If the query specifies a container or specific spot, center the placement there\n" + "- Otherwise, find the most appropriate nearby surface\n\n" + "Return ONLY a JSON object with this exact format:\n" + "{'point': [x, y]}\n" + "where [x, y] are the pixel coordinates of the optimal placement location." + ) + + try: + response = query_single_frame(frame, prompt, model_name=self.model_name) + return parse_qwen_single_point_response(response) + except Exception as e: + logger.error(f"Error querying Qwen for {point_type} point: {e}") + return None + + def __call__(self) -> Dict[str, Any]: + """ + Execute the pick and place operation. + + Returns: + Dictionary with operation results + """ + super().__call__() + + if not self._robot: + error_msg = "No robot instance provided to PickAndPlace skill" + logger.error(error_msg) + return {"success": False, "error": error_msg} + + # Register skill as running + skill_library = self._robot.get_skills() + self.register_as_running("PickAndPlace", skill_library) + + # Get camera frame + frame = self._get_camera_frame() + if frame is None: + return {"success": False, "error": "Failed to capture camera frame"} + + # Get pick and place points from Qwen + pick_point = None + place_point = None + + # Determine mode based on whether target_query is provided + if self.target_query is None: + # Pick only mode + logger.info("Pick-only mode (no target specified)") + + # Query for pick point + pick_point = self._query_single_point(frame, self.object_query, "pick") + if not pick_point: + return {"success": False, "error": f"Failed to find {self.object_query}"} + + # No place point needed for pick-only + place_point = None + else: + # Pick and 
place mode - can use either single or dual query + logger.info("Pick and place mode (target specified)") + + # Try single query first for efficiency + points = self._query_pick_and_place_points(frame) + pick_point, place_point = points + + logger.info(f"Pick point: {pick_point}, Place point: {place_point}") + + # Save debug image with marked points + if pick_point or place_point: + save_debug_image_with_points(frame, pick_point, place_point) + + # Execute pick (and optionally place) using the robot's interface + try: + if place_point: + # Pick and place + result = self._robot.pick_and_place( + pick_x=pick_point[0], + pick_y=pick_point[1], + place_x=place_point[0], + place_y=place_point[1], + ) + else: + # Pick only + result = self._robot.pick_and_place( + pick_x=pick_point[0], pick_y=pick_point[1], place_x=None, place_y=None + ) + + if result: + if self.target_query: + message = ( + f"Successfully picked {self.object_query} and placed it {self.target_query}" + ) + else: + message = f"Successfully picked {self.object_query}" + + return { + "success": True, + "pick_point": pick_point, + "place_point": place_point, + "object": self.object_query, + "target": self.target_query, + "message": message, + } + else: + operation = "Pick and place" if self.target_query else "Pick" + return { + "success": False, + "pick_point": pick_point, + "place_point": place_point, + "error": f"{operation} operation failed", + } + + except Exception as e: + logger.error(f"Error executing pick and place: {e}") + return { + "success": False, + "error": f"Execution error: {str(e)}", + "pick_point": pick_point, + "place_point": place_point, + } + finally: + # Always unregister skill when done + self.stop() + + def stop(self) -> None: + """ + Stop the pick and place operation and perform cleanup. 
+ """ + logger.info("Stopping PickAndPlace skill") + + # Unregister skill from skill library + if self._robot: + skill_library = self._robot.get_skills() + self.unregister_as_running("PickAndPlace", skill_library) + + logger.info("PickAndPlace skill stopped successfully") diff --git a/tests/test_pick_and_place_module.py b/tests/test_pick_and_place_module.py index 53f39be74e..6a8470863e 100644 --- a/tests/test_pick_and_place_module.py +++ b/tests/test_pick_and_place_module.py @@ -14,7 +14,7 @@ # limitations under the License. """ -Test script for pick and place manipulation module. +Run script for Piper Arm robot with pick and place functionality. Subscribes to visualization images and handles mouse/keyboard input. """ @@ -32,18 +32,14 @@ print("Error: ZED SDK not installed.") sys.exit(1) -from dimos import core -from dimos.hardware.zed_camera import ZEDModule -from dimos.manipulation.visual_servoing.manipulation_module import ManipulationModule -from dimos.protocol import pubsub +from dimos.robot.agilex.piper_arm import PiperArmRobot from dimos.utils.logging_config import setup_logger # Import LCM message types -from dimos_lcm.sensor_msgs import Image as LCMImage -from dimos_lcm.sensor_msgs import CameraInfo +from dimos_lcm.sensor_msgs import Image from dimos.protocol.pubsub.lcmpubsub import LCM, Topic -logger = setup_logger("test_pick_and_place_module") +logger = setup_logger("dimos.tests.test_pick_and_place_module") # Global for mouse events mouse_click = None @@ -67,16 +63,16 @@ def mouse_callback(event, x, y, _flags, param): class VisualizationNode: """Node that subscribes to visualization images and handles user input.""" - def __init__(self, manipulation_module): + def __init__(self, robot: PiperArmRobot): self.lcm = LCM() self.latest_viz = None self.latest_camera = None self._running = False - self.manipulation = manipulation_module + self.robot = robot # Subscribe to visualization topic - self.viz_topic = Topic("/manipulation/viz", LCMImage) - 
self.camera_topic = Topic("/zed/color_image", LCMImage) + self.viz_topic = Topic("/manipulation/viz", Image) + self.camera_topic = Topic("/zed/color_image", Image) def start(self): """Start the visualization node.""" @@ -95,7 +91,7 @@ def stop(self): self._running = False cv2.destroyAllWindows() - def _on_viz_image(self, msg: LCMImage, topic: str): + def _on_viz_image(self, msg: Image, topic: str): """Handle visualization image messages.""" try: # Convert LCM message to numpy array @@ -108,7 +104,7 @@ def _on_viz_image(self, msg: LCMImage, topic: str): except Exception as e: logger.error(f"Error processing viz image: {e}") - def _on_camera_image(self, msg: LCMImage, topic: str): + def _on_camera_image(self, msg: Image, topic: str): """Handle camera image messages.""" try: # Convert LCM message to numpy array @@ -132,7 +128,7 @@ def run_visualization(self): cv2.namedWindow("Camera Feed") cv2.setMouseCallback("Camera Feed", mouse_callback, "Camera Feed") - print("=== Pick and Place Module Test ===") + print("=== Piper Arm Robot - Pick and Place ===") print("Control mode: Module-based with LCM communication") print("\nPICK AND PLACE WORKFLOW:") print("1. 
Click on an object to select PICK location") @@ -232,15 +228,15 @@ def run_visualization(self): place_location = None place_mode = False logger.info("Reset pick and place selections") - # Also send reset to manipulation module - action = self.manipulation.handle_keyboard_command("r") + # Also send reset to robot + action = self.robot.handle_keyboard_command("r") if action: logger.info(f"Action: {action}") elif key == ord("p"): # Execute pick-only task if pick location is set if pick_location is not None: logger.info(f"Executing pick-only task at {pick_location}") - result = self.manipulation.pick_and_place( + result = self.robot.pick_and_place( pick_location[0], pick_location[1], None, # No place location @@ -253,11 +249,11 @@ def run_visualization(self): else: logger.warning("Please select a pick location first!") else: - # Send keyboard command to manipulation module + # Send keyboard command to robot if key in [82, 84]: # Arrow keys - action = self.manipulation.handle_keyboard_command(str(key)) + action = self.robot.handle_keyboard_command(str(key)) else: - action = self.manipulation.handle_keyboard_command(chr(key)) + action = self.robot.handle_keyboard_command(chr(key)) if action: logger.info(f"Action: {action}") @@ -276,9 +272,7 @@ def run_visualization(self): logger.info(f"Executing pick at {pick_location} and place at ({x}, {y})") # Start pick and place task with both locations - result = self.manipulation.pick_and_place( - pick_location[0], pick_location[1], x, y - ) + result = self.robot.pick_and_place(pick_location[0], pick_location[1], x, y) logger.info(f"Pick and place task started: {result}") # Clear all points after sending mission @@ -303,9 +297,7 @@ def run_visualization(self): logger.info(f"Executing pick at {pick_location} and place at ({x}, {y})") # Start pick and place task with both locations - result = self.manipulation.pick_and_place( - pick_location[0], pick_location[1], x, y - ) + result = self.robot.pick_and_place(pick_location[0], 
pick_location[1], x, y) logger.info(f"Pick and place task started: {result}") # Clear all points after sending mission @@ -317,64 +309,22 @@ def run_visualization(self): time.sleep(0.03) # ~30 FPS -async def test_pick_and_place_module(): - """Test the pick and place manipulation module.""" - logger.info("Starting Pick and Place Module test") - - # Start Dask - dimos = core.start(2) # Need 2 workers for ZED and manipulation modules +async def run_piper_arm_with_viz(): + """Run the Piper Arm robot with visualization.""" + logger.info("Starting Piper Arm Robot") - # Enable LCM auto-configuration - pubsub.lcm.autoconf() + # Create robot instance + robot = PiperArmRobot() try: - # Deploy ZED module - logger.info("Deploying ZED module...") - zed = dimos.deploy( - ZEDModule, - camera_id=0, - resolution="HD720", - depth_mode="NEURAL", - fps=30, - enable_tracking=False, # We don't need tracking for manipulation - publish_rate=30.0, - frame_id="zed_camera", - ) - - # Configure ZED LCM transports - zed.color_image.transport = core.LCMTransport("/zed/color_image", LCMImage) - zed.depth_image.transport = core.LCMTransport("/zed/depth_image", LCMImage) - zed.camera_info.transport = core.LCMTransport("/zed/camera_info", CameraInfo) - - # Deploy manipulation module - logger.info("Deploying manipulation module...") - manipulation = dimos.deploy(ManipulationModule) - - # Connect manipulation inputs to ZED outputs - manipulation.rgb_image.connect(zed.color_image) - manipulation.depth_image.connect(zed.depth_image) - manipulation.camera_info.connect(zed.camera_info) - - # Configure manipulation output - manipulation.viz_image.transport = core.LCMTransport("/manipulation/viz", LCMImage) - - # Print module info - logger.info("Modules configured:") - print("\nZED Module:") - print(zed.io().result()) - print("\nManipulation Module:") - print(manipulation.io().result()) - - # Start modules - logger.info("Starting modules...") - zed.start() - manipulation.start() - - # Give modules time to 
initialize + # Start the robot + await robot.start() + + # Give modules time to fully initialize await asyncio.sleep(2) # Create and start visualization node - viz_node = VisualizationNode(manipulation) + viz_node = VisualizationNode(robot) viz_node.start() # Run visualization in separate thread @@ -385,26 +335,21 @@ async def test_pick_and_place_module(): while viz_node._running: await asyncio.sleep(0.1) - # Stop modules - logger.info("Stopping modules...") - manipulation.stop() - zed.stop() - # Stop visualization viz_node.stop() except Exception as e: - logger.error(f"Error in test: {e}") + logger.error(f"Error running robot: {e}") import traceback traceback.print_exc() finally: # Clean up - dimos.close() - logger.info("Test completed") + robot.stop() + logger.info("Robot stopped") if __name__ == "__main__": - # Run the test - asyncio.run(test_pick_and_place_module()) + # Run the robot + asyncio.run(run_piper_arm_with_viz()) diff --git a/tests/test_pick_and_place_skill.py b/tests/test_pick_and_place_skill.py new file mode 100644 index 0000000000..40cf2c23b0 --- /dev/null +++ b/tests/test_pick_and_place_skill.py @@ -0,0 +1,154 @@ +#!/usr/bin/env python3 +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Run script for Piper Arm robot with pick and place functionality. +Uses hardcoded points and the PickAndPlace skill. 
+""" + +import sys +import asyncio + +try: + import pyzed.sl as sl # Required for ZED camera +except ImportError: + print("Error: ZED SDK not installed.") + sys.exit(1) + +from dimos.robot.agilex.piper_arm import PiperArmRobot +from dimos.skills.manipulation.pick_and_place import PickAndPlace +from dimos.utils.logging_config import setup_logger + +logger = setup_logger("dimos.robot.agilex.run_robot") + + +async def run_piper_arm(): + """Run the Piper Arm robot with pick and place skill.""" + logger.info("Starting Piper Arm Robot") + + # Create robot instance + robot = PiperArmRobot() + + try: + # Start the robot + await robot.start() + + # Give modules time to fully initialize + await asyncio.sleep(3) + + # Add the PickAndPlace skill to the robot's skill library + robot.skill_library.add(PickAndPlace) + + logger.info("Robot initialized successfully") + print("\n=== Piper Arm Robot - Pick and Place Demo ===") + print("This demo uses hardcoded pick and place points.") + print("\nCommands:") + print(" 1. Run pick and place with hardcoded points") + print(" 2. Run pick-only with hardcoded point") + print(" r. Reset robot to idle") + print(" q. 
Quit") + print("") + + running = True + while running: + try: + # Get user input + command = input("\nEnter command: ").strip().lower() + + if command == "q": + logger.info("Quit requested") + running = False + break + + elif command == "r" or command == "s": + logger.info("Resetting robot") + robot.handle_keyboard_command(command) + + elif command == "1": + # Hardcoded pick and place points + # These should be adjusted based on your camera view + print("\nExecuting pick and place with hardcoded points...") + + # Create and execute the skill + skill = PickAndPlace( + robot=robot, + object_query="labubu doll", # Will use visual detection + target_query="on the keyboard", # Will use visual detection + ) + + result = skill() + + if result["success"]: + print(f"✓ {result['message']}") + else: + print(f"✗ Failed: {result.get('error', 'Unknown error')}") + + elif command == "2": + # Pick-only with hardcoded point + print("\nExecuting pick-only with hardcoded point...") + + # Create and execute the skill for pick-only + skill = PickAndPlace( + robot=robot, + object_query="labubu doll", # Will use visual detection + target_query=None, # No place target - pick only + ) + + result = skill() + + if result["success"]: + print(f"✓ {result['message']}") + else: + print(f"✗ Failed: {result.get('error', 'Unknown error')}") + + else: + print("Invalid command. 
Please try again.") + + # Small delay to prevent CPU spinning + await asyncio.sleep(0.1) + + except KeyboardInterrupt: + logger.info("Keyboard interrupt received") + running = False + break + except Exception as e: + logger.error(f"Error in command loop: {e}") + print(f"Error: {e}") + + except Exception as e: + logger.error(f"Error running robot: {e}") + import traceback + + traceback.print_exc() + + finally: + # Clean up + logger.info("Shutting down robot...") + await robot.stop() + logger.info("Robot stopped") + + +def main(): + """Main entry point.""" + print("Starting Piper Arm Robot...") + print("Note: The robot will use Qwen VLM to identify objects and locations") + print("based on the queries specified in the code.") + + # Run the robot + asyncio.run(run_piper_arm()) + + +if __name__ == "__main__": + main() From 31860e1a744e2bf489f355442c1d9ff616142f7e Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Fri, 25 Jul 2025 12:17:42 -0700 Subject: [PATCH 82/89] bug fix, increased failure tolerance time --- dimos/manipulation/visual_servoing/manipulation_module.py | 2 +- dimos/manipulation/visual_servoing/pbvs.py | 2 +- dimos/skills/manipulation/pick_and_place.py | 8 ++++---- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/dimos/manipulation/visual_servoing/manipulation_module.py b/dimos/manipulation/visual_servoing/manipulation_module.py index 65c440d24b..244b9bbe77 100644 --- a/dimos/manipulation/visual_servoing/manipulation_module.py +++ b/dimos/manipulation/visual_servoing/manipulation_module.py @@ -171,7 +171,7 @@ def __init__( # Pose stabilization tracking self.pose_history_size = 4 self.pose_stabilization_threshold = 0.01 - self.stabilization_timeout = 15.0 + self.stabilization_timeout = 25.0 self.stabilization_start_time = None self.reached_poses = deque(maxlen=self.pose_history_size) self.adjustment_count = 0 diff --git a/dimos/manipulation/visual_servoing/pbvs.py b/dimos/manipulation/visual_servoing/pbvs.py index aef95066e5..b0f87e5c73 100644 
--- a/dimos/manipulation/visual_servoing/pbvs.py +++ b/dimos/manipulation/visual_servoing/pbvs.py @@ -55,7 +55,7 @@ def __init__( max_velocity: float = 0.1, # m/s max_angular_velocity: float = 0.5, # rad/s target_tolerance: float = 0.01, # 1cm - max_tracking_distance_threshold: float = 0.15, # Max distance for target tracking (m) + max_tracking_distance_threshold: float = 0.12, # Max distance for target tracking (m) min_size_similarity: float = 0.6, # Min size similarity threshold (0.0-1.0) direct_ee_control: bool = True, # If True, output target poses instead of velocities ): diff --git a/dimos/skills/manipulation/pick_and_place.py b/dimos/skills/manipulation/pick_and_place.py index 51bed5240e..4306975d8d 100644 --- a/dimos/skills/manipulation/pick_and_place.py +++ b/dimos/skills/manipulation/pick_and_place.py @@ -94,10 +94,6 @@ def save_debug_image_with_points( # Create a copy to avoid modifying original debug_image = image.copy() - # Convert RGB to BGR for OpenCV if needed - if len(debug_image.shape) == 3 and debug_image.shape[2] == 3: - debug_image = cv2.cvtColor(debug_image, cv2.COLOR_RGB2BGR) - # Draw pick point crosshair (green) if pick_point: x, y = pick_point @@ -342,6 +338,10 @@ def __call__(self) -> Dict[str, Any]: if frame is None: return {"success": False, "error": "Failed to capture camera frame"} + # Convert RGB to BGR for OpenCV if needed + if len(frame.shape) == 3 and frame.shape[2] == 3: + frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) + # Get pick and place points from Qwen pick_point = None place_point = None From 9a31c8f1d14e582e6a2533a34df3c3cf160a1262 Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Fri, 25 Jul 2025 14:56:03 -0700 Subject: [PATCH 83/89] fixed most of Stash's comments --- dimos/hardware/piper_arm.py | 15 ++------------- dimos/manipulation/visual_servoing/detection3d.py | 7 +------ .../visual_servoing/manipulation_module.py | 4 +--- dimos/manipulation/visual_servoing/pbvs.py | 1 - dimos/manipulation/visual_servoing/utils.py | 10 
---------- dimos/perception/segmentation/sam_2d_seg.py | 2 -- dimos/robot/agilex/piper_arm.py | 2 +- 7 files changed, 5 insertions(+), 36 deletions(-) diff --git a/dimos/hardware/piper_arm.py b/dimos/hardware/piper_arm.py index 9921c53c8a..774f70b1c6 100644 --- a/dimos/hardware/piper_arm.py +++ b/dimos/hardware/piper_arm.py @@ -45,9 +45,9 @@ class PiperArm: def __init__(self, arm_name: str = "arm"): - self.init_can() self.arm = C_PiperInterface_V2() self.arm.ConnectPort() + self.resetArm() time.sleep(0.5) self.resetArm() time.sleep(0.5) @@ -57,17 +57,6 @@ def __init__(self, arm_name: str = "arm"): time.sleep(1) self.init_vel_controller() - def init_can(self): - result = subprocess.run( - [ - "bash", - "dimos/hardware/can_activate.sh", - ], # pass the script path directly if it has a shebang and execute perms - stdout=subprocess.PIPE, # capture stdout - stderr=subprocess.PIPE, # capture stderr - text=True, # return strings instead of bytes - ) - def enable(self): while not self.arm.EnablePiper(): pass @@ -250,7 +239,7 @@ def gripper_object_detected(self, commanded_effort: float = 0.25) -> bool: def resetArm(self): self.arm.MotionCtrl_1(0x02, 0, 0) - self.arm.MotionCtrl_2(0, 0, 0, 0xAD) + self.arm.MotionCtrl_2(0, 0, 0, 0x00) logger.info("Resetting arm") def init_vel_controller(self): diff --git a/dimos/manipulation/visual_servoing/detection3d.py b/dimos/manipulation/visual_servoing/detection3d.py index 9eaf48d774..887fd023ab 100644 --- a/dimos/manipulation/visual_servoing/detection3d.py +++ b/dimos/manipulation/visual_servoing/detection3d.py @@ -46,7 +46,7 @@ transform_pose, ) -logger = setup_logger("dimos.perception.detection3d") +logger = setup_logger("dimos.manipulation.visual_servoing.detection3d") class Detection3DProcessor: @@ -84,10 +84,8 @@ def __init__( use_tracker=False, use_analyzer=False, use_filtering=True, - device="cuda" if cv2.cuda.getCudaEnabledDeviceCount() > 0 else "cpu", ) - # Store confidence threshold for filtering self.min_confidence = 
min_confidence logger.info( @@ -116,7 +114,6 @@ def process_frame( # Run Sam segmentation with tracking masks, bboxes, track_ids, probs, names = self.detector.process_image(bgr_image) - # Early exit if no detections if not masks or len(masks) == 0: return Detection3DArray( detections_length=0, header=Header(), detections=[] @@ -138,13 +135,11 @@ def process_frame( camera_intrinsics=self.camera_intrinsics, ) - # Build detection results detections_3d = [] detections_2d = [] pose_dict = {p["mask_idx"]: p for p in poses if p["centroid"][2] < self.max_depth} for i, (bbox, name, prob, track_id) in enumerate(zip(bboxes, names, probs, track_ids)): - # Skip if no 3D pose data if i not in pose_dict: continue diff --git a/dimos/manipulation/visual_servoing/manipulation_module.py b/dimos/manipulation/visual_servoing/manipulation_module.py index 244b9bbe77..b2724fb59a 100644 --- a/dimos/manipulation/visual_servoing/manipulation_module.py +++ b/dimos/manipulation/visual_servoing/manipulation_module.py @@ -51,7 +51,7 @@ ) from dimos.utils.logging_config import setup_logger -logger = setup_logger("dimos.manipulation.manipulation_module") +logger = setup_logger("dimos.manipulation.visual_servoing.manipulation_module") class GraspStage(Enum): @@ -234,7 +234,6 @@ def stop(self): if self.task_thread and self.task_thread.is_alive(): self.task_thread.join(timeout=5.0) - # Reset to idle self.reset_to_idle() logger.info("Manipulation module stopped") @@ -536,7 +535,6 @@ def check_reach_and_adjust(self) -> bool: if timed_out: return False - # Add current pose to history self.ee_pose_history.append(ee_pose) # Check if robot is stuck diff --git a/dimos/manipulation/visual_servoing/pbvs.py b/dimos/manipulation/visual_servoing/pbvs.py index b0f87e5c73..e34ec94557 100644 --- a/dimos/manipulation/visual_servoing/pbvs.py +++ b/dimos/manipulation/visual_servoing/pbvs.py @@ -334,7 +334,6 @@ def __init__( self.max_angular_velocity = max_angular_velocity self.target_tolerance = target_tolerance - # 
State variables for visualization self.last_position_error = None self.last_rotation_error = None self.last_velocity_cmd = None diff --git a/dimos/manipulation/visual_servoing/utils.py b/dimos/manipulation/visual_servoing/utils.py index f475cd70c8..4546326ef6 100644 --- a/dimos/manipulation/visual_servoing/utils.py +++ b/dimos/manipulation/visual_servoing/utils.py @@ -73,7 +73,6 @@ def transform_pose( Returns: Object pose in desired frame as Pose """ - # Create object pose from input # Convert euler angles to quaternion using utility function euler_vector = Vector3(obj_orientation[0], obj_orientation[1], obj_orientation[2]) obj_orientation_quat = euler_to_quaternion(euler_vector) @@ -130,7 +129,6 @@ def transform_points_3d( if points_3d.size == 0: return np.zeros((0, 3), dtype=np.float32) - # Ensure points_3d is the right shape points_3d = np.asarray(points_3d) if points_3d.ndim == 1: points_3d = points_3d.reshape(1, -1) @@ -138,7 +136,6 @@ def transform_points_3d( transformed_points = [] for point in points_3d: - # Create pose with identity orientation for each point input_point_pose = Pose( Point(point[0], point[1], point[2]), Quaternion(0.0, 0.0, 0.0, 1.0), # Identity quaternion @@ -200,7 +197,6 @@ def select_points_from_depth( x_target, y_target = target_point height, width = depth_image.shape - # Define bounding box around target point x_min = max(0, x_target - radius) x_max = min(width, x_target + radius) y_min = max(0, y_target - radius) @@ -216,18 +212,14 @@ def select_points_from_depth( # Extract corresponding depth values using advanced indexing depth_flat = depth_image[y_flat, x_flat] - # Create mask for valid depth values valid_mask = (depth_flat > 0) & np.isfinite(depth_flat) - # Early exit if no valid points if not np.any(valid_mask): return np.zeros((0, 3), dtype=np.float32) - # Filter to get valid points and depths points_2d = np.column_stack([x_flat[valid_mask], y_flat[valid_mask]]).astype(np.float32) depth_values = 
depth_flat[valid_mask].astype(np.float32) - # Use the common utility function for 3D projection points_3d = project_2d_points_to_3d(points_2d, depth_values, camera_intrinsics) return points_3d @@ -249,7 +241,6 @@ def update_target_grasp_pose( Target grasp pose or None if target is invalid """ - # Get target position target_pos = target_pose.position # Calculate orientation pointing from target towards EE @@ -267,7 +258,6 @@ def update_target_grasp_pose( updated_pose = Pose(target_pos, target_orientation) if grasp_distance > 0.0: - # Apply grasp distance return apply_grasp_distance(updated_pose, grasp_distance) else: return updated_pose diff --git a/dimos/perception/segmentation/sam_2d_seg.py b/dimos/perception/segmentation/sam_2d_seg.py index 462342872b..cb2acaf076 100644 --- a/dimos/perception/segmentation/sam_2d_seg.py +++ b/dimos/perception/segmentation/sam_2d_seg.py @@ -42,14 +42,12 @@ def __init__( self, model_path="models_fastsam", model_name="FastSAM-s.onnx", - device="cpu", min_analysis_interval=5.0, use_tracker=True, use_analyzer=True, use_rich_labeling=False, use_filtering=True, ): - self.device = device if is_cuda_available(): logger.info("Using CUDA for SAM 2d segmenter") if hasattr(onnxruntime, "preload_dlls"): # Handles CUDA 11 / onnxruntime-gpu<=1.18 diff --git a/dimos/robot/agilex/piper_arm.py b/dimos/robot/agilex/piper_arm.py index 1bb43a76d8..63dc419a78 100644 --- a/dimos/robot/agilex/piper_arm.py +++ b/dimos/robot/agilex/piper_arm.py @@ -29,7 +29,7 @@ # Import LCM message types from dimos_lcm.sensor_msgs import CameraInfo -logger = setup_logger("dimos.robot.agilex.piper_arm", level=logging.INFO) +logger = setup_logger("dimos.robot.agilex.piper_arm") class PiperArmRobot: From e5abd8b5429c098584a5c1af9e36b8bcac48639a Mon Sep 17 00:00:00 2001 From: stash Date: Fri, 25 Jul 2025 16:59:50 -0700 Subject: [PATCH 84/89] Added gpu pytest tag to addopts --- dimos-lcm | 1 + 1 file changed, 1 insertion(+) create mode 160000 dimos-lcm diff --git a/dimos-lcm 
b/dimos-lcm new file mode 160000 index 0000000000..61e0b1893c --- /dev/null +++ b/dimos-lcm @@ -0,0 +1 @@ +Subproject commit 61e0b1893c14074893aad7dc07790948b2e6b3b3 From 691b64e9c63bd583913f6ab74970bf15ab9801d5 Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Sat, 26 Jul 2025 02:13:15 -0700 Subject: [PATCH 85/89] added chinese readme --- dimos/robot/agilex/README.md | 371 +++++++++++++++++++++++++ dimos/robot/agilex/README_CN.md | 465 ++++++++++++++++++++++++++++++++ docs/modules_CN.md | 188 +++++++++++++ 3 files changed, 1024 insertions(+) create mode 100644 dimos/robot/agilex/README.md create mode 100644 dimos/robot/agilex/README_CN.md create mode 100644 docs/modules_CN.md diff --git a/dimos/robot/agilex/README.md b/dimos/robot/agilex/README.md new file mode 100644 index 0000000000..1e678cae65 --- /dev/null +++ b/dimos/robot/agilex/README.md @@ -0,0 +1,371 @@ +# DIMOS Manipulator Robot Development Guide + +This guide explains how to create robot classes, integrate agents, and use the DIMOS module system with LCM transport. + +## Table of Contents +1. [Robot Class Architecture](#robot-class-architecture) +2. [Module System & LCM Transport](#module-system--lcm-transport) +3. [Agent Integration](#agent-integration) +4. 
[Complete Example](#complete-example) + +## Robot Class Architecture + +### Basic Robot Class Structure + +A DIMOS robot class should follow this pattern: + +```python +from typing import Optional, List +from dimos import core +from dimos.types.robot_capabilities import RobotCapability + +class YourRobot: + """Your robot implementation.""" + + def __init__(self, robot_capabilities: Optional[List[RobotCapability]] = None): + # Core components + self.dimos = None + self.modules = {} + self.skill_library = SkillLibrary() + + # Define capabilities + self.capabilities = robot_capabilities or [ + RobotCapability.VISION, + RobotCapability.MANIPULATION, + ] + + async def start(self): + """Start the robot modules.""" + # Initialize DIMOS with worker count + self.dimos = core.start(2) # Number of workers needed + + # Deploy modules + # ... (see Module System section) + + def stop(self): + """Stop all modules and clean up.""" + # Stop modules + # Close DIMOS + if self.dimos: + self.dimos.close() +``` + +### Key Components Explained + +1. **Initialization**: Store references to modules, skills, and capabilities +2. **Async Start**: Modules must be deployed asynchronously +3. **Proper Cleanup**: Always stop modules before closing DIMOS + +## Module System & LCM Transport + +### Understanding DIMOS Modules + +Modules are the building blocks of DIMOS robots. 
They: +- Process data streams (inputs) +- Produce outputs +- Can be connected together +- Communicate via LCM (Lightweight Communications and Marshalling) + +### Deploying a Module + +```python +# Deploy a camera module +self.camera = self.dimos.deploy( + ZEDModule, # Module class + camera_id=0, # Module parameters + resolution="HD720", + depth_mode="NEURAL", + fps=30, + publish_rate=30.0, + frame_id="camera_frame" +) +``` + +### Setting Up LCM Transport + +LCM transport enables inter-module communication: + +```python +# Enable LCM auto-configuration +from dimos.protocol import pubsub +pubsub.lcm.autoconf() + +# Configure output transport +self.camera.color_image.transport = core.LCMTransport( + "/camera/color_image", # Topic name + Image # Message type +) +self.camera.depth_image.transport = core.LCMTransport( + "/camera/depth_image", + Image +) +``` + +### Connecting Modules + +Connect module outputs to inputs: + +```python +# Connect manipulation module to camera outputs +self.manipulation.rgb_image.connect(self.camera.color_image) +self.manipulation.depth_image.connect(self.camera.depth_image) +self.manipulation.camera_info.connect(self.camera.camera_info) +``` + +### Module Communication Pattern + +``` +┌──────────────┐ LCM ┌────────────────┐ LCM ┌──────────────┐ +│ Camera │────────▶│ Manipulation │────────▶│ Visualization│ +│ Module │ Messages│ Module │ Messages│ Output │ +└──────────────┘ └────────────────┘ └──────────────┘ + ▲ ▲ + │ │ + └──────────────────────────┘ + Direct Connection via RPC call +``` + +## Agent Integration + +### Setting Up Agent with Robot + +The run file pattern for agent integration: + +```python +#!/usr/bin/env python3 +import asyncio +import reactivex as rx +from dimos.agents.claude_agent import ClaudeAgent +from dimos.web.robot_web_interface import RobotWebInterface + +def main(): + # 1. Create and start robot + robot = YourRobot() + asyncio.run(robot.start()) + + # 2. 
Set up skills + skills = robot.get_skills() + skills.add(YourSkill) + skills.create_instance("YourSkill", robot=robot) + + # 3. Set up reactive streams + agent_response_subject = rx.subject.Subject() + agent_response_stream = agent_response_subject.pipe(ops.share()) + + # 4. Create web interface + web_interface = RobotWebInterface( + port=5555, + text_streams={"agent_responses": agent_response_stream}, + audio_subject=rx.subject.Subject() + ) + + # 5. Create agent + agent = ClaudeAgent( + dev_name="your_agent", + input_query_stream=web_interface.query_stream, + skills=skills, + system_query="Your system prompt here", + model_name="claude-3-5-haiku-latest" + ) + + # 6. Connect agent responses + agent.get_response_observable().subscribe( + lambda x: agent_response_subject.on_next(x) + ) + + # 7. Run interface + web_interface.run() +``` + +### Key Integration Points + +1. **Reactive Streams**: Use RxPy for event-driven communication +2. **Web Interface**: Provides user input/output +3. **Agent**: Processes natural language and executes skills +4. 
**Skills**: Define robot capabilities as executable actions + +## Complete Example + +### Step 1: Create Robot Class (`my_robot.py`) + +```python +import asyncio +from typing import Optional, List +from dimos import core +from dimos.hardware.camera import CameraModule +from dimos.manipulation.module import ManipulationModule +from dimos.skills.skills import SkillLibrary +from dimos.types.robot_capabilities import RobotCapability +from dimos_lcm.sensor_msgs import Image, CameraInfo +from dimos.protocol import pubsub + +class MyRobot: + def __init__(self, robot_capabilities: Optional[List[RobotCapability]] = None): + self.dimos = None + self.camera = None + self.manipulation = None + self.skill_library = SkillLibrary() + + self.capabilities = robot_capabilities or [ + RobotCapability.VISION, + RobotCapability.MANIPULATION, + ] + + async def start(self): + # Start DIMOS + self.dimos = core.start(2) + + # Enable LCM + pubsub.lcm.autoconf() + + # Deploy camera + self.camera = self.dimos.deploy( + CameraModule, + camera_id=0, + fps=30 + ) + + # Configure camera LCM + self.camera.color_image.transport = core.LCMTransport("/camera/rgb", Image) + self.camera.depth_image.transport = core.LCMTransport("/camera/depth", Image) + self.camera.camera_info.transport = core.LCMTransport("/camera/info", CameraInfo) + + # Deploy manipulation + self.manipulation = self.dimos.deploy(ManipulationModule) + + # Connect modules + self.manipulation.rgb_image.connect(self.camera.color_image) + self.manipulation.depth_image.connect(self.camera.depth_image) + self.manipulation.camera_info.connect(self.camera.camera_info) + + # Configure manipulation output + self.manipulation.viz_image.transport = core.LCMTransport("/viz/output", Image) + + # Start modules + self.camera.start() + self.manipulation.start() + + await asyncio.sleep(2) # Allow initialization + + def get_skills(self): + return self.skill_library + + def stop(self): + if self.manipulation: + self.manipulation.stop() + if self.camera: 
+ self.camera.stop() + if self.dimos: + self.dimos.close() +``` + +### Step 2: Create Run Script (`run.py`) + +```python +#!/usr/bin/env python3 +import asyncio +import os +from my_robot import MyRobot +from dimos.agents.claude_agent import ClaudeAgent +from dimos.skills.basic import BasicSkill +from dimos.web.robot_web_interface import RobotWebInterface +import reactivex as rx +import reactivex.operators as ops + +SYSTEM_PROMPT = """You are a helpful robot assistant.""" + +def main(): + # Check API key + if not os.getenv("ANTHROPIC_API_KEY"): + print("Please set ANTHROPIC_API_KEY") + return + + # Create robot + robot = MyRobot() + + try: + # Start robot + asyncio.run(robot.start()) + + # Set up skills + skills = robot.get_skills() + skills.add(BasicSkill) + skills.create_instance("BasicSkill", robot=robot) + + # Set up streams + agent_response_subject = rx.subject.Subject() + agent_response_stream = agent_response_subject.pipe(ops.share()) + + # Create web interface + web_interface = RobotWebInterface( + port=5555, + text_streams={"agent_responses": agent_response_stream} + ) + + # Create agent + agent = ClaudeAgent( + dev_name="my_agent", + input_query_stream=web_interface.query_stream, + skills=skills, + system_query=SYSTEM_PROMPT + ) + + # Connect responses + agent.get_response_observable().subscribe( + lambda x: agent_response_subject.on_next(x) + ) + + print("Robot ready at http://localhost:5555") + + # Run + web_interface.run() + + finally: + robot.stop() + +if __name__ == "__main__": + main() +``` + +### Step 3: Define Skills (`skills.py`) + +```python +from dimos.skills import Skill, skill + +@skill( + description="Perform a basic action", + parameters={ + "action": "The action to perform" + } +) +class BasicSkill(Skill): + def __init__(self, robot): + self.robot = robot + + def run(self, action: str): + # Implement skill logic + return f"Performed: {action}" +``` + +## Best Practices + +1. 
**Module Lifecycle**: Always start DIMOS before deploying modules +2. **LCM Topics**: Use descriptive topic names with namespaces +3. **Error Handling**: Wrap module operations in try-except blocks +4. **Resource Cleanup**: Ensure proper cleanup in stop() methods +5. **Async Operations**: Use asyncio for non-blocking operations +6. **Stream Management**: Use RxPy for reactive programming patterns + +## Debugging Tips + +1. **Check Module Status**: Print module.io().result() to see connections +2. **Monitor LCM**: Use Foxglove to visualize LCM messages +3. **Log Everything**: Use dimos.utils.logging_config.setup_logger() +4. **Test Modules Independently**: Deploy and test one module at a time + +## Common Issues + +1. **"Module not started"**: Ensure start() is called after deployment +2. **"No data received"**: Check LCM transport configuration +3. **"Connection failed"**: Verify input/output types match +4. **"Cleanup errors"**: Stop modules before closing DIMOS \ No newline at end of file diff --git a/dimos/robot/agilex/README_CN.md b/dimos/robot/agilex/README_CN.md new file mode 100644 index 0000000000..482a09dd6d --- /dev/null +++ b/dimos/robot/agilex/README_CN.md @@ -0,0 +1,465 @@ +# DIMOS 机械臂机器人开发指南 + +本指南介绍如何创建机器人类、集成智能体(Agent)以及使用 DIMOS 模块系统和 LCM 传输。 + +## 目录 +1. [机器人类架构](#机器人类架构) +2. [模块系统与 LCM 传输](#模块系统与-lcm-传输) +3. [智能体集成](#智能体集成) +4. 
[完整示例](#完整示例) + +## 机器人类架构 + +### 基本机器人类结构 + +DIMOS 机器人类应遵循以下模式: + +```python +from typing import Optional, List +from dimos import core +from dimos.types.robot_capabilities import RobotCapability + +class YourRobot: + """您的机器人实现。""" + + def __init__(self, robot_capabilities: Optional[List[RobotCapability]] = None): + # 核心组件 + self.dimos = None + self.modules = {} + self.skill_library = SkillLibrary() + + # 定义能力 + self.capabilities = robot_capabilities or [ + RobotCapability.VISION, + RobotCapability.MANIPULATION, + ] + + async def start(self): + """启动机器人模块。""" + # 初始化 DIMOS,指定工作线程数 + self.dimos = core.start(2) # 需要的工作线程数 + + # 部署模块 + # ... (参见模块系统章节) + + def stop(self): + """停止所有模块并清理资源。""" + # 停止模块 + # 关闭 DIMOS + if self.dimos: + self.dimos.close() +``` + +### 关键组件说明 + +1. **初始化**:存储模块、技能和能力的引用 +2. **异步启动**:模块必须异步部署 +3. **正确清理**:在关闭 DIMOS 之前始终停止模块 + +## 模块系统与 LCM 传输 + +### 理解 DIMOS 模块 + +模块是 DIMOS 机器人的构建块。它们: +- 处理数据流(输入) +- 产生输出 +- 可以相互连接 +- 通过 LCM(轻量级通信和编组)进行通信 + +### 部署模块 + +```python +# 部署相机模块 +self.camera = self.dimos.deploy( + ZEDModule, # 模块类 + camera_id=0, # 模块参数 + resolution="HD720", + depth_mode="NEURAL", + fps=30, + publish_rate=30.0, + frame_id="camera_frame" +) +``` + +### 设置 LCM 传输 + +LCM 传输实现模块间通信: + +```python +# 启用 LCM 自动配置 +from dimos.protocol import pubsub +pubsub.lcm.autoconf() + +# 配置输出传输 +self.camera.color_image.transport = core.LCMTransport( + "/camera/color_image", # 主题名称 + Image # 消息类型 +) +self.camera.depth_image.transport = core.LCMTransport( + "/camera/depth_image", + Image +) +``` + +### 连接模块 + +将模块输出连接到输入: + +```python +# 将操作模块连接到相机输出 +self.manipulation.rgb_image.connect(self.camera.color_image) # ROS set_callback +self.manipulation.depth_image.connect(self.camera.depth_image) +self.manipulation.camera_info.connect(self.camera.camera_info) +``` + +### 模块通信模式 + +``` +┌──────────────┐ LCM ┌────────────────┐ LCM ┌──────────────┐ +│ 相机模块 │────────▶│ 操作模块 │────────▶│ 可视化输出 │ +│ │ 消息 │ │ 消息 │ │ +└──────────────┘ └────────────────┘ 
└──────────────┘ + ▲ ▲ + │ │ + └──────────────────────────┘ + 直接连接(RPC指令) +``` + +## 智能体集成 + +### 设置智能体与机器人 + +运行文件的智能体集成模式: + +```python +#!/usr/bin/env python3 +import asyncio +import reactivex as rx +from dimos.agents.claude_agent import ClaudeAgent +from dimos.web.robot_web_interface import RobotWebInterface + +def main(): + # 1. 创建并启动机器人 + robot = YourRobot() + asyncio.run(robot.start()) + + # 2. 设置技能 + skills = robot.get_skills() + skills.add(YourSkill) + skills.create_instance("YourSkill", robot=robot) + + # 3. 设置响应式流 + agent_response_subject = rx.subject.Subject() + agent_response_stream = agent_response_subject.pipe(ops.share()) + + # 4. 创建 Web 界面 + web_interface = RobotWebInterface( + port=5555, + text_streams={"agent_responses": agent_response_stream}, + audio_subject=rx.subject.Subject() + ) + + # 5. 创建智能体 + agent = ClaudeAgent( + dev_name="your_agent", + input_query_stream=web_interface.query_stream, + skills=skills, + system_query="您的系统提示词", + model_name="claude-3-5-haiku-latest" + ) + + # 6. 连接智能体响应 + agent.get_response_observable().subscribe( + lambda x: agent_response_subject.on_next(x) + ) + + # 7. 运行界面 + web_interface.run() +``` + +### 关键集成点 + +1. **响应式流**:使用 RxPy 进行事件驱动通信 +2. **Web 界面**:提供用户输入/输出 +3. **智能体**:处理自然语言并执行技能 +4. 
**技能**:将机器人能力定义为可执行动作 + +## 完整示例 + +### 步骤 1:创建机器人类(`my_robot.py`) + +```python +import asyncio +from typing import Optional, List +from dimos import core +from dimos.hardware.camera import CameraModule +from dimos.manipulation.module import ManipulationModule +from dimos.skills.skills import SkillLibrary +from dimos.types.robot_capabilities import RobotCapability +from dimos_lcm.sensor_msgs import Image, CameraInfo +from dimos.protocol import pubsub + +class MyRobot: + def __init__(self, robot_capabilities: Optional[List[RobotCapability]] = None): + self.dimos = None + self.camera = None + self.manipulation = None + self.skill_library = SkillLibrary() + + self.capabilities = robot_capabilities or [ + RobotCapability.VISION, + RobotCapability.MANIPULATION, + ] + + async def start(self): + # 启动 DIMOS + self.dimos = core.start(2) + + # 启用 LCM + pubsub.lcm.autoconf() + + # 部署相机 + self.camera = self.dimos.deploy( + CameraModule, + camera_id=0, + fps=30 + ) + + # 配置相机 LCM + self.camera.color_image.transport = core.LCMTransport("/camera/rgb", Image) + self.camera.depth_image.transport = core.LCMTransport("/camera/depth", Image) + self.camera.camera_info.transport = core.LCMTransport("/camera/info", CameraInfo) + + # 部署操作模块 + self.manipulation = self.dimos.deploy(ManipulationModule) + + # 连接模块 + self.manipulation.rgb_image.connect(self.camera.color_image) + self.manipulation.depth_image.connect(self.camera.depth_image) + self.manipulation.camera_info.connect(self.camera.camera_info) + + # 配置操作输出 + self.manipulation.viz_image.transport = core.LCMTransport("/viz/output", Image) + + # 启动模块 + self.camera.start() + self.manipulation.start() + + await asyncio.sleep(2) # 允许初始化 + + def get_skills(self): + return self.skill_library + + def stop(self): + if self.manipulation: + self.manipulation.stop() + if self.camera: + self.camera.stop() + if self.dimos: + self.dimos.close() +``` + +### 步骤 2:创建运行脚本(`run.py`) + +```python +#!/usr/bin/env python3 +import asyncio +import os +from 
my_robot import MyRobot +from dimos.agents.claude_agent import ClaudeAgent +from dimos.skills.basic import BasicSkill +from dimos.web.robot_web_interface import RobotWebInterface +import reactivex as rx +import reactivex.operators as ops + +SYSTEM_PROMPT = """您是一个有用的机器人助手。""" + +def main(): + # 检查 API 密钥 + if not os.getenv("ANTHROPIC_API_KEY"): + print("请设置 ANTHROPIC_API_KEY") + return + + # 创建机器人 + robot = MyRobot() + + try: + # 启动机器人 + asyncio.run(robot.start()) + + # 设置技能 + skills = robot.get_skills() + skills.add(BasicSkill) + skills.create_instance("BasicSkill", robot=robot) + + # 设置流 + agent_response_subject = rx.subject.Subject() + agent_response_stream = agent_response_subject.pipe(ops.share()) + + # 创建 Web 界面 + web_interface = RobotWebInterface( + port=5555, + text_streams={"agent_responses": agent_response_stream} + ) + + # 创建智能体 + agent = ClaudeAgent( + dev_name="my_agent", + input_query_stream=web_interface.query_stream, + skills=skills, + system_query=SYSTEM_PROMPT + ) + + # 连接响应 + agent.get_response_observable().subscribe( + lambda x: agent_response_subject.on_next(x) + ) + + print("机器人就绪,访问 http://localhost:5555") + + # 运行 + web_interface.run() + + finally: + robot.stop() + +if __name__ == "__main__": + main() +``` + +### 步骤 3:定义技能(`skills.py`) + +```python +from dimos.skills import Skill, skill + +@skill( + description="执行一个基本动作", + parameters={ + "action": "要执行的动作" + } +) +class BasicSkill(Skill): + def __init__(self, robot): + self.robot = robot + + def run(self, action: str): + # 实现技能逻辑 + return f"已执行:{action}" +``` + +## 最佳实践 + +1. **模块生命周期**:在部署模块之前始终先启动 DIMOS +2. **LCM 主题**:使用带命名空间的描述性主题名称 +3. **错误处理**:用 try-except 块包装模块操作 +4. **资源清理**:确保在 stop() 方法中正确清理 +5. **异步操作**:使用 asyncio 进行非阻塞操作 +6. **流管理**:使用 RxPy 进行响应式编程模式 + +## 调试技巧 + +1. **检查模块状态**:打印 module.io().result() 查看连接 +2. **监控 LCM**:使用 Foxglove 可视化 LCM 消息 +3. **记录一切**:使用 dimos.utils.logging_config.setup_logger() +4. **独立测试模块**:一次部署和测试一个模块 + +## 常见问题 + +1. **"模块未启动"**:确保在部署后调用 start() +2. 
**"未收到数据"**:检查 LCM 传输配置 +3. **"连接失败"**:验证输入/输出类型是否匹配 +4. **"清理错误"**:在关闭 DIMOS 之前停止模块 + +## 高级主题 + +### 自定义模块开发 + +创建自定义模块的基本结构: + +```python +from dimos.core import Module, In, Out, rpc + +class CustomModule(Module): + # 定义输入 + input_data: In[DataType] = None + + # 定义输出 + output_data: Out[DataType] = None + + def __init__(self, param1, param2, **kwargs): + super().__init__(**kwargs) + self.param1 = param1 + self.param2 = param2 + + @rpc + def start(self): + """启动模块处理。""" + self.input_data.subscribe(self._process_data) + + def _process_data(self, data): + """处理输入数据。""" + # 处理逻辑 + result = self.process(data) + # 发布输出 + self.output_data.publish(result) + + @rpc + def stop(self): + """停止模块。""" + # 清理资源 + pass +``` + +### 技能开发指南 + +技能是机器人可执行的高级动作: + +```python +from dimos.skills import Skill, skill +from typing import Optional + +@skill( + description="复杂操作技能", + parameters={ + "target": "目标对象", + "location": "目标位置" + } +) +class ComplexSkill(Skill): + def __init__(self, robot, **kwargs): + super().__init__(**kwargs) + self.robot = robot + + def run(self, target: str, location: Optional[str] = None): + """执行技能逻辑。""" + try: + # 1. 感知阶段 + object_info = self.robot.detect_object(target) + + # 2. 规划阶段 + if location: + plan = self.robot.plan_movement(object_info, location) + + # 3. 执行阶段 + result = self.robot.execute_plan(plan) + + return { + "success": True, + "message": f"成功移动 {target} 到 {location}" + } + + except Exception as e: + return { + "success": False, + "error": str(e) + } +``` + +### 性能优化 + +1. **并行处理**:使用多个工作线程处理不同模块 +2. **数据缓冲**:为高频数据流实现缓冲机制 +3. **延迟加载**:仅在需要时初始化重型模块 +4. **资源池化**:重用昂贵的资源(如神经网络模型) + +希望本指南能帮助您快速上手 DIMOS 机器人开发! \ No newline at end of file diff --git a/docs/modules_CN.md b/docs/modules_CN.md new file mode 100644 index 0000000000..d8f088ef59 --- /dev/null +++ b/docs/modules_CN.md @@ -0,0 +1,188 @@ +# Dimensional 模块系统 + +DimOS 模块系统使用 Dask 进行计算分布和 LCM(轻量级通信和编组)进行高性能进程间通信,实现分布式、多进程的机器人应用。 + +## 核心概念 + +### 1. 
模块定义 +模块是继承自 `dimos.core.Module` 的 Python 类,定义输入、输出和 RPC 方法: + +```python +from dimos.core import Module, In, Out, rpc +from dimos.msgs.geometry_msgs import Vector3 + +class MyModule(Module): # ROS Node + # 将输入/输出声明为初始化为 None 的类属性 + data_in: In[Vector3] = None # ROS Subscriber + data_out: Out[Vector3] = None # ROS Publisher + + def __init__(self): + # 调用父类 Module 初始化 + super().__init__() + + @rpc + def remote_method(self, param): + """使用 @rpc 装饰的方法可以远程调用""" + return param * 2 +``` + +### 2. 模块部署 +使用 `dimos.deploy()` 方法在 Dask 工作进程中部署模块: + +```python +from dimos import core + +# 启动具有 N 个工作进程的 Dask 集群 +dimos = core.start(4) + +# 部署模块时可以传递初始化参数 +# 在这种情况下,param1 和 param2 被传递到模块初始化中 +module = dimos.deploy(Module, param1="value1", param2=123) +``` + +### 3. 流连接 +模块通过使用 LCM 传输的响应式流进行通信: + +```python +# 为输出配置 LCM 传输 +module1.data_out.transport = core.LCMTransport("/topic_name", MessageType) + +# 将模块输入连接到输出 +module2.data_in.connect(module1.data_out) + +# 访问底层的 Observable 流 +stream = module1.data_out.observable() +stream.subscribe(lambda msg: print(f"接收到: {msg}")) +``` + +### 4. 
模块生命周期 +```python +# 启动模块以开始处理 +module.start() # 如果定义了 @rpc start() 方法,则调用它 + +# 检查模块 I/O 配置 +print(module.io().result()) # 显示输入、输出和 RPC 方法 + +# 优雅关闭 +dimos.shutdown() +``` + +## 实际示例:机器人控制系统 + +```python +# 连接模块封装机器人硬件/仿真 +connection = dimos.deploy(ConnectionModule, ip=robot_ip) +connection.lidar.transport = core.LCMTransport("/lidar", LidarMessage) +connection.video.transport = core.LCMTransport("/video", Image) + +# 感知模块处理传感器数据 +perception = dimos.deploy(PersonTrackingStream, camera_intrinsics=[...]) +perception.video.connect(connection.video) +perception.tracking_data.transport = core.pLCMTransport("/person_tracking") + +# 开始处理 +connection.start() +perception.start() + +# 通过 RPC 启用跟踪 +perception.enable_tracking() + +# 获取最新的跟踪数据 +data = perception.get_tracking_data() +``` + +## LCM 传输配置 + +```python +# 用于简单类型(如激光雷达)的标准 LCM 传输 +connection.lidar.transport = core.LCMTransport("/lidar", LidarMessage) + +# 用于复杂 Python 对象/字典的基于 pickle 的传输 +connection.tracking_data.transport = core.pLCMTransport("/person_tracking") + +# 自动配置 LCM 系统缓冲区(在容器中必需) +from dimos.protocol import pubsub +pubsub.lcm.autoconf() +``` + +这种架构使得能够将复杂的机器人系统构建为可组合的分布式模块,这些模块通过流和 RPC 高效通信,从单机扩展到集群。 + +# Dimensional 安装指南 +## Python 安装(Ubuntu 22.04) + +```bash +sudo apt install python3-venv + +# 克隆仓库(dev 分支,无子模块) +git clone -b dev https://github.com/dimensionalOS/dimos.git +cd dimos + +# 创建并激活虚拟环境 +python3 -m venv venv +source venv/bin/activate + +sudo apt install portaudio19-dev python3-pyaudio + +# 如果尚未安装,请安装 torch 和 torchvision +# 示例 CUDA 11.8,Pytorch 2.0.1(如果需要不同的 pytorch 版本,请替换) +pip install torch==2.0.1 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 +``` + +### 安装依赖 +```bash +# 仅 CPU(建议首先尝试) +pip install .[cpu,dev] + +# CUDA 安装 +pip install .[cuda,dev] + +# 复制并配置环境变量 +cp default.env .env +``` + +### 测试安装 +```bash +# 运行标准测试 +pytest -s dimos/ + +# 测试模块功能 +pytest -s -m module dimos/ + +# 测试 LCM 通信 +pytest -s -m lcm dimos/ +``` + +# Unitree Go2 快速开始 + +要快速测试模块系统,您可以直接运行 
Unitree Go2 多进程示例: + +```bash +# 确保设置了所需的环境变量 +export ROBOT_IP= + +# 运行多进程 Unitree Go2 示例 +python dimos/robot/unitree_webrtc/multiprocess/unitree_go2.py +``` + +## 模块系统的高级特性 + +### 分布式计算 +DimOS 模块系统建立在 Dask 之上,提供了强大的分布式计算能力: + +- **自动负载均衡**:模块自动分布在可用的工作进程中 +- **容错性**:如果工作进程失败,模块可以在其他工作进程上重新启动 +- **可扩展性**:从单机到集群的无缝扩展 + +### 响应式编程模型 +使用 RxPY 实现的响应式流提供了: + +- **异步处理**:非阻塞的数据流处理 +- **背压处理**:自动管理快速生产者和慢速消费者 +- **操作符链**:使用 map、filter、merge 等操作符进行流转换 + +### 性能优化 +LCM 传输针对机器人应用进行了优化: + +- **零拷贝**:大型消息的高效内存使用 +- **低延迟**:微秒级的消息传递 +- **多播支持**:一对多的高效通信 \ No newline at end of file From e11a96dba3038845072346bc5823f0fc6e299cfe Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Sat, 26 Jul 2025 15:11:02 -0700 Subject: [PATCH 86/89] removed submodule --- dimos-lcm | 1 - 1 file changed, 1 deletion(-) delete mode 160000 dimos-lcm diff --git a/dimos-lcm b/dimos-lcm deleted file mode 160000 index 61e0b1893c..0000000000 --- a/dimos-lcm +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 61e0b1893c14074893aad7dc07790948b2e6b3b3 From 43a32e6909e05327b38add5c582263978cbb5130 Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Sat, 26 Jul 2025 15:40:13 -0700 Subject: [PATCH 87/89] added piper-sdk to manipulation requirements --- .gitignore | 4 +++- pyproject.toml | 3 +++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index f5c93cf65f..ea773eb114 100644 --- a/.gitignore +++ b/.gitignore @@ -41,4 +41,6 @@ dist/ data/* !data/.lfs/ FastSAM-x.pt -yolo11n.pt +yolo11n.pt + +dimos-lcm/ diff --git a/pyproject.toml b/pyproject.toml index 7d51aa91d4..436b9a5750 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -115,6 +115,9 @@ manipulation = [ "pyyaml>=6.0", "contact-graspnet-pytorch @ git+https://github.com/dimensionalOS/contact_graspnet_pytorch.git", + # piper arm + "piper-sdk", + # Visualization (Optional) "kaleido>=0.2.1", "plotly>=5.9.0", From 09377c067368f98aa5bf0316d15587dfaaab3ce5 Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Sat, 26 Jul 2025 23:27:19 -0700 Subject: 
[PATCH 88/89] undo addition to gitignore, also finally fixed bgr video --- .gitignore | 4 +--- dimos/robot/frontier_exploration/__init__.py | 1 - dimos/robot/unitree_webrtc/connection.py | 2 +- dimos/robot/unitree_webrtc/multiprocess/unitree_go2.py | 4 +++- 4 files changed, 5 insertions(+), 6 deletions(-) delete mode 100644 dimos/robot/frontier_exploration/__init__.py diff --git a/.gitignore b/.gitignore index ea773eb114..f5c93cf65f 100644 --- a/.gitignore +++ b/.gitignore @@ -41,6 +41,4 @@ dist/ data/* !data/.lfs/ FastSAM-x.pt -yolo11n.pt - -dimos-lcm/ +yolo11n.pt diff --git a/dimos/robot/frontier_exploration/__init__.py b/dimos/robot/frontier_exploration/__init__.py deleted file mode 100644 index 2b69011a9f..0000000000 --- a/dimos/robot/frontier_exploration/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from utils import * diff --git a/dimos/robot/unitree_webrtc/connection.py b/dimos/robot/unitree_webrtc/connection.py index 9bc1874cbe..6119cba860 100644 --- a/dimos/robot/unitree_webrtc/connection.py +++ b/dimos/robot/unitree_webrtc/connection.py @@ -239,7 +239,7 @@ async def accept_track(track: MediaStreamTrack) -> VideoMessage: if stop_event.is_set(): return frame = await track.recv() - subject.on_next(Image.from_numpy(frame.to_ndarray(format="bgr24"))) + subject.on_next(Image.from_numpy(frame.to_ndarray(format="rgb24"))) self.conn.video.add_track_callback(accept_track) diff --git a/dimos/robot/unitree_webrtc/multiprocess/unitree_go2.py b/dimos/robot/unitree_webrtc/multiprocess/unitree_go2.py index f2b701fc63..0532be8320 100644 --- a/dimos/robot/unitree_webrtc/multiprocess/unitree_go2.py +++ b/dimos/robot/unitree_webrtc/multiprocess/unitree_go2.py @@ -94,7 +94,9 @@ def odom_stream(self): @functools.cache def video_stream(self): print("video stream start") - video_store = TimedSensorReplay("unitree_office_walk/video", autocast=Image.from_numpy) + video_store = TimedSensorReplay( + "unitree_office_walk/video", autocast=lambda x: Image.from_numpy(x).to_rgb() + ) return 
video_store.stream() def move(self, vector: Vector): From a8e205398aa533324b895552d81c3e789cbcb585 Mon Sep 17 00:00:00 2001 From: alexlin2 Date: Sun, 27 Jul 2025 00:10:25 -0700 Subject: [PATCH 89/89] put back init.py, relaxed frontier exploration filter conditions --- dimos/robot/frontier_exploration/__init__.py | 1 + .../test_wavefront_frontier_goal_selector.py | 9 +-------- .../wavefront_frontier_goal_selector.py | 19 ++----------------- 3 files changed, 4 insertions(+), 25 deletions(-) create mode 100644 dimos/robot/frontier_exploration/__init__.py diff --git a/dimos/robot/frontier_exploration/__init__.py b/dimos/robot/frontier_exploration/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/dimos/robot/frontier_exploration/__init__.py @@ -0,0 +1 @@ + diff --git a/dimos/robot/frontier_exploration/test_wavefront_frontier_goal_selector.py b/dimos/robot/frontier_exploration/test_wavefront_frontier_goal_selector.py index c9b75b28d8..cd344dd0b4 100644 --- a/dimos/robot/frontier_exploration/test_wavefront_frontier_goal_selector.py +++ b/dimos/robot/frontier_exploration/test_wavefront_frontier_goal_selector.py @@ -25,7 +25,7 @@ ) from dimos.robot.unitree_webrtc.type.lidar import LidarMessage from dimos.robot.unitree_webrtc.type.map import Map -from dimos.types.vector import Vector +from dimos.msgs.geometry_msgs import Vector3 as Vector from dimos.utils.testing import SensorReplay @@ -130,13 +130,6 @@ def test_exploration_goal_selection(): assert isinstance(goal, Vector), "Goal should be a Vector" print(f"Selected exploration goal: ({goal.x:.2f}, {goal.y:.2f})") - # Verify goal is at reasonable distance from robot - distance = np.sqrt((goal.x - robot_pose.x) ** 2 + (goal.y - robot_pose.y) ** 2) - print(f"Goal distance from robot: {distance:.2f}m") - assert distance >= explorer.min_distance_from_robot, ( - "Goal should respect minimum distance from robot" - ) - # Test that goal gets marked as explored assert len(explorer.explored_goals) == 1, 
"Goal should be marked as explored" assert explorer.explored_goals[0] == goal, "Explored goal should match selected goal" diff --git a/dimos/robot/frontier_exploration/wavefront_frontier_goal_selector.py b/dimos/robot/frontier_exploration/wavefront_frontier_goal_selector.py index 1aca32fc93..454a70e803 100644 --- a/dimos/robot/frontier_exploration/wavefront_frontier_goal_selector.py +++ b/dimos/robot/frontier_exploration/wavefront_frontier_goal_selector.py @@ -82,11 +82,9 @@ class WavefrontFrontierExplorer: def __init__( self, - min_frontier_size: int = 10, + min_frontier_size: int = 8, occupancy_threshold: int = 65, - subsample_resolution: int = 2, - min_distance_from_robot: float = 0.5, - explored_area_buffer: float = 0.5, + subsample_resolution: int = 3, min_distance_from_obstacles: float = 0.6, info_gain_threshold: float = 0.03, num_no_gain_attempts: int = 4, @@ -101,8 +99,6 @@ def __init__( min_frontier_size: Minimum number of points to consider a valid frontier occupancy_threshold: Cost threshold above which a cell is considered occupied (0-255) subsample_resolution: Factor by which to subsample the costmap for faster processing (1=no subsampling, 2=half resolution, 4=quarter resolution) - min_distance_from_robot: Minimum distance frontier must be from robot (meters) - explored_area_buffer: Buffer distance around free areas to consider as explored (meters) min_distance_from_obstacles: Minimum distance frontier must be from obstacles (meters) info_gain_threshold: Minimum percentage increase in costmap information required to continue exploration (0.05 = 5%) num_no_gain_attempts: Maximum number of consecutive attempts with no information gain @@ -113,8 +109,6 @@ def __init__( self.min_frontier_size = min_frontier_size self.occupancy_threshold = occupancy_threshold self.subsample_resolution = subsample_resolution - self.min_distance_from_robot = min_distance_from_robot - self.explored_area_buffer = explored_area_buffer self.min_distance_from_obstacles = 
min_distance_from_obstacles self.info_gain_threshold = info_gain_threshold self.num_no_gain_attempts = num_no_gain_attempts @@ -513,15 +507,6 @@ def _rank_frontiers( valid_frontiers = [] for i, frontier in enumerate(frontier_centroids): - robot_distance = np.sqrt( - (frontier.x - robot_pose.x) ** 2 + (frontier.y - robot_pose.y) ** 2 - ) - - # Filter 1: Skip frontiers too close to robot - if robot_distance < self.min_distance_from_robot: - continue - - # Filter 2: Skip frontiers too close to obstacles obstacle_distance = self._compute_distance_to_obstacles(frontier, costmap) if obstacle_distance < self.min_distance_from_obstacles: continue