diff --git a/.envrc b/.envrc index e22018404a..a73ae4a035 100644 --- a/.envrc +++ b/.envrc @@ -2,3 +2,4 @@ if ! has nix_direnv_version || ! nix_direnv_version 3.0.6; then source_url "https://raw.githubusercontent.com/nix-community/nix-direnv/3.0.6/direnvrc" "sha256-RYcUJaRMf8oF5LznDrlCXbkOQrywm0HDv1VjYGaJGdM=" fi use flake . +dotenv diff --git a/dimos/agents2/skills/navigation.py b/dimos/agents2/skills/navigation.py index 8bd7a030b2..c20293f60e 100644 --- a/dimos/agents2/skills/navigation.py +++ b/dimos/agents2/skills/navigation.py @@ -28,6 +28,7 @@ from dimos.msgs.geometry_msgs.Vector3 import make_vector3 from dimos.utils.transform_utils import euler_to_quaternion, quaternion_to_euler from dimos.utils.logging_config import setup_logger +from dimos.navigation.bt_navigator.navigator import NavigatorState from reactivex.disposable import Disposable, CompositeDisposable logger = setup_logger(__file__) @@ -162,13 +163,41 @@ def _navigate_to_object(self, query: str) -> Optional[str]: logger.info(f"Found {query} at {bbox}") - success = self._robot.navigate_to_object(bbox) + # Start tracking - BBoxNavigationModule automatically generates goals + self._robot.object_tracker.track(bbox) - if not success: - logger.warning(f"Failed to navigate to '{query}' at {bbox}") - return None - - return "Successfully navigated to object from query '{query}'." + start_time = time.time() + timeout = 30.0 + goal_set = False + + while time.time() - start_time < timeout: + # Check if navigator finished + if self._robot.navigator.get_state() == NavigatorState.IDLE and goal_set: + logger.info("Waiting for goal result") + time.sleep(1.0) + if not self._robot.navigator.is_goal_reached(): + logger.info(f"Goal cancelled, tracking '{query}' failed") + self._robot.object_tracker.stop_track() + return None + else: + logger.info(f"Reached '{query}'") + self._robot.object_tracker.stop_track() + return f"Successfully arrived at '{query}'" + + # If goal set and tracking lost, just continue (tracker will resume or timeout) + if goal_set and not self._robot.object_tracker.is_tracking(): + continue + + # BBoxNavigationModule automatically sends goals when tracker publishes + # Just check if we have any detections to mark goal_set + if self._robot.object_tracker.is_tracking(): + goal_set = True + + time.sleep(0.25) + + logger.warning(f"Navigation to '{query}' timed out after {timeout}s") + self._robot.object_tracker.stop_track() + return None def _get_bbox_for_current_frame(self, query: str) -> Optional[BBox]: if self._latest_image is None: diff --git a/dimos/agents2/skills/test_navigation.py b/dimos/agents2/skills/test_navigation.py index f497ea0953..f90f8a2d19 100644 --- a/dimos/agents2/skills/test_navigation.py +++ b/dimos/agents2/skills/test_navigation.py @@ -36,7 +36,20 @@ def test_take_a_look_around(fake_robot, create_navigation_agent, mocker): def test_go_to_object(fake_robot, create_navigation_agent, mocker): - fake_robot.navigate_to_object.return_value = True + fake_robot.object_tracker = mocker.MagicMock() + fake_robot.object_tracker.is_tracking.side_effect = [True, True, True, True] # Tracking active + fake_robot.navigator = mocker.MagicMock() + + # Simulate navigation states: FOLLOWING_PATH -> IDLE (goal reached) + from dimos.navigation.bt_navigator.navigator import NavigatorState + + fake_robot.navigator.get_state.side_effect = [ + NavigatorState.FOLLOWING_PATH, + NavigatorState.FOLLOWING_PATH, + NavigatorState.IDLE, + ] + fake_robot.navigator.is_goal_reached.return_value = True + mocker.patch( "dimos.agents2.skills.navigation.NavigationSkillContainer._navigate_by_tagged_location", return_value=None, @@ -45,12 +58,14 @@ def test_go_to_object(fake_robot, create_navigation_agent, mocker): "dimos.agents2.skills.navigation.NavigationSkillContainer._navigate_using_semantic_map", return_value=None, ) + mocker.patch("dimos.agents2.skills.navigation.time.sleep") + agent = create_navigation_agent(fixture="test_go_to_object.json") agent.query("go to the chair") - fake_robot.navigate_to_object.assert_called_once() - actual_bbox = fake_robot.navigate_to_object.call_args[0][0] + fake_robot.object_tracker.track.assert_called_once() + actual_bbox = fake_robot.object_tracker.track.call_args[0][0] expected_bbox = (82, 51, 163, 159) for actual_val, expected_val in zip(actual_bbox, expected_bbox): @@ -58,6 +73,8 @@ def test_go_to_object(fake_robot, create_navigation_agent, mocker): f"BBox {actual_bbox} not within ±5 of {expected_bbox}" ) + fake_robot.object_tracker.stop_track.assert_called_once() + def test_go_to_semantic_location(fake_robot, create_navigation_agent, mocker): mocker.patch( diff --git a/dimos/msgs/vision_msgs/Detection2DArray.py b/dimos/msgs/vision_msgs/Detection2DArray.py index 004f8fd9b3..133893b9f0 100644 --- a/dimos/msgs/vision_msgs/Detection2DArray.py +++ b/dimos/msgs/vision_msgs/Detection2DArray.py @@ -17,3 +17,6 @@ class Detection2DArray(LCMDetection2DArray): msg_name = "vision_msgs.Detection2DArray" + + # for _get_field_type() to work when decoding in _decode_one() + __annotations__ = LCMDetection2DArray.__annotations__ diff --git a/dimos/navigation/bbox_navigation.py b/dimos/navigation/bbox_navigation.py new file mode 100644 index 0000000000..aaafc32ac7 --- /dev/null +++ b/dimos/navigation/bbox_navigation.py @@ -0,0 +1,66 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos.core import Module, In, Out, rpc +from dimos.msgs.vision_msgs import Detection2DArray +from dimos.msgs.geometry_msgs import PoseStamped, Vector3, Quaternion +from dimos_lcm.sensor_msgs import CameraInfo +from dimos.utils.logging_config import setup_logger +import logging + +logger = setup_logger(__name__, level=logging.DEBUG) + + +class BBoxNavigationModule(Module): + """Minimal module that converts 2D bbox center to navigation goals.""" + + detection2d: In[Detection2DArray] = None + camera_info: In[CameraInfo] = None + goal_request: Out[PoseStamped] = None + + def __init__(self, goal_distance: float = 1.0): + super().__init__() + self.goal_distance = goal_distance + self.camera_intrinsics = None + + @rpc + def start(self): + self.camera_info.subscribe( + lambda msg: setattr(self, "camera_intrinsics", [msg.K[0], msg.K[4], msg.K[2], msg.K[5]]) + ) + self.detection2d.subscribe(self._on_detection) + + def _on_detection(self, det: Detection2DArray): + if det.detections_length == 0 or not self.camera_intrinsics: + return + fx, fy, cx, cy = self.camera_intrinsics + center_x, center_y = ( + det.detections[0].bbox.center.position.x, + det.detections[0].bbox.center.position.y, + ) + x, y, z = ( + (center_x - cx) / fx * self.goal_distance, + (center_y - cy) / fy * self.goal_distance, + self.goal_distance, + ) + goal = PoseStamped( + position=Vector3(z, -x, -y), + orientation=Quaternion(0, 0, 0, 1), + frame_id=det.header.frame_id, + ) + logger.debug( + f"BBox center: ({center_x:.1f}, {center_y:.1f}) → " + f"Goal pose: ({z:.2f}, {-x:.2f}, {-y:.2f}) in frame '{det.header.frame_id}'" + ) + self.goal_request.publish(goal) diff --git a/dimos/perception/object_tracker_2d.py b/dimos/perception/object_tracker_2d.py new file mode 100644 index 0000000000..481b69e1ac --- /dev/null +++ b/dimos/perception/object_tracker_2d.py @@ -0,0 +1,296 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import cv2 +import numpy as np +import time +import threading +from typing import Dict, List, Optional +import logging + +from dimos.core import In, Out, Module, rpc +from dimos.msgs.std_msgs import Header +from dimos.msgs.sensor_msgs import Image, ImageFormat +from dimos.msgs.vision_msgs import Detection2DArray +from dimos.utils.logging_config import setup_logger + +# Import LCM messages +from dimos_lcm.vision_msgs import ( + BoundingBox2D, + Detection2D, + ObjectHypothesis, + ObjectHypothesisWithPose, + Point2D, + Pose2D, +) + +logger = setup_logger("dimos.perception.object_tracker_2d", level=logging.INFO) + + +class ObjectTracker2D(Module): + """Pure 2D object tracking module using OpenCV's CSRT tracker.""" + + color_image: In[Image] = None + + detection2darray: Out[Detection2DArray] = None + tracked_overlay: Out[Image] = None # Visualization output + + def __init__( + self, + frame_id: str = "camera_link", + ): + """ + Initialize 2D object tracking module using OpenCV's CSRT tracker. + + Args: + frame_id: TF frame ID for the camera (default: "camera_link") + """ + super().__init__() + + self.frame_id = frame_id + + # Tracker state + self.tracker = None + self.tracking_bbox = None # Stores (x, y, w, h) + self.tracking_initialized = False + + # Stuck detection + self._last_bbox = None + self._stuck_count = 0 + self._max_stuck_frames = 10 # Higher threshold for stationary objects + + # Frame management + self._frame_lock = threading.Lock() + self._latest_rgb_frame: Optional[np.ndarray] = None + self._frame_arrival_time: Optional[float] = None + + # Tracking thread control + self.tracking_thread: Optional[threading.Thread] = None + self.stop_tracking_event = threading.Event() + self.tracking_rate = 5.0 # Hz + self.tracking_period = 1.0 / self.tracking_rate + + # Store latest detection for RPC access + self._latest_detection2d: Optional[Detection2DArray] = None + + @rpc + def start(self): + """Start the object tracking module and subscribe to video stream.""" + + def on_frame(frame_msg: Image): + arrival_time = time.perf_counter() + with self._frame_lock: + self._latest_rgb_frame = frame_msg.data + self._frame_arrival_time = arrival_time + + self.color_image.subscribe(on_frame) + logger.info("ObjectTracker2D module started") + + @rpc + def track(self, bbox: List[float]) -> Dict: + """ + Initialize tracking with a bounding box. + + Args: + bbox: Bounding box in format [x1, y1, x2, y2] + + Returns: + Dict containing tracking status + """ + if self._latest_rgb_frame is None: + logger.warning("No RGB frame available for tracking") + return {"status": "no_frame"} + + # Initialize tracking + x1, y1, x2, y2 = map(int, bbox) + w, h = x2 - x1, y2 - y1 + if w <= 0 or h <= 0: + logger.warning(f"Invalid initial bbox provided: {bbox}. Tracking not started.") + return {"status": "invalid_bbox"} + + self.tracking_bbox = (x1, y1, w, h) + self.tracker = cv2.legacy.TrackerCSRT_create() + self.tracking_initialized = False + logger.info(f"Tracking target set with bbox: {self.tracking_bbox}") + + # Convert RGB to BGR for CSRT (OpenCV expects BGR) + frame_bgr = cv2.cvtColor(self._latest_rgb_frame, cv2.COLOR_RGB2BGR) + init_success = self.tracker.init(frame_bgr, self.tracking_bbox) + if init_success: + self.tracking_initialized = True + logger.info("Tracker initialized successfully.") + else: + logger.error("Tracker initialization failed.") + self.stop_track() + return {"status": "init_failed"} + + # Start tracking thread + self._start_tracking_thread() + + return {"status": "tracking_started", "bbox": self.tracking_bbox} + + def _start_tracking_thread(self): + """Start the tracking thread.""" + self.stop_tracking_event.clear() + self.tracking_thread = threading.Thread(target=self._tracking_loop, daemon=True) + self.tracking_thread.start() + logger.info("Started tracking thread") + + def _tracking_loop(self): + """Main tracking loop that runs in a separate thread.""" + while not self.stop_tracking_event.is_set() and self.tracking_initialized: + self._process_tracking() + time.sleep(self.tracking_period) + logger.info("Tracking loop ended") + + def _reset_tracking_state(self): + """Reset tracking state without stopping the thread.""" + self.tracker = None + self.tracking_bbox = None + self.tracking_initialized = False + self._last_bbox = None + self._stuck_count = 0 + + # Publish empty detection + empty_2d = Detection2DArray( + detections_length=0, header=Header(time.time(), self.frame_id), detections=[] + ) + self._latest_detection2d = empty_2d + self.detection2darray.publish(empty_2d) + + @rpc + def stop_track(self) -> bool: + """ + Stop tracking the current object. + + Returns: + bool: True if tracking was successfully stopped + """ + self._reset_tracking_state() + + # Stop tracking thread if running + if self.tracking_thread and self.tracking_thread.is_alive(): + if threading.current_thread() != self.tracking_thread: + self.stop_tracking_event.set() + self.tracking_thread.join(timeout=1.0) + self.tracking_thread = None + else: + self.stop_tracking_event.set() + + logger.info("Tracking stopped") + return True + + @rpc + def is_tracking(self) -> bool: + """ + Check if the tracker is currently tracking an object. + + Returns: + bool: True if tracking is active + """ + return self.tracking_initialized + + def _process_tracking(self): + """Process current frame for tracking and publish 2D detections.""" + if self.tracker is None or not self.tracking_initialized: + return + + # Get frame copy + with self._frame_lock: + if self._latest_rgb_frame is None: + return + frame = self._latest_rgb_frame.copy() + + # Convert RGB to BGR for CSRT (OpenCV expects BGR) + frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) + + tracker_succeeded, bbox_cv = self.tracker.update(frame_bgr) + + if not tracker_succeeded: + logger.info("Tracker update failed. Stopping track.") + self._reset_tracking_state() + return + + # Extract bbox + x, y, w, h = map(int, bbox_cv) + current_bbox_x1y1x2y2 = [x, y, x + w, y + h] + x1, y1, x2, y2 = current_bbox_x1y1x2y2 + + # Check if tracker is stuck + if self._last_bbox is not None: + if (x1, y1, x2, y2) == self._last_bbox: + self._stuck_count += 1 + if self._stuck_count >= self._max_stuck_frames: + logger.warning(f"Tracker stuck for {self._stuck_count} frames. Stopping track.") + self._reset_tracking_state() + return + else: + self._stuck_count = 0 + + self._last_bbox = (x1, y1, x2, y2) + + center_x = (x1 + x2) / 2.0 + center_y = (y1 + y2) / 2.0 + width = float(x2 - x1) + height = float(y2 - y1) + + # Create 2D detection header + header = Header(time.time(), self.frame_id) + + # Create Detection2D with all fields in constructors + detection_2d = Detection2D( + id="0", + results_length=1, + header=header, + bbox=BoundingBox2D( + center=Pose2D(position=Point2D(x=center_x, y=center_y), theta=0.0), + size_x=width, + size_y=height, + ), + results=[ + ObjectHypothesisWithPose( + hypothesis=ObjectHypothesis(class_id="tracked_object", score=1.0) + ) + ], + ) + + detection2darray = Detection2DArray( + detections_length=1, header=header, detections=[detection_2d] + ) + + # Store and publish + self._latest_detection2d = detection2darray + self.detection2darray.publish(detection2darray) + + # Create visualization + viz_image = self._draw_visualization(frame, current_bbox_x1y1x2y2) + viz_copy = viz_image.copy() # Force copy needed to prevent frame reuse + viz_msg = Image.from_numpy(viz_copy, format=ImageFormat.RGB) + self.tracked_overlay.publish(viz_msg) + + def _draw_visualization(self, image: np.ndarray, bbox: List[int]) -> np.ndarray: + """Draw tracking visualization.""" + viz_image = image.copy() + x1, y1, x2, y2 = bbox + cv2.rectangle(viz_image, (x1, y1), (x2, y2), (0, 255, 0), 2) + cv2.putText(viz_image, "TRACKING", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2) + return viz_image + + @rpc + def cleanup(self): + """Clean up resources.""" + self.stop_track() + if self.tracking_thread and self.tracking_thread.is_alive(): + self.stop_tracking_event.set() + self.tracking_thread.join(timeout=2.0) diff --git a/dimos/perception/object_tracker_3d.py b/dimos/perception/object_tracker_3d.py new file mode 100644 index 0000000000..a5dc96bae9 --- /dev/null +++ b/dimos/perception/object_tracker_3d.py @@ -0,0 +1,309 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from typing import List, Optional + +from dimos.core import In, Out, rpc +from dimos.msgs.std_msgs import Header +from dimos.msgs.sensor_msgs import Image, ImageFormat +from dimos.msgs.vision_msgs import Detection2DArray, Detection3DArray +from dimos.msgs.geometry_msgs import Vector3, Quaternion, Transform, Pose +from dimos.perception.object_tracker_2d import ObjectTracker2D +from dimos.protocol.tf import TF +from dimos.types.timestamped import align_timestamped +from dimos.utils.logging_config import setup_logger +from dimos.utils.transform_utils import ( + yaw_towards_point, + optical_to_robot_frame, + euler_to_quaternion, +) +from dimos.manipulation.visual_servoing.utils import visualize_detections_3d + +# Import LCM messages +from dimos_lcm.sensor_msgs import CameraInfo +from dimos_lcm.vision_msgs import Detection3D, ObjectHypothesisWithPose + +logger = setup_logger("dimos.perception.object_tracker_3d") + + +class ObjectTracker3D(ObjectTracker2D): + """3D object tracking module extending ObjectTracker2D with depth capabilities.""" + + # Additional inputs (2D tracker already has color_image) + depth: In[Image] = None + camera_info: In[CameraInfo] = None + + # Additional outputs (2D tracker already has detection2darray and tracked_overlay) + detection3darray: Out[Detection3DArray] = None + + def __init__(self, **kwargs): + """ + Initialize 3D object tracking module. + + Args: + **kwargs: Arguments passed to parent ObjectTracker2D + """ + super().__init__(**kwargs) + + # Additional state for 3D tracking + self.camera_intrinsics = None + self._latest_depth_frame: Optional[np.ndarray] = None + self._latest_camera_info: Optional[CameraInfo] = None + self._aligned_frames_subscription = None + + # TF publisher for tracked object + self.tf = TF() + + # Store latest 3D detection + self._latest_detection3d: Optional[Detection3DArray] = None + + @rpc + def start(self): + """Start the 3D tracking module with depth stream alignment.""" + + # Subscribe to aligned RGB and depth streams + def on_aligned_frames(frames_tuple): + rgb_msg, depth_msg = frames_tuple + with self._frame_lock: + self._latest_rgb_frame = rgb_msg.data + + depth_data = depth_msg.data + # Convert from millimeters to meters if depth is DEPTH16 format + if depth_msg.format == ImageFormat.DEPTH16: + depth_data = depth_data.astype(np.float32) / 1000.0 + self._latest_depth_frame = depth_data + + # Create aligned observable for RGB and depth + aligned_frames = align_timestamped( + self.color_image.observable(), + self.depth.observable(), + buffer_size=2.0, # 2 second buffer + match_tolerance=0.5, # 500ms tolerance + ) + self._aligned_frames_subscription = aligned_frames.subscribe(on_aligned_frames) + + # Subscribe to camera info + def on_camera_info(camera_info_msg: CameraInfo): + self._latest_camera_info = camera_info_msg + # Extract intrinsics: K is [fx, 0, cx, 0, fy, cy, 0, 0, 1] + self.camera_intrinsics = [ + camera_info_msg.K[0], + camera_info_msg.K[4], + camera_info_msg.K[2], + camera_info_msg.K[5], + ] + + self.camera_info.subscribe(on_camera_info) + + logger.info("ObjectTracker3D module started with aligned frame subscription") + + def _process_tracking(self): + """Override to add 3D detection creation after 2D tracking.""" + # Call parent 2D tracking + super()._process_tracking() + + # Enhance with 3D if we have depth and a valid 2D detection + if ( + self._latest_detection2d + and self._latest_detection2d.detections_length > 0 + and self._latest_depth_frame is not None + and self.camera_intrinsics is not None + ): + detection_3d = self._create_detection3d_from_2d(self._latest_detection2d) + if detection_3d: + self._latest_detection3d = detection_3d + self.detection3darray.publish(detection_3d) + + # Update visualization with 3D info + with self._frame_lock: + if self._latest_rgb_frame is not None: + frame = self._latest_rgb_frame.copy() + + # Extract 2D bbox for visualization + det_2d = self._latest_detection2d.detections[0] + x1 = det_2d.bbox.center.position.x - det_2d.bbox.size_x / 2 + y1 = det_2d.bbox.center.position.y - det_2d.bbox.size_y / 2 + x2 = det_2d.bbox.center.position.x + det_2d.bbox.size_x / 2 + y2 = det_2d.bbox.center.position.y + det_2d.bbox.size_y / 2 + bbox_2d = [[x1, y1, x2, y2]] + + # Create 3D visualization + viz_image = visualize_detections_3d( + frame, detection_3d.detections, show_coordinates=True, bboxes_2d=bbox_2d + ) + + # Overlay Re-ID matches + if self.last_good_matches and self.last_roi_kps and self.last_roi_bbox: + viz_image = self._draw_reid_overlay(viz_image) + + viz_msg = Image.from_numpy(viz_image) + self.tracked_overlay.publish(viz_msg) + + def _create_detection3d_from_2d( + self, detection2d: Detection2DArray + ) -> Optional[Detection3DArray]: + """Create 3D detection from 2D detection using depth.""" + if detection2d.detections_length == 0: + return None + + det_2d = detection2d.detections[0] + + # Get bbox center + center_x = det_2d.bbox.center.position.x + center_y = det_2d.bbox.center.position.y + width = det_2d.bbox.size_x + height = det_2d.bbox.size_y + + # Convert to bbox coordinates + x1 = int(center_x - width / 2) + y1 = int(center_y - height / 2) + x2 = int(center_x + width / 2) + y2 = int(center_y + height / 2) + + # Get depth value + depth_value = self._get_depth_from_bbox([x1, y1, x2, y2], self._latest_depth_frame) + + if depth_value is None or depth_value <= 0: + return None + + fx, fy, cx, cy = self.camera_intrinsics + + # Convert pixel coordinates to 3D in optical frame + z_optical = depth_value + x_optical = (center_x - cx) * z_optical / fx + y_optical = (center_y - cy) * z_optical / fy + + # Create pose in optical frame + optical_pose = Pose() + optical_pose.position = Vector3(x_optical, y_optical, z_optical) + optical_pose.orientation = Quaternion(0.0, 0.0, 0.0, 1.0) + + # Convert to robot frame + robot_pose = optical_to_robot_frame(optical_pose) + + # Calculate orientation: object facing towards camera + yaw = yaw_towards_point(robot_pose.position) + euler = Vector3(0.0, 0.0, yaw) + robot_pose.orientation = euler_to_quaternion(euler) + + # Estimate object size in meters + size_x = width * z_optical / fx + size_y = height * z_optical / fy + size_z = 0.1 # Default depth size + + # Create Detection3D + header = Header(self.frame_id) + detection_3d = Detection3D() + detection_3d.id = "0" + detection_3d.results_length = 1 + detection_3d.header = header + + # Create hypothesis + hypothesis = ObjectHypothesisWithPose() + hypothesis.hypothesis.class_id = "tracked_object" + hypothesis.hypothesis.score = 1.0 + detection_3d.results = [hypothesis] + + # Create 3D bounding box + detection_3d.bbox.center = Pose() + detection_3d.bbox.center.position = robot_pose.position + detection_3d.bbox.center.orientation = robot_pose.orientation + detection_3d.bbox.size = Vector3(size_x, size_y, size_z) + + detection3darray = Detection3DArray() + detection3darray.detections_length = 1 + detection3darray.header = header + detection3darray.detections = [detection_3d] + + # Publish TF for tracked object + tracked_object_tf = Transform( + translation=robot_pose.position, + rotation=robot_pose.orientation, + frame_id=self.frame_id, + child_frame_id="tracked_object", + ts=header.ts, + ) + self.tf.publish(tracked_object_tf) + + return detection3darray + + def _get_depth_from_bbox(self, bbox: List[int], depth_frame: np.ndarray) -> Optional[float]: + """ + Calculate depth from bbox using the 25th percentile of closest points. + + Args: + bbox: Bounding box coordinates [x1, y1, x2, y2] + depth_frame: Depth frame to extract depth values from + + Returns: + Depth value or None if not available + """ + if depth_frame is None: + return None + + x1, y1, x2, y2 = bbox + + # Ensure bbox is within frame bounds + y1 = max(0, y1) + y2 = min(depth_frame.shape[0], y2) + x1 = max(0, x1) + x2 = min(depth_frame.shape[1], x2) + + # Extract depth values from the bbox + roi_depth = depth_frame[y1:y2, x1:x2] + + # Get valid (finite and positive) depth values + valid_depths = roi_depth[np.isfinite(roi_depth) & (roi_depth > 0)] + + if len(valid_depths) > 0: + return float(np.percentile(valid_depths, 25)) + + return None + + def _draw_reid_overlay(self, image: np.ndarray) -> np.ndarray: + """Draw Re-ID feature matches on visualization.""" + import cv2 + + viz_image = image.copy() + x1, y1, _x2, _y2 = self.last_roi_bbox + + # Draw keypoints + for kp in self.last_roi_kps: + pt = (int(kp.pt[0] + x1), int(kp.pt[1] + y1)) + cv2.circle(viz_image, pt, 3, (0, 255, 0), -1) + + # Draw matches + for match in self.last_good_matches: + current_kp = self.last_roi_kps[match.trainIdx] + pt_current = (int(current_kp.pt[0] + x1), int(current_kp.pt[1] + y1)) + cv2.circle(viz_image, pt_current, 5, (0, 255, 255), 2) + + intensity = int(255 * (1.0 - min(match.distance / 100.0, 1.0))) + cv2.circle(viz_image, pt_current, 2, (intensity, intensity, 255), -1) + + # Draw match count + text = f"REID: {len(self.last_good_matches)}/{len(self.last_roi_kps)}" + cv2.putText(viz_image, text, (10, 90), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2) + + return viz_image + + @rpc + def cleanup(self): + """Clean up resources.""" + super().cleanup() + + if self._aligned_frames_subscription: + self._aligned_frames_subscription.dispose() + self._aligned_frames_subscription = None diff --git a/dimos/robot/robot.py b/dimos/robot/robot.py index ac70c7a020..7cdd50cf0b 100644 --- a/dimos/robot/robot.py +++ b/dimos/robot/robot.py @@ -68,12 +68,6 @@ class UnitreeRobot(Robot): @abstractmethod def get_odom(self) -> PoseStamped: ... - @abstractmethod - def navigate_to(self, pose: PoseStamped, blocking: bool = True) -> None: ... - - @abstractmethod - def navigate_to_object(self, pose: PoseStamped, blocking: bool = True) -> None: ... - @abstractmethod def explore(self) -> bool: ... diff --git a/dimos/robot/unitree_webrtc/unitree_go2.py b/dimos/robot/unitree_webrtc/unitree_go2.py index 1db6bef9a3..3c05062149 100644 --- a/dimos/robot/unitree_webrtc/unitree_go2.py +++ b/dimos/robot/unitree_webrtc/unitree_go2.py @@ -20,24 +20,23 @@ import os import time import warnings -from typing import List, Optional +from typing import Optional from reactivex import Observable from dimos import core -from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE, DEFAULT_CAPACITY_DEPTH_IMAGE +from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE from dimos.core import In, Module, Out, rpc from dimos.mapping.types import LatLon from dimos.msgs.std_msgs import Header from dimos.msgs.geometry_msgs import PoseStamped, Transform, Twist, Vector3, Quaternion from dimos.msgs.nav_msgs import OccupancyGrid, Path from dimos.msgs.sensor_msgs import Image -from dimos.msgs.vision_msgs import Detection2DArray, Detection3DArray +from dimos.msgs.vision_msgs import Detection2DArray from dimos_lcm.std_msgs import String from dimos_lcm.sensor_msgs import CameraInfo from dimos.perception.spatial_perception import SpatialMemory from dimos.perception.common.utils import ( - extract_pose_from_detection3d, load_camera_info, load_camera_info_opencv, rectify_image, @@ -57,13 +56,12 @@ from dimos.robot.unitree_webrtc.type.map import Map from dimos.robot.unitree_webrtc.type.odometry import Odometry from dimos.robot.unitree_webrtc.unitree_skills import MyUnitreeSkills -from dimos.robot.unitree_webrtc.depth_module import DepthModule from dimos.skills.skills import AbstractRobotSkill, SkillLibrary from dimos.utils.data import get_data from dimos.utils.logging_config import setup_logger from dimos.utils.testing import TimedSensorReplay -from dimos.utils.transform_utils import offset_distance -from dimos.perception.object_tracker import ObjectTracking +from dimos.perception.object_tracker_2d import ObjectTracker2D +from dimos.navigation.bbox_navigation import BBoxNavigationModule from dimos_lcm.std_msgs import Bool from dimos.robot.robot import UnitreeRobot from dimos.types.robot_capabilities import RobotCapability @@ -350,7 +348,6 @@ def __init__( self.websocket_vis = None self.foxglove_bridge = None self.spatial_memory_module = None - self.depth_module = None self.object_tracker = None self.utilization_module = None @@ -385,7 +382,7 @@ def _setup_directories(self): def start(self): """Start the robot system with all modules.""" - self.dimos = core.start(8) + self.dimos = core.start(8, memory_limit="8GiB") self._deploy_connection() self._deploy_mapping() @@ -412,7 +409,6 @@ def stop(self): # self.websocket_vis.stop() # self.foxglove_bridge.stop() self.spatial_memory_module.stop() - # self.depth_module.stop() # self.object_tracker.stop() self.utilization_module.stop() self.dimos.close_all() @@ -516,7 +512,7 @@ def _deploy_foxglove_bridge(self): self.foxglove_bridge = FoxgloveBridge( shm_channels=[ "/go2/color_image#sensor_msgs.Image", - "/go2/depth_image#sensor_msgs.Image", + "/go2/tracked_overlay#sensor_msgs.Image", ] ) @@ -540,48 +536,43 @@ def _deploy_perception(self): logger.info("Spatial memory module deployed and connected") - # Deploy object tracker + # Deploy 2D object tracker self.object_tracker = self.dimos.deploy( - ObjectTracking, + ObjectTracker2D, frame_id="camera_link", ) + # Deploy bbox navigation module + self.bbox_navigator = self.dimos.deploy(BBoxNavigationModule, goal_distance=1.0) + self.utilization_module = self.dimos.deploy(UtilizationModule) - # Set up transports + # Set up transports for object tracker self.object_tracker.detection2darray.transport = core.LCMTransport( "/go2/detection2d", Detection2DArray ) - self.object_tracker.detection3darray.transport = core.LCMTransport( - "/go2/detection3d", Detection3DArray - ) - self.object_tracker.tracked_overlay.transport = core.LCMTransport( - "/go2/tracked_overlay", Image + self.object_tracker.tracked_overlay.transport = core.pSHMTransport( + "/go2/tracked_overlay", default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE ) - logger.info("Object tracker module deployed") + # Set up transports for bbox navigator + self.bbox_navigator.goal_request.transport = core.LCMTransport("/goal_request", PoseStamped) + + logger.info("Object tracker and bbox navigator modules deployed") def _deploy_camera(self): """Deploy and configure the camera module.""" - gt_depth_scale = 1.0 if self.connection_type == "mujoco" else 0.5 - self.depth_module = self.dimos.deploy(DepthModule, gt_depth_scale=gt_depth_scale) - - # Set up transports - self.depth_module.color_image.transport = core.pSHMTransport( - "/go2/color_image", default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE - ) - self.depth_module.depth_image.transport = core.pSHMTransport( - "/go2/depth_image", default_capacity=DEFAULT_CAPACITY_DEPTH_IMAGE - ) - self.depth_module.camera_info.transport = core.LCMTransport("/go2/camera_info", CameraInfo) - - logger.info("Camera module deployed and connected") - # Connect object tracker inputs after camera module is deployed + # Connect object tracker inputs if self.object_tracker: self.object_tracker.color_image.connect(self.connection.video) - self.object_tracker.depth.connect(self.depth_module.depth_image) - self.object_tracker.camera_info.connect(self.connection.camera_info) - logger.info("Object tracker connected to camera module") + logger.info("Object tracker connected to camera") + + # Connect bbox navigator inputs + if self.bbox_navigator: + self.bbox_navigator.detection2d.connect(self.object_tracker.detection2darray) + self.bbox_navigator.camera_info.connect(self.connection.camera_info) + self.bbox_navigator.goal_request.connect(self.navigator.goal_request) + logger.info("BBox navigator connected") def _start_modules(self): """Start all deployed modules in the correct order.""" @@ -594,8 +585,8 @@ def _start_modules(self): # self.websocket_vis.start() self.foxglove_bridge.start() self.spatial_memory_module.start() - self.depth_module.start() self.object_tracker.start() + self.bbox_navigator.start() self.utilization_module.start() # Initialize skills after connection is established @@ -703,64 +694,6 @@ def get_odom(self) -> PoseStamped: """ return self.connection.get_odom() - def navigate_to_object(self, bbox: List[float], distance: float = 0.5, timeout: float = 30.0): - """Navigate to an object by tracking it and maintaining a specified distance. - - Args: - bbox: Bounding box of the object to track [x1, y1, x2, y2] - distance: Distance to maintain from the object (meters) - timeout: Total timeout for the navigation (seconds) - - Returns: - True if navigation completed successfully, False otherwise - """ - if not self.object_tracker: - logger.error("Object tracker not initialized") - return False - - logger.info(f"Starting object tracking with bbox: {bbox}") - self.object_tracker.track(bbox) - - start_time = time.time() - goal_set = False - - while time.time() - start_time < timeout: - if self.navigator.get_state() == NavigatorState.IDLE and goal_set: - logger.info("Waiting for goal result") - time.sleep(1.0) - if not self.navigator.is_goal_reached(): - logger.info("Goal cancelled, object tracking failed") - return False - else: - logger.info("Object tracking goal reached") - return True - - if goal_set and not self.object_tracker.is_tracking(): - continue - - detection_topic = Topic("/go2/detection3d", Detection3DArray) - detection_msg = self.lcm.wait_for_message(detection_topic, timeout=1.0) - - if detection_msg and len(detection_msg.detections) > 0: - target_pose = extract_pose_from_detection3d(detection_msg.detections[0]) - - retracted_pose = offset_distance( - target_pose, distance, approach_vector=Vector3(-1, 0, 0) - ) - - goal_pose = PoseStamped( - frame_id=detection_msg.header.frame_id, - position=retracted_pose.position, - orientation=retracted_pose.orientation, - ) - self.navigator.set_goal(goal_pose) - goal_set = True - - time.sleep(0.25) - - logger.info("Object tracking timed out") - return False - def main(): """Main entry point."""