diff --git a/assets/agent/prompt.txt b/assets/agent/prompt.txt index b3d159a7df..f38c13eb13 100644 --- a/assets/agent/prompt.txt +++ b/assets/agent/prompt.txt @@ -29,35 +29,12 @@ PERCEPTION & TEMPORAL AWARENESS: - You can recognize and track humans and objects in your field of view NAVIGATION & MOVEMENT: -- You can navigate to semantically described locations using Navigate (e.g., "go to the kitchen") -- You can navigate to visually identified objects using NavigateToObject (e.g., "go to the red chair") +- You can navigate to semantically described locations using NavigateWithText (e.g., "go to the kitchen") +- You can navigate to visually identified objects using NavigateWithText (e.g., "go to the red chair") - You can follow humans through complex environments using FollowHuman -- You can execute precise movement to specific coordinates using NavigateToGoal like if you're navigating to a GetPose waypoint - You can perform various body movements and gestures (sit, stand, dance, etc.) -- When navigating to a location like Kitchen or Bathroom or couch, use the generic Navigate skill to query spatial memory and navigate - You can stop any navigation process that is currently running using KillSkill -- Appended to every query you will find current objects detection and Saved Locations like this: - -Current objects detected: -[DETECTED OBJECTS] -Object 1: refrigerator - ID: 1 - Confidence: 0.88 - Position: x=9.44m, y=5.87m, z=-0.13m - Rotation: yaw=0.11 rad - Size: width=1.00m, height=1.46m - Depth: 4.92m - Bounding box: [606, 212, 773, 456] ----------------------------------- -Object 2: box - ID: 2 - Confidence: 0.84 - Position: x=11.30m, y=5.10m, z=-0.19m - Rotation: yaw=-0.03 rad - Size: width=0.91m, height=0.37m - Depth: 6.60m - Bounding box: [753, 149, 867, 195] ----------------------------------- + Saved Robot Locations: - LOCATION_NAME: Position (X, Y, Z), Rotation (X, Y, Z) @@ -70,8 +47,6 @@ Saved Robot Locations: ***When navigating to an object not in current object detected, run NavigateWithText, DO NOT EXPLORE with raw move commands!!!*** -***The object detection list is not a comprehensive source of information, when given a visual query like "go to the person wearing a hat" or "Do you see a dog", always Prioritize running observe skill and NavigateWithText*** - PLANNING & REASONING: - You can develop both short-term and long-term plans to achieve complex goals - You can reason about spatial relationships and plan efficient navigation paths diff --git a/dimos/core/__init__.py b/dimos/core/__init__.py index 81b1ad4cee..dfae0fb3fb 100644 --- a/dimos/core/__init__.py +++ b/dimos/core/__init__.py @@ -54,7 +54,7 @@ def __getattr__(self, name: str): if name in self.rpcs: return lambda *args, **kwargs: self.rpc.call_sync( - f"{self.remote_name}/{name}", (args, kwargs) + f"{self.remote_name}/{name}", (args, kwargs), timeout=2.0 ) # return super().__getattr__(name) diff --git a/dimos/hardware/zed_camera.py b/dimos/hardware/zed_camera.py index 7ee2aed634..7f24eb8ec8 100644 --- a/dimos/hardware/zed_camera.py +++ b/dimos/hardware/zed_camera.py @@ -31,13 +31,14 @@ from dimos.hardware.stereo_camera import StereoCamera from dimos.core import Module, Out, rpc from dimos.utils.logging_config import setup_logger +from dimos.protocol.tf import TF +from dimos.msgs.geometry_msgs import Transform, Vector3, Quaternion # Import LCM message types -from dimos_lcm.sensor_msgs import Image +from dimos.msgs.sensor_msgs import Image, ImageFormat from dimos_lcm.sensor_msgs import CameraInfo -from dimos_lcm.geometry_msgs import PoseStamped -from dimos_lcm.std_msgs import Header, Time -from dimos_lcm.geometry_msgs import Pose, Point, Quaternion +from dimos.msgs.geometry_msgs import PoseStamped +from dimos.msgs.std_msgs import Header logger = setup_logger(__name__) @@ -591,6 +592,9 @@ def __init__( self._subscription = None self._sequence = 0 + # Initialize TF publisher + self.tf = TF() + logger.info(f"ZEDModule initialized for camera {camera_id}") @rpc @@ -675,12 +679,8 @@ def _capture_and_publish(self): if left_img is None or depth is None: return - # Get timestamp - timestamp_ns = time.time_ns() - timestamp = Time(sec=timestamp_ns // 1_000_000_000, nsec=timestamp_ns % 1_000_000_000) - # Create header - header = Header(seq=self._sequence, stamp=timestamp, frame_id=self.frame_id) + header = Header(self.frame_id) self._sequence += 1 # Publish color image @@ -709,20 +709,11 @@ def _publish_color_image(self, image: np.ndarray, header: Header): image_rgb = image # Create LCM Image message - height, width = image_rgb.shape[:2] - encoding = "rgb8" if len(image_rgb.shape) == 3 else "mono8" - step = width * (3 if len(image_rgb.shape) == 3 else 1) - data = image_rgb.tobytes() - msg = Image( - data_length=len(data), - header=header, - height=height, - width=width, - encoding=encoding, - is_bigendian=0, - step=step, - data=data, + data=image_rgb, + format=ImageFormat.RGB, + frame_id=header.frame_id, + ts=header.ts, ) self.color_image.publish(msg) @@ -734,22 +725,12 @@ def _publish_depth_image(self, depth: np.ndarray, header: Header): """Publish depth image as LCM message.""" try: # Depth is float32 in meters - height, width = depth.shape[:2] - encoding = "32FC1" # 32-bit float, single channel - step = width * 4 # 4 bytes per float - data = depth.astype(np.float32).tobytes() - msg = Image( - data_length=len(data), - header=header, - height=height, - width=width, - encoding=encoding, - is_bigendian=0, - step=step, - data=data, + data=depth, + format=ImageFormat.DEPTH, + frame_id=header.frame_id, + ts=header.ts, ) - self.depth_image.publish(msg) except Exception as e: @@ -767,7 +748,7 @@ def _publish_camera_info(self): resolution = info.get("resolution", {}) # Create CameraInfo message - header = Header(seq=0, stamp=Time(sec=int(time.time()), nsec=0), frame_id=self.frame_id) + header = Header(self.frame_id) # Create camera matrix K (3x3) K = [ @@ -830,22 +811,25 @@ def _publish_camera_info(self): logger.error(f"Error publishing camera info: {e}") def _publish_pose(self, pose_data: Dict[str, Any], header: Header): - """Publish camera pose as PoseStamped message.""" + """Publish camera pose as PoseStamped message and TF transform.""" try: position = pose_data.get("position", [0, 0, 0]) rotation = pose_data.get("rotation", [0, 0, 0, 1]) # quaternion [x,y,z,w] - # Create Pose message - pose = Pose( - position=Point(x=position[0], y=position[1], z=position[2]), - orientation=Quaternion(x=rotation[0], y=rotation[1], z=rotation[2], w=rotation[3]), - ) - # Create PoseStamped message - msg = PoseStamped(header=header, pose=pose) - + msg = PoseStamped(ts=header.ts, position=position, orientation=rotation) self.pose.publish(msg) + # Publish TF transform + camera_tf = Transform( + translation=Vector3(position), + rotation=Quaternion(rotation), + frame_id="zed_world", + child_frame_id="zed_camera_link", + ts=header.ts, + ) + self.tf.publish(camera_tf) + except Exception as e: logger.error(f"Error publishing pose: {e}") diff --git a/dimos/manipulation/visual_servoing/detection3d.py b/dimos/manipulation/visual_servoing/detection3d.py index 887fd023ab..5fcc1451b6 100644 --- a/dimos/manipulation/visual_servoing/detection3d.py +++ b/dimos/manipulation/visual_servoing/detection3d.py @@ -26,7 +26,8 @@ from dimos.perception.detection2d.utils import calculate_object_size_from_bbox from dimos.perception.common.utils import bbox2d_to_corners -from dimos_lcm.geometry_msgs import Pose, Vector3, Quaternion, Point +from dimos.msgs.geometry_msgs import Pose, Vector3, Quaternion +from dimos.msgs.std_msgs import Header from dimos_lcm.vision_msgs import ( Detection3D, Detection3DArray, @@ -39,7 +40,6 @@ Pose2D, Point2D, ) -from dimos_lcm.std_msgs import Header from dimos.manipulation.visual_servoing.utils import ( estimate_object_depth, visualize_detections_3d, @@ -179,8 +179,8 @@ def process_frame( else: # If no transform, use camera coordinates center_pose = Pose( - Point(obj_cam_pos[0], obj_cam_pos[1], obj_cam_pos[2]), - Quaternion(0.0, 0.0, 0.0, 1.0), # Default orientation + position=Vector3(obj_cam_pos[0], obj_cam_pos[1], obj_cam_pos[2]), + orientation=Quaternion(0.0, 0.0, 0.0, 1.0), # Default orientation ) # Create Detection3D object diff --git a/dimos/manipulation/visual_servoing/manipulation_module.py b/dimos/manipulation/visual_servoing/manipulation_module.py index dfcb1dbcb0..eda3daa557 100644 --- a/dimos/manipulation/visual_servoing/manipulation_module.py +++ b/dimos/manipulation/visual_servoing/manipulation_module.py @@ -27,8 +27,9 @@ import numpy as np from dimos.core import Module, In, Out, rpc -from dimos_lcm.sensor_msgs import Image, CameraInfo -from dimos_lcm.geometry_msgs import Vector3, Pose, Point, Quaternion +from dimos.msgs.sensor_msgs import Image +from dimos.msgs.geometry_msgs import Vector3, Pose, Quaternion +from dimos_lcm.sensor_msgs import CameraInfo from dimos_lcm.vision_msgs import Detection3DArray, Detection2DArray from dimos.hardware.piper_arm import PiperArm @@ -207,7 +208,9 @@ def __init__( self.target_click = None # Place target position and object info - self.home_pose = Pose(Point(0.0, 0.0, 0.0), Quaternion(0.0, 0.0, 0.0, 1.0)) + self.home_pose = Pose( + position=Vector3(0.0, 0.0, 0.0), orientation=Quaternion(0.0, 0.0, 0.0, 1.0) + ) self.place_target_position = None self.target_object_height = None self.retract_distance = 0.12 @@ -239,22 +242,14 @@ def stop(self): def _on_rgb_image(self, msg: Image): """Handle RGB image messages.""" try: - data = np.frombuffer(msg.data, dtype=np.uint8) - if msg.encoding == "rgb8": - self.latest_rgb = data.reshape((msg.height, msg.width, 3)) - else: - logger.warning(f"Unsupported RGB encoding: {msg.encoding}") + self.latest_rgb = msg.data except Exception as e: logger.error(f"Error processing RGB image: {e}") def _on_depth_image(self, msg: Image): """Handle depth image messages.""" try: - if msg.encoding == "32FC1": - data = np.frombuffer(msg.data, dtype=np.float32) - self.latest_depth = data.reshape((msg.height, msg.width)) - else: - logger.warning(f"Unsupported depth encoding: {msg.encoding}") + self.latest_depth = msg.data except Exception as e: logger.error(f"Error processing depth image: {e}") @@ -896,19 +891,7 @@ def _publish_visualization(self, viz_image: np.ndarray): """Publish visualization image to LCM.""" try: viz_rgb = cv2.cvtColor(viz_image, cv2.COLOR_BGR2RGB) - height, width = viz_rgb.shape[:2] - data = viz_rgb.tobytes() - - msg = Image( - data_length=len(data), - height=height, - width=width, - encoding="rgb8", - is_bigendian=0, - step=width * 3, - data=data, - ) - + msg = Image.from_numpy(viz_rgb) self.viz_image.publish(msg) except Exception as e: logger.error(f"Error publishing visualization: {e}") @@ -935,7 +918,8 @@ def get_place_target_pose(self) -> Optional[Pose]: place_pos[2] += z_offset + 0.1 place_center_pose = Pose( - Point(place_pos[0], place_pos[1], place_pos[2]), Quaternion(0.0, 0.0, 0.0, 1.0) + position=Vector3(place_pos[0], place_pos[1], place_pos[2]), + orientation=Quaternion(0.0, 0.0, 0.0, 1.0), ) ee_pose = self.arm.get_ee_pose() diff --git a/dimos/manipulation/visual_servoing/pbvs.py b/dimos/manipulation/visual_servoing/pbvs.py index e34ec94557..77b4103104 100644 --- a/dimos/manipulation/visual_servoing/pbvs.py +++ b/dimos/manipulation/visual_servoing/pbvs.py @@ -21,7 +21,7 @@ from typing import Optional, Tuple, List from collections import deque from scipy.spatial.transform import Rotation as R -from dimos_lcm.geometry_msgs import Pose, Vector3, Quaternion, Point +from dimos.msgs.geometry_msgs import Pose, Vector3, Quaternion from dimos_lcm.vision_msgs import Detection3D, Detection3DArray from dimos.utils.logging_config import setup_logger from dimos.manipulation.visual_servoing.utils import ( @@ -158,11 +158,7 @@ def update_tracking(self, new_detections: Optional[Detection3DArray] = None) -> True if target was successfully tracked, False if lost (but target is kept) """ # Check if we have a current target - if ( - not self.current_target - or not self.current_target.bbox - or not self.current_target.bbox.center - ): + if not self.current_target: return False # Add new detections to history if provided diff --git a/dimos/manipulation/visual_servoing/utils.py b/dimos/manipulation/visual_servoing/utils.py index 4546326ef6..df78d85327 100644 --- a/dimos/manipulation/visual_servoing/utils.py +++ b/dimos/manipulation/visual_servoing/utils.py @@ -16,7 +16,7 @@ from typing import Dict, Any, Optional, List, Tuple, Union from dataclasses import dataclass -from dimos_lcm.geometry_msgs import Pose, Vector3, Quaternion, Point +from dimos.msgs.geometry_msgs import Pose, Vector3, Quaternion from dimos_lcm.vision_msgs import Detection3D, Detection2D import cv2 from dimos.perception.detection2d.utils import plot_results @@ -30,6 +30,7 @@ compose_transforms, yaw_towards_point, get_distance, + offset_distance, ) @@ -77,7 +78,9 @@ def transform_pose( euler_vector = Vector3(obj_orientation[0], obj_orientation[1], obj_orientation[2]) obj_orientation_quat = euler_to_quaternion(euler_vector) - input_pose = Pose(Point(obj_pos[0], obj_pos[1], obj_pos[2]), obj_orientation_quat) + input_pose = Pose( + position=Vector3(obj_pos[0], obj_pos[1], obj_pos[2]), orientation=obj_orientation_quat + ) # Apply input frame conversion based on flags if to_robot: @@ -137,8 +140,8 @@ def transform_points_3d( for point in points_3d: input_point_pose = Pose( - Point(point[0], point[1], point[2]), - Quaternion(0.0, 0.0, 0.0, 1.0), # Identity quaternion + position=Vector3(point[0], point[1], point[2]), + orientation=Quaternion(0.0, 0.0, 0.0, 1.0), # Identity quaternion ) # Apply input frame conversion based on flags @@ -258,44 +261,11 @@ def update_target_grasp_pose( updated_pose = Pose(target_pos, target_orientation) if grasp_distance > 0.0: - return apply_grasp_distance(updated_pose, grasp_distance) + return offset_distance(updated_pose, grasp_distance) else: return updated_pose -def apply_grasp_distance(target_pose: Pose, distance: float) -> Pose: - """ - Apply grasp distance offset to target pose along its approach direction. - - Args: - target_pose: Target grasp pose - distance: Distance to offset along the approach direction (meters) - - Returns: - Target pose offset by the specified distance along its approach direction - """ - # Convert pose to transformation matrix to extract rotation - T_target = pose_to_matrix(target_pose) - rotation_matrix = T_target[:3, :3] - - # Define the approach vector based on the target pose orientation - # Assuming the gripper approaches along its local -z axis (common for downward grasps) - # You can change this to [1, 0, 0] for x-axis or [0, 1, 0] for y-axis based on your gripper - approach_vector_local = np.array([0, 0, -1]) - - # Transform approach vector to world coordinates - approach_vector_world = rotation_matrix @ approach_vector_local - - # Apply offset along the approach direction - offset_position = Point( - target_pose.position.x + distance * approach_vector_world[0], - target_pose.position.y + distance * approach_vector_world[1], - target_pose.position.z + distance * approach_vector_world[2], - ) - - return Pose(offset_position, target_pose.orientation) - - def is_target_reached(target_pose: Pose, current_pose: Pose, tolerance: float = 0.01) -> bool: """ Check if the target pose has been reached within tolerance. @@ -461,11 +431,11 @@ def parse_zed_pose(zed_pose_data: Dict[str, Any]) -> Optional[Pose]: # Extract position position = zed_pose_data.get("position", [0, 0, 0]) - pos_vector = Point(position[0], position[1], position[2]) + pos_vector = Vector3(position[0], position[1], position[2]) quat = zed_pose_data["rotation"] orientation = Quaternion(quat[0], quat[1], quat[2], quat[3]) - return Pose(pos_vector, orientation) + return Pose(position=pos_vector, orientation=orientation) def estimate_object_depth( diff --git a/dimos/models/depth/metric3d.py b/dimos/models/depth/metric3d.py index 58cb63f640..b4f00718bc 100644 --- a/dimos/models/depth/metric3d.py +++ b/dimos/models/depth/metric3d.py @@ -24,7 +24,7 @@ class Metric3D: - def __init__(self, gt_depth_scale=256.0): + def __init__(self, camera_intrinsics=None, gt_depth_scale=256.0): # self.conf = get_config("zoedepth", "infer") # self.depth_model = build_model(self.conf) self.depth_model = torch.hub.load( @@ -35,7 +35,7 @@ def __init__(self, gt_depth_scale=256.0): # self.depth_model = torch.nn.DataParallel(self.depth_model) self.depth_model.eval() - self.intrinsic = [707.0493, 707.0493, 604.0814, 180.5066] + self.intrinsic = camera_intrinsics self.intrinsic_scaled = None self.gt_depth_scale = gt_depth_scale # And this self.pad_info = None @@ -76,12 +76,8 @@ def infer_depth(self, img, debug=False): # Convert to PIL format depth_image = self.unpad_transform_depth(pred_depth) - out_16bit_numpy = (depth_image.squeeze().cpu().numpy() * self.gt_depth_scale).astype( - np.uint16 - ) - depth_map_pil = Image.fromarray(out_16bit_numpy) - return depth_map_pil + return depth_image.cpu().numpy() def save_depth(self, pred_depth): # Save the depth map to a file diff --git a/dimos/models/qwen/video_query.py b/dimos/models/qwen/video_query.py index c37ca953c2..3416a1cf51 100644 --- a/dimos/models/qwen/video_query.py +++ b/dimos/models/qwen/video_query.py @@ -197,7 +197,7 @@ def get_bbox_from_qwen( return None -def get_bbox_from_qwen_frame(frame, object_name: Optional[str] = None) -> Optional[tuple]: +def get_bbox_from_qwen_frame(frame, object_name: Optional[str] = None) -> Optional[list]: """Get bounding box coordinates from Qwen for a specific object or any object using a single frame. Args: @@ -205,16 +205,15 @@ def get_bbox_from_qwen_frame(frame, object_name: Optional[str] = None) -> Option object_name: Optional name of object to detect Returns: - tuple: (bbox, size) where bbox is [x1, y1, x2, y2] or None if no detection - and size is the estimated height in meters + list: bbox as [x1, y1, x2, y2] or None if no detection """ # Ensure frame is numpy array if not isinstance(frame, np.ndarray): raise ValueError("Frame must be a numpy array") prompt = ( - f"Look at this image and find the {object_name if object_name else 'most prominent object'}. Estimate the approximate height of the subject." - "Return ONLY a JSON object with format: {'name': 'object_name', 'bbox': [x1, y1, x2, y2], 'size': height_in_meters} " + f"Look at this image and find the {object_name if object_name else 'most prominent object'}. " + "Return ONLY a JSON object with format: {'name': 'object_name', 'bbox': [x1, y1, x2, y2]} " "where x1,y1 is the top-left and x2,y2 is the bottom-right corner of the bounding box. If not found, return None." ) @@ -230,7 +229,7 @@ def get_bbox_from_qwen_frame(frame, object_name: Optional[str] = None) -> Option # Extract and validate bbox if "bbox" in result and len(result["bbox"]) == 4: - return result["bbox"], result["size"] + return result["bbox"] except Exception as e: print(f"Error parsing Qwen response: {e}") print(f"Raw response: {response}") diff --git a/dimos/msgs/sensor_msgs/Image.py b/dimos/msgs/sensor_msgs/Image.py index 008cd93546..90bd851222 100644 --- a/dimos/msgs/sensor_msgs/Image.py +++ b/dimos/msgs/sensor_msgs/Image.py @@ -28,14 +28,15 @@ class ImageFormat(Enum): - """Supported image formats.""" + """Supported image formats for internal representation.""" - BGR = "bgr8" - RGB = "rgb8" - RGBA = "rgba8" - BGRA = "bgra8" - GRAY = "mono8" - GRAY16 = "mono16" + BGR = "BGR" # 8-bit Blue-Green-Red color + RGB = "RGB" # 8-bit Red-Green-Blue color + RGBA = "RGBA" # 8-bit RGB with Alpha + BGRA = "BGRA" # 8-bit BGR with Alpha + GRAY = "GRAY" # 8-bit Grayscale + GRAY16 = "GRAY16" # 16-bit Grayscale + DEPTH = "DEPTH" # 32-bit Float Depth @dataclass @@ -124,9 +125,12 @@ def from_file(cls, filepath: str, format: ImageFormat = ImageFormat.BGR) -> "Ima if cv_image is None: raise ValueError(f"Could not load image from {filepath}") - # Detect format based on channels + # Detect format based on channels and data type if len(cv_image.shape) == 2: - detected_format = ImageFormat.GRAY + if cv_image.dtype == np.uint16: + detected_format = ImageFormat.GRAY16 + else: + detected_format = ImageFormat.GRAY elif cv_image.shape[2] == 3: detected_format = ImageFormat.BGR # OpenCV default elif cv_image.shape[2] == 4: @@ -136,6 +140,19 @@ def from_file(cls, filepath: str, format: ImageFormat = ImageFormat.BGR) -> "Ima return cls(data=cv_image, format=detected_format) + @classmethod + def from_depth(cls, depth_data: np.ndarray, frame_id: str = "", ts: float = None) -> "Image": + """Create Image from depth data (float32 array).""" + if depth_data.dtype != np.float32: + depth_data = depth_data.astype(np.float32) + + return cls( + data=depth_data, + format=ImageFormat.DEPTH, + frame_id=frame_id, + ts=ts if ts is not None else time.time(), + ) + def to_opencv(self) -> np.ndarray: """Convert to OpenCV-compatible array (BGR format).""" if self.format == ImageFormat.BGR: @@ -150,6 +167,8 @@ def to_opencv(self) -> np.ndarray: return self.data elif self.format == ImageFormat.GRAY16: return self.data + elif self.format == ImageFormat.DEPTH: + return self.data # Depth images are already in the correct format else: raise ValueError(f"Unsupported format conversion: {self.format}") @@ -284,7 +303,7 @@ def lcm_encode(self, frame_id: Optional[str] = None) -> LCMImage: # Image properties msg.height = self.height msg.width = self.width - msg.encoding = self.format.value + msg.encoding = self._get_lcm_encoding() # Convert format to LCM encoding msg.is_bigendian = False # Use little endian msg.step = self._get_row_step() @@ -331,9 +350,38 @@ def _get_bytes_per_pixel(self) -> int: bytes_per_element = self.data.dtype.itemsize return self.channels * bytes_per_element + def _get_lcm_encoding(self) -> str: + """Get LCM encoding string from internal format and data type.""" + # Map internal format to LCM encoding based on format and dtype + if self.format == ImageFormat.GRAY: + if self.dtype == np.uint8: + return "mono8" + elif self.dtype == np.uint16: + return "mono16" + elif self.format == ImageFormat.GRAY16: + return "mono16" + elif self.format == ImageFormat.RGB: + return "rgb8" + elif self.format == ImageFormat.RGBA: + return "rgba8" + elif self.format == ImageFormat.BGR: + return "bgr8" + elif self.format == ImageFormat.BGRA: + return "bgra8" + elif self.format == ImageFormat.DEPTH: + if self.dtype == np.float32: + return "32FC1" + elif self.dtype == np.float64: + return "64FC1" + + raise ValueError( + f"Cannot determine LCM encoding for format={self.format}, dtype={self.dtype}" + ) + @staticmethod def _parse_encoding(encoding: str) -> dict: """Parse LCM image encoding string to determine format and data type.""" + # Standard encodings encoding_map = { "mono8": {"format": ImageFormat.GRAY, "dtype": np.uint8, "channels": 1}, "mono16": {"format": ImageFormat.GRAY16, "dtype": np.uint16, "channels": 1}, @@ -341,6 +389,10 @@ def _parse_encoding(encoding: str) -> dict: "rgba8": {"format": ImageFormat.RGBA, "dtype": np.uint8, "channels": 4}, "bgr8": {"format": ImageFormat.BGR, "dtype": np.uint8, "channels": 3}, "bgra8": {"format": ImageFormat.BGRA, "dtype": np.uint8, "channels": 4}, + # Depth/float encodings + "32FC1": {"format": ImageFormat.DEPTH, "dtype": np.float32, "channels": 1}, + "32FC3": {"format": ImageFormat.RGB, "dtype": np.float32, "channels": 3}, + "64FC1": {"format": ImageFormat.DEPTH, "dtype": np.float64, "channels": 1}, } if encoding not in encoding_map: diff --git a/dimos/msgs/sensor_msgs/__init__.py b/dimos/msgs/sensor_msgs/__init__.py index 170587e286..434ec75afb 100644 --- a/dimos/msgs/sensor_msgs/__init__.py +++ b/dimos/msgs/sensor_msgs/__init__.py @@ -1,2 +1,2 @@ -from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 diff --git a/dimos/msgs/std_msgs/Header.py b/dimos/msgs/std_msgs/Header.py index fd708be2f3..baefae3afa 100644 --- a/dimos/msgs/std_msgs/Header.py +++ b/dimos/msgs/std_msgs/Header.py @@ -31,18 +31,22 @@ class Header(LCMHeader): msg_name = "std_msgs.Header" + ts: float @dispatch def __init__(self) -> None: """Initialize a Header with current time and empty frame_id.""" - super().__init__(seq=0, stamp=LCMTime(), frame_id="") + self.ts = time.time() + sec = int(self.ts) + nsec = int((self.ts - sec) * 1_000_000_000) + super().__init__(seq=0, stamp=LCMTime(sec=sec, nsec=nsec), frame_id="") @dispatch def __init__(self, frame_id: str) -> None: """Initialize a Header with current time and specified frame_id.""" - ts = time.time() - sec = int(ts) - nsec = int((ts - sec) * 1_000_000_000) + self.ts = time.time() + sec = int(self.ts) + nsec = int((self.ts - sec) * 1_000_000_000) super().__init__(seq=1, stamp=LCMTime(sec=sec, nsec=nsec), frame_id=frame_id) @dispatch @@ -55,9 +59,9 @@ def __init__(self, timestamp: float, frame_id: str = "", seq: int = 1) -> None: @dispatch def __init__(self, timestamp: datetime, frame_id: str = "") -> None: """Initialize a Header with datetime object and frame_id.""" - ts = timestamp.timestamp() - sec = int(ts) - nsec = int((ts - sec) * 1_000_000_000) + self.ts = timestamp.timestamp() + sec = int(self.ts) + nsec = int((self.ts - sec) * 1_000_000_000) super().__init__(seq=1, stamp=LCMTime(sec=sec, nsec=nsec), frame_id=frame_id) @dispatch diff --git a/dimos/navigation/bt_navigator/goal_validator.py b/dimos/navigation/bt_navigator/goal_validator.py index 871c351db0..f43a45969c 100644 --- a/dimos/navigation/bt_navigator/goal_validator.py +++ b/dimos/navigation/bt_navigator/goal_validator.py @@ -400,6 +400,10 @@ def _is_position_safe( True if position is safe, False otherwise """ + # Check bounds first + if not (0 <= x < costmap.width and 0 <= y < costmap.height): + return False + # Check if position itself is free if costmap.grid[y, x] >= cost_threshold or costmap.grid[y, x] == CostValues.UNKNOWN: return False diff --git a/dimos/navigation/bt_navigator/navigator.py b/dimos/navigation/bt_navigator/navigator.py index 3ca4587cb8..8a81af0356 100644 --- a/dimos/navigation/bt_navigator/navigator.py +++ b/dimos/navigation/bt_navigator/navigator.py @@ -21,13 +21,15 @@ import threading import time from enum import Enum -from typing import Optional +from typing import Callable, Optional from dimos.core import Module, In, Out, rpc from dimos.msgs.geometry_msgs import PoseStamped from dimos.msgs.nav_msgs import OccupancyGrid +from dimos_lcm.std_msgs import String from dimos.navigation.local_planner.local_planner import BaseLocalPlanner from dimos.navigation.bt_navigator.goal_validator import find_safe_goal +from dimos.navigation.bt_navigator.recovery_server import RecoveryServer from dimos.protocol.tf import TF from dimos.utils.logging_config import setup_logger from dimos_lcm.std_msgs import Bool @@ -66,17 +68,20 @@ class BehaviorTreeNavigator(Module): # LCM outputs goal: Out[PoseStamped] = None goal_reached: Out[Bool] = None + navigation_state: Out[String] = None def __init__( self, - local_planner: BaseLocalPlanner, publishing_frequency: float = 1.0, + reset_local_planner: Callable[[], None] = None, + check_goal_reached: Callable[[], bool] = None, **kwargs, ): """Initialize the Navigator. Args: publishing_frequency: Frequency to publish goals to global planner (Hz) + goal_tolerance: Distance threshold to consider goal reached (meters) """ super().__init__(**kwargs) @@ -90,11 +95,11 @@ def __init__( # Current goal self.current_goal: Optional[PoseStamped] = None + self.original_goal: Optional[PoseStamped] = None self.goal_lock = threading.Lock() # Goal reached state self._goal_reached = False - self._goal_reached_lock = threading.Lock() # Latest data self.latest_odom: Optional[PoseStamped] = None @@ -104,11 +109,17 @@ def __init__( self.control_thread: Optional[threading.Thread] = None self.stop_event = threading.Event() - self.local_planner = local_planner # TF listener self.tf = TF() - logger.info("Navigator initialized") + # Local planner + self.reset_local_planner = reset_local_planner + self.check_goal_reached = check_goal_reached + + # Recovery server for stuck detection + self.recovery_server = RecoveryServer(stuck_duration=5.0) + + logger.info("Navigator initialized with stuck detection") @rpc def start(self): @@ -150,7 +161,7 @@ def cleanup(self): logger.info("Navigator cleanup complete") @rpc - def set_goal(self, goal: PoseStamped, blocking: bool = False) -> bool: + def set_goal(self, goal: PoseStamped) -> bool: """ Set a new navigation goal. @@ -168,21 +179,13 @@ def set_goal(self, goal: PoseStamped, blocking: bool = False) -> bool: with self.goal_lock: self.current_goal = transformed_goal + self.original_goal = transformed_goal - with self._goal_reached_lock: - self._goal_reached = False + self._goal_reached = False with self.state_lock: self.state = NavigatorState.FOLLOWING_PATH - if blocking: - while not self.is_goal_reached(): - if self.state == NavigatorState.IDLE: - logger.info("Navigation was cancelled") - return False - - time.sleep(self.publishing_period) - return True @rpc @@ -194,6 +197,9 @@ def _on_odom(self, msg: PoseStamped): """Handle incoming odometry messages.""" self.latest_odom = msg + if self.state == NavigatorState.FOLLOWING_PATH: + self.recovery_server.update_odom(msg) + def _on_goal_request(self, msg: PoseStamped): """Handle incoming goal requests.""" self.set_goal(msg) @@ -241,19 +247,29 @@ def _control_loop(self): while not self.stop_event.is_set(): with self.state_lock: current_state = self.state + self.navigation_state.publish(String(data=current_state.value)) if current_state == NavigatorState.FOLLOWING_PATH: with self.goal_lock: goal = self.current_goal + original_goal = self.original_goal if goal is not None and self.latest_costmap is not None: + # Check if robot is stuck + if self.recovery_server.check_stuck(): + logger.warning("Robot is stuck! Cancelling goal and resetting.") + self.cancel_goal() + continue + + costmap = self.latest_costmap.inflate(0.1).gradient(max_distance=1.0) + # Find safe goal position safe_goal_pos = find_safe_goal( - self.latest_costmap, - goal.position, + costmap, + original_goal.position, algorithm="bfs", - cost_threshold=80, - min_clearance=0.1, + cost_threshold=60, + min_clearance=0.25, max_search_distance=5.0, ) @@ -266,21 +282,19 @@ def _control_loop(self): ts=goal.ts, ) self.goal.publish(safe_goal) + self.current_goal = safe_goal else: + logger.warning("Could not find safe goal position, cancelling goal") self.cancel_goal() - if self.local_planner.is_goal_reached(): - with self._goal_reached_lock: - self._goal_reached = True - logger.info("Goal reached!") + # Check if goal is reached + if self.check_goal_reached(): reached_msg = Bool() reached_msg.data = True self.goal_reached.publish(reached_msg) - self.local_planner.reset() - with self.goal_lock: - self.current_goal = None - with self.state_lock: - self.state = NavigatorState.IDLE + self.stop() + self._goal_reached = True + logger.info("Goal reached, resetting local planner") elif current_state == NavigatorState.RECOVERY: with self.state_lock: @@ -290,21 +304,24 @@ def _control_loop(self): @rpc def is_goal_reached(self) -> bool: - """Check if the current goal has been reached.""" - with self._goal_reached_lock: - return self._goal_reached + """Check if the current goal has been reached. + + Returns: + True if goal was reached, False otherwise + """ + return self._goal_reached def stop(self): """Stop navigation and return to IDLE state.""" with self.goal_lock: self.current_goal = None - with self._goal_reached_lock: - self._goal_reached = False + self._goal_reached = False with self.state_lock: self.state = NavigatorState.IDLE - self.local_planner.reset() + self.reset_local_planner() + self.recovery_server.reset() # Reset recovery server when stopping logger.info("Navigator stopped") diff --git a/dimos/navigation/bt_navigator/recovery_server.py b/dimos/navigation/bt_navigator/recovery_server.py new file mode 100644 index 0000000000..a5afa3b090 --- /dev/null +++ b/dimos/navigation/bt_navigator/recovery_server.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python3 + +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Recovery server for handling stuck detection and recovery behaviors. +""" + +from collections import deque + +from dimos.msgs.geometry_msgs import PoseStamped +from dimos.utils.logging_config import setup_logger +from dimos.utils.transform_utils import get_distance + +logger = setup_logger("dimos.navigation.bt_navigator.recovery_server") + + +class RecoveryServer: + """ + Recovery server for detecting stuck situations and executing recovery behaviors. + + Currently implements stuck detection based on time without significant movement. + Will be extended with actual recovery behaviors in the future. + """ + + def __init__( + self, + position_threshold: float = 0.2, + stuck_duration: float = 3.0, + ): + """Initialize the recovery server. + + Args: + position_threshold: Minimum distance to travel to reset stuck timer (meters) + stuck_duration: Time duration without significant movement to consider stuck (seconds) + """ + self.position_threshold = position_threshold + self.stuck_duration = stuck_duration + + # Store last position that exceeded threshold + self.last_moved_pose = None + self.last_moved_time = None + self.current_odom = None + + logger.info( + f"RecoveryServer initialized with position_threshold={position_threshold}, " + f"stuck_duration={stuck_duration}" + ) + + def update_odom(self, odom: PoseStamped) -> None: + """Update the odometry data for stuck detection. + + Args: + odom: Current robot odometry with timestamp + """ + if odom is None: + return + + # Store current odom for checking stuck + self.current_odom = odom + + # Initialize on first update + if self.last_moved_pose is None: + self.last_moved_pose = odom + self.last_moved_time = odom.ts + return + + # Calculate distance from the reference position (last significant movement) + distance = get_distance(odom, self.last_moved_pose) + + # If robot has moved significantly from the reference, update reference + if distance > self.position_threshold: + self.last_moved_pose = odom + self.last_moved_time = odom.ts + + def check_stuck(self) -> bool: + """Check if the robot is stuck based on time without movement. + + Returns: + True if robot appears to be stuck, False otherwise + """ + if self.last_moved_time is None: + return False + + # Need current odom to check + if self.current_odom is None: + return False + + # Calculate time since last significant movement + current_time = self.current_odom.ts + time_since_movement = current_time - self.last_moved_time + + # Check if stuck based on duration without movement + is_stuck = time_since_movement > self.stuck_duration + + if is_stuck: + logger.warning( + f"Robot appears stuck! No movement for {time_since_movement:.1f} seconds" + ) + + return is_stuck + + def reset(self) -> None: + """Reset the recovery server state.""" + self.last_moved_pose = None + self.last_moved_time = None + self.current_odom = None + logger.debug("RecoveryServer reset") diff --git a/dimos/navigation/frontier_exploration/test_wavefront_frontier_goal_selector.py b/dimos/navigation/frontier_exploration/test_wavefront_frontier_goal_selector.py index 14d792ca65..f305c4ff14 100644 --- a/dimos/navigation/frontier_exploration/test_wavefront_frontier_goal_selector.py +++ b/dimos/navigation/frontier_exploration/test_wavefront_frontier_goal_selector.py @@ -188,7 +188,7 @@ def test_frontier_ranking(): # Initialize explorer with custom parameters explorer = WavefrontFrontierExplorer( - min_frontier_size=5, min_distance_from_obstacles=0.5, info_gain_threshold=0.02 + min_frontier_perimeter=0.5, safe_distance=0.5, info_gain_threshold=0.02 ) robot_pose = first_lidar.origin @@ -218,12 +218,13 @@ def test_frontier_ranking(): # Test distance to obstacles obstacle_dist = explorer._compute_distance_to_obstacles(goal1, costmap) - assert obstacle_dist >= explorer.min_distance_from_obstacles, ( - f"Goal should be at least {explorer.min_distance_from_obstacles}m from obstacles" + # Note: Goals might be closer than safe_distance if that's the best available frontier + # The safe_distance is used for scoring, not as a hard constraint + print( + f"Distance to obstacles: {obstacle_dist:.2f}m (safe distance: {explorer.safe_distance}m)" ) print(f"Frontier ranking test passed - selected goal at ({goal1.x:.2f}, {goal1.y:.2f})") - print(f"Distance to obstacles: {obstacle_dist:.2f}m") print(f"Total frontiers detected: {len(frontiers1)}") else: print("No frontiers found for ranking test") diff --git a/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py b/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py index dd26f6f79c..292b8e162d 100644 --- a/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py +++ b/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py @@ -32,6 +32,7 @@ from dimos.msgs.nav_msgs import OccupancyGrid, CostValues from dimos.utils.logging_config import setup_logger from dimos_lcm.std_msgs import Bool +from dimos.utils.transform_utils import get_distance logger = setup_logger("dimos.robot.unitree.frontier_exploration") @@ -98,28 +99,32 @@ class WavefrontFrontierExplorer(Module): def __init__( self, - min_frontier_size: int = 5, + min_frontier_perimeter: float = 0.5, occupancy_threshold: int = 99, - min_distance_from_obstacles: float = 0.2, + safe_distance: float = 3.0, + lookahead_distance: float = 5.0, + max_explored_distance: float = 10.0, info_gain_threshold: float = 0.03, num_no_gain_attempts: int = 4, - goal_timeout: float = 30.0, + goal_timeout: float = 15.0, **kwargs, ): """ Initialize the frontier explorer. Args: - min_frontier_size: Minimum number of points to consider a valid frontier + min_frontier_perimeter: Minimum perimeter in meters to consider a valid frontier occupancy_threshold: Cost threshold above which a cell is considered occupied (0-255) - min_distance_from_obstacles: Minimum distance frontier must be from obstacles (meters) + safe_distance: Safe distance from obstacles for scoring (meters) info_gain_threshold: Minimum percentage increase in costmap information required to continue exploration (0.05 = 5%) num_no_gain_attempts: Maximum number of consecutive attempts with no information gain """ super().__init__(**kwargs) - self.min_frontier_size = min_frontier_size + self.min_frontier_perimeter = min_frontier_perimeter self.occupancy_threshold = occupancy_threshold - self.min_distance_from_obstacles = min_distance_from_obstacles + self.safe_distance = safe_distance + self.max_explored_distance = max_explored_distance + self.lookahead_distance = lookahead_distance self.info_gain_threshold = info_gain_threshold self.num_no_gain_attempts = num_no_gain_attempts self._cache = FrontierCache() @@ -347,7 +352,9 @@ def detect_frontiers(self, robot_pose: Vector3, costmap: OccupancyGrid) -> List[ frontier_point.classification |= PointClassification.FrontierClosed # Check if we found a large enough frontier - if len(new_frontier) >= self.min_frontier_size: + # Convert minimum perimeter to minimum number of cells based on resolution + min_cells = int(self.min_frontier_perimeter / costmap.resolution) + if len(new_frontier) >= min_cells: world_points = [] for point in new_frontier: world_pos = costmap.grid_to_world( @@ -459,7 +466,7 @@ def _compute_distance_to_obstacles(self, frontier: Vector3, costmap: OccupancyGr min_distance = float("inf") search_radius = ( - int(self.min_distance_from_obstacles / costmap.resolution) + 5 + int(self.safe_distance / costmap.resolution) + 5 ) # Search a bit beyond minimum # Search in a square around the frontier point @@ -483,7 +490,9 @@ def _compute_distance_to_obstacles(self, frontier: Vector3, costmap: OccupancyGr distance = np.sqrt(dx**2 + dy**2) * costmap.resolution min_distance = min(min_distance, distance) - return min_distance if min_distance != float("inf") else float("inf") + # If no obstacles found within search radius, return the safe distance + # This indicates the frontier is safely away from obstacles + return min_distance if min_distance != float("inf") else self.safe_distance def _compute_comprehensive_frontier_score( self, frontier: Vector3, frontier_size: int, robot_pose: Vector3, costmap: OccupancyGrid @@ -491,29 +500,38 @@ def _compute_comprehensive_frontier_score( """Compute comprehensive score considering multiple criteria.""" # 1. Distance from robot (preference for moderate distances) - robot_distance = np.sqrt( - (frontier.x - robot_pose.x) ** 2 + (frontier.y - robot_pose.y) ** 2 - ) + robot_distance = get_distance(frontier, robot_pose) # Distance score: prefer moderate distances (not too close, not too far) - optimal_distance = 4.0 # meters - distance_score = 1.0 / (1.0 + abs(robot_distance - optimal_distance)) + # Normalized to 0-1 range + distance_score = 1.0 / (1.0 + abs(robot_distance - self.lookahead_distance)) # 2. Information gain (frontier size) - info_gain_score = frontier_size + # Normalize by a reasonable max frontier size + max_expected_frontier_size = self.min_frontier_perimeter / costmap.resolution * 10 + info_gain_score = min(frontier_size / max_expected_frontier_size, 1.0) # 3. Distance to explored goals (bonus for being far from explored areas) + # Normalize by a reasonable max distance (e.g., 10 meters) explored_goals_distance = self._compute_distance_to_explored_goals(frontier) - explored_goals_score = explored_goals_distance + explored_goals_score = min(explored_goals_distance / self.max_explored_distance, 1.0) - # 4. Distance to obstacles (penalty for being too close) + # 4. Distance to obstacles (score based on safety) + # 0 = too close to obstacles, 1 = at or beyond safe distance obstacles_distance = self._compute_distance_to_obstacles(frontier, costmap) - obstacles_score = obstacles_distance + if obstacles_distance >= self.safe_distance: + obstacles_score = 1.0 # Fully safe + else: + obstacles_score = obstacles_distance / self.safe_distance # Linear penalty - # 5. Direction momentum (if we have a current direction) + # 5. Direction momentum (already in 0-1 range from dot product) momentum_score = self._compute_direction_momentum_score(frontier, robot_pose) - # Combine scores with consistent scaling (no arbitrary multipliers) + logger.info( + f"Distance score: {distance_score:.2f}, Info gain: {info_gain_score:.2f}, Explored goals: {explored_goals_score:.2f}, Obstacles: {obstacles_score:.2f}, Momentum: {momentum_score:.2f}" + ) + + # Combine scores with consistent scaling total_score = ( 0.3 * info_gain_score # 30% information gain + 0.3 * explored_goals_score # 30% distance from explored goals @@ -549,10 +567,6 @@ def _rank_frontiers( valid_frontiers = [] for i, frontier in enumerate(frontier_centroids): - obstacle_distance = self._compute_distance_to_obstacles(frontier, costmap) - if obstacle_distance < self.min_distance_from_obstacles: - continue - # Compute comprehensive score frontier_size = frontier_sizes[i] if i < len(frontier_sizes) else 1 score = self._compute_comprehensive_frontier_score( @@ -717,7 +731,8 @@ def _exploration_loop(self): ) # Get exploration goal - goal = self.get_exploration_goal(robot_pose, self.latest_costmap) + costmap = self.latest_costmap.inflate(0.25) + goal = self.get_exploration_goal(robot_pose, costmap) if goal: # Publish goal to navigator diff --git a/dimos/navigation/global_planner/planner.py b/dimos/navigation/global_planner/planner.py index 186163cffb..47622f9cce 100644 --- a/dimos/navigation/global_planner/planner.py +++ b/dimos/navigation/global_planner/planner.py @@ -21,9 +21,57 @@ from dimos.msgs.nav_msgs import OccupancyGrid, Path from dimos.navigation.global_planner.algo import astar from dimos.utils.logging_config import setup_logger +from dimos.utils.transform_utils import euler_to_quaternion logger = setup_logger("dimos.robot.unitree.global_planner") +import math +from dimos.msgs.geometry_msgs import Quaternion, Vector3 + + +def add_orientations_to_path(path: Path, goal_orientation: Quaternion = None) -> Path: + """Add orientations to path poses based on direction of movement. + + Args: + path: Path with poses to add orientations to + goal_orientation: Desired orientation for the final pose + + Returns: + Path with orientations added to all poses + """ + if not path.poses or len(path.poses) < 2: + return path + + # Calculate orientations for all poses except the last one + for i in range(len(path.poses) - 1): + current_pose = path.poses[i] + next_pose = path.poses[i + 1] + + # Calculate direction to next point + dx = next_pose.position.x - current_pose.position.x + dy = next_pose.position.y - current_pose.position.y + + # Calculate yaw angle + yaw = math.atan2(dy, dx) + + # Convert to quaternion (roll=0, pitch=0, yaw) + orientation = euler_to_quaternion(Vector3(0, 0, yaw)) + current_pose.orientation = orientation + + # Set last pose orientation + identity_quat = Quaternion(0, 0, 0, 1) + if goal_orientation is not None and goal_orientation != identity_quat: + # Use the provided goal orientation if it's not the identity + path.poses[-1].orientation = goal_orientation + elif len(path.poses) > 1: + # Use the previous pose's orientation + path.poses[-1].orientation = path.poses[-2].orientation + else: + # Single pose with identity goal orientation + path.poses[-1].orientation = identity_quat + + return path + def resample_path(path: Path, spacing: float) -> Path: """Resample a path to have approximately uniform spacing between poses. @@ -142,6 +190,8 @@ def _on_target(self, msg: PoseStamped): path = self.plan(msg) if path: + # Add orientations to the path, using the goal's orientation for the final pose + path = add_orientations_to_path(path, msg.orientation) self.path.publish(path) def plan(self, goal: Pose) -> Optional[Path]: @@ -154,9 +204,10 @@ def plan(self, goal: Pose) -> Optional[Path]: # Get current position from odometry robot_pos = self.latest_odom.position + costmap = self.latest_costmap.inflate(0.1).gradient(max_distance=1.0) # Run A* planning - path = astar(self.latest_costmap, goal.position, robot_pos) + path = astar(costmap, goal.position, robot_pos) if path: path = resample_path(path, 0.1) diff --git a/dimos/navigation/local_planner/holonomic_local_planner.py b/dimos/navigation/local_planner/holonomic_local_planner.py index 3a8c73d3e2..28a220cb41 100644 --- a/dimos/navigation/local_planner/holonomic_local_planner.py +++ b/dimos/navigation/local_planner/holonomic_local_planner.py @@ -24,7 +24,7 @@ from dimos.msgs.geometry_msgs import Vector3 from dimos.navigation.local_planner import BaseLocalPlanner -from dimos.utils.transform_utils import quaternion_to_euler, normalize_angle +from dimos.utils.transform_utils import quaternion_to_euler, normalize_angle, get_distance class HolonomicLocalPlanner(BaseLocalPlanner): @@ -47,15 +47,20 @@ def __init__( self, lookahead_dist: float = 1.0, k_rep: float = 0.5, + k_angular: float = 0.75, alpha: float = 0.5, v_max: float = 0.8, goal_tolerance: float = 0.5, + orientation_tolerance: float = 0.2, control_frequency: float = 10.0, **kwargs, ): """Initialize the GLAP planner with specified parameters.""" super().__init__( - goal_tolerance=goal_tolerance, control_frequency=control_frequency, **kwargs + goal_tolerance=goal_tolerance, + orientation_tolerance=orientation_tolerance, + control_frequency=control_frequency, + **kwargs, ) # Algorithm parameters @@ -63,6 +68,7 @@ def __init__( self.k_rep = k_rep self.alpha = alpha self.v_max = v_max + self.k_angular = k_angular # Previous velocity for filtering (vx, vy, vtheta) self.v_prev = np.array([0.0, 0.0, 0.0]) @@ -106,20 +112,39 @@ def compute_velocity(self) -> Optional[Vector3]: v_robot_x = cos_yaw * v_odom[0] + sin_yaw * v_odom[1] v_robot_y = -sin_yaw * v_odom[0] + cos_yaw * v_odom[1] - # Compute angular velocity to align with path direction + # Compute angular velocity closest_idx, _ = self._find_closest_point_on_path(pose, path) - lookahead_point = self._find_lookahead_point(path, closest_idx) - dx = lookahead_point[0] - pose[0] - dy = lookahead_point[1] - pose[1] - desired_yaw = np.arctan2(dy, dx) + # Check if we're near the final goal + goal_pose = self.latest_path.poses[-1] + distance_to_goal = get_distance(self.latest_odom, goal_pose) + + if distance_to_goal < self.goal_tolerance: + # Near goal - rotate to match final goal orientation + goal_euler = quaternion_to_euler(goal_pose.orientation) + desired_yaw = goal_euler.z + else: + # Not near goal - align with path direction + lookahead_point = self._find_lookahead_point(path, closest_idx) + dx = lookahead_point[0] - pose[0] + dy = lookahead_point[1] - pose[1] + desired_yaw = np.arctan2(dy, dx) yaw_error = normalize_angle(desired_yaw - robot_yaw) - k_angular = 2.0 # Angular gain + k_angular = self.k_angular v_theta = k_angular * yaw_error - v_robot_x = np.clip(v_robot_x, -self.v_max, self.v_max) - v_robot_y = np.clip(v_robot_y, -self.v_max, self.v_max) + # Slow down linear velocity when turning + # Scale linear velocity based on angular velocity magnitude + angular_speed = abs(v_theta) + max_angular_speed = self.v_max + + # Calculate speed reduction factor (1.0 when not turning, 0.2 when at max turn rate) + turn_slowdown = 1.0 - 0.8 * min(angular_speed / max_angular_speed, 1.0) + + # Apply speed reduction to linear velocities + v_robot_x = np.clip(v_robot_x * turn_slowdown, -self.v_max, self.v_max) + v_robot_y = np.clip(v_robot_y * turn_slowdown, -self.v_max, self.v_max) v_theta = np.clip(v_theta, -self.v_max, self.v_max) v_raw = np.array([v_robot_x, v_robot_y, v_theta]) diff --git a/dimos/navigation/local_planner/local_planner.py b/dimos/navigation/local_planner/local_planner.py index 2fa8fc6f37..7e503f7425 100644 --- a/dimos/navigation/local_planner/local_planner.py +++ b/dimos/navigation/local_planner/local_planner.py @@ -28,7 +28,7 @@ from dimos.msgs.geometry_msgs import Vector3, PoseStamped from dimos.msgs.nav_msgs import OccupancyGrid, Path from dimos.utils.logging_config import setup_logger -from dimos.utils.transform_utils import get_distance +from dimos.utils.transform_utils import get_distance, quaternion_to_euler, normalize_angle logger = setup_logger("dimos.robot.local_planner") @@ -54,17 +54,25 @@ class BaseLocalPlanner(Module): # LCM outputs cmd_vel: Out[Vector3] = None - def __init__(self, goal_tolerance: float = 0.5, control_frequency: float = 10.0, **kwargs): + def __init__( + self, + goal_tolerance: float = 0.5, + orientation_tolerance: float = 0.2, + control_frequency: float = 10.0, + **kwargs, + ): """Initialize the local planner module. Args: goal_tolerance: Distance threshold to consider goal reached (meters) + orientation_tolerance: Orientation threshold to consider goal reached (radians) control_frequency: Frequency for control loop (Hz) """ super().__init__(**kwargs) # Parameters self.goal_tolerance = goal_tolerance + self.orientation_tolerance = orientation_tolerance self.control_frequency = control_frequency self.control_period = 1.0 / control_frequency @@ -113,7 +121,6 @@ def _follow_path_loop(self): """Main planning loop that runs in a separate thread.""" while not self.stop_planning.is_set(): if self.is_goal_reached(): - logger.info("Goal reached, stopping planning thread") self.stop_planning.set() stop_cmd = Vector3(0, 0, 0) self.cmd_vel.publish(stop_cmd) @@ -145,7 +152,7 @@ def compute_velocity(self) -> Optional[Vector3]: @rpc def is_goal_reached(self) -> bool: """ - Check if the robot has reached the goal position. + Check if the robot has reached the goal position and orientation. Returns: True if goal is reached within tolerance, False otherwise @@ -159,12 +166,18 @@ def is_goal_reached(self) -> bool: goal_pose = self.latest_path.poses[-1] distance = get_distance(self.latest_odom, goal_pose) - goal_reached = distance < self.goal_tolerance + # Check distance tolerance + if distance >= self.goal_tolerance: + return False + + # Check orientation tolerance + current_euler = quaternion_to_euler(self.latest_odom.orientation) + goal_euler = quaternion_to_euler(goal_pose.orientation) - if goal_reached: - logger.info(f"Goal reached! Distance: {distance:.3f}m < {self.goal_tolerance}m") + # Calculate yaw difference and normalize to [-pi, pi] + yaw_error = normalize_angle(goal_euler.z - current_euler.z) - return goal_reached + return abs(yaw_error) < self.orientation_tolerance @rpc def reset(self): diff --git a/dimos/navigation/local_planner/test_base_local_planner.py b/dimos/navigation/local_planner/test_base_local_planner.py index 93ec26882b..ee68586c9e 100644 --- a/dimos/navigation/local_planner/test_base_local_planner.py +++ b/dimos/navigation/local_planner/test_base_local_planner.py @@ -302,8 +302,12 @@ def test_curved_path_following(self, planner, empty_costmap): # Should have both X and Y components for curved motion assert vel is not None - assert vel.x > 0.3 # Moving forward - assert vel.y > 0.1 # Turning left (positive Y) + # Test general behavior: should be moving (not exact values) + assert vel.x > 0 # Moving forward (any positive value) + assert vel.y > 0 # Turning left (any positive value) + # Ensure we have meaningful movement, not just noise + total_linear = np.sqrt(vel.x**2 + vel.y**2) + assert total_linear > 0.1 # Some reasonable movement def test_robot_frame_transformation(self, empty_costmap): """Test that velocities are correctly transformed to robot frame.""" @@ -340,9 +344,13 @@ def test_robot_frame_transformation(self, empty_costmap): # Robot is facing +Y, path is along +X # So in robot frame: forward is +Y direction, path is to the right assert vel is not None - assert abs(vel.x) < 0.1 # No forward velocity in robot frame - assert vel.y < -0.5 # Should move right (negative Y in robot frame) - assert vel.z < -0.5 # Should turn right (negative angular velocity) + # Test relative magnitudes and signs rather than exact values + # Path is to the right, so Y velocity should be negative + assert vel.y < 0 # Should move right (negative Y in robot frame) + # Should turn to align with path + assert vel.z < 0 # Should turn right (negative angular velocity) + # X velocity should be relatively small compared to Y + assert abs(vel.x) < abs(vel.y) # Lateral movement dominates def test_angular_velocity_computation(self, empty_costmap): """Test that angular velocity is computed to align with path.""" @@ -377,6 +385,12 @@ def test_angular_velocity_computation(self, empty_costmap): # Path is at 45 degrees, robot facing 0 degrees # Should have positive angular velocity to turn left assert vel is not None - assert vel.x > 0.5 # Moving forward - assert vel.y > 0.5 # Also moving left (diagonal path) - assert vel.z > 0.5 # Positive angular velocity to turn towards path + # Test general behavior without exact thresholds + assert vel.x > 0 # Moving forward (any positive value) + assert vel.y > 0 # Moving left (holonomic, any positive value) + assert vel.z > 0 # Turning left (positive angular velocity) + # Verify the robot is actually moving with reasonable speed + total_linear = np.sqrt(vel.x**2 + vel.y**2) + assert total_linear > 0.1 # Some meaningful movement + # Since path is diagonal, X and Y should be similar magnitude + assert abs(vel.x - vel.y) < max(vel.x, vel.y) * 0.5 # Within 50% of each other diff --git a/dimos/perception/common/utils.py b/dimos/perception/common/utils.py index 10d05d9b4d..d7292fde13 100644 --- a/dimos/perception/common/utils.py +++ b/dimos/perception/common/utils.py @@ -19,6 +19,7 @@ from dimos.types.vector import Vector from dimos.utils.logging_config import setup_logger from dimos_lcm.vision_msgs import Detection3D, Detection2D, BoundingBox2D +from dimos.msgs.geometry_msgs import Pose, Quaternion, Vector3 import torch logger = setup_logger("dimos.perception.common.utils") @@ -121,13 +122,16 @@ def project_2d_points_to_3d( return points_3d -def colorize_depth(depth_img: np.ndarray, max_depth: float = 5.0) -> Optional[np.ndarray]: +def colorize_depth( + depth_img: np.ndarray, max_depth: float = 5.0, overlay_stats: bool = True +) -> Optional[np.ndarray]: """ - Normalize and colorize depth image using COLORMAP_JET. + Normalize and colorize depth image using COLORMAP_JET with optional statistics overlay. Args: depth_img: Depth image (H, W) in meters max_depth: Maximum depth value for normalization + overlay_stats: Whether to overlay depth statistics on the image Returns: Colorized depth image (H, W, 3) in RGB format, or None if input is None @@ -144,6 +148,122 @@ def colorize_depth(depth_img: np.ndarray, max_depth: float = 5.0) -> Optional[np # Make the depth image less bright by scaling down the values depth_rgb = (depth_rgb * 0.6).astype(np.uint8) + if overlay_stats and valid_mask.any(): + # Calculate statistics + valid_depths = depth_img[valid_mask] + min_depth = np.min(valid_depths) + max_depth_actual = np.max(valid_depths) + + # Get center depth + h, w = depth_img.shape + center_y, center_x = h // 2, w // 2 + # Sample a small region around center for robustness + center_region = depth_img[ + max(0, center_y - 2) : min(h, center_y + 3), max(0, center_x - 2) : min(w, center_x + 3) + ] + center_mask = np.isfinite(center_region) & (center_region > 0) + if center_mask.any(): + center_depth = np.median(center_region[center_mask]) + else: + center_depth = depth_img[center_y, center_x] if valid_mask[center_y, center_x] else 0.0 + + # Prepare text overlays + font = cv2.FONT_HERSHEY_SIMPLEX + font_scale = 0.6 + thickness = 1 + line_type = cv2.LINE_AA + + # Text properties + text_color = (255, 255, 255) # White + bg_color = (0, 0, 0) # Black background + padding = 5 + + # Min depth (top-left) + min_text = f"Min: {min_depth:.2f}m" + (text_w, text_h), _ = cv2.getTextSize(min_text, font, font_scale, thickness) + cv2.rectangle( + depth_rgb, + (padding, padding), + (padding + text_w + 4, padding + text_h + 6), + bg_color, + -1, + ) + cv2.putText( + depth_rgb, + min_text, + (padding + 2, padding + text_h + 2), + font, + font_scale, + text_color, + thickness, + line_type, + ) + + # Max depth (top-right) + max_text = f"Max: {max_depth_actual:.2f}m" + (text_w, text_h), _ = cv2.getTextSize(max_text, font, font_scale, thickness) + cv2.rectangle( + depth_rgb, + (w - padding - text_w - 4, padding), + (w - padding, padding + text_h + 6), + bg_color, + -1, + ) + cv2.putText( + depth_rgb, + max_text, + (w - padding - text_w - 2, padding + text_h + 2), + font, + font_scale, + text_color, + thickness, + line_type, + ) + + # Center depth (center) + if center_depth > 0: + center_text = f"{center_depth:.2f}m" + (text_w, text_h), _ = cv2.getTextSize(center_text, font, font_scale, thickness) + center_text_x = center_x - text_w // 2 + center_text_y = center_y + text_h // 2 + + # Draw crosshair + cross_size = 10 + cross_color = (255, 255, 255) + cv2.line( + depth_rgb, + (center_x - cross_size, center_y), + (center_x + cross_size, center_y), + cross_color, + 1, + ) + cv2.line( + depth_rgb, + (center_x, center_y - cross_size), + (center_x, center_y + cross_size), + cross_color, + 1, + ) + + # Draw center depth text with background + cv2.rectangle( + depth_rgb, + (center_text_x - 2, center_text_y - text_h - 2), + (center_text_x + text_w + 2, center_text_y + 2), + bg_color, + -1, + ) + cv2.putText( + depth_rgb, + center_text, + (center_text_x, center_text_y), + font, + font_scale, + text_color, + thickness, + line_type, + ) + return depth_rgb @@ -492,3 +612,27 @@ def find_clicked_detection( return detections_3d[i] return None + + +def extract_pose_from_detection3d(detection3d: Detection3D): + """Extract PoseStamped from Detection3D message. + + Args: + detection3d: Detection3D message + + Returns: + Pose or None if no valid detection + """ + if not detection3d or not detection3d.bbox or not detection3d.bbox.center: + return None + + # Extract position + pos = detection3d.bbox.center.position + position = Vector3(pos.x, pos.y, pos.z) + + # Extract orientation + orient = detection3d.bbox.center.orientation + orientation = Quaternion(orient.x, orient.y, orient.z, orient.w) + + pose = Pose(position=position, orientation=orientation) + return pose diff --git a/dimos/perception/object_tracker.py b/dimos/perception/object_tracker.py index e4e96f443d..edd87134b1 100644 --- a/dimos/perception/object_tracker.py +++ b/dimos/perception/object_tracker.py @@ -13,64 +13,76 @@ # limitations under the License. import cv2 -from reactivex import Observable, interval -from reactivex import operators as ops import numpy as np +import time +import threading from typing import Dict, List, Optional from dimos.core import In, Out, Module, rpc +from dimos.msgs.std_msgs import Header from dimos.msgs.sensor_msgs import Image -from dimos.perception.common.ibvs import ObjectDistanceEstimator -from dimos.models.depth.metric3d import Metric3D -from dimos.perception.detection2d.utils import calculate_depth_from_bbox +from dimos.msgs.geometry_msgs import Vector3, Quaternion, Transform, Pose, PoseStamped +from dimos.protocol.tf import TF from dimos.utils.logging_config import setup_logger +# Import LCM messages +from dimos_lcm.vision_msgs import ( + Detection2D, + Detection2DArray, + Detection3D, + Detection3DArray, + ObjectHypothesisWithPose, +) +from dimos_lcm.sensor_msgs import CameraInfo +from dimos.utils.transform_utils import ( + yaw_towards_point, + optical_to_robot_frame, + euler_to_quaternion, +) +from dimos.manipulation.visual_servoing.utils import visualize_detections_3d + logger = setup_logger("dimos.perception.object_tracker") -class ObjectTrackingStream(Module): +class ObjectTracking(Module): """Module for object tracking with LCM input/output.""" # LCM inputs - video: In[Image] = None + color_image: In[Image] = None + depth: In[Image] = None + camera_info: In[CameraInfo] = None # LCM outputs - tracking_data: Out[Dict] = None + detection2darray: Out[Detection2DArray] = None + detection3darray: Out[Detection3DArray] = None + tracked_overlay: Out[Image] = None # Visualization output def __init__( self, - camera_intrinsics=None, - camera_pitch=0.0, - camera_height=1.0, - reid_threshold=5, - reid_fail_tolerance=10, - gt_depth_scale=1000.0, + camera_intrinsics: Optional[List[float]] = None, # [fx, fy, cx, cy] + reid_threshold: int = 10, + reid_fail_tolerance: int = 5, + frame_id: str = "camera_link", ): """ - Initialize an object tracking stream using OpenCV's CSRT tracker with ORB re-ID. + Initialize an object tracking module using OpenCV's CSRT tracker with ORB re-ID. Args: - camera_intrinsics: List in format [fx, fy, cx, cy] where: - - fx: Focal length in x direction (pixels) - - fy: Focal length in y direction (pixels) - - cx: Principal point x-coordinate (pixels) - - cy: Principal point y-coordinate (pixels) - camera_pitch: Camera pitch angle in radians (positive is up) - camera_height: Height of the camera from the ground in meters + camera_intrinsics: Optional [fx, fy, cx, cy] camera parameters. + If None, will use camera_info input. reid_threshold: Minimum good feature matches needed to confirm re-ID. reid_fail_tolerance: Number of consecutive frames Re-ID can fail before tracking is stopped. - gt_depth_scale: Ground truth depth scale factor for Metric3D model + frame_id: TF frame ID for the camera (default: "camera_link") """ # Call parent Module init super().__init__() self.camera_intrinsics = camera_intrinsics - self.camera_pitch = camera_pitch - self.camera_height = camera_height + self._camera_info_received = False self.reid_threshold = reid_threshold self.reid_fail_tolerance = reid_fail_tolerance - self.gt_depth_scale = gt_depth_scale + self.frame_id = frame_id self.tracker = None self.tracking_bbox = None # Stores (x, y, w, h) for tracker initialization @@ -78,173 +90,128 @@ def __init__( self.orb = cv2.ORB_create() self.bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=False) self.original_des = None # Store original ORB descriptors - self.reid_threshold = reid_threshold - self.reid_fail_tolerance = reid_fail_tolerance + self.original_kps = None # Store original ORB keypoints self.reid_fail_count = 0 # Counter for consecutive re-id failures + self.last_good_matches = [] # Store good matches for visualization + self.last_roi_kps = None # Store last ROI keypoints for visualization + self.last_roi_bbox = None # Store last ROI bbox for visualization + self.reid_confirmed = False # Store current reid confirmation state - # Initialize distance estimator if camera parameters are provided - self.distance_estimator = None - if camera_intrinsics is not None: - # Convert [fx, fy, cx, cy] to 3x3 camera matrix - fx, fy, cx, cy = camera_intrinsics - K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32) + # For tracking latest frame data + self._latest_rgb_frame: Optional[np.ndarray] = None + self._latest_depth_frame: Optional[np.ndarray] = None + self._latest_camera_info: Optional[CameraInfo] = None - self.distance_estimator = ObjectDistanceEstimator( - K=K, camera_pitch=camera_pitch, camera_height=camera_height - ) + # Tracking thread control + self.tracking_thread: Optional[threading.Thread] = None + self.stop_tracking = threading.Event() + self.tracking_rate = 30.0 # Hz + self.tracking_period = 1.0 / self.tracking_rate - # Initialize depth model with error handling - try: - self.depth_model = Metric3D(gt_depth_scale) - if camera_intrinsics is not None: - self.depth_model.update_intrinsic(camera_intrinsics) - except RuntimeError as e: - logger.error(f"Failed to initialize Metric3D depth model: {e}") - if "CUDA" in str(e): - logger.error("This appears to be a CUDA initialization error. Please check:") - logger.error("- CUDA is properly installed") - logger.error("- GPU drivers are up to date") - logger.error("- CUDA_VISIBLE_DEVICES environment variable is set correctly") - raise # Re-raise the exception to fail initialization - except Exception as e: - logger.error(f"Unexpected error initializing Metric3D depth model: {e}") - raise + # Initialize TF publisher + self.tf = TF() - # For tracking latest frame data - self._latest_frame: Optional[np.ndarray] = None - self._process_interval = 0.1 # Process at 10Hz + # Store latest detections for RPC access + self._latest_detection2d: Optional[Detection2DArray] = None + self._latest_detection3d: Optional[Detection3DArray] = None + self._detection_event = threading.Event() @rpc def start(self): """Start the object tracking module and subscribe to LCM streams.""" - # Subscribe to video stream - def set_video(image_msg: Image): - if hasattr(image_msg, "data"): - self._latest_frame = image_msg.data - else: - logger.warning("Received image message without data attribute") - - self.video.subscribe(set_video) + # Subscribe to rgb image stream + def on_rgb(image_msg: Image): + self._latest_rgb_frame = image_msg.data + + self.color_image.subscribe(on_rgb) + + # Subscribe to depth stream + def on_depth(image_msg: Image): + self._latest_depth_frame = image_msg.data + + self.depth.subscribe(on_depth) + + # Subscribe to camera info stream + def on_camera_info(camera_info_msg: CameraInfo): + self._latest_camera_info = camera_info_msg + # Extract intrinsics from camera info K matrix + # K is a 3x3 matrix in row-major order: [fx, 0, cx, 0, fy, cy, 0, 0, 1] + self.camera_intrinsics = [ + camera_info_msg.K[0], + camera_info_msg.K[4], + camera_info_msg.K[2], + camera_info_msg.K[5], + ] + if not self._camera_info_received: + self._camera_info_received = True + logger.info( + f"Camera intrinsics received from camera_info: {self.camera_intrinsics}" + ) - # Start periodic processing - interval(self._process_interval).subscribe(lambda _: self._process_frame()) + self.camera_info.subscribe(on_camera_info) logger.info("ObjectTracking module started and subscribed to LCM streams") - def _process_frame(self): - """Process the latest frame if available.""" - if self._latest_frame is None: - return - - # TODO: Better implementation for handling track RPC init - if self.tracker is None or self.tracking_bbox is None: - return - - # Process frame through tracking pipeline - result = self._process_tracking(self._latest_frame) - - # Publish result to LCM - if result: - self.tracking_data.publish(result) - @rpc def track( self, bbox: List[float], - frame: Optional[np.ndarray] = None, - distance: Optional[float] = None, - size: Optional[float] = None, - ) -> bool: + ) -> Dict: """ - Set the initial bounding box for tracking. Features are extracted later. + Initialize tracking with a bounding box and process current frame. Args: bbox: Bounding box in format [x1, y1, x2, y2] - frame: Optional - Current frame for depth estimation and feature extraction - distance: Optional - Known distance to object (meters) - size: Optional - Known size of object (meters) Returns: - bool: True if intention to track is set (bbox is valid) + Dict containing tracking results with 2D and 3D detections """ - if frame is None: - frame = self._latest_frame + if self._latest_rgb_frame is None: + logger.warning("No RGB frame available for tracking") + + # Initialize tracking x1, y1, x2, y2 = map(int, bbox) w, h = x2 - x1, y2 - y1 if w <= 0 or h <= 0: logger.warning(f"Invalid initial bbox provided: {bbox}. Tracking not started.") - self.stop_track() # Ensure clean state - return False + # Set tracking parameters self.tracking_bbox = (x1, y1, w, h) # Store in (x, y, w, h) format self.tracker = cv2.legacy.TrackerCSRT_create() - self.tracking_initialized = False # Reset flag - self.original_des = None # Clear previous descriptors - self.reid_fail_count = 0 # Reset counter on new track + self.tracking_initialized = False + self.original_des = None + self.reid_fail_count = 0 logger.info(f"Tracking target set with bbox: {self.tracking_bbox}") - # Calculate depth only if distance and size not provided - depth_estimate = None - if frame is not None and distance is None and size is None: - depth_map = self.depth_model.infer_depth(frame) - depth_map = np.array(depth_map) - depth_estimate = calculate_depth_from_bbox(depth_map, bbox) - if depth_estimate is not None: - logger.info(f"Estimated depth for object: {depth_estimate:.2f}m") - - # Update distance estimator if needed - if self.distance_estimator is not None: - if size is not None: - self.distance_estimator.set_estimated_object_size(size) - elif distance is not None: - self.distance_estimator.estimate_object_size(bbox, distance) - elif depth_estimate is not None: - self.distance_estimator.estimate_object_size(bbox, depth_estimate) + # Extract initial features + roi = self._latest_rgb_frame[y1:y2, x1:x2] + if roi.size > 0: + self.original_kps, self.original_des = self.orb.detectAndCompute(roi, None) + if self.original_des is None: + logger.warning("No ORB features found in initial ROI.") + self.stop_track() + return {"status": "tracking_failed", "bbox": self.tracking_bbox} else: - logger.info("No distance or size provided. Cannot estimate object size.") - - return True # Indicate intention to track is set - - def calculate_depth_from_bbox(self, frame, bbox): - """ - Calculate the average depth of an object within a bounding box. - Uses the 25th to 75th percentile range to filter outliers. - - Args: - frame: The image frame - bbox: Bounding box in format [x1, y1, x2, y2] + logger.info(f"Initial ORB features extracted: {len(self.original_des)}") - Returns: - float: Average depth in meters, or None if depth estimation fails - """ - try: - # Get depth map for the entire frame - depth_map = self.depth_model.infer_depth(frame) - depth_map = np.array(depth_map) - - # Extract region of interest from the depth map - x1, y1, x2, y2 = map(int, bbox) - roi_depth = depth_map[y1:y2, x1:x2] - - if roi_depth.size == 0: - return None - - # Calculate 25th and 75th percentile to filter outliers - p25 = np.percentile(roi_depth, 25) - p75 = np.percentile(roi_depth, 75) - - # Filter depth values within this range - filtered_depth = roi_depth[(roi_depth >= p25) & (roi_depth <= p75)] + # Initialize the tracker + init_success = self.tracker.init(self._latest_rgb_frame, self.tracking_bbox) + if init_success: + self.tracking_initialized = True + logger.info("Tracker initialized successfully.") + else: + logger.error("Tracker initialization failed.") + self.stop_track() + else: + logger.error("Empty ROI during tracker initialization.") + self.stop_track() - # Calculate average depth (convert to meters) - if filtered_depth.size > 0: - return np.mean(filtered_depth) / 1000.0 # Convert mm to meters + # Start tracking thread + self._start_tracking_thread() - return None - except Exception as e: - logger.error(f"Error calculating depth from bbox: {e}") - return None + # Return initial tracking result + return {"status": "tracking_started", "bbox": self.tracking_bbox} def reid(self, frame, current_bbox) -> bool: """Check if features in current_bbox match stored original features.""" @@ -255,27 +222,74 @@ def reid(self, frame, current_bbox) -> bool: if roi.size == 0: return False # Empty ROI cannot match - _, des_current = self.orb.detectAndCompute(roi, None) + kps_current, des_current = self.orb.detectAndCompute(roi, None) if des_current is None or len(des_current) < 2: return False # Need at least 2 descriptors for knnMatch + # Store ROI keypoints and bbox for visualization + self.last_roi_kps = kps_current + self.last_roi_bbox = [x1, y1, x2, y2] + # Handle case where original_des has only 1 descriptor (cannot use knnMatch with k=2) if len(self.original_des) < 2: matches = self.bf.match(self.original_des, des_current) + self.last_good_matches = matches # Store all matches for visualization good_matches = len(matches) else: matches = self.bf.knnMatch(self.original_des, des_current, k=2) # Apply Lowe's ratio test robustly + good_matches_list = [] good_matches = 0 for match_pair in matches: if len(match_pair) == 2: m, n = match_pair if m.distance < 0.75 * n.distance: + good_matches_list.append(m) good_matches += 1 + self.last_good_matches = good_matches_list # Store good matches for visualization - # print(f"ReID: Good Matches={good_matches}, Threshold={self.reid_threshold}") # Debug return good_matches >= self.reid_threshold + def _start_tracking_thread(self): + """Start the tracking thread.""" + self.stop_tracking.clear() + self.tracking_thread = threading.Thread(target=self._tracking_loop, daemon=True) + self.tracking_thread.start() + logger.info("Started tracking thread") + + def _tracking_loop(self): + """Main tracking loop that runs in a separate thread.""" + while not self.stop_tracking.is_set() and self.tracking_initialized: + # Process tracking for current frame + self._process_tracking() + + # Sleep to maintain tracking rate + time.sleep(self.tracking_period) + + logger.info("Tracking loop ended") + + def _reset_tracking_state(self): + """Reset tracking state without stopping the thread.""" + self.tracker = None + self.tracking_bbox = None + self.tracking_initialized = False + self.original_des = None + self.original_kps = None + self.reid_fail_count = 0 # Reset counter + self.last_good_matches = [] + self.last_roi_kps = None + self.last_roi_bbox = None + self.reid_confirmed = False # Reset reid confirmation state + + # Publish empty detections to clear any visualizations + empty_2d = Detection2DArray(detections_length=0, header=Header(), detections=[]) + empty_3d = Detection3DArray(detections_length=0, header=Header(), detections=[]) + self._latest_detection2d = empty_2d + self._latest_detection3d = empty_3d + self._detection_event.clear() + self.detection2darray.publish(empty_2d) + self.detection3darray.publish(empty_3d) + @rpc def stop_track(self) -> bool: """ @@ -285,166 +299,296 @@ def stop_track(self) -> bool: Returns: bool: True if tracking was successfully stopped """ - self.tracker = None - self.tracking_bbox = None - self.tracking_initialized = False - self.original_des = None - self.reid_fail_count = 0 # Reset counter + # Reset tracking state first + self._reset_tracking_state() + + # Stop tracking thread if running (only if called from outside the thread) + if self.tracking_thread and self.tracking_thread.is_alive(): + # Check if we're being called from within the tracking thread + if threading.current_thread() != self.tracking_thread: + self.stop_tracking.set() + self.tracking_thread.join(timeout=1.0) + self.tracking_thread = None + else: + # If called from within thread, just set the stop flag + self.stop_tracking.set() + + logger.info("Tracking stopped") return True - def _process_tracking(self, frame): - """Process a single frame for tracking.""" - viz_frame = frame.copy() + @rpc + def is_tracking(self) -> bool: + """ + Check if the tracker is currently tracking an object successfully. + + Returns: + bool: True if tracking is active and REID is confirmed, False otherwise + """ + return self.tracking_initialized and self.reid_confirmed + + def _process_tracking(self): + """Process current frame for tracking and publish detections.""" + if self._latest_rgb_frame is None or self.tracker is None or not self.tracking_initialized: + return + + frame = self._latest_rgb_frame tracker_succeeded = False reid_confirmed_this_frame = False final_success = False - target_data = None current_bbox_x1y1x2y2 = None - if self.tracker is not None and self.tracking_bbox is not None: - if not self.tracking_initialized: - # Extract initial features and initialize tracker on first frame - x_init, y_init, w_init, h_init = self.tracking_bbox - roi = frame[y_init : y_init + h_init, x_init : x_init + w_init] - - if roi.size > 0: - _, self.original_des = self.orb.detectAndCompute(roi, None) - if self.original_des is None: - logger.warning( - "No ORB features found in initial ROI during stream processing." - ) - else: - logger.info(f"Initial ORB features extracted: {len(self.original_des)}") - - # Initialize the tracker - init_success = self.tracker.init(frame, self.tracking_bbox) - if init_success: - self.tracking_initialized = True - tracker_succeeded = True - reid_confirmed_this_frame = True - current_bbox_x1y1x2y2 = [ - x_init, - y_init, - x_init + w_init, - y_init + h_init, - ] - logger.info("Tracker initialized successfully.") - else: - logger.error("Tracker initialization failed in stream.") - self.stop_track() - else: - logger.error("Empty ROI during tracker initialization in stream.") - self.stop_track() - - else: # Tracker already initialized, perform update and re-id - tracker_succeeded, bbox_cv = self.tracker.update(frame) - if tracker_succeeded: - x, y, w, h = map(int, bbox_cv) - current_bbox_x1y1x2y2 = [x, y, x + w, y + h] - # Perform re-ID check - reid_confirmed_this_frame = self.reid(frame, current_bbox_x1y1x2y2) - - if reid_confirmed_this_frame: - self.reid_fail_count = 0 - else: - self.reid_fail_count += 1 - logger.warning( - f"Re-ID failed ({self.reid_fail_count}/{self.reid_fail_tolerance}). Continuing track..." - ) - - # Determine final success and stop tracking if needed + # Perform tracker update + tracker_succeeded, bbox_cv = self.tracker.update(frame) + if tracker_succeeded: + x, y, w, h = map(int, bbox_cv) + current_bbox_x1y1x2y2 = [x, y, x + w, y + h] + # Perform re-ID check + reid_confirmed_this_frame = self.reid(frame, current_bbox_x1y1x2y2) + self.reid_confirmed = reid_confirmed_this_frame # Store for is_tracking() RPC + + if reid_confirmed_this_frame: + self.reid_fail_count = 0 + else: + self.reid_fail_count += 1 + else: + self.reid_confirmed = False # No tracking if tracker failed + + # Determine final success if tracker_succeeded: if self.reid_fail_count >= self.reid_fail_tolerance: logger.warning( f"Re-ID failed consecutively {self.reid_fail_count} times. Target lost." ) final_success = False + self._reset_tracking_state() else: final_success = True else: final_success = False if self.tracking_initialized: logger.info("Tracker update failed. Stopping track.") + self._reset_tracking_state() + + if not reid_confirmed_this_frame: + return + + # Create detections if tracking succeeded + header = Header(self.frame_id) + detection2darray = Detection2DArray(detections_length=0, header=header, detections=[]) + detection3darray = Detection3DArray(detections_length=0, header=header, detections=[]) - # Post-processing based on final_success if final_success and current_bbox_x1y1x2y2 is not None: x1, y1, x2, y2 = current_bbox_x1y1x2y2 - viz_color = (0, 255, 0) if reid_confirmed_this_frame else (0, 165, 255) - cv2.rectangle(viz_frame, (x1, y1), (x2, y2), viz_color, 2) - - target_data = { - "target_id": 0, - "bbox": current_bbox_x1y1x2y2, - "confidence": 1.0, - "reid_confirmed": reid_confirmed_this_frame, - } - - dist_text = "Object Tracking" - if not reid_confirmed_this_frame: - dist_text += " (Re-ID Failed - Tolerated)" - - if ( - self.distance_estimator is not None - and self.distance_estimator.estimated_object_size is not None - ): - distance, angle = self.distance_estimator.estimate_distance_angle( - current_bbox_x1y1x2y2 - ) - if distance is not None: - target_data["distance"] = distance - target_data["angle"] = angle - dist_text = f"Object: {distance:.2f}m, {np.rad2deg(angle):.1f} deg" - if not reid_confirmed_this_frame: - dist_text += " (Re-ID Failed - Tolerated)" - - text_size = cv2.getTextSize(dist_text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)[0] - label_bg_y = max(y1 - text_size[1] - 5, 0) - cv2.rectangle(viz_frame, (x1, label_bg_y), (x1 + text_size[0], y1), (0, 0, 0), -1) - cv2.putText( - viz_frame, - dist_text, - (x1, y1 - 5), - cv2.FONT_HERSHEY_SIMPLEX, - 0.5, - (255, 255, 255), - 1, + center_x = (x1 + x2) / 2.0 + center_y = (y1 + y2) / 2.0 + width = float(x2 - x1) + height = float(y2 - y1) + + # Create Detection2D + detection_2d = Detection2D() + detection_2d.id = "0" + detection_2d.results_length = 1 + detection_2d.header = header + + # Create hypothesis + hypothesis = ObjectHypothesisWithPose() + hypothesis.hypothesis.class_id = "tracked_object" + hypothesis.hypothesis.score = 1.0 + detection_2d.results = [hypothesis] + + # Create bounding box + detection_2d.bbox.center.position.x = center_x + detection_2d.bbox.center.position.y = center_y + detection_2d.bbox.center.theta = 0.0 + detection_2d.bbox.size_x = width + detection_2d.bbox.size_y = height + + detection2darray = Detection2DArray() + detection2darray.detections_length = 1 + detection2darray.header = header + detection2darray.detections = [detection_2d] + + # Create Detection3D if depth is available + if self._latest_depth_frame is not None: + # Calculate 3D position using depth and camera intrinsics + depth_value = self._get_depth_from_bbox(current_bbox_x1y1x2y2) + if ( + depth_value is not None + and depth_value > 0 + and self.camera_intrinsics is not None + ): + fx, fy, cx, cy = self.camera_intrinsics + + # Convert pixel coordinates to 3D in optical frame + z_optical = depth_value + x_optical = (center_x - cx) * z_optical / fx + y_optical = (center_y - cy) * z_optical / fy + + # Create pose in optical frame + optical_pose = Pose() + optical_pose.position = Vector3(x_optical, y_optical, z_optical) + optical_pose.orientation = Quaternion(0.0, 0.0, 0.0, 1.0) # Identity for now + + # Convert to robot frame + robot_pose = optical_to_robot_frame(optical_pose) + + # Calculate orientation: object facing towards camera (origin) + yaw = yaw_towards_point(robot_pose.position) + euler = Vector3(0.0, 0.0, yaw) # Only yaw, no roll/pitch + robot_pose.orientation = euler_to_quaternion(euler) + + # Estimate object size in meters + size_x = width * z_optical / fx + size_y = height * z_optical / fy + size_z = 0.1 # Default depth size + + # Create Detection3D + detection_3d = Detection3D() + detection_3d.id = "0" + detection_3d.results_length = 1 + detection_3d.header = header + + # Reuse hypothesis from 2D + detection_3d.results = [hypothesis] + + # Create 3D bounding box with robot frame pose + detection_3d.bbox.center = Pose() + detection_3d.bbox.center.position = robot_pose.position + detection_3d.bbox.center.orientation = robot_pose.orientation + detection_3d.bbox.size = Vector3(size_x, size_y, size_z) + + detection3darray = Detection3DArray() + detection3darray.detections_length = 1 + detection3darray.header = header + detection3darray.detections = [detection_3d] + + # Publish transform for tracked object + # The optical pose is in camera optical frame, so publish it relative to the camera frame + tracked_object_tf = Transform( + translation=robot_pose.position, + rotation=robot_pose.orientation, + frame_id=self.frame_id, # Use configured camera frame + child_frame_id=f"tracked_object", + ts=header.ts, + ) + self.tf.publish(tracked_object_tf) + + # Store latest detections for RPC access + self._latest_detection2d = detection2darray + self._latest_detection3d = detection3darray + + # Signal that new detections are available + if detection2darray.detections_length > 0 or detection3darray.detections_length > 0: + self._detection_event.set() + + # Publish detections + self.detection2darray.publish(detection2darray) + self.detection3darray.publish(detection3darray) + + # Create and publish visualization if tracking is active + if self.tracking_initialized and self._latest_rgb_frame is not None: + # Convert single detection to list for visualization + detections_3d = ( + detection3darray.detections if detection3darray.detections_length > 0 else [] + ) + detections_2d = ( + detection2darray.detections if detection2darray.detections_length > 0 else [] ) - elif self.tracking_initialized: - self.stop_track() + if detections_3d and detections_2d: + # Extract 2D bbox for visualization + det_2d = detections_2d[0] + bbox_2d = [] + if det_2d.bbox: + x1 = det_2d.bbox.center.position.x - det_2d.bbox.size_x / 2 + y1 = det_2d.bbox.center.position.y - det_2d.bbox.size_y / 2 + x2 = det_2d.bbox.center.position.x + det_2d.bbox.size_x / 2 + y2 = det_2d.bbox.center.position.y + det_2d.bbox.size_y / 2 + bbox_2d = [[x1, y1, x2, y2]] + + # Create visualization + viz_image = visualize_detections_3d( + self._latest_rgb_frame, detections_3d, show_coordinates=True, bboxes_2d=bbox_2d + ) - return { - "frame": frame, - "viz_frame": viz_frame, - "targets": [target_data] if target_data else [], - } + # Overlay REID feature matches if available + if self.last_good_matches and self.last_roi_kps and self.last_roi_bbox: + viz_image = self._draw_reid_matches(viz_image) - @rpc - def get_tracking_data(self) -> Dict: - """Get the latest tracking data. + # Convert to Image message and publish + viz_msg = Image.from_numpy(viz_image) + self.tracked_overlay.publish(viz_msg) - Returns: - Dictionary containing tracking results - """ - if self._latest_frame is not None: - return self._process_tracking(self._latest_frame) - return {"frame": None, "viz_frame": None, "targets": []} + def _draw_reid_matches(self, image: np.ndarray) -> np.ndarray: + """Draw REID feature matches on the image.""" + viz_image = image.copy() - def create_stream(self, video_stream: Observable) -> Observable: - """ - Create an Observable stream of object tracking results from a video stream. - This method is maintained for backward compatibility. + x1, y1, x2, y2 = self.last_roi_bbox - Args: - video_stream: Observable that emits video frames + # Draw keypoints from current ROI in green + for kp in self.last_roi_kps: + pt = (int(kp.pt[0] + x1), int(kp.pt[1] + y1)) + cv2.circle(viz_image, pt, 3, (0, 255, 0), -1) - Returns: - Observable that emits dictionaries containing tracking results and visualizations - """ - return video_stream.pipe(ops.map(self._process_tracking)) + for match in self.last_good_matches: + current_kp = self.last_roi_kps[match.trainIdx] + pt_current = (int(current_kp.pt[0] + x1), int(current_kp.pt[1] + y1)) + + # Draw a larger circle for matched points in yellow + cv2.circle(viz_image, pt_current, 5, (0, 255, 255), 2) # Yellow for matched points + + # Draw match strength indicator (smaller circle with intensity based on distance) + # Lower distance = better match = brighter color + intensity = int(255 * (1.0 - min(match.distance / 100.0, 1.0))) + cv2.circle(viz_image, pt_current, 2, (intensity, intensity, 255), -1) + + text = f"REID Matches: {len(self.last_good_matches)}/{len(self.last_roi_kps) if self.last_roi_kps else 0}" + cv2.putText(viz_image, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2) + + if len(self.last_good_matches) >= self.reid_threshold: + status_text = "REID: CONFIRMED" + status_color = (0, 255, 0) # Green + else: + status_text = f"REID: WEAK ({self.reid_fail_count}/{self.reid_fail_tolerance})" + status_color = (0, 165, 255) # Orange + + cv2.putText( + viz_image, status_text, (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.7, status_color, 2 + ) + + return viz_image + + def _get_depth_from_bbox(self, bbox: List[int]) -> Optional[float]: + """Calculate depth from bbox using the 25th percentile of closest points.""" + if self._latest_depth_frame is None: + return None + + x1, y1, x2, y2 = bbox + + # Ensure bbox is within frame bounds + y1 = max(0, y1) + y2 = min(self._latest_depth_frame.shape[0], y2) + x1 = max(0, x1) + x2 = min(self._latest_depth_frame.shape[1], x2) + + # Extract depth values from the entire bbox + roi_depth = self._latest_depth_frame[y1:y2, x1:x2] + + # Get valid (finite and positive) depth values + valid_depths = roi_depth[np.isfinite(roi_depth) & (roi_depth > 0)] + + if len(valid_depths) > 0: + depth_25th_percentile = float(np.percentile(valid_depths, 25)) + return depth_25th_percentile + + return None @rpc def cleanup(self): """Clean up resources.""" self.stop_track() - # CUDA cleanup is now handled by WorkerPlugin in dimos.core + + # Ensure thread is stopped + if self.tracking_thread and self.tracking_thread.is_alive(): + self.stop_tracking.set() + self.tracking_thread.join(timeout=2.0) diff --git a/dimos/perception/spatial_perception.py b/dimos/perception/spatial_perception.py index 188b9b81d9..aa9b843569 100644 --- a/dimos/perception/spatial_perception.py +++ b/dimos/perception/spatial_perception.py @@ -29,7 +29,6 @@ from dimos.core import In, Module, Out, rpc from dimos.msgs.sensor_msgs import Image from dimos.msgs.geometry_msgs import Vector3, Quaternion, Pose, PoseStamped -from dimos.robot.unitree_webrtc.type.odometry import Odometry from dimos.utils.logging_config import setup_logger from dimos.agents.memory.spatial_vector_db import SpatialVectorDB from dimos.agents.memory.image_embedding import ImageEmbeddingProvider @@ -52,7 +51,7 @@ class SpatialMemory(Module): # LCM inputs video: In[Image] = None - odom: In[Odometry] = None + odom: In[PoseStamped] = None def __init__( self, @@ -63,14 +62,12 @@ def __init__( min_time_threshold: float = 1.0, # Min time in seconds to record a new frame db_path: Optional[str] = None, # Path for ChromaDB persistence visual_memory_path: Optional[str] = None, # Path for saving/loading visual memory - new_memory: bool = False, # Whether to create a new memory from scratch + new_memory: bool = True, # Whether to create a new memory from scratch output_dir: Optional[str] = None, # Directory for storing visual memory data chroma_client: Any = None, # Optional ChromaDB client for persistence visual_memory: Optional[ "VisualMemory" ] = None, # Optional VisualMemory instance for storing images - video_stream: Optional[Observable] = None, # Video stream to process - get_pose: Optional[callable] = None, # Function that returns position and rotation ): """ Initialize the spatial perception system. @@ -170,7 +167,7 @@ def __init__( # Track latest data for processing self._latest_video_frame: Optional[np.ndarray] = None - self._latest_odom: Optional[Odometry] = None + self._latest_odom: Optional[PoseStamped] = None self._process_interval = 1 logger.info(f"SpatialMemory initialized with model {embedding_model}") @@ -187,7 +184,7 @@ def set_video(image_msg: Image): else: logger.warning("Received image message without data attribute") - def set_odom(odom_msg: Odometry): + def set_odom(odom_msg: PoseStamped): self._latest_odom = odom_msg self.video.subscribe(set_video) @@ -416,7 +413,8 @@ def process_combined_data(data): logger.info("No position or rotation data available, skipping frame") return None - position_v3 = Vector3(position_vec.x, position_vec.y, position_vec.z) + # position_vec is already a Vector3, no need to recreate it + position_v3 = position_vec if self.last_position is not None: distance_moved = np.linalg.norm( diff --git a/dimos/protocol/rpc/spec.py b/dimos/protocol/rpc/spec.py index 80ad3e5c0e..b5abe1f114 100644 --- a/dimos/protocol/rpc/spec.py +++ b/dimos/protocol/rpc/spec.py @@ -43,8 +43,9 @@ def call( ) -> Optional[Callable[[], Any]]: ... # we bootstrap these from the call() implementation above - def call_sync(self, name: str, arguments: Args, rpc_timeout: float = None) -> Any: + def call_sync(self, name: str, arguments: Args, timeout: float = 1.0) -> Any: res = Empty + start_time = time.time() def receive_value(val): nonlocal res @@ -54,9 +55,9 @@ def receive_value(val): total_time = 0.0 while res is Empty: - if rpc_timeout is not None and total_time >= rpc_timeout: - raise TimeoutError(f"RPC call to {name} timed out after {rpc_timeout} seconds") - + if time.time() - start_time > timeout: + print(f"RPC {name} timed out") + return None time.sleep(0.05) total_time += 0.1 return res diff --git a/dimos/robot/agilex/piper_arm.py b/dimos/robot/agilex/piper_arm.py index 63dc419a78..2c917b71fb 100644 --- a/dimos/robot/agilex/piper_arm.py +++ b/dimos/robot/agilex/piper_arm.py @@ -19,12 +19,13 @@ from dimos import core from dimos.hardware.zed_camera import ZEDModule from dimos.manipulation.visual_servoing.manipulation_module import ManipulationModule -from dimos_lcm.sensor_msgs import Image +from dimos.msgs.sensor_msgs import Image from dimos.protocol import pubsub from dimos.skills.skills import SkillLibrary from dimos.types.robot_capabilities import RobotCapability from dimos.robot.foxglove_bridge import FoxgloveBridge from dimos.utils.logging_config import setup_logger +from dimos.robot.robot import Robot # Import LCM message types from dimos_lcm.sensor_msgs import CameraInfo @@ -32,10 +33,11 @@ logger = setup_logger("dimos.robot.agilex.piper_arm") -class PiperArmRobot: +class PiperArmRobot(Robot): """Piper Arm robot with ZED camera and manipulation capabilities.""" def __init__(self, robot_capabilities: Optional[List[RobotCapability]] = None): + super().__init__() self.dimos = None self.stereo_camera = None self.manipulation_interface = None @@ -91,9 +93,9 @@ async def start(self): # Print module info logger.info("Modules configured:") print("\nZED Module:") - print(self.stereo_camera.io().result()) + print(self.stereo_camera.io()) print("\nManipulation Module:") - print(self.manipulation_interface.io().result()) + print(self.manipulation_interface.io()) # Start modules logger.info("Starting modules...") @@ -106,14 +108,6 @@ async def start(self): logger.info("PiperArmRobot initialized and started") - def get_skills(self): - """Get the robot's skill library. - - Returns: - The robot's skill library for adding/managing skills - """ - return self.skill_library - def pick_and_place( self, pick_x: int, pick_y: int, place_x: Optional[int] = None, place_y: Optional[int] = None ): @@ -149,17 +143,6 @@ def handle_keyboard_command(self, key: str): logger.error("Manipulation module not initialized") return None - def has_capability(self, capability: RobotCapability) -> bool: - """Check if the robot has a specific capability. - - Args: - capability: The capability to check for - - Returns: - bool: True if the robot has the capability - """ - return capability in self.capabilities - def stop(self): """Stop all modules and clean up.""" logger.info("Stopping PiperArmRobot...") diff --git a/dimos/robot/agilex/run.py b/dimos/robot/agilex/run.py index c9e5a036d8..a2db03c898 100644 --- a/dimos/robot/agilex/run.py +++ b/dimos/robot/agilex/run.py @@ -31,7 +31,6 @@ from dimos.agents.claude_agent import ClaudeAgent from dimos.skills.manipulation.pick_and_place import PickAndPlace from dimos.skills.kill_skill import KillSkill -from dimos.skills.observe import Observe from dimos.web.robot_web_interface import RobotWebInterface from dimos.stream.audio.pipelines import stt, tts from dimos.utils.logging_config import setup_logger @@ -53,7 +52,6 @@ - **PickAndPlace**: Execute pick and place operations based on object and location descriptions - Pick only: "Pick up the red mug" - Pick and place: "Move the book to the shelf" -- **Observe**: Capture and analyze the current camera view - **KillSkill**: Stop any currently running skill ## Guidelines: @@ -70,7 +68,6 @@ You: "I'll place the toy on the table." [Execute PickAndPlace with object_query="toy", target_query="on the table"] - User: "What do you see?" - You: "Let me take a look at the current view." [Execute Observe] Remember: You're here to assist with manipulation tasks. Be helpful, precise, and always prioritize safe operation of the robot.""" @@ -109,12 +106,10 @@ def main(): # Set up skill library skills = robot.get_skills() skills.add(PickAndPlace) - skills.add(Observe) skills.add(KillSkill) # Create skill instances skills.create_instance("PickAndPlace", robot=robot) - skills.create_instance("Observe", robot=robot) skills.create_instance("KillSkill", robot=robot, skill_library=skills) logger.info(f"Skills registered: {[skill.__name__ for skill in skills.get_class_skills()]}") diff --git a/dimos/robot/robot.py b/dimos/robot/robot.py new file mode 100644 index 0000000000..772a7d46bb --- /dev/null +++ b/dimos/robot/robot.py @@ -0,0 +1,59 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Minimal robot interface for DIMOS robots.""" + +from abc import ABC +from typing import List + +from dimos.types.robot_capabilities import RobotCapability + + +class Robot(ABC): + """Minimal abstract base class for all DIMOS robots. + + This class provides the essential interface that all robot implementations + can share, with no required methods - just common properties and helpers. + """ + + def __init__(self): + """Initialize the robot with basic properties.""" + self.capabilities: List[RobotCapability] = [] + self.skill_library = None + + def has_capability(self, capability: RobotCapability) -> bool: + """Check if the robot has a specific capability. + + Args: + capability: The capability to check for + + Returns: + bool: True if the robot has the capability + """ + return capability in self.capabilities + + def get_skills(self): + """Get the robot's skill library. + + Returns: + The robot's skill library for managing skills + """ + return self.skill_library + + def cleanup(self): + """Clean up robot resources. + + Override this method to provide cleanup logic. + """ + pass diff --git a/dimos/robot/unitree_webrtc/camera_module.py b/dimos/robot/unitree_webrtc/camera_module.py new file mode 100644 index 0000000000..beff3561ba --- /dev/null +++ b/dimos/robot/unitree_webrtc/camera_module.py @@ -0,0 +1,332 @@ +#!/usr/bin/env python3 + +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +import threading +from typing import List, Optional + +import cv2 +import numpy as np + +from dimos.core import Module, In, Out, rpc +from dimos.msgs.geometry_msgs import PoseStamped +from dimos.msgs.sensor_msgs import Image, ImageFormat +from dimos_lcm.sensor_msgs import CameraInfo +from dimos.msgs.std_msgs import Header +from dimos.protocol.tf import TF +from dimos.utils.logging_config import setup_logger +from dimos.perception.common.utils import colorize_depth + +logger = setup_logger(__name__) + + +class UnitreeCameraModule(Module): + """ + Camera module for Unitree Go2 that processes RGB images to generate depth using Metric3D. + + Subscribes to: + - /video: RGB camera images from Unitree + + Publishes: + - /go2/color_image: RGB camera images + - /go2/depth_image: Depth images generated by Metric3D + - /go2/depth_colorized: Colorized depth images with statistics overlay + - /go2/camera_info: Camera calibration information + - /go2/camera_pose: Camera pose from TF lookup + """ + + # LCM inputs + video: In[Image] = None + + # LCM outputs + color_image: Out[Image] = None + depth_image: Out[Image] = None + depth_colorized: Out[Image] = None + camera_info: Out[CameraInfo] = None + camera_pose: Out[PoseStamped] = None + + def __init__( + self, + camera_intrinsics: List[float], + world_frame_id: str = "world", + camera_frame_id: str = "camera_link", + base_frame_id: str = "base_link", + gt_depth_scale: float = 2.0, + **kwargs, + ): + """ + Initialize Unitree Camera Module. + + Args: + camera_intrinsics: Camera intrinsics [fx, fy, cx, cy] + camera_frame_id: TF frame ID for camera + base_frame_id: TF frame ID for robot base + """ + super().__init__(**kwargs) + + if len(camera_intrinsics) != 4: + raise ValueError("Camera intrinsics must be [fx, fy, cx, cy]") + + self.camera_intrinsics = camera_intrinsics + self.camera_frame_id = camera_frame_id + self.base_frame_id = base_frame_id + self.world_frame_id = world_frame_id + + # Initialize components + from dimos.models.depth.metric3d import Metric3D + + self.metric3d = Metric3D(camera_intrinsics=self.camera_intrinsics) + self.gt_depth_scale = gt_depth_scale + self.tf = TF() + + # Processing state + self._running = False + self._latest_frame = None + self._last_image = None + self._last_timestamp = None + self._last_depth = None + + # Threading + self._processing_thread: Optional[threading.Thread] = None + self._stop_processing = threading.Event() + + logger.info(f"UnitreeCameraModule initialized with intrinsics: {camera_intrinsics}") + + @rpc + def start(self): + """Start the camera module.""" + if self._running: + logger.warning("Camera module already running") + return + + # Set running flag before starting + self._running = True + + # Subscribe to video input + self.video.subscribe(self._on_video) + + # Start processing thread + self._start_processing_thread() + + logger.info("Camera module started") + + @rpc + def stop(self): + """Stop the camera module.""" + if not self._running: + return + + self._running = False + self._stop_processing.set() + + # Wait for thread to finish + if self._processing_thread and self._processing_thread.is_alive(): + self._processing_thread.join(timeout=2.0) + + logger.info("Camera module stopped") + + def _on_video(self, msg: Image): + """Store latest video frame for processing.""" + if not self._running: + return + + # Simply store the latest frame - processing happens in main loop + self._latest_frame = msg + logger.debug( + f"Received video frame: format={msg.format}, shape={msg.data.shape if hasattr(msg.data, 'shape') else 'unknown'}" + ) + + def _start_processing_thread(self): + """Start the processing thread.""" + self._stop_processing.clear() + self._processing_thread = threading.Thread(target=self._main_processing_loop, daemon=True) + self._processing_thread.start() + logger.info("Started camera processing thread") + + def _main_processing_loop(self): + """Main processing loop that continuously processes latest frames.""" + logger.info("Starting main processing loop") + + while not self._stop_processing.is_set(): + # Process latest frame if available + if self._latest_frame is not None: + try: + msg = self._latest_frame + self._latest_frame = None # Clear to avoid reprocessing + # Store for publishing + self._last_image = msg.data + self._last_timestamp = msg.ts if msg.ts else time.time() + # Process depth + self._process_depth(self._last_image) + + except Exception as e: + logger.error(f"Error in main processing loop: {e}", exc_info=True) + else: + # Small sleep to avoid busy waiting + time.sleep(0.001) + + logger.info("Main processing loop stopped") + + def _process_depth(self, img_array: np.ndarray): + """Process depth estimation using Metric3D.""" + try: + logger.debug(f"Processing depth for image shape: {img_array.shape}") + + # Generate depth map + depth_array = self.metric3d.infer_depth(img_array) / self.gt_depth_scale + + self._last_depth = depth_array + logger.debug(f"Generated depth map shape: {depth_array.shape}") + + self._publish_synchronized_data() + + except Exception as e: + logger.error(f"Error processing depth: {e}", exc_info=True) + + def _publish_synchronized_data(self): + """Publish all data synchronously.""" + if not self._running: + return + + try: + # Create header + header = Header(self.camera_frame_id) + + logger.debug("Publishing synchronized camera data") + + # Publish color image + color_msg = Image( + data=self._last_image, + format=ImageFormat.RGB, + frame_id=header.frame_id, + ts=header.ts, + ) + self.color_image.publish(color_msg) + logger.debug(f"Published color image: shape={self._last_image.shape}") + + # Publish depth image + if self._last_depth is not None: + depth_msg = Image( + data=self._last_depth, + format=ImageFormat.DEPTH, + frame_id=header.frame_id, + ts=header.ts, + ) + self.depth_image.publish(depth_msg) + logger.debug(f"Published depth image: shape={self._last_depth.shape}") + + # Publish colorized depth image + depth_colorized_array = colorize_depth( + self._last_depth, max_depth=10.0, overlay_stats=True + ) + if depth_colorized_array is not None: + depth_colorized_msg = Image( + data=depth_colorized_array, + format=ImageFormat.RGB, + frame_id=header.frame_id, + ts=header.ts, + ) + self.depth_colorized.publish(depth_colorized_msg) + logger.debug( + f"Published colorized depth image: shape={depth_colorized_array.shape}" + ) + + # Publish camera info + self._publish_camera_info(header) + + # Publish camera pose + self._publish_camera_pose(header) + + except Exception as e: + logger.error(f"Error publishing synchronized data: {e}", exc_info=True) + + def _publish_camera_info(self, header: Header): + """Publish camera calibration information.""" + try: + # Extract intrinsics + fx, fy, cx, cy = self.camera_intrinsics + + # Get image dimensions from last image + height, width = self._last_image.shape[:2] + + # Camera matrix K (3x3) + K = [fx, 0, cx, 0, fy, cy, 0, 0, 1] + + # No distortion coefficients for now + D = [0.0, 0.0, 0.0, 0.0, 0.0] + + # Identity rotation matrix + R = [1, 0, 0, 0, 1, 0, 0, 0, 1] + + # Projection matrix P (3x4) + P = [fx, 0, cx, 0, 0, fy, cy, 0, 0, 0, 1, 0] + + msg = CameraInfo( + D_length=len(D), + header=header, + height=height, + width=width, + distortion_model="plumb_bob", + D=D, + K=K, + R=R, + P=P, + binning_x=0, + binning_y=0, + ) + + self.camera_info.publish(msg) + + except Exception as e: + logger.error(f"Error publishing camera info: {e}") + + def _publish_camera_pose(self, header: Header): + """Publish camera pose from TF lookup.""" + try: + # Look up transform from base_link to camera_link + transform = self.tf.get( + parent_frame=self.world_frame_id, + child_frame=self.camera_frame_id, + time_point=header.ts, + time_tolerance=1.0, + ) + + if transform: + # Create PoseStamped from transform + pose_msg = PoseStamped( + ts=header.ts, + frame_id=self.camera_frame_id, + position=transform.translation, + orientation=transform.rotation, + ) + self.camera_pose.publish(pose_msg) + else: + logger.warning( + f"Could not find transform from {self.base_frame_id} to {self.camera_frame_id}" + ) + + except Exception as e: + logger.error(f"Error publishing camera pose: {e}") + + @rpc + def get_camera_intrinsics(self) -> List[float]: + """Get camera intrinsics.""" + return self.camera_intrinsics + + def cleanup(self): + """Clean up resources on module destruction.""" + self.stop() + self.metric3d.cleanup() diff --git a/dimos/robot/unitree_webrtc/run.py b/dimos/robot/unitree_webrtc/run.py index 10d45bfb09..aca66ab654 100644 --- a/dimos/robot/unitree_webrtc/run.py +++ b/dimos/robot/unitree_webrtc/run.py @@ -28,8 +28,6 @@ from dimos.robot.unitree_webrtc.unitree_go2 import UnitreeGo2 from dimos.agents.claude_agent import ClaudeAgent -from dimos.skills.observe_stream import ObserveStream -from dimos.skills.observe import Observe from dimos.skills.kill_skill import KillSkill from dimos.skills.navigation import NavigateWithText, GetPose, NavigateToGoal, Explore from dimos.skills.unitree.unitree_speak import UnitreeSpeak @@ -91,8 +89,6 @@ def main(): # Set up skill library skills = robot.get_skills() - # skills.add(ObserveStream) - # skills.add(Observe) skills.add(KillSkill) skills.add(NavigateWithText) skills.add(GetPose) @@ -154,8 +150,6 @@ def main(): tts_node.consume_text(agent.get_response_observable()) # Create skill instances that need agent reference - skills.create_instance("ObserveStream", robot=robot, agent=agent) - skills.create_instance("Observe", robot=robot, agent=agent) logger.info("=" * 60) logger.info("Unitree Go2 Agent Ready!") diff --git a/dimos/robot/unitree_webrtc/test_unitree_go2_integration.py b/dimos/robot/unitree_webrtc/test_unitree_go2_integration.py index a9f2ce7d25..706130dae2 100644 --- a/dimos/robot/unitree_webrtc/test_unitree_go2_integration.py +++ b/dimos/robot/unitree_webrtc/test_unitree_go2_integration.py @@ -157,7 +157,7 @@ async def test_unitree_go2_navigation_stack(self): # Set navigation goal (non-blocking) try: - navigator.set_goal(target_pose, blocking=False) + navigator.set_goal(target_pose) logger.info("Navigation goal set") except Exception as e: logger.warning(f"Navigation goal setting failed: {e}") diff --git a/dimos/robot/unitree_webrtc/type/map.py b/dimos/robot/unitree_webrtc/type/map.py index a674d4d0b7..78d26427e8 100644 --- a/dimos/robot/unitree_webrtc/type/map.py +++ b/dimos/robot/unitree_webrtc/type/map.py @@ -57,15 +57,11 @@ def publish(_): # temporary, not sure if it belogs in mapper # used only for visualizations, not for any algo - occupancygrid = ( - OccupancyGrid.from_pointcloud( - self.to_lidar_message(), - resolution=self.cost_resolution, - min_height=0.15, - max_height=0.6, - ) - .inflate(0.1) - .gradient(max_distance=1.0) + occupancygrid = OccupancyGrid.from_pointcloud( + self.to_lidar_message(), + resolution=self.cost_resolution, + min_height=0.15, + max_height=0.6, ) self.global_costmap.publish(occupancygrid) diff --git a/dimos/robot/unitree_webrtc/unitree_go2.py b/dimos/robot/unitree_webrtc/unitree_go2.py index 0578547760..6296b38b5c 100644 --- a/dimos/robot/unitree_webrtc/unitree_go2.py +++ b/dimos/robot/unitree_webrtc/unitree_go2.py @@ -20,32 +20,43 @@ import os import time import warnings -from typing import Callable, Optional +from typing import List, Optional from dimos import core from dimos.core import In, Module, Out, rpc from dimos.msgs.geometry_msgs import PoseStamped, Transform, Vector3, Quaternion from dimos.msgs.nav_msgs import OccupancyGrid, Path from dimos.msgs.sensor_msgs import Image +from dimos_lcm.std_msgs import String +from dimos_lcm.sensor_msgs import CameraInfo +from dimos_lcm.vision_msgs import Detection2DArray, Detection3DArray from dimos.perception.spatial_perception import SpatialMemory from dimos.protocol import pubsub +from dimos.protocol.pubsub.lcmpubsub import LCM, Topic from dimos.protocol.tf import TF from dimos.robot.foxglove_bridge import FoxgloveBridge from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule from dimos.navigation.global_planner import AstarPlanner from dimos.navigation.local_planner.holonomic_local_planner import HolonomicLocalPlanner -from dimos.navigation.bt_navigator.navigator import BehaviorTreeNavigator +from dimos.navigation.bt_navigator.navigator import BehaviorTreeNavigator, NavigatorState from dimos.navigation.frontier_exploration import WavefrontFrontierExplorer from dimos.robot.unitree_webrtc.connection import UnitreeWebRTCConnection from dimos.robot.unitree_webrtc.type.lidar import LidarMessage from dimos.robot.unitree_webrtc.type.map import Map from dimos.robot.unitree_webrtc.type.odometry import Odometry from dimos.robot.unitree_webrtc.unitree_skills import MyUnitreeSkills +from dimos.robot.unitree_webrtc.camera_module import UnitreeCameraModule from dimos.skills.skills import AbstractRobotSkill, SkillLibrary from dimos.utils.data import get_data from dimos.utils.logging_config import setup_logger from dimos.utils.testing import TimedSensorReplay +from dimos.utils.transform_utils import offset_distance +from dimos.perception.common.utils import extract_pose_from_detection3d +from dimos.perception.object_tracker import ObjectTracking from dimos_lcm.std_msgs import Bool +from dimos.robot.robot import Robot +from dimos.types.robot_capabilities import RobotCapability + logger = setup_logger("dimos.robot.unitree_webrtc.unitree_go2", level=logging.INFO) @@ -66,7 +77,7 @@ class FakeRTC: """Fake WebRTC connection for testing with recorded data.""" def __init__(self, *args, **kwargs): - data = get_data("unitree_office_walk") + get_data("unitree_office_walk") # Preload data for testing def connect(self): pass @@ -188,7 +199,7 @@ def publish_request(self, topic: str, data: dict): return self.connection.publish_request(topic, data) -class UnitreeGo2: +class UnitreeGo2(Robot): """Full Unitree Go2 robot with navigation and perception capabilities.""" def __init__( @@ -204,21 +215,28 @@ def __init__( Args: ip: Robot IP address (or None for fake connection) output_dir: Directory for saving outputs (default: assets/output) - enable_perception: Whether to enable spatial memory/perception websocket_port: Port for web visualization skill_library: Skill library instance playback: If True, use recorded data instead of real robot connection """ + super().__init__() self.ip = ip self.playback = playback or (ip is None) # Auto-enable playback if no IP provided self.output_dir = output_dir or os.path.join(os.getcwd(), "assets", "output") self.websocket_port = websocket_port + self.lcm = LCM() + + # Default camera intrinsics + self.camera_intrinsics = [819.553492, 820.646595, 625.284099, 336.808987] # Initialize skill library if skill_library is None: skill_library = MyUnitreeSkills() self.skill_library = skill_library + # Set capabilities + self.capabilities = [RobotCapability.LOCOMOTION, RobotCapability.VISION] + self.dimos = None self.connection = None self.mapper = None @@ -229,6 +247,8 @@ def __init__( self.websocket_vis = None self.foxglove_bridge = None self.spatial_memory_module = None + self.camera_module = None + self.object_tracker = None self._setup_directories() @@ -253,16 +273,19 @@ def _setup_directories(self): def start(self): """Start the robot system with all modules.""" - self.dimos = core.start(4) + self.dimos = core.start(8) self._deploy_connection() self._deploy_mapping() self._deploy_navigation() self._deploy_visualization() self._deploy_perception() + self._deploy_camera() self._start_modules() + self.lcm.start() + logger.info("UnitreeGo2 initialized and started") logger.info(f"WebSocket visualization available at http://localhost:{self.websocket_port}") @@ -289,12 +312,17 @@ def _deploy_navigation(self): """Deploy and configure navigation modules.""" self.global_planner = self.dimos.deploy(AstarPlanner) self.local_planner = self.dimos.deploy(HolonomicLocalPlanner) - self.navigator = self.dimos.deploy(BehaviorTreeNavigator, local_planner=self.local_planner) + self.navigator = self.dimos.deploy( + BehaviorTreeNavigator, + reset_local_planner=self.local_planner.reset, + check_goal_reached=self.local_planner.is_goal_reached, + ) self.frontier_explorer = self.dimos.deploy(WavefrontFrontierExplorer) self.navigator.goal.transport = core.LCMTransport("/navigation_goal", PoseStamped) self.navigator.goal_request.transport = core.LCMTransport("/goal_request", PoseStamped) self.navigator.goal_reached.transport = core.LCMTransport("/goal_reached", Bool) + self.navigator.navigation_state.transport = core.LCMTransport("/navigation_state", String) self.navigator.global_costmap.transport = core.LCMTransport( "/global_costmap", OccupancyGrid ) @@ -333,7 +361,8 @@ def _deploy_visualization(self): self.foxglove_bridge = FoxgloveBridge() def _deploy_perception(self): - """Deploy and configure the spatial memory module.""" + """Deploy and configure perception modules.""" + # Deploy spatial memory self.spatial_memory_module = self.dimos.deploy( SpatialMemory, collection_name=self.spatial_memory_collection, @@ -342,11 +371,65 @@ def _deploy_perception(self): output_dir=self.spatial_memory_dir, ) - self.spatial_memory_module.video.connect(self.connection.video) - self.spatial_memory_module.odom.connect(self.connection.odom) + self.spatial_memory_module.video.transport = core.LCMTransport("/go2/color_image", Image) + self.spatial_memory_module.odom.transport = core.LCMTransport( + "/go2/camera_pose", PoseStamped + ) logger.info("Spatial memory module deployed and connected") + # Deploy object tracker + self.object_tracker = self.dimos.deploy( + ObjectTracking, + camera_intrinsics=self.camera_intrinsics, + frame_id="camera_link", + ) + + # Set up transports + self.object_tracker.detection2darray.transport = core.LCMTransport( + "/go2/detection2d", Detection2DArray + ) + self.object_tracker.detection3darray.transport = core.LCMTransport( + "/go2/detection3d", Detection3DArray + ) + self.object_tracker.tracked_overlay.transport = core.LCMTransport( + "/go2/tracked_overlay", Image + ) + + logger.info("Object tracker module deployed") + + def _deploy_camera(self): + """Deploy and configure the camera module.""" + self.camera_module = self.dimos.deploy( + UnitreeCameraModule, + camera_intrinsics=self.camera_intrinsics, + camera_frame_id="camera_link", + base_frame_id="base_link", + ) + + # Set up transports + self.camera_module.color_image.transport = core.LCMTransport("/go2/color_image", Image) + self.camera_module.depth_image.transport = core.LCMTransport("/go2/depth_image", Image) + self.camera_module.depth_colorized.transport = core.LCMTransport( + "/go2/depth_colorized", Image + ) + self.camera_module.camera_info.transport = core.LCMTransport("/go2/camera_info", CameraInfo) + self.camera_module.camera_pose.transport = core.LCMTransport( + "/go2/camera_pose", PoseStamped + ) + + # Connect video input from connection module + self.camera_module.video.connect(self.connection.video) + + logger.info("Camera module deployed and connected") + + # Connect object tracker inputs after camera module is deployed + if self.object_tracker: + self.object_tracker.color_image.connect(self.camera_module.color_image) + self.object_tracker.depth.connect(self.camera_module.depth_image) + self.object_tracker.camera_info.connect(self.camera_module.camera_info) + logger.info("Object tracker connected to camera module") + def _start_modules(self): """Start all deployed modules in the correct order.""" self.connection.start() @@ -357,9 +440,9 @@ def _start_modules(self): self.frontier_explorer.start() self.websocket_vis.start() self.foxglove_bridge.start() - - if self.spatial_memory_module: - self.spatial_memory_module.start() + self.spatial_memory_module.start() + self.camera_module.start() + self.object_tracker.start() # Initialize skills after connection is established if self.skill_library is not None: @@ -371,6 +454,10 @@ def _start_modules(self): self.skill_library.init() self.skill_library.initialize_skills() + def get_single_rgb_frame(self, timeout: float = 2.0) -> Image: + topic = Topic("/go2/color_image", Image) + return self.lcm.wait_for_message(topic, timeout=timeout) + def move(self, vector: Vector3, duration: float = 0.0): """Send movement command to robot.""" self.connection.move(vector, duration) @@ -398,7 +485,22 @@ def navigate_to(self, pose: PoseStamped, blocking: bool = True): logger.info( f"Navigating to pose: ({pose.position.x:.2f}, {pose.position.y:.2f}, {pose.position.z:.2f})" ) - return self.navigator.set_goal(pose, blocking=blocking) + self.navigator.set_goal(pose) + time.sleep(1.0) + + if blocking: + while self.navigator.get_state() == NavigatorState.FOLLOWING_PATH: + time.sleep(0.25) + + time.sleep(1.0) + if not self.navigator.is_goal_reached(): + logger.info("Navigation was cancelled or failed") + return False + else: + logger.info("Navigation goal reached") + return True + + return True def stop_exploration(self) -> bool: """Stop autonomous exploration. @@ -425,14 +527,6 @@ def spatial_memory(self) -> Optional[SpatialMemory]: """ return self.spatial_memory_module - def get_skills(self): - """Get the robot's skill library. - - Returns: - The robot's skill library for adding/managing skills - """ - return self.skill_library - def get_odom(self) -> PoseStamped: """Get the robot's odometry. @@ -441,6 +535,64 @@ def get_odom(self) -> PoseStamped: """ return self.connection.get_odom() + def navigate_to_object(self, bbox: List[float], distance: float = 0.5, timeout: float = 30.0): + """Navigate to an object by tracking it and maintaining a specified distance. + + Args: + bbox: Bounding box of the object to track [x1, y1, x2, y2] + distance: Distance to maintain from the object (meters) + timeout: Total timeout for the navigation (seconds) + + Returns: + True if navigation completed successfully, False otherwise + """ + if not self.object_tracker: + logger.error("Object tracker not initialized") + return False + + logger.info(f"Starting object tracking with bbox: {bbox}") + self.object_tracker.track(bbox) + + start_time = time.time() + goal_set = False + + while time.time() - start_time < timeout: + if self.navigator.get_state() == NavigatorState.IDLE and goal_set: + logger.info("Waiting for goal result") + time.sleep(1.0) + if not self.navigator.is_goal_reached(): + logger.info("Goal cancelled, object tracking failed") + return False + else: + logger.info("Object tracking goal reached") + return True + + if not self.object_tracker.is_tracking(): + continue + + detection_topic = Topic("/go2/detection3d", Detection3DArray) + detection_msg = self.lcm.wait_for_message(detection_topic, timeout=1.0) + + if detection_msg and len(detection_msg.detections) > 0: + target_pose = extract_pose_from_detection3d(detection_msg.detections[0]) + + retracted_pose = offset_distance( + target_pose, distance, approach_vector=Vector3(-1, 0, 0) + ) + + goal_pose = PoseStamped( + frame_id=detection_msg.header.frame_id, + position=retracted_pose.position, + orientation=retracted_pose.orientation, + ) + self.navigator.set_goal(goal_pose) + goal_set = True + + time.sleep(0.25) + + logger.info("Object tracking timed out") + return False + def main(): """Main entry point.""" diff --git a/dimos/skills/manipulation/pick_and_place.py b/dimos/skills/manipulation/pick_and_place.py index 4306975d8d..15570d5373 100644 --- a/dimos/skills/manipulation/pick_and_place.py +++ b/dimos/skills/manipulation/pick_and_place.py @@ -26,7 +26,7 @@ import numpy as np from pydantic import Field -from dimos.skills.manipulation.abstract_manipulation_skill import AbstractManipulationSkill +from dimos.skills.skills import AbstractRobotSkill from dimos.models.qwen.video_query import query_single_frame from dimos.utils.logging_config import setup_logger @@ -170,7 +170,7 @@ def parse_qwen_single_point_response(response: str) -> Optional[Tuple[int, int]] return None -class PickAndPlace(AbstractManipulationSkill): +class PickAndPlace(AbstractRobotSkill): """ A skill that performs pick and place operations using vision-language guidance. diff --git a/dimos/skills/navigation.py b/dimos/skills/navigation.py index e5ead5ab85..c6b51b2ddd 100644 --- a/dimos/skills/navigation.py +++ b/dimos/skills/navigation.py @@ -22,18 +22,16 @@ import os import time -import threading from typing import Optional, Tuple -from dimos.utils.threadpool import get_scheduler -from reactivex import operators as ops from pydantic import Field from dimos.skills.skills import AbstractRobotSkill from dimos.types.robot_location import RobotLocation from dimos.utils.logging_config import setup_logger from dimos.models.qwen.video_query import get_bbox_from_qwen_frame -from dimos.utils.transform_utils import distance_angle_to_goal_xy +from dimos.msgs.geometry_msgs import PoseStamped, Vector3 +from dimos.utils.transform_utils import euler_to_quaternion logger = setup_logger("dimos.skills.semantic_map_skills") @@ -75,7 +73,7 @@ class NavigateWithText(AbstractRobotSkill): query: str = Field("", description="Text query to search for in the semantic map") limit: int = Field(1, description="Maximum number of results to return") - distance: float = Field(1.0, description="Desired distance to maintain from object in meters") + distance: float = Field(0.3, description="Desired distance to maintain from object in meters") skip_visual_search: bool = Field(False, description="Skip visual search for object in view") timeout: float = Field(40.0, description="Maximum time to spend navigating in seconds") @@ -88,12 +86,8 @@ def __init__(self, robot=None, **data): **data: Additional data for configuration """ super().__init__(robot=robot, **data) - self._stop_event = threading.Event() self._spatial_memory = None - self._scheduler = get_scheduler() # Use the shared DiMOS thread pool - self._navigation_disposable = None # Disposable returned by scheduler.schedule() - self._tracking_subscriber = None # For object tracking - self._similarity_threshold = 0.25 + self._similarity_threshold = 0.24 def _navigate_to_object(self): """ @@ -102,128 +96,58 @@ def _navigate_to_object(self): Returns: dict: Result dictionary with success status and details """ - # Stop any existing operation - self._stop_event.clear() + logger.info( + f"Attempting to navigate to visible object: {self.query} with desired distance {self.distance}m, timeout {self.timeout} seconds..." + ) + # Try to get a bounding box from Qwen + bbox = None try: - logger.warning( - f"Attempting to navigate to visible object: {self.query} with desired distance {self.distance}m, timeout {self.timeout} seconds..." - ) - - # Try to get a bounding box from Qwen - only try once - bbox = None - try: - # Use the robot's existing video stream instead of creating a new one - frame = self._robot.get_video_stream().pipe(ops.take(1)).run() - # Use the frame-based function - bbox, object_size = get_bbox_from_qwen_frame(frame, object_name=self.query) - except Exception as e: - logger.error(f"Error querying Qwen: {e}") + # Get a single frame from the robot's camera + frame = self._robot.get_single_rgb_frame() + if frame is None: + logger.error("Failed to get camera frame") return { "success": False, "failure_reason": "Perception", - "error": f"Could not detect {self.query} in view: {e}", - } - - if bbox is None or self._stop_event.is_set(): - logger.error(f"Failed to get bounding box for {self.query}") - return { - "success": False, - "failure_reason": "Perception", - "error": f"Could not find {self.query} in view", - } - - logger.info(f"Found {self.query} at {bbox} with size {object_size}") - - # Start the object tracker with the detected bbox - self._robot.object_tracker.track(bbox, frame=frame) - - # Get the first tracking data with valid distance and angle - start_time = time.time() - target_acquired = False - goal_x_robot = 0 - goal_y_robot = 0 - goal_angle = 0 - - while ( - time.time() - start_time < 10.0 - and not self._stop_event.is_set() - and not target_acquired - ): - # Get the latest tracking data - tracking_data = self._robot.object_tracking_stream.pipe(ops.take(1)).run() - - if tracking_data and tracking_data.get("targets") and tracking_data["targets"]: - target = tracking_data["targets"][0] - - if "distance" in target and "angle" in target: - # Convert target distance and angle to xy coordinates in robot frame - goal_distance = ( - target["distance"] - self.distance - ) # Subtract desired distance to stop short - goal_angle = -target["angle"] - logger.info(f"Target distance: {goal_distance}, Target angle: {goal_angle}") - - goal_x_robot, goal_y_robot = distance_angle_to_goal_xy( - goal_distance, goal_angle - ) - target_acquired = True - break - - else: - logger.warning("No valid target tracking data found.") - - else: - logger.warning("No valid target tracking data found.") - - time.sleep(0.1) - - if not target_acquired: - logger.error("Failed to acquire valid target tracking data") - return { - "success": False, - "failure_reason": "Perception", - "error": "Failed to track object", + "error": "Could not get camera frame", } + bbox = get_bbox_from_qwen_frame(frame.data, object_name=self.query) + except Exception as e: + logger.error(f"Error getting frame or bbox: {e}") + return { + "success": False, + "failure_reason": "Perception", + "error": f"Error getting frame or bbox: {e}", + } + if bbox is None: + logger.error(f"Failed to get bounding box for {self.query}") + return { + "success": False, + "failure_reason": "Perception", + "error": f"Could not find {self.query} in view", + } - logger.info( - f"Navigating to target at local coordinates: ({goal_x_robot:.2f}, {goal_y_robot:.2f}), angle: {goal_angle:.2f}" - ) + logger.info(f"Found {self.query} at {bbox}") - # Use navigate_to_goal_local instead of directly controlling the local planner - success = navigate_to_goal_local( - robot=self._robot, - goal_xy_robot=(goal_x_robot, goal_y_robot), - goal_theta=goal_angle, - distance=0.0, # We already accounted for desired distance - timeout=self.timeout, - stop_event=self._stop_event, - ) + # Use the robot's navigate_to_object method + success = self._robot.navigate_to_object(bbox, self.distance, self.timeout) - if success: - logger.info(f"Successfully navigated to {self.query}") - return { - "success": True, - "failure_reason": None, - "query": self.query, - "message": f"Successfully navigated to {self.query} in view", - } - else: - logger.warning( - f"Failed to reach {self.query} within timeout or operation was stopped" - ) - return { - "success": False, - "failure_reason": "Navigation", - "error": f"Failed to reach {self.query} within timeout", - } - - except Exception as e: - logger.error(f"Error in navigate to object: {e}") - return {"success": False, "failure_reason": "Code Error", "error": f"Error: {e}"} - finally: - # Clean up - self._robot.object_tracker.cleanup() + if success: + logger.info(f"Successfully navigated to {self.query}") + return { + "success": True, + "failure_reason": None, + "query": self.query, + "message": f"Successfully navigated to {self.query} in view", + } + else: + logger.warning(f"Failed to reach {self.query} within timeout") + return { + "success": False, + "failure_reason": "Navigation", + "error": f"Failed to reach {self.query} within timeout", + } def _navigate_using_semantic_map(self): """ @@ -235,10 +159,10 @@ def _navigate_using_semantic_map(self): logger.info(f"Querying semantic map for: '{self.query}'") try: - self._spatial_memory = self._robot.get_spatial_memory() + self._spatial_memory = self._robot.spatial_memory # Run the query - results = self._spatial_memory.query_by_text(self.query, limit=self.limit) + results = self._spatial_memory.query_by_text(self.query, self.limit) if not results: logger.warning(f"No results found for query: '{self.query}'") @@ -289,56 +213,40 @@ def _navigate_using_semantic_map(self): "error": f"Match found but similarity score ({similarity:.4f}) is below threshold ({self._similarity_threshold})", } - # Reset the stop event before starting navigation - self._stop_event.clear() - - # The scheduler approach isn't working, switch to direct threading - # Define a navigation function that will run on a separate thread - def run_navigation(): - skill_library = self._robot.get_skills() - self.register_as_running("Navigate", skill_library) - - try: - logger.info( - f"Starting navigation to ({pos_x:.2f}, {pos_y:.2f}) with rotation {theta:.2f}" - ) - # Pass our stop_event to allow cancellation - result = False - try: - result = self._robot.global_planner.set_goal( - (pos_x, pos_y), goal_theta=theta, stop_event=self._stop_event - ) - except Exception as e: - logger.error(f"Error calling global_planner.set_goal: {e}") - - if result: - logger.info("Navigation completed successfully") - else: - logger.error("Navigation did not complete successfully") - return result - except Exception as e: - logger.error(f"Unexpected error in navigation thread: {e}") - return False - finally: - self.stop() - - # Cancel any existing navigation before starting a new one - # Signal stop to any running navigation - self._stop_event.set() - # Clear stop event for new navigation - self._stop_event.clear() - - # Run the navigation in the main thread - run_navigation() + # Create a PoseStamped for navigation + goal_pose = PoseStamped( + position=Vector3(pos_x, pos_y, 0), + orientation=euler_to_quaternion(Vector3(0, 0, theta)), + frame_id="world", + ) - return { - "success": True, - "query": self.query, - "position": (pos_x, pos_y), - "rotation": theta, - "similarity": similarity, - "metadata": metadata, - } + logger.info( + f"Starting navigation to ({pos_x:.2f}, {pos_y:.2f}) with rotation {theta:.2f}" + ) + + # Use the robot's navigate_to method + result = self._robot.navigate_to(goal_pose, blocking=True) + + if result: + logger.info("Navigation completed successfully") + return { + "success": True, + "query": self.query, + "position": (pos_x, pos_y), + "rotation": theta, + "similarity": similarity, + "metadata": metadata, + } + else: + logger.error("Navigation did not complete successfully") + return { + "success": False, + "query": self.query, + "position": (pos_x, pos_y), + "rotation": theta, + "similarity": similarity, + "error": "Navigation did not complete successfully", + } else: logger.warning(f"No valid position data found for query: '{self.query}'") return { @@ -397,21 +305,12 @@ def stop(self): """ logger.info("Stopping Navigate skill") - # Signal any running processes to stop via the shared event - self._stop_event.set() + # Cancel navigation + self._robot.cancel_navigation() skill_library = self._robot.get_skills() self.unregister_as_running("Navigate", skill_library) - # Dispose of any existing navigation task - if hasattr(self, "_navigation_disposable") and self._navigation_disposable: - logger.info("Disposing navigation task") - try: - self._navigation_disposable.dispose() - except Exception as e: - logger.error(f"Error disposing navigation task: {e}") - self._navigation_disposable = None - return "Navigate skill stopped successfully." @@ -478,7 +377,7 @@ def __call__(self): # If location_name is provided, remember this location if self.location_name: # Get the spatial memory instance - spatial_memory = self._robot.get_spatial_memory() + spatial_memory = self._robot.spatial_memory # Create a RobotLocation object location = RobotLocation( @@ -528,7 +427,6 @@ def __init__(self, robot=None, **data): **data: Additional data for configuration """ super().__init__(robot=robot, **data) - self._stop_event = threading.Event() def __call__(self): """ @@ -544,9 +442,6 @@ def __call__(self): logger.error(error_msg) return {"success": False, "error": error_msg} - # Reset stop event to make sure we don't immediately abort - self._stop_event.clear() - skill_library = self._robot.get_skills() self.register_as_running("NavigateToGoal", skill_library) @@ -557,11 +452,15 @@ def __call__(self): ) try: - # Use the global planner to set the goal and generate a path - result = self._robot.global_planner.set_goal( - self.position, goal_theta=self.rotation, stop_event=self._stop_event + # Create a PoseStamped for navigation + goal_pose = PoseStamped( + position=Vector3(self.position[0], self.position[1], 0), + orientation=euler_to_quaternion(Vector3(0, 0, self.rotation or 0)), ) + # Use the robot's navigate_to method + result = self._robot.navigate_to(goal_pose, blocking=True) + if result: logger.info("Navigation completed successfully") return { @@ -601,7 +500,7 @@ def stop(self): logger.info("Stopping NavigateToGoal") skill_library = self._robot.get_skills() self.unregister_as_running("NavigateToGoal", skill_library) - self._stop_event.set() + self._robot.cancel_navigation() return "Navigation stopped" @@ -615,7 +514,7 @@ class Explore(AbstractRobotSkill): Don't save GetPose locations when frontier exploring. Don't call any other skills except stop skill when needed. """ - timeout: float = Field(60.0, description="Maximum time (in seconds) allowed for exploration") + timeout: float = Field(240.0, description="Maximum time (in seconds) allowed for exploration") def __init__(self, robot=None, **data): """ @@ -626,7 +525,6 @@ def __init__(self, robot=None, **data): **data: Additional data for configuration """ super().__init__(robot=robot, **data) - self._stop_event = threading.Event() def __call__(self): """ @@ -642,9 +540,6 @@ def __call__(self): logger.error(error_msg) return {"success": False, "error": error_msg} - # Reset stop event to make sure we don't immediately abort - self._stop_event.clear() - skill_library = self._robot.get_skills() self.register_as_running("Explore", skill_library) @@ -660,13 +555,6 @@ def __call__(self): # Wait for exploration to complete or timeout start_time = time.time() while time.time() - start_time < self.timeout: - if self._stop_event.is_set(): - logger.info("Exploration stopped by user") - self._robot.stop_exploration() - return { - "success": False, - "message": "Exploration stopped by user", - } time.sleep(0.5) # Timeout reached, stop exploration @@ -703,7 +591,6 @@ def stop(self): logger.info("Stopping Explore") skill_library = self._robot.get_skills() self.unregister_as_running("Explore", skill_library) - self._stop_event.set() # Stop the robot's exploration if it's running try: diff --git a/dimos/skills/observe.py b/dimos/skills/observe.py deleted file mode 100644 index 8a934bf34d..0000000000 --- a/dimos/skills/observe.py +++ /dev/null @@ -1,189 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Observer skill for an agent. - -This module provides a skill that sends a single image from any -Robot Data Stream to the Qwen VLM for inference and adds the response -to the agent's conversation history. -""" - -import time -from typing import Optional -import base64 -import cv2 -import numpy as np -import reactivex as rx -from reactivex import operators as ops -from pydantic import Field - -from dimos.skills.skills import AbstractRobotSkill -from dimos.agents.agent import LLMAgent -from dimos.models.qwen.video_query import query_single_frame -from dimos.utils.logging_config import setup_logger - -logger = setup_logger("dimos.skills.observe") - - -class Observe(AbstractRobotSkill): - """ - A skill that captures a single frame from a Robot Video Stream, sends it to a VLM, - and adds the response to the agent's conversation history. - - This skill is used for visual reasoning, spatial understanding, or any queries involving visual information that require critical thinking. - """ - - query_text: str = Field( - "What do you see in this image? Describe the environment in detail.", - description="Query text to send to the VLM model with the image", - ) - - def __init__(self, robot=None, agent: Optional[LLMAgent] = None, **data): - """ - Initialize the Observe skill. - - Args: - robot: The robot instance - agent: The agent to store results in - **data: Additional data for configuration - """ - super().__init__(robot=robot, **data) - self._agent = agent - self._model_name = "qwen2.5-vl-72b-instruct" - - # Get the video stream from the robot - self._video_stream = self._robot.video_stream - if self._video_stream is None: - logger.error("Failed to get video stream from robot") - - def __call__(self): - """ - Capture a single frame, process it with Qwen, and add the result to conversation history. - - Returns: - A message indicating the observation result - """ - super().__call__() - - if self._agent is None: - error_msg = "No agent provided to Observe skill" - logger.error(error_msg) - return error_msg - - if self._robot is None: - error_msg = "No robot instance provided to Observe skill" - logger.error(error_msg) - return error_msg - - if self._video_stream is None: - error_msg = "No video stream available" - logger.error(error_msg) - return error_msg - - try: - logger.info("Capturing frame for Qwen observation") - - # Get a single frame from the video stream - frame = self._get_frame_from_stream() - - if frame is None: - error_msg = "Failed to capture frame from video stream" - logger.error(error_msg) - return error_msg - - # Process the frame with Qwen - response = self._process_frame_with_qwen(frame) - - logger.info(f"Added Qwen observation to conversation history") - return f"Observation complete: {response}" - - except Exception as e: - error_msg = f"Error in Observe skill: {e}" - logger.error(error_msg) - return error_msg - - def _get_frame_from_stream(self): - """ - Get a single frame from the video stream. - - Returns: - A single frame from the video stream, or None if no frame is available - """ - if self._video_stream is None: - logger.error("Video stream is None") - return None - - frame = None - frame_subject = rx.subject.Subject() - - subscription = self._video_stream.pipe( - ops.take(1) # Take just one frame - ).subscribe( - on_next=lambda x: frame_subject.on_next(x), - on_error=lambda e: logger.error(f"Error getting frame: {e}"), - ) - - # Wait up to 5 seconds for a frame - timeout = 5.0 - start_time = time.time() - - def on_frame(f): - nonlocal frame - frame = f - - frame_subject.subscribe(on_frame) - - while frame is None and time.time() - start_time < timeout: - time.sleep(0.1) - - subscription.dispose() - return frame - - def _process_frame_with_qwen(self, frame): - """ - Process a frame with the Qwen model using query_single_frame. - - Args: - frame: The video frame to process (numpy array) - - Returns: - The response from Qwen - """ - logger.info(f"Processing frame with Qwen model: {self._model_name}") - - try: - # Ensure frame is in RGB format for Qwen - if isinstance(frame, np.ndarray): - # OpenCV uses BGR, convert to RGB if needed - if frame.shape[-1] == 3: # Check if it has color channels - frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - else: - frame_rgb = frame - else: - raise ValueError("Frame must be a numpy array") - - # Query Qwen with the frame (direct function call) - response = query_single_frame( - frame_rgb, - self.query_text, - model_name=self._model_name, - ) - - logger.info(f"Qwen response received: {response[:100]}...") - return response - - except Exception as e: - logger.error(f"Error processing frame with Qwen: {e}") - raise diff --git a/dimos/skills/observe_stream.py b/dimos/skills/observe_stream.py deleted file mode 100644 index 1766ffe2aa..0000000000 --- a/dimos/skills/observe_stream.py +++ /dev/null @@ -1,243 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Observer skill for an agent. - -This module provides a skill that periodically sends images from any -Robot Data Stream to an agent for inference. -""" - -import time -import threading -from typing import Optional -import base64 -import cv2 -import numpy as np -import reactivex as rx -from reactivex import operators as ops -from pydantic import Field - -from dimos.skills.skills import AbstractRobotSkill -from dimos.agents.agent import LLMAgent -from dimos.models.qwen.video_query import query_single_frame -from dimos.utils.threadpool import get_scheduler -from dimos.utils.logging_config import setup_logger - -logger = setup_logger("dimos.skills.observe_stream") - - -class ObserveStream(AbstractRobotSkill): - """ - A skill that periodically Observes a Robot Video Stream and sends images to current instance of an agent for context. - - This skill runs in a non-halting manner, allowing other skills to run concurrently. - It can be used for continuous perception and passive monitoring, such as waiting for a person to enter a room - or to monitor changes in the environment. - """ - - timestep: float = Field( - 60.0, description="Time interval in seconds between observation queries" - ) - query_text: str = Field( - "What do you see in this image? Alert me if you see any people or important changes.", - description="Query text to send to agent with each image", - ) - max_duration: float = Field( - 0.0, description="Maximum duration to run the observer in seconds (0 for indefinite)" - ) - - def __init__(self, robot=None, agent: Optional[LLMAgent] = None, video_stream=None, **data): - """ - Initialize the ObserveStream skill. - - Args: - robot: The robot instance - agent: The agent to send queries to - **data: Additional data for configuration - """ - super().__init__(robot=robot, **data) - self._agent = agent - self._stop_event = threading.Event() - self._monitor_thread = None - self._scheduler = get_scheduler() - self._subscription = None - - # Get the video stream - # TODO: Use the video stream provided in the constructor for dynamic video_stream selection by the agent - self._video_stream = self._robot.video_stream - if self._video_stream is None: - logger.error("Failed to get video stream from robot") - return - - def __call__(self): - """ - Start the observing process in a separate thread using the threadpool. - - Returns: - A message indicating the observer has started - """ - super().__call__() - - if self._agent is None: - error_msg = "No agent provided to ObserveStream" - logger.error(error_msg) - return error_msg - - if self._robot is None: - error_msg = "No robot instance provided to ObserveStream" - logger.error(error_msg) - return error_msg - - self.stop() - - self._stop_event.clear() - - # Initialize start time for duration tracking - self._start_time = time.time() - - interval_observable = rx.interval(self.timestep, scheduler=self._scheduler).pipe( - ops.take_while(lambda _: not self._stop_event.is_set()) - ) - - # Subscribe to the interval observable - self._subscription = interval_observable.subscribe( - on_next=self._monitor_iteration, - on_error=lambda e: logger.error(f"Error in monitor observable: {e}"), - on_completed=lambda: logger.info("Monitor observable completed"), - ) - - skill_library = self._robot.get_skills() - self.register_as_running("ObserveStream", skill_library, self._subscription) - - logger.info(f"Observer started with timestep={self.timestep}s, query='{self.query_text}'") - return f"Observer started with timestep={self.timestep}s, query='{self.query_text}'" - - def _monitor_iteration(self, iteration): - """ - Execute a single observer iteration. - - Args: - iteration: The iteration number (provided by rx.interval) - """ - try: - if self.max_duration > 0: - elapsed_time = time.time() - self._start_time - if elapsed_time > self.max_duration: - logger.info(f"Observer reached maximum duration of {self.max_duration}s") - self.stop() - return - - logger.info(f"Observer iteration {iteration} executing") - - # Get a frame from the video stream - frame = self._get_frame_from_stream() - - if frame is not None: - self._process_frame(frame) - else: - logger.warning("Failed to get frame from video stream") - - except Exception as e: - logger.error(f"Error in monitor iteration {iteration}: {e}") - - def _get_frame_from_stream(self): - """ - Get a single frame from the video stream. - - Args: - video_stream: The ROS video stream observable - - Returns: - A single frame from the video stream, or None if no frame is available - """ - frame = None - - frame_subject = rx.subject.Subject() - - subscription = self._video_stream.pipe( - ops.take(1) # Take just one frame - ).subscribe( - on_next=lambda x: frame_subject.on_next(x), - on_error=lambda e: logger.error(f"Error getting frame: {e}"), - ) - - timeout = 5.0 # 5 seconds timeout - start_time = time.time() - - def on_frame(f): - nonlocal frame - frame = f - - frame_subject.subscribe(on_frame) - - while frame is None and time.time() - start_time < timeout: - time.sleep(0.1) - - subscription.dispose() - - return frame - - def _process_frame(self, frame): - """ - Process a frame with the Qwen VLM and add the response to conversation history. - - Args: - frame: The video frame to process - """ - logger.info("Processing frame with Qwen VLM") - - try: - # Ensure frame is in RGB format for Qwen - if isinstance(frame, np.ndarray): - # OpenCV uses BGR, convert to RGB if needed - if frame.shape[-1] == 3: # Check if it has color channels - frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - else: - frame_rgb = frame - else: - raise ValueError("Frame must be a numpy array") - - # Use Qwen to process the frame - model_name = "qwen2.5-vl-72b-instruct" # Using the most capable model - response = query_single_frame(frame_rgb, self.query_text, model_name=model_name) - - logger.info(f"Qwen response received: {response[:100]}...") - - # Add the response to the conversation history - # self._agent.append_to_history( - # f"Observation: {response}", - # ) - response = self._agent.run_observable_query(f"Observation: {response}") - - logger.info("Added Qwen observation to conversation history") - - except Exception as e: - logger.error(f"Error processing frame with Qwen VLM: {e}") - - def stop(self): - """ - Stop the ObserveStream monitoring process. - - Returns: - A message indicating the observer has stopped - """ - if self._subscription is not None and not self._subscription.is_disposed: - logger.info("Stopping ObserveStream") - self._stop_event.set() - self._subscription.dispose() - self._subscription = None - - return "Observer stopped" - return "Observer was not running" diff --git a/dimos/utils/test_transform_utils.py b/dimos/utils/test_transform_utils.py index 94d15e9ee4..85128ac09c 100644 --- a/dimos/utils/test_transform_utils.py +++ b/dimos/utils/test_transform_utils.py @@ -570,6 +570,12 @@ def test_same_pose(self): distance = transform_utils.get_distance(pose1, pose2) assert np.isclose(distance, 0) + def test_vector_distance(self): + pose1 = Vector3(1, 2, 3) + pose2 = Vector3(4, 5, 6) + distance = transform_utils.get_distance(pose1, pose2) + assert np.isclose(distance, np.sqrt(3**2 + 3**2 + 3**2)) + def test_distance_x_axis(self): pose1 = Pose(Vector3(0, 0, 0), Quaternion(0, 0, 0, 1)) pose2 = Pose(Vector3(5, 0, 0), Quaternion(0, 0, 0, 1)) @@ -607,7 +613,7 @@ def test_retract_along_negative_z(self): # Default case: gripper approaches along -z axis # Positive distance moves away from the surface (opposite to approach direction) target_pose = Pose(Vector3(0, 0, 1), Quaternion(0, 0, 0, 1)) - retracted = transform_utils.retract_distance(target_pose, 0.5) + retracted = transform_utils.offset_distance(target_pose, 0.5) # Moving along -z approach vector with positive distance = retracting upward # Since approach is -z and we retract (positive distance), we move in +z @@ -627,7 +633,7 @@ def test_retract_with_rotation(self): q = r.as_quat() target_pose = Pose(Vector3(0, 0, 1), Quaternion(q[0], q[1], q[2], q[3])) - retracted = transform_utils.retract_distance(target_pose, 0.5) + retracted = transform_utils.offset_distance(target_pose, 0.5) # After 90 degree rotation around x, -z becomes +y assert np.isclose(retracted.position.x, 0) @@ -637,7 +643,7 @@ def test_retract_with_rotation(self): def test_retract_negative_distance(self): # Negative distance should move forward (toward the approach direction) target_pose = Pose(Vector3(0, 0, 1), Quaternion(0, 0, 0, 1)) - retracted = transform_utils.retract_distance(target_pose, -0.3) + retracted = transform_utils.offset_distance(target_pose, -0.3) # Moving along -z approach vector with negative distance = moving downward assert np.isclose(retracted.position.x, 0) @@ -651,7 +657,7 @@ def test_retract_arbitrary_pose(self): target_pose = Pose(Vector3(5, 3, 2), Quaternion(q[0], q[1], q[2], q[3])) distance = 1.0 - retracted = transform_utils.retract_distance(target_pose, distance) + retracted = transform_utils.offset_distance(target_pose, distance) # Verify the distance between original and retracted is as expected # (approximately, due to the approach vector direction) diff --git a/dimos/utils/transform_utils.py b/dimos/utils/transform_utils.py index 0b93b9a0f3..5b49d285cc 100644 --- a/dimos/utils/transform_utils.py +++ b/dimos/utils/transform_utils.py @@ -324,7 +324,7 @@ def quaternion_to_euler(quaternion: Quaternion, degrees: bool = False) -> Vector return Vector3(euler[0], euler[1], euler[2]) -def get_distance(pose1: Pose, pose2: Pose) -> float: +def get_distance(pose1: Pose | Vector3, pose2: Pose | Vector3) -> float: """ Calculate Euclidean distance between two poses. @@ -335,18 +335,25 @@ def get_distance(pose1: Pose, pose2: Pose) -> float: Returns: Euclidean distance between the two poses in meters """ - dx = pose1.position.x - pose2.position.x - dy = pose1.position.y - pose2.position.y - dz = pose1.position.z - pose2.position.z + if hasattr(pose1, "position"): + pose1 = pose1.position + if hasattr(pose2, "position"): + pose2 = pose2.position + + dx = pose1.x - pose2.x + dy = pose1.y - pose2.y + dz = pose1.z - pose2.z return np.linalg.norm(np.array([dx, dy, dz])) -def retract_distance(target_pose: Pose, distance: float) -> Pose: +def offset_distance( + target_pose: Pose, distance: float, approach_vector: Vector3 = Vector3(0, 0, -1) +) -> Pose: """ Apply distance offset to target pose along its approach direction. - This is commonly used in grasping to retract the gripper by a certain distance + This is commonly used in grasping to offset the gripper by a certain distance along the approach vector before or after grasping. Args: @@ -363,7 +370,7 @@ def retract_distance(target_pose: Pose, distance: float) -> Pose: # Define the approach vector based on the target pose orientation # Assuming the gripper approaches along its local -z axis (common for downward grasps) # You can change this to [1, 0, 0] for x-axis or [0, 1, 0] for y-axis based on your gripper - approach_vector_local = np.array([0, 0, -1]) + approach_vector_local = np.array([approach_vector.x, approach_vector.y, approach_vector.z]) # Transform approach vector to world coordinates approach_vector_world = rotation_matrix @ approach_vector_local diff --git a/dimos/web/websocket_vis/websocket_vis_module.py b/dimos/web/websocket_vis/websocket_vis_module.py index 02f56b8460..878f39eef8 100644 --- a/dimos/web/websocket_vis/websocket_vis_module.py +++ b/dimos/web/websocket_vis/websocket_vis_module.py @@ -19,10 +19,11 @@ """ import asyncio -import concurrent.futures import os import threading from typing import Any, Dict, Optional +import base64 +import numpy as np import socketio import uvicorn @@ -78,7 +79,8 @@ def __init__(self, port: int = 7779, **kwargs): self.server_thread: Optional[threading.Thread] = None self.sio: Optional[socketio.AsyncServer] = None self.app = None - self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + self._broadcast_loop = None + self._broadcast_thread = None # Visualization state self.vis_state = { @@ -90,12 +92,31 @@ def __init__(self, port: int = 7779, **kwargs): logger.info(f"WebSocket visualization module initialized on port {port}") + def _start_broadcast_loop(self): + """Start the broadcast event loop in a background thread.""" + + def run_loop(): + self._broadcast_loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._broadcast_loop) + try: + self._broadcast_loop.run_forever() + except Exception as e: + logger.error(f"Broadcast loop error: {e}") + finally: + self._broadcast_loop.close() + + self._broadcast_thread = threading.Thread(target=run_loop, daemon=True) + self._broadcast_thread.start() + @rpc def start(self): """Start the WebSocket server and subscribe to inputs.""" # Create the server self._create_server() + # Start the broadcast event loop in a background thread + self._start_broadcast_loop() + # Start the server in a background thread self.server_thread = threading.Thread(target=self._run_server, daemon=True) self.server_thread.start() @@ -110,8 +131,10 @@ def start(self): @rpc def stop(self): """Stop the WebSocket server.""" - if self._executor: - self._executor.shutdown(wait=True) + if self._broadcast_loop and not self._broadcast_loop.is_closed(): + self._broadcast_loop.call_soon_threadsafe(self._broadcast_loop.stop) + if self._broadcast_thread and self._broadcast_thread.is_alive(): + self._broadcast_thread.join(timeout=1.0) logger.info("WebSocket visualization module stopped") def _create_server(self): @@ -200,8 +223,7 @@ def _on_global_costmap(self, msg: OccupancyGrid): def _process_costmap(self, costmap: OccupancyGrid) -> Dict[str, Any]: """Convert OccupancyGrid to visualization format.""" - import base64 - import numpy as np + costmap = costmap.inflate(0.1).gradient(max_distance=1.0) # Convert grid data to base64 encoded string grid_bytes = costmap.grid.astype(np.float32).tobytes() @@ -236,12 +258,7 @@ def _update_state(self, new_data: Dict[str, Any]): self.vis_state.update(new_data) # Broadcast update asynchronously - def broadcast(): - try: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop.run_until_complete(self.sio.emit("state_update", new_data)) - except Exception as e: - logger.error(f"Failed to broadcast state update: {e}") - - self._executor.submit(broadcast) + if self._broadcast_loop and not self._broadcast_loop.is_closed(): + asyncio.run_coroutine_threadsafe( + self.sio.emit("state_update", new_data), self._broadcast_loop + ) diff --git a/pyproject.toml b/pyproject.toml index 43604151da..80ad9f0f7b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,7 +95,7 @@ dependencies = [ "dask[complete]==2025.5.1", # LCM / DimOS utilities - "dimos-lcm @ git+https://github.com/dimensionalOS/dimos-lcm.git@ba3445d16be75a7ade6fb2a516b39a3e44319d5c" + "dimos-lcm @ git+https://github.com/dimensionalOS/dimos-lcm.git@de4038871a4f166c3007ef6b6bc3ff83642219b2" ] diff --git a/tests/test_object_tracking_module.py b/tests/test_object_tracking_module.py new file mode 100755 index 0000000000..374315c184 --- /dev/null +++ b/tests/test_object_tracking_module.py @@ -0,0 +1,290 @@ +#!/usr/bin/env python3 +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test script for Object Tracking module with ZED camera.""" + +import asyncio +import cv2 + +from dimos import core +from dimos.hardware.zed_camera import ZEDModule +from dimos.perception.object_tracker import ObjectTracking +from dimos.protocol import pubsub +from dimos.utils.logging_config import setup_logger +from dimos.robot.foxglove_bridge import FoxgloveBridge + +# Import message types +from dimos.msgs.sensor_msgs import Image +from dimos_lcm.sensor_msgs import CameraInfo +from dimos.msgs.geometry_msgs import PoseStamped +from dimos.protocol.pubsub.lcmpubsub import LCM, Topic + +logger = setup_logger("test_object_tracking_module") + +# Suppress verbose Foxglove bridge warnings +import logging + +logging.getLogger("lcm_foxglove_bridge").setLevel(logging.ERROR) +logging.getLogger("FoxgloveServer").setLevel(logging.ERROR) + + +class TrackingVisualization: + """Handles visualization and user interaction for object tracking.""" + + def __init__(self): + self.lcm = LCM() + self.latest_color = None + + # Mouse interaction state + self.selecting_bbox = False + self.bbox_start = None + self.current_bbox = None + self.tracking_active = False + + # Subscribe to color image topic only + self.color_topic = Topic("/zed/color_image", Image) + + def start(self): + """Start the visualization node.""" + self.lcm.start() + + # Subscribe to color image only + self.lcm.subscribe(self.color_topic, self._on_color_image) + + logger.info("Visualization started, subscribed to color image topic") + + def _on_color_image(self, msg: Image, _: str): + """Handle color image messages.""" + try: + # Convert dimos Image to OpenCV format (BGR) for display + self.latest_color = msg.to_opencv() + logger.debug(f"Received color image: {msg.width}x{msg.height}, format: {msg.format}") + except Exception as e: + logger.error(f"Error processing color image: {e}") + + def mouse_callback(self, event, x, y, _, param): + """Handle mouse events for bbox selection.""" + tracker_module = param.get("tracker") + + if event == cv2.EVENT_LBUTTONDOWN: + self.selecting_bbox = True + self.bbox_start = (x, y) + self.current_bbox = None + + elif event == cv2.EVENT_MOUSEMOVE and self.selecting_bbox: + # Update current selection for visualization + x1, y1 = self.bbox_start + self.current_bbox = [min(x1, x), min(y1, y), max(x1, x), max(y1, y)] + + elif event == cv2.EVENT_LBUTTONUP and self.selecting_bbox: + self.selecting_bbox = False + if self.bbox_start: + x1, y1 = self.bbox_start + x2, y2 = x, y + # Ensure valid bbox + bbox = [min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)] + + # Check if bbox is valid (has area) + if bbox[2] > bbox[0] and bbox[3] > bbox[1]: + # Call track RPC on the tracker module + if tracker_module: + result = tracker_module.track(bbox) + logger.info(f"Tracking initialized: {result}") + self.tracking_active = True + self.current_bbox = None + + def draw_interface(self, frame): + """Draw UI elements on the frame.""" + # Draw bbox selection if in progress + if self.selecting_bbox and self.current_bbox: + x1, y1, x2, y2 = self.current_bbox + cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 255), 2) + + # Draw instructions + cv2.putText( + frame, + "Click and drag to select object", + (10, 30), + cv2.FONT_HERSHEY_SIMPLEX, + 0.7, + (255, 255, 255), + 2, + ) + cv2.putText( + frame, + "Press 's' to stop tracking, 'q' to quit", + (10, 60), + cv2.FONT_HERSHEY_SIMPLEX, + 0.7, + (255, 255, 255), + 2, + ) + + # Show tracking status + if self.tracking_active: + status = "Tracking Active" + color = (0, 255, 0) + else: + status = "No Target" + color = (0, 0, 255) + cv2.putText(frame, f"Status: {status}", (10, 90), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2) + + return frame + + +async def test_object_tracking_module(): + """Test object tracking with ZED camera module.""" + logger.info("Starting Object Tracking Module test") + + # Start Dimos + dimos = core.start(2) + + # Enable LCM auto-configuration + pubsub.lcm.autoconf() + + viz = None + tracker = None + zed = None + foxglove_bridge = None + + try: + # Deploy ZED module + logger.info("Deploying ZED module...") + zed = dimos.deploy( + ZEDModule, + camera_id=0, + resolution="HD720", + depth_mode="NEURAL", + fps=30, + enable_tracking=True, + publish_rate=15.0, + frame_id="zed_camera_link", + ) + + # Configure ZED LCM transports + zed.color_image.transport = core.LCMTransport("/zed/color_image", Image) + zed.depth_image.transport = core.LCMTransport("/zed/depth_image", Image) + zed.camera_info.transport = core.LCMTransport("/zed/camera_info", CameraInfo) + zed.pose.transport = core.LCMTransport("/zed/pose", PoseStamped) + + # Start ZED to begin publishing + zed.start() + await asyncio.sleep(2) # Wait for camera to initialize + + # Deploy Object Tracking module + logger.info("Deploying Object Tracking module...") + tracker = dimos.deploy( + ObjectTracking, + camera_intrinsics=None, # Will get from camera_info topic + reid_threshold=5, + reid_fail_tolerance=10, + frame_id="zed_camera_link", + ) + + # Configure tracking LCM transports + tracker.color_image.transport = core.LCMTransport("/zed/color_image", Image) + tracker.depth.transport = core.LCMTransport("/zed/depth_image", Image) + tracker.camera_info.transport = core.LCMTransport("/zed/camera_info", CameraInfo) + + # Configure output transports + from dimos_lcm.vision_msgs import Detection2DArray, Detection3DArray + + tracker.detection2darray.transport = core.LCMTransport( + "/detection2darray", Detection2DArray + ) + tracker.detection3darray.transport = core.LCMTransport( + "/detection3darray", Detection3DArray + ) + tracker.tracked_overlay.transport = core.LCMTransport("/tracked_overlay", Image) + + # Connect inputs + tracker.color_image.connect(zed.color_image) + tracker.depth.connect(zed.depth_image) + tracker.camera_info.connect(zed.camera_info) + + # Start tracker + tracker.start() + + # Create visualization + viz = TrackingVisualization() + viz.start() + + # Start Foxglove bridge for visualization + foxglove_bridge = FoxgloveBridge() + foxglove_bridge.start() + + # Give modules time to initialize + await asyncio.sleep(1) + + # Create OpenCV window and set mouse callback + cv2.namedWindow("Object Tracking") + cv2.setMouseCallback("Object Tracking", viz.mouse_callback, {"tracker": tracker}) + + logger.info("System ready. Click and drag to select an object to track.") + logger.info("Foxglove visualization available at http://localhost:8765") + + # Main visualization loop + while True: + # Get the color frame to display + if viz.latest_color is not None: + display_frame = viz.latest_color.copy() + else: + # Wait for frames + await asyncio.sleep(0.03) + continue + + # Draw UI elements + display_frame = viz.draw_interface(display_frame) + + # Show frame + cv2.imshow("Object Tracking", display_frame) + + # Handle keyboard input + key = cv2.waitKey(1) & 0xFF + if key == ord("q"): + logger.info("Quit requested") + break + elif key == ord("s"): + # Stop tracking + if tracker: + tracker.stop_track() + viz.tracking_active = False + logger.info("Tracking stopped") + + await asyncio.sleep(0.03) # ~30 FPS + + except Exception as e: + logger.error(f"Error in test: {e}") + import traceback + + traceback.print_exc() + + finally: + # Clean up + cv2.destroyAllWindows() + + if tracker: + tracker.cleanup() + if zed: + zed.stop() + if foxglove_bridge: + foxglove_bridge.stop() + + dimos.close() + logger.info("Test completed") + + +if __name__ == "__main__": + asyncio.run(test_object_tracking_module()) diff --git a/tests/test_observe_stream_skill.py b/tests/test_observe_stream_skill.py deleted file mode 100644 index 7f18789fb0..0000000000 --- a/tests/test_observe_stream_skill.py +++ /dev/null @@ -1,131 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Test for the monitor skill and kill skill. - -This script demonstrates how to use the monitor skill to periodically -send images from the robot's video stream to a Claude agent, and how -to use the kill skill to terminate the monitor skill. -""" - -import os -import time -import threading -from dotenv import load_dotenv -import reactivex as rx -from reactivex import operators as ops -import logging - -from dimos.agents.claude_agent import ClaudeAgent -from dimos.robot.unitree.unitree_go2 import UnitreeGo2 -from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl -from dimos.robot.unitree.unitree_skills import MyUnitreeSkills -from dimos.skills.observe_stream import ObserveStream -from dimos.skills.kill_skill import KillSkill -from dimos.web.robot_web_interface import RobotWebInterface -from dimos.utils.logging_config import setup_logger -import tests.test_header - -logger = setup_logger("tests.test_observe_stream_skill") - -load_dotenv() - - -def main(): - # Initialize the robot with mock connection for testing - robot = UnitreeGo2( - ip=os.getenv("ROBOT_IP", "192.168.123.161"), skills=MyUnitreeSkills(), mock_connection=True - ) - - agent_response_subject = rx.subject.Subject() - agent_response_stream = agent_response_subject.pipe(ops.share()) - - streams = {"unitree_video": robot.get_ros_video_stream()} - text_streams = { - "agent_responses": agent_response_stream, - } - - web_interface = RobotWebInterface(port=5555, text_streams=text_streams, **streams) - - agent = ClaudeAgent( - dev_name="test_agent", - input_query_stream=web_interface.query_stream, - skills=robot.get_skills(), - system_query="""You are an agent monitoring a robot's environment. - When you see an image, describe what you see and alert if you notice any people or important changes. - Be concise but thorough in your observations.""", - model_name="claude-3-7-sonnet-latest", - thinking_budget_tokens=10000, - ) - - agent.get_response_observable().subscribe(lambda x: agent_response_subject.on_next(x)) - - robot_skills = robot.get_skills() - - robot_skills.add(ObserveStream) - robot_skills.add(KillSkill) - - robot_skills.create_instance("ObserveStream", robot=robot, agent=agent) - robot_skills.create_instance("KillSkill", skill_library=robot_skills) - - web_interface_thread = threading.Thread(target=web_interface.run) - web_interface_thread.daemon = True - web_interface_thread.start() - - logger.info("Starting monitor skill...") - - memory_file = os.path.join(agent.output_dir, "memory.txt") - with open(memory_file, "a") as f: - f.write( - "SKILL CALL: ObserveStream(timestep=10.0, query_text='What do you see in this image? Alert me if you see any people.', max_duration=120.0)" - ) - - result = robot_skills.call( - "ObserveStream", - timestep=10.0, # 20 seconds between monitoring queries - query_text="What do you see in this image? Alert me if you see any people.", - max_duration=120.0, - ) # Run for 120 seconds - logger.info(f"Monitor skill result: {result}") - - logger.info(f"Running skills: {robot_skills.get_running_skills().keys()}") - - try: - logger.info("Observer running. Will stop after 35 seconds...") - time.sleep(20.0) - - logger.info(f"Running skills before kill: {robot_skills.get_running_skills().keys()}") - logger.info("Killing the observer skill...") - - memory_file = os.path.join(agent.output_dir, "memory.txt") - with open(memory_file, "a") as f: - f.write("\n\nSKILL CALL: KillSkill(skill_name='observer')\n\n") - - kill_result = robot_skills.call("KillSkill", skill_name="observer") - logger.info(f"Kill skill result: {kill_result}") - - logger.info(f"Running skills after kill: {robot_skills.get_running_skills().keys()}") - - # Keep test running until user interrupts - while True: - time.sleep(1.0) - except KeyboardInterrupt: - logger.info("Test interrupted by user") - - logger.info("Test completed") - - -if __name__ == "__main__": - main()