diff --git a/dimos/core/__init__.py b/dimos/core/__init__.py index f9fc841ef9..0b7755e2e3 100644 --- a/dimos/core/__init__.py +++ b/dimos/core/__init__.py @@ -54,7 +54,7 @@ def __getattr__(self, name: str): if name in self.rpcs: return lambda *args, **kwargs: self.rpc.call_sync( - f"{self.remote_name}/{name}", (args, kwargs), timeout=2.0 + f"{self.remote_name}/{name}", (args, kwargs) ) # return super().__getattr__(name) diff --git a/dimos/core/test_rpcstress.py b/dimos/core/test_rpcstress.py new file mode 100644 index 0000000000..8f7a0dac40 --- /dev/null +++ b/dimos/core/test_rpcstress.py @@ -0,0 +1,177 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import threading +import time + +from dimos.core import In, Module, Out, rpc + + +class Counter(Module): + current_count: int = 0 + + count_stream: Out[int] = None + + def __init__(self): + super().__init__() + self.current_count = 0 + + @rpc + def increment(self): + """Increment the counter and publish the new value.""" + self.current_count += 1 + self.count_stream.publish(self.current_count) + return self.current_count + + +class CounterValidator(Module): + """Calls counter.increment() as fast as possible and validates no numbers are skipped.""" + + count_in: In[int] = None + + def __init__(self, increment_func): + super().__init__() + self.increment_func = increment_func + self.last_seen = 0 + self.missing_numbers = [] + self.running = False + self.call_thread = None + self.call_count = 0 + self.total_latency = 0.0 + self.call_start_time = None + self.waiting_for_response = False + + @rpc + def start(self): + """Start the validator.""" + self.count_in.subscribe(self._on_count_received) + self.running = True + self.call_thread = threading.Thread(target=self._call_loop) + self.call_thread.start() + + @rpc + def stop(self): + """Stop the validator.""" + self.running = False + if self.call_thread: + self.call_thread.join() + + def _on_count_received(self, count: int): + """Check if we received all numbers in sequence and trigger next call.""" + # Calculate round trip time + if self.call_start_time: + latency = time.time() - self.call_start_time + self.total_latency += latency + + if count != self.last_seen + 1: + for missing in range(self.last_seen + 1, count): + self.missing_numbers.append(missing) + print(f"[VALIDATOR] Missing number detected: {missing}") + self.last_seen = count + + # Signal that we can make the next call + self.waiting_for_response = False + + def _call_loop(self): + """Call increment only after receiving response from previous call.""" + while self.running: + if not self.waiting_for_response: + try: + self.waiting_for_response = True + 
self.call_start_time = time.time() + result = self.increment_func() + call_time = time.time() - self.call_start_time + self.call_count += 1 + if self.call_count % 100 == 0: + avg_latency = ( + self.total_latency / self.call_count if self.call_count > 0 else 0 + ) + print( + f"[VALIDATOR] Made {self.call_count} calls, last result: {result}, RPC call time: {call_time * 1000:.2f}ms, avg RTT: {avg_latency * 1000:.2f}ms" + ) + except Exception as e: + print(f"[VALIDATOR] Error calling increment: {e}") + self.waiting_for_response = False + time.sleep(0.001) # Small delay on error + else: + # Don't sleep - busy wait for maximum speed + pass + + @rpc + def get_stats(self): + """Get validation statistics.""" + avg_latency = self.total_latency / self.call_count if self.call_count > 0 else 0 + return { + "call_count": self.call_count, + "last_seen": self.last_seen, + "missing_count": len(self.missing_numbers), + "missing_numbers": self.missing_numbers[:10] if self.missing_numbers else [], + "avg_rtt_ms": avg_latency * 1000, + "calls_per_sec": self.call_count / 10.0 if self.call_count > 0 else 0, + } + + +if __name__ == "__main__": + import dimos.core as core + from dimos.core import pLCMTransport + + # Start dimos with 2 workers + client = core.start(2) + + # Deploy counter module + counter = client.deploy(Counter) + counter.count_stream.transport = pLCMTransport("/counter_stream") + + # Deploy validator module with increment function + validator = client.deploy(CounterValidator, counter.increment) + validator.count_in.transport = pLCMTransport("/counter_stream") + + # Connect validator to counter's output + validator.count_in.connect(counter.count_stream) + + # Start modules + validator.start() + + print("[MAIN] Counter and validator started. 
Running for 10 seconds...") + + # Test direct RPC speed for comparison + print("\n[MAIN] Testing direct RPC call speed for 1 second...") + start = time.time() + direct_count = 0 + while time.time() - start < 1.0: + counter.increment() + direct_count += 1 + print(f"[MAIN] Direct RPC calls per second: {direct_count}") + + # Run for 10 seconds + time.sleep(10) + + # Get stats before stopping + stats = validator.get_stats() + print(f"\n[MAIN] Final statistics:") + print(f" - Total calls made: {stats['call_count']}") + print(f" - Last number seen: {stats['last_seen']}") + print(f" - Missing numbers: {stats['missing_count']}") + print(f" - Average RTT: {stats['avg_rtt_ms']:.2f}ms") + print(f" - Calls per second: {stats['calls_per_sec']:.1f}") + if stats["missing_numbers"]: + print(f" - First missing numbers: {stats['missing_numbers']}") + + # Stop modules + validator.stop() + + # Shutdown dimos + client.shutdown() + + print("[MAIN] Test complete.") diff --git a/dimos/msgs/sensor_msgs/Image.py b/dimos/msgs/sensor_msgs/Image.py index 90bd851222..fb57cfcd3e 100644 --- a/dimos/msgs/sensor_msgs/Image.py +++ b/dimos/msgs/sensor_msgs/Image.py @@ -37,6 +37,7 @@ class ImageFormat(Enum): GRAY = "GRAY" # 8-bit Grayscale GRAY16 = "GRAY16" # 16-bit Grayscale DEPTH = "DEPTH" # 32-bit Float Depth + DEPTH16 = "DEPTH16" # 16-bit Integer Depth (millimeters) @dataclass @@ -169,6 +170,8 @@ def to_opencv(self) -> np.ndarray: return self.data elif self.format == ImageFormat.DEPTH: return self.data # Depth images are already in the correct format + elif self.format == ImageFormat.DEPTH16: + return self.data # 16-bit depth images are already in the correct format else: raise ValueError(f"Unsupported format conversion: {self.format}") @@ -373,6 +376,11 @@ def _get_lcm_encoding(self) -> str: return "32FC1" elif self.dtype == np.float64: return "64FC1" + elif self.format == ImageFormat.DEPTH16: + if self.dtype == np.uint16: + return "16UC1" # 16-bit unsigned depth + elif self.dtype == np.int16: + 
return "16SC1" # 16-bit signed depth raise ValueError( f"Cannot determine LCM encoding for format={self.format}, dtype={self.dtype}" @@ -393,6 +401,9 @@ def _parse_encoding(encoding: str) -> dict: "32FC1": {"format": ImageFormat.DEPTH, "dtype": np.float32, "channels": 1}, "32FC3": {"format": ImageFormat.RGB, "dtype": np.float32, "channels": 3}, "64FC1": {"format": ImageFormat.DEPTH, "dtype": np.float64, "channels": 1}, + # 16-bit depth encodings + "16UC1": {"format": ImageFormat.DEPTH16, "dtype": np.uint16, "channels": 1}, + "16SC1": {"format": ImageFormat.DEPTH16, "dtype": np.int16, "channels": 1}, } if encoding not in encoding_map: diff --git a/dimos/navigation/bt_navigator/navigator.py b/dimos/navigation/bt_navigator/navigator.py index 8a81af0356..d64971c4a3 100644 --- a/dimos/navigation/bt_navigator/navigator.py +++ b/dimos/navigation/bt_navigator/navigator.py @@ -218,16 +218,28 @@ def _transform_goal_to_odom_frame(self, goal: PoseStamped) -> Optional[PoseStamp return goal try: - transform = self.tf.get( - parent_frame=odom_frame, - child_frame=goal.frame_id, - time_point=goal.ts, - time_tolerance=1.0, - ) + transform = None + max_retries = 3 + + for attempt in range(max_retries): + transform = self.tf.get( + parent_frame=odom_frame, + child_frame=goal.frame_id, + ) + + if transform: + break - if not transform: - logger.error(f"Could not find transform from '{goal.frame_id}' to '{odom_frame}'") - return None + if attempt < max_retries - 1: + logger.warning( + f"Transform attempt {attempt + 1}/{max_retries} failed, retrying..." 
+ ) + time.sleep(1.0) + else: + logger.error( + f"Could not find transform from '{goal.frame_id}' to '{odom_frame}' after {max_retries} attempts" + ) + return None pose = apply_transform(goal, transform) transformed_goal = PoseStamped( diff --git a/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py b/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py index 9f94c7f79c..4bfae6c45e 100644 --- a/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py +++ b/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py @@ -107,7 +107,7 @@ def __init__( lookahead_distance: float = 5.0, max_explored_distance: float = 10.0, info_gain_threshold: float = 0.03, - num_no_gain_attempts: int = 4, + num_no_gain_attempts: int = 2, goal_timeout: float = 15.0, **kwargs, ): @@ -639,7 +639,8 @@ def get_exploration_goal( logger.info( f"No information gain for {self.no_gain_counter} consecutive attempts" ) - self.reset_exploration_session() + self.no_gain_counter = 0 # Reset counter when stopping due to no gain + self.stop_exploration() return None else: self.no_gain_counter = 0 @@ -724,6 +725,7 @@ def stop_exploration(self) -> bool: return False self.exploration_active = False + self.no_gain_counter = 0 # Reset counter when exploration stops self.stop_event.set() if self.exploration_thread and self.exploration_thread.is_alive(): diff --git a/dimos/navigation/global_planner/planner.py b/dimos/navigation/global_planner/planner.py index 47622f9cce..984873f67a 100644 --- a/dimos/navigation/global_planner/planner.py +++ b/dimos/navigation/global_planner/planner.py @@ -204,7 +204,7 @@ def plan(self, goal: Pose) -> Optional[Path]: # Get current position from odometry robot_pos = self.latest_odom.position - costmap = self.latest_costmap.inflate(0.1).gradient(max_distance=1.0) + costmap = self.latest_costmap.inflate(0.2).gradient(max_distance=1.5) # Run A* planning path = astar(costmap, goal.position, robot_pos) diff 
--git a/dimos/perception/object_tracker.py b/dimos/perception/object_tracker.py index edd87134b1..c4bedf9ac9 100644 --- a/dimos/perception/object_tracker.py +++ b/dimos/perception/object_tracker.py @@ -20,7 +20,7 @@ from dimos.core import In, Out, Module, rpc from dimos.msgs.std_msgs import Header -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs import Image, ImageFormat from dimos.msgs.geometry_msgs import Vector3, Quaternion, Transform, Pose, PoseStamped from dimos.protocol.tf import TF from dimos.utils.logging_config import setup_logger @@ -40,6 +40,7 @@ euler_to_quaternion, ) from dimos.manipulation.visual_servoing.utils import visualize_detections_3d +from dimos.types.timestamped import align_timestamped logger = setup_logger("dimos.perception.object_tracker") @@ -96,11 +97,14 @@ def __init__( self.last_roi_kps = None # Store last ROI keypoints for visualization self.last_roi_bbox = None # Store last ROI bbox for visualization self.reid_confirmed = False # Store current reid confirmation state + self.tracking_frame_count = 0 # Count frames since tracking started + self.reid_warmup_frames = 3 # Number of frames before REID starts - # For tracking latest frame data + self._frame_lock = threading.Lock() self._latest_rgb_frame: Optional[np.ndarray] = None self._latest_depth_frame: Optional[np.ndarray] = None self._latest_camera_info: Optional[CameraInfo] = None + self._aligned_frames_subscription = None # Tracking thread control self.tracking_thread: Optional[threading.Thread] = None @@ -120,19 +124,28 @@ def __init__( def start(self): """Start the object tracking module and subscribe to LCM streams.""" - # Subscribe to rgb image stream - def on_rgb(image_msg: Image): - self._latest_rgb_frame = image_msg.data - - self.color_image.subscribe(on_rgb) - - # Subscribe to depth stream - def on_depth(image_msg: Image): - self._latest_depth_frame = image_msg.data - - self.depth.subscribe(on_depth) + # Subscribe to aligned rgb and depth streams + def 
on_aligned_frames(frames_tuple): + rgb_msg, depth_msg = frames_tuple + with self._frame_lock: + self._latest_rgb_frame = rgb_msg.data + + depth_data = depth_msg.data + # Convert from millimeters to meters if depth is DEPTH16 format + if depth_msg.format == ImageFormat.DEPTH16: + depth_data = depth_data.astype(np.float32) / 1000.0 + self._latest_depth_frame = depth_data + + # Create aligned observable for RGB and depth + aligned_frames = align_timestamped( + self.color_image.observable(), + self.depth.observable(), + buffer_size=2.0, # 2 second buffer + match_tolerance=0.05, # 50ms tolerance + ) + self._aligned_frames_subscription = aligned_frames.subscribe(on_aligned_frames) - # Subscribe to camera info stream + # Subscribe to camera info stream separately (doesn't need alignment) def on_camera_info(camera_info_msg: CameraInfo): self._latest_camera_info = camera_info_msg # Extract intrinsics from camera info K matrix @@ -151,7 +164,7 @@ def on_camera_info(camera_info_msg: CameraInfo): self.camera_info.subscribe(on_camera_info) - logger.info("ObjectTracking module started and subscribed to LCM streams") + logger.info("ObjectTracking module started with aligned frame subscription") @rpc def track( @@ -189,9 +202,7 @@ def track( if roi.size > 0: self.original_kps, self.original_des = self.orb.detectAndCompute(roi, None) if self.original_des is None: - logger.warning("No ORB features found in initial ROI.") - self.stop_track() - return {"status": "tracking_failed", "bbox": self.tracking_bbox} + logger.warning("No ORB features found in initial ROI. 
REID will be disabled.") else: logger.info(f"Initial ORB features extracted: {len(self.original_des)}") @@ -199,6 +210,7 @@ def track( init_success = self.tracker.init(self._latest_rgb_frame, self.tracking_bbox) if init_success: self.tracking_initialized = True + self.tracking_frame_count = 0 # Reset frame counter logger.info("Tracker initialized successfully.") else: logger.error("Tracker initialization failed.") @@ -215,8 +227,12 @@ def track( def reid(self, frame, current_bbox) -> bool: """Check if features in current_bbox match stored original features.""" + # During warm-up period, always return True + if self.tracking_frame_count < self.reid_warmup_frames: + return True + if self.original_des is None: - return True # Cannot re-id if no original features + return False x1, y1, x2, y2 = map(int, current_bbox) roi = frame[y1:y2, x1:x2] if roi.size == 0: @@ -280,6 +296,7 @@ def _reset_tracking_state(self): self.last_roi_kps = None self.last_roi_bbox = None self.reid_confirmed = False # Reset reid confirmation state + self.tracking_frame_count = 0 # Reset frame counter # Publish empty detections to clear any visualizations empty_2d = Detection2DArray(detections_length=0, header=Header(), detections=[]) @@ -328,10 +345,15 @@ def is_tracking(self) -> bool: def _process_tracking(self): """Process current frame for tracking and publish detections.""" - if self._latest_rgb_frame is None or self.tracker is None or not self.tracking_initialized: + if self.tracker is None or not self.tracking_initialized: return - frame = self._latest_rgb_frame + # Get local copies of frames under lock + with self._frame_lock: + if self._latest_rgb_frame is None or self._latest_depth_frame is None: + return + frame = self._latest_rgb_frame.copy() + depth_frame = self._latest_depth_frame.copy() tracker_succeeded = False reid_confirmed_this_frame = False final_success = False @@ -369,7 +391,9 @@ def _process_tracking(self): logger.info("Tracker update failed. 
Stopping track.") self._reset_tracking_state() - if not reid_confirmed_this_frame: + self.tracking_frame_count += 1 + + if not reid_confirmed_this_frame and self.tracking_frame_count >= self.reid_warmup_frames: return # Create detections if tracking succeeded @@ -409,9 +433,9 @@ def _process_tracking(self): detection2darray.detections = [detection_2d] # Create Detection3D if depth is available - if self._latest_depth_frame is not None: + if depth_frame is not None: # Calculate 3D position using depth and camera intrinsics - depth_value = self._get_depth_from_bbox(current_bbox_x1y1x2y2) + depth_value = self._get_depth_from_bbox(current_bbox_x1y1x2y2, depth_frame) if ( depth_value is not None and depth_value > 0 @@ -486,7 +510,7 @@ def _process_tracking(self): self.detection3darray.publish(detection3darray) # Create and publish visualization if tracking is active - if self.tracking_initialized and self._latest_rgb_frame is not None: + if self.tracking_initialized: # Convert single detection to list for visualization detections_3d = ( detection3darray.detections if detection3darray.detections_length > 0 else [] @@ -508,7 +532,7 @@ def _process_tracking(self): # Create visualization viz_image = visualize_detections_3d( - self._latest_rgb_frame, detections_3d, show_coordinates=True, bboxes_2d=bbox_2d + frame, detections_3d, show_coordinates=True, bboxes_2d=bbox_2d ) # Overlay REID feature matches if available @@ -545,7 +569,12 @@ def _draw_reid_matches(self, image: np.ndarray) -> np.ndarray: text = f"REID Matches: {len(self.last_good_matches)}/{len(self.last_roi_kps) if self.last_roi_kps else 0}" cv2.putText(viz_image, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2) - if len(self.last_good_matches) >= self.reid_threshold: + if self.tracking_frame_count < self.reid_warmup_frames: + status_text = ( + f"REID: WARMING UP ({self.tracking_frame_count}/{self.reid_warmup_frames})" + ) + status_color = (255, 255, 0) # Yellow + elif len(self.last_good_matches) >= 
self.reid_threshold: status_text = "REID: CONFIRMED" status_color = (0, 255, 0) # Green else: @@ -558,21 +587,29 @@ def _draw_reid_matches(self, image: np.ndarray) -> np.ndarray: return viz_image - def _get_depth_from_bbox(self, bbox: List[int]) -> Optional[float]: - """Calculate depth from bbox using the 25th percentile of closest points.""" - if self._latest_depth_frame is None: + def _get_depth_from_bbox(self, bbox: List[int], depth_frame: np.ndarray) -> Optional[float]: + """Calculate depth from bbox using the 25th percentile of closest points. + + Args: + bbox: Bounding box coordinates [x1, y1, x2, y2] + depth_frame: Depth frame to extract depth values from + + Returns: + Depth value or None if not available + """ + if depth_frame is None: return None x1, y1, x2, y2 = bbox # Ensure bbox is within frame bounds y1 = max(0, y1) - y2 = min(self._latest_depth_frame.shape[0], y2) + y2 = min(depth_frame.shape[0], y2) x1 = max(0, x1) - x2 = min(self._latest_depth_frame.shape[1], x2) + x2 = min(depth_frame.shape[1], x2) # Extract depth values from the entire bbox - roi_depth = self._latest_depth_frame[y1:y2, x1:x2] + roi_depth = depth_frame[y1:y2, x1:x2] # Get valid (finite and positive) depth values valid_depths = roi_depth[np.isfinite(roi_depth) & (roi_depth > 0)] @@ -592,3 +629,8 @@ def cleanup(self): if self.tracking_thread and self.tracking_thread.is_alive(): self.stop_tracking.set() self.tracking_thread.join(timeout=2.0) + + # Unsubscribe from aligned frames + if self._aligned_frames_subscription: + self._aligned_frames_subscription.dispose() + self._aligned_frames_subscription = None diff --git a/dimos/perception/spatial_perception.py b/dimos/perception/spatial_perception.py index aa9b843569..1d1390bb6e 100644 --- a/dimos/perception/spatial_perception.py +++ b/dimos/perception/spatial_perception.py @@ -22,6 +22,7 @@ from typing import Dict, List, Optional, Any import numpy as np +import cv2 from reactivex import Observable, disposable, just, interval from 
reactivex import operators as ops from datetime import datetime @@ -180,7 +181,9 @@ def start(self): def set_video(image_msg: Image): # Convert Image message to numpy array if hasattr(image_msg, "data"): - self._latest_video_frame = image_msg.data + frame = image_msg.data + frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) + self._latest_video_frame = frame else: logger.warning("Received image message without data attribute") diff --git a/dimos/protocol/rpc/spec.py b/dimos/protocol/rpc/spec.py index baace4be94..9e3fe8d1bd 100644 --- a/dimos/protocol/rpc/spec.py +++ b/dimos/protocol/rpc/spec.py @@ -13,6 +13,7 @@ # limitations under the License. import asyncio +import threading import time from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, overload @@ -42,25 +43,19 @@ def call( self, name: str, arguments: Args, cb: Optional[Callable] ) -> Optional[Callable[[], Any]]: ... - # we bootstrap these from the call() implementation above - def call_sync(self, name: str, arguments: Args, timeout: float = 1.0) -> Any: - res = Empty - start_time = time.time() + # we expect to crash if we don't get a return value after 10 seconds + # but callers can override this timeout for extra long functions + def call_sync(self, name: str, arguments: Args, rpc_timeout: Optional[float] = 10.0) -> Any: + event = threading.Event() def receive_value(val): - nonlocal res - res = val + event.result = val # attach to event + event.set() self.call(name, arguments, receive_value) - - total_time = 0.0 - while res is Empty: - if time.time() - start_time > timeout: - print(f"RPC {name} timed out") - return None - time.sleep(0.05) - total_time += 0.1 - return res + if not event.wait(rpc_timeout): + raise TimeoutError(f"RPC call to '{name}' timed out after {rpc_timeout} seconds") + return event.result async def call_async(self, name: str, arguments: Args) -> Any: loop = asyncio.get_event_loop() diff --git a/dimos/protocol/rpc/test_lcmrpc_timeout.py 
b/dimos/protocol/rpc/test_lcmrpc_timeout.py new file mode 100644 index 0000000000..e7375ff8d4 --- /dev/null +++ b/dimos/protocol/rpc/test_lcmrpc_timeout.py @@ -0,0 +1,164 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import threading +import time + +import pytest + +from dimos.protocol.rpc.lcmrpc import LCMRPC +from dimos.protocol.service.lcmservice import autoconf + + +@pytest.fixture(scope="session", autouse=True) +def setup_lcm_autoconf(): + """Setup LCM autoconf once for the entire test session""" + autoconf() + yield + + +@pytest.fixture +def lcm_server(): + """Fixture that provides started LCMRPC server""" + server = LCMRPC() + server.start() + + yield server + + server.stop() + + +@pytest.fixture +def lcm_client(): + """Fixture that provides started LCMRPC client""" + client = LCMRPC() + client.start() + + yield client + + client.stop() + + +def test_lcmrpc_timeout_no_reply(lcm_server, lcm_client): + """Test that RPC calls timeout when no reply is received""" + server = lcm_server + client = lcm_client + + # Track if the function was called + function_called = threading.Event() + + # Serve a function that never responds + def never_responds(a: int, b: int): + # Signal that the function was called + function_called.set() + # Simulating a server that receives the request but never sends a reply + time.sleep(1) # Long sleep to ensure timeout happens first + return a + b + + server.serve_rpc(never_responds, "slow_add") + + 
# Test with call_sync and explicit timeout + start_time = time.time() + + # Should raise TimeoutError when timeout occurs + with pytest.raises(TimeoutError, match="RPC call to 'slow_add' timed out after 0.1 seconds"): + client.call_sync("slow_add", ([1, 2], {}), rpc_timeout=0.1) + + elapsed = time.time() - start_time + + # Should timeout after ~0.1 seconds + assert elapsed < 0.3, f"Timeout took too long: {elapsed}s" + + # Verify the function was actually called + assert function_called.wait(0.5), "Server function was never called" + + +def test_lcmrpc_timeout_nonexistent_service(lcm_client): + """Test that RPC calls timeout when calling a non-existent service""" + client = lcm_client + + # Call a service that doesn't exist + start_time = time.time() + + # Should raise TimeoutError when timeout occurs + with pytest.raises( + TimeoutError, match="RPC call to 'nonexistent/service' timed out after 0.1 seconds" + ): + client.call_sync("nonexistent/service", ([1, 2], {}), rpc_timeout=0.1) + + elapsed = time.time() - start_time + + # Should timeout after ~0.1 seconds + assert elapsed < 0.3, f"Timeout took too long: {elapsed}s" + + +def test_lcmrpc_callback_with_timeout(lcm_server, lcm_client): + """Test that callback-based RPC calls handle timeouts properly""" + server = lcm_server + client = lcm_client + # Track if the function was called + function_called = threading.Event() + + # Serve a function that never responds + def never_responds(a: int, b: int): + function_called.set() + time.sleep(1) + return a + b + + server.serve_rpc(never_responds, "slow_add") + + callback_called = threading.Event() + received_value = [] + + def callback(value): + received_value.append(value) + callback_called.set() + + # Make the call with callback + unsub = client.call("slow_add", ([1, 2], {}), callback) + + # Wait for a short time - callback should not be called + callback_called.wait(0.2) + assert not callback_called.is_set(), "Callback should not have been called" + assert 
len(received_value) == 0 + + # Verify the server function was actually called + assert function_called.wait(0.5), "Server function was never called" + + # Clean up - unsubscribe if possible + if unsub: + unsub() + + +def test_lcmrpc_normal_operation(lcm_server, lcm_client): + """Sanity check that normal RPC calls still work""" + server = lcm_server + client = lcm_client + + def quick_add(a: int, b: int): + return a + b + + server.serve_rpc(quick_add, "add") + + # Normal call should work quickly + start_time = time.time() + result = client.call_sync("add", ([5, 3], {}), rpc_timeout=0.5) + elapsed = time.time() - start_time + + assert result == 8 + assert elapsed < 0.2, f"Normal call took too long: {elapsed}s" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/dimos/protocol/tf/tf.py b/dimos/protocol/tf/tf.py index c06998bed0..1c7bb32101 100644 --- a/dimos/protocol/tf/tf.py +++ b/dimos/protocol/tf/tf.py @@ -93,23 +93,14 @@ def _prune_old_transforms(self) -> None: self._items.pop(0) def get( - self, time_point: Optional[float] = None, time_tolerance: Optional[float] = None + self, time_point: Optional[float] = None, time_tolerance: float = 1.0 ) -> Optional[Transform]: """Get transform at specified time or latest if no time given.""" if time_point is None: # Return the latest transform return self[-1] if len(self) > 0 else None - # Find closest transform within tolerance - closest = self.find_closest(time_point) - if closest is None: - return None - - if time_tolerance is not None: - if abs(closest.ts - time_point) > time_tolerance: - return None - - return closest + return self.find_closest(time_point, time_tolerance) def __str__(self) -> str: if not self._items: diff --git a/dimos/robot/unitree_webrtc/camera_module.py b/dimos/robot/unitree_webrtc/camera_module.py index ef07c1be1f..4378024bf1 100644 --- a/dimos/robot/unitree_webrtc/camera_module.py +++ b/dimos/robot/unitree_webrtc/camera_module.py @@ -28,7 +28,6 @@ from dimos.msgs.std_msgs 
import Header from dimos.protocol.tf import TF from dimos.utils.logging_config import setup_logger -from dimos.perception.common.utils import colorize_depth logger = setup_logger(__name__) @@ -38,23 +37,19 @@ class UnitreeCameraModule(Module): Camera module for Unitree Go2 that processes RGB images to generate depth using Metric3D. Subscribes to: - - /video: RGB camera images from Unitree + - /go2/color_image: RGB camera images from Unitree Publishes: - - /go2/color_image: RGB camera images - /go2/depth_image: Depth images generated by Metric3D - - /go2/depth_colorized: Colorized depth images with statistics overlay - /go2/camera_info: Camera calibration information - /go2/camera_pose: Camera pose from TF lookup """ # LCM inputs - video: In[Image] = None + color_image: In[Image] = None # LCM outputs - color_image: Out[Image] = None depth_image: Out[Image] = None - depth_colorized: Out[Image] = None camera_info: Out[CameraInfo] = None camera_pose: Out[PoseStamped] = None @@ -117,7 +112,7 @@ def start(self): self._running = True # Subscribe to video input - self.video.subscribe(self._on_video) + self.color_image.subscribe(self._on_video) # Start processing thread self._start_processing_thread() @@ -213,42 +208,20 @@ def _publish_synchronized_data(self): logger.debug("Publishing synchronized camera data") - # Publish color image - color_msg = Image( - data=self._last_image, - format=ImageFormat.RGB, - frame_id=header.frame_id, - ts=header.ts, - ) - self.color_image.publish(color_msg) - logger.debug(f"Published color image: shape={self._last_image.shape}") - # Publish depth image if self._last_depth is not None: + # Convert depth to uint16 (millimeters) for more efficient storage + # Clamp to valid range [0, 65.535] meters before converting + depth_clamped = np.clip(self._last_depth, 0, 65.535) + depth_uint16 = (depth_clamped * 1000).astype(np.uint16) depth_msg = Image( - data=self._last_depth, - format=ImageFormat.DEPTH, + data=depth_uint16, + 
format=ImageFormat.DEPTH16, # Use DEPTH16 format for uint16 depth frame_id=header.frame_id, ts=header.ts, ) self.depth_image.publish(depth_msg) - logger.debug(f"Published depth image: shape={self._last_depth.shape}") - - # Publish colorized depth image - depth_colorized_array = colorize_depth( - self._last_depth, max_depth=10.0, overlay_stats=True - ) - if depth_colorized_array is not None: - depth_colorized_msg = Image( - data=depth_colorized_array, - format=ImageFormat.RGB, - frame_id=header.frame_id, - ts=header.ts, - ) - self.depth_colorized.publish(depth_colorized_msg) - logger.debug( - f"Published colorized depth image: shape={depth_colorized_array.shape}" - ) + logger.debug(f"Published depth image (uint16): shape={depth_uint16.shape}") # Publish camera info self._publish_camera_info(header) diff --git a/dimos/robot/unitree_webrtc/run.py b/dimos/robot/unitree_webrtc/run.py index aca66ab654..b127bfc9a2 100644 --- a/dimos/robot/unitree_webrtc/run.py +++ b/dimos/robot/unitree_webrtc/run.py @@ -79,6 +79,7 @@ def main(): # Create robot instance robot = UnitreeGo2( ip=os.getenv("ROBOT_IP"), + connection_type=os.getenv("CONNECTION_TYPE", "webrtc"), ) robot.start() diff --git a/dimos/robot/unitree_webrtc/unitree_go2.py b/dimos/robot/unitree_webrtc/unitree_go2.py index 13f59ec20b..c45bbd6fde 100644 --- a/dimos/robot/unitree_webrtc/unitree_go2.py +++ b/dimos/robot/unitree_webrtc/unitree_go2.py @@ -307,7 +307,7 @@ def _deploy_connection(self): self.connection.lidar.transport = core.LCMTransport("/lidar", LidarMessage) self.connection.odom.transport = core.LCMTransport("/odom", PoseStamped) - self.connection.video.transport = core.LCMTransport("/video", Image) + self.connection.video.transport = core.LCMTransport("/go2/color_image", Image) self.connection.movecmd.transport = core.LCMTransport("/cmd_vel", Vector3) def _deploy_mapping(self): @@ -429,17 +429,11 @@ def _deploy_camera(self): # Set up transports self.camera_module.color_image.transport = 
core.LCMTransport("/go2/color_image", Image) self.camera_module.depth_image.transport = core.LCMTransport("/go2/depth_image", Image) - self.camera_module.depth_colorized.transport = core.LCMTransport( - "/go2/depth_colorized", Image - ) self.camera_module.camera_info.transport = core.LCMTransport("/go2/camera_info", CameraInfo) self.camera_module.camera_pose.transport = core.LCMTransport( "/go2/camera_pose", PoseStamped ) - # Connect video input from connection module - self.camera_module.video.connect(self.connection.video) - logger.info("Camera module deployed and connected") # Connect object tracker inputs after camera module is deployed @@ -527,6 +521,7 @@ def stop_exploration(self) -> bool: Returns: True if exploration was stopped """ + self.navigator.cancel_goal() return self.frontier_explorer.stop_exploration() def cancel_navigation(self) -> bool: @@ -586,7 +581,7 @@ def navigate_to_object(self, bbox: List[float], distance: float = 0.5, timeout: logger.info("Object tracking goal reached") return True - if not self.object_tracker.is_tracking(): + if goal_set and not self.object_tracker.is_tracking(): continue detection_topic = Topic("/go2/detection3d", Detection3DArray) diff --git a/dimos/robot/unitree_webrtc/unitree_go2_nav_only.py b/dimos/robot/unitree_webrtc/unitree_go2_nav_only.py new file mode 100644 index 0000000000..da0beb252c --- /dev/null +++ b/dimos/robot/unitree_webrtc/unitree_go2_nav_only.py @@ -0,0 +1,427 @@ +#!/usr/bin/env python3 + +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +import functools +import logging +import os +import time +import warnings +from typing import Optional + +from dimos import core +from dimos.core import In, Module, Out, rpc +from dimos.msgs.geometry_msgs import PoseStamped, Transform, Vector3, Quaternion +from dimos.msgs.nav_msgs import OccupancyGrid, Path +from dimos.msgs.sensor_msgs import Image +from dimos_lcm.std_msgs import String, Bool +from dimos.protocol import pubsub +from dimos.protocol.pubsub.lcmpubsub import LCM +from dimos.protocol.tf import TF +from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule +from dimos.navigation.global_planner import AstarPlanner +from dimos.navigation.local_planner.holonomic_local_planner import HolonomicLocalPlanner +from dimos.navigation.bt_navigator.navigator import BehaviorTreeNavigator, NavigatorState +from dimos.navigation.frontier_exploration import WavefrontFrontierExplorer +from dimos.robot.unitree_webrtc.connection import UnitreeWebRTCConnection +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.map import Map +from dimos.robot.unitree_webrtc.type.odometry import Odometry +from dimos.utils.data import get_data +from dimos.utils.logging_config import setup_logger +from dimos.utils.testing import TimedSensorReplay +from dimos.robot.robot import Robot +from dimos.types.robot_capabilities import RobotCapability + + +logger = setup_logger("dimos.robot.unitree_webrtc.unitree_go2_nav_only", level=logging.INFO) + +# Suppress verbose loggers +logging.getLogger("aiortc.codecs.h264").setLevel(logging.ERROR) +logging.getLogger("lcm_foxglove_bridge").setLevel(logging.ERROR) +logging.getLogger("websockets.server").setLevel(logging.ERROR) +logging.getLogger("FoxgloveServer").setLevel(logging.ERROR) +logging.getLogger("asyncio").setLevel(logging.ERROR) 
+logging.getLogger("root").setLevel(logging.WARNING) + +# Suppress warnings +warnings.filterwarnings("ignore", message="coroutine.*was never awaited") +warnings.filterwarnings("ignore", message="H264Decoder.*failed to decode") + + +class FakeRTC: + """Fake WebRTC connection for testing with recorded data.""" + + def __init__(self, *args, **kwargs): + get_data("unitree_office_walk") # Preload data for testing + + def connect(self): + pass + + def standup(self): + print("standup suppressed") + + def liedown(self): + print("liedown suppressed") + + @functools.cache + def lidar_stream(self): + print("lidar stream start") + lidar_store = TimedSensorReplay("unitree_office_walk/lidar", autocast=LidarMessage.from_msg) + return lidar_store.stream() + + @functools.cache + def odom_stream(self): + print("odom stream start") + odom_store = TimedSensorReplay("unitree_office_walk/odom", autocast=Odometry.from_msg) + return odom_store.stream() + + @functools.cache + def video_stream(self): + print("video stream start") + video_store = TimedSensorReplay( + "unitree_office_walk/video", autocast=lambda x: Image.from_numpy(x).to_rgb() + ) + return video_store.stream() + + def move(self, vector: Vector3, duration: float = 0.0): + pass + + def publish_request(self, topic: str, data: dict): + """Fake publish request for testing.""" + return {"status": "ok", "message": "Fake publish"} + + +class ConnectionModule(Module): + """Module that handles robot sensor data and movement commands.""" + + movecmd: In[Vector3] = None + odom: Out[PoseStamped] = None + lidar: Out[LidarMessage] = None + video: Out[Image] = None + ip: str + connection_type: str = "webrtc" + + _odom: PoseStamped = None + _lidar: LidarMessage = None + + def __init__(self, ip: str = None, connection_type: str = "webrtc", *args, **kwargs): + self.ip = ip + self.connection_type = connection_type + self.tf = TF() + self.connection = None + Module.__init__(self, *args, **kwargs) + + @rpc + def start(self): + """Start the 
connection and subscribe to sensor streams.""" + match self.connection_type: + case "webrtc": + self.connection = UnitreeWebRTCConnection(self.ip) + case "fake": + self.connection = FakeRTC(self.ip) + case "mujoco": + from dimos.robot.unitree_webrtc.mujoco_connection import MujocoConnection + + self.connection = MujocoConnection() + self.connection.start() + case _: + raise ValueError(f"Unknown connection type: {self.connection_type}") + + # Connect sensor streams to outputs + self.connection.lidar_stream().subscribe(self.lidar.publish) + self.connection.odom_stream().subscribe(self._publish_tf) + self.connection.video_stream().subscribe(self.video.publish) + self.movecmd.subscribe(self.move) + + def _publish_tf(self, msg): + self._odom = msg + self.odom.publish(msg) + self.tf.publish(Transform.from_pose("base_link", msg)) + camera_link = Transform( + translation=Vector3(0.3, 0.0, 0.0), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + frame_id="base_link", + child_frame_id="camera_link", + ts=time.time(), + ) + self.tf.publish(camera_link) + + @rpc + def get_odom(self) -> Optional[PoseStamped]: + """Get the robot's odometry. + + Returns: + The robot's odometry + """ + return self._odom + + @rpc + def move(self, vector: Vector3, duration: float = 0.0): + """Send movement command to robot.""" + self.connection.move(vector, duration) + + @rpc + def standup(self): + """Make the robot stand up.""" + return self.connection.standup() + + @rpc + def liedown(self): + """Make the robot lie down.""" + return self.connection.liedown() + + @rpc + def publish_request(self, topic: str, data: dict): + """Publish a request to the WebRTC connection. 
+ Args: + topic: The RTC topic to publish to + data: The data dictionary to publish + Returns: + The result of the publish request + """ + return self.connection.publish_request(topic, data) + + +class UnitreeGo2NavOnly(Robot): + """Minimal Unitree Go2 robot with only navigation and visualization capabilities.""" + + def __init__( + self, + ip: str, + websocket_port: int = 7779, + connection_type: Optional[str] = "webrtc", + ): + """Initialize the navigation-only robot system. + + Args: + ip: Robot IP address (or None for fake connection) + websocket_port: Port for web visualization + connection_type: webrtc, fake, or mujoco + """ + super().__init__() + self.ip = ip + self.connection_type = connection_type or "webrtc" + if ip is None and self.connection_type == "webrtc": + self.connection_type = "fake" # Auto-enable playback if no IP provided + self.websocket_port = websocket_port + self.lcm = LCM() + + # Set capabilities - navigation only + self.capabilities = [RobotCapability.LOCOMOTION] + + self.dimos = None + self.connection = None + self.mapper = None + self.global_planner = None + self.local_planner = None + self.navigator = None + self.frontier_explorer = None + self.websocket_vis = None + + def start(self): + """Start the robot system with navigation modules only.""" + self.dimos = core.start(8) + + self._deploy_connection() + self._deploy_mapping() + self._deploy_navigation() + self._deploy_visualization() + + self._start_modules() + + self.lcm.start() + + logger.info("UnitreeGo2NavOnly initialized and started") + logger.info(f"WebSocket visualization available at http://localhost:{self.websocket_port}") + + def _deploy_connection(self): + """Deploy and configure the connection module.""" + self.connection = self.dimos.deploy( + ConnectionModule, self.ip, connection_type=self.connection_type + ) + + self.connection.lidar.transport = core.LCMTransport("/lidar", LidarMessage) + self.connection.odom.transport = core.LCMTransport("/odom", PoseStamped) + 
self.connection.video.transport = core.LCMTransport("/go2/color_image", Image) + self.connection.movecmd.transport = core.LCMTransport("/cmd_vel", Vector3) + + def _deploy_mapping(self): + """Deploy and configure the mapping module.""" + min_height = 0.3 if self.connection_type == "mujoco" else 0.15 + self.mapper = self.dimos.deploy( + Map, voxel_size=0.5, global_publish_interval=2.5, min_height=min_height + ) + + self.mapper.global_map.transport = core.LCMTransport("/global_map", LidarMessage) + self.mapper.global_costmap.transport = core.LCMTransport("/global_costmap", OccupancyGrid) + self.mapper.local_costmap.transport = core.LCMTransport("/local_costmap", OccupancyGrid) + + self.mapper.lidar.connect(self.connection.lidar) + + def _deploy_navigation(self): + """Deploy and configure navigation modules.""" + self.global_planner = self.dimos.deploy(AstarPlanner) + self.local_planner = self.dimos.deploy(HolonomicLocalPlanner) + self.navigator = self.dimos.deploy( + BehaviorTreeNavigator, + reset_local_planner=self.local_planner.reset, + check_goal_reached=self.local_planner.is_goal_reached, + ) + self.frontier_explorer = self.dimos.deploy(WavefrontFrontierExplorer) + + self.navigator.goal.transport = core.LCMTransport("/navigation_goal", PoseStamped) + self.navigator.goal_request.transport = core.LCMTransport("/goal_request", PoseStamped) + self.navigator.goal_reached.transport = core.LCMTransport("/goal_reached", Bool) + self.navigator.navigation_state.transport = core.LCMTransport("/navigation_state", String) + self.navigator.global_costmap.transport = core.LCMTransport( + "/global_costmap", OccupancyGrid + ) + self.global_planner.path.transport = core.LCMTransport("/global_path", Path) + self.local_planner.cmd_vel.transport = core.LCMTransport("/cmd_vel", Vector3) + self.frontier_explorer.goal_request.transport = core.LCMTransport( + "/goal_request", PoseStamped + ) + self.frontier_explorer.goal_reached.transport = core.LCMTransport("/goal_reached", Bool) + 
self.frontier_explorer.explore_cmd.transport = core.LCMTransport("/explore_cmd", Bool) + self.frontier_explorer.stop_explore_cmd.transport = core.LCMTransport( + "/stop_explore_cmd", Bool + ) + + self.global_planner.target.connect(self.navigator.goal) + + self.global_planner.global_costmap.connect(self.mapper.global_costmap) + self.global_planner.odom.connect(self.connection.odom) + + self.local_planner.path.connect(self.global_planner.path) + self.local_planner.local_costmap.connect(self.mapper.local_costmap) + self.local_planner.odom.connect(self.connection.odom) + + self.connection.movecmd.connect(self.local_planner.cmd_vel) + + self.navigator.odom.connect(self.connection.odom) + + self.frontier_explorer.costmap.connect(self.mapper.global_costmap) + self.frontier_explorer.odometry.connect(self.connection.odom) + + def _deploy_visualization(self): + """Deploy and configure visualization modules.""" + self.websocket_vis = self.dimos.deploy(WebsocketVisModule, port=self.websocket_port) + self.websocket_vis.click_goal.transport = core.LCMTransport("/goal_request", PoseStamped) + + self.websocket_vis.robot_pose.connect(self.connection.odom) + self.websocket_vis.path.connect(self.global_planner.path) + self.websocket_vis.global_costmap.connect(self.mapper.global_costmap) + + def _start_modules(self): + """Start all deployed modules in the correct order.""" + self.connection.start() + self.mapper.start() + self.global_planner.start() + self.local_planner.start() + self.navigator.start() + self.frontier_explorer.start() + self.websocket_vis.start() + + def move(self, vector: Vector3, duration: float = 0.0): + """Send movement command to robot.""" + self.connection.move(vector, duration) + + def explore(self) -> bool: + """Start autonomous frontier exploration. + + Returns: + True if exploration started successfully + """ + return self.frontier_explorer.explore() + + def navigate_to(self, pose: PoseStamped, blocking: bool = True): + """Navigate to a target pose. 
+ + Args: + pose: Target pose to navigate to + blocking: If True, block until goal is reached. If False, return immediately. + + Returns: + If blocking=True: True if navigation was successful, False otherwise + If blocking=False: True if goal was accepted, False otherwise + """ + + logger.info( + f"Navigating to pose: ({pose.position.x:.2f}, {pose.position.y:.2f}, {pose.position.z:.2f})" + ) + self.navigator.set_goal(pose) + time.sleep(1.0) + + if blocking: + while self.navigator.get_state() == NavigatorState.FOLLOWING_PATH: + time.sleep(0.25) + + time.sleep(1.0) + if not self.navigator.is_goal_reached(): + logger.info("Navigation was cancelled or failed") + return False + else: + logger.info("Navigation goal reached") + return True + + return True + + def stop_exploration(self) -> bool: + """Stop autonomous exploration. + + Returns: + True if exploration was stopped + """ + self.navigator.cancel_goal() + return self.frontier_explorer.stop_exploration() + + def cancel_navigation(self) -> bool: + """Cancel the current navigation goal. + + Returns: + True if goal was cancelled + """ + return self.navigator.cancel_goal() + + def get_odom(self) -> PoseStamped: + """Get the robot's odometry. 
+ + Returns: + The robot's odometry + """ + return self.connection.get_odom() + + +def main(): + """Main entry point.""" + ip = os.getenv("ROBOT_IP") + connection_type = os.getenv("CONNECTION_TYPE", "webrtc") + + pubsub.lcm.autoconf() + + robot = UnitreeGo2NavOnly(ip=ip, websocket_port=7779, connection_type=connection_type) + robot.start() + + try: + while True: + time.sleep(1) + except KeyboardInterrupt: + logger.info("Shutting down...") + + +if __name__ == "__main__": + main() diff --git a/dimos/skills/navigation.py b/dimos/skills/navigation.py index c6b51b2ddd..bb3741ff31 100644 --- a/dimos/skills/navigation.py +++ b/dimos/skills/navigation.py @@ -23,7 +23,7 @@ import os import time from typing import Optional, Tuple - +import cv2 from pydantic import Field from dimos.skills.skills import AbstractRobotSkill @@ -31,7 +31,7 @@ from dimos.utils.logging_config import setup_logger from dimos.models.qwen.video_query import get_bbox_from_qwen_frame from dimos.msgs.geometry_msgs import PoseStamped, Vector3 -from dimos.utils.transform_utils import euler_to_quaternion +from dimos.utils.transform_utils import euler_to_quaternion, quaternion_to_euler logger = setup_logger("dimos.skills.semantic_map_skills") @@ -87,7 +87,7 @@ def __init__(self, robot=None, **data): """ super().__init__(robot=robot, **data) self._spatial_memory = None - self._similarity_threshold = 0.24 + self._similarity_threshold = 0.23 def _navigate_to_object(self): """ @@ -104,7 +104,8 @@ def _navigate_to_object(self): bbox = None try: # Get a single frame from the robot's camera - frame = self._robot.get_single_rgb_frame() + frame = self._robot.get_single_rgb_frame().data + frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) if frame is None: logger.error("Failed to get camera frame") return { @@ -112,7 +113,7 @@ def _navigate_to_object(self): "failure_reason": "Perception", "error": "Could not get camera frame", } - bbox = get_bbox_from_qwen_frame(frame.data, object_name=self.query) + bbox = 
get_bbox_from_qwen_frame(frame, object_name=self.query) except Exception as e: logger.error(f"Error getting frame or bbox: {e}") return { @@ -357,11 +358,11 @@ def __call__(self): try: # Get the current pose using the robot's get_pose method - pose_data = self._robot.get_pose() + pose_data = self._robot.get_odom() # Extract position and rotation from the new dictionary format - position = pose_data["position"] - rotation = pose_data["rotation"] + position = pose_data.position + rotation = quaternion_to_euler(pose_data.orientation) # Format the response result = { diff --git a/dimos/types/test_timestamped.py b/dimos/types/test_timestamped.py index 7f043750ea..d723421c6a 100644 --- a/dimos/types/test_timestamped.py +++ b/dimos/types/test_timestamped.py @@ -12,11 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. +import time from datetime import datetime, timezone import pytest - -from dimos.types.timestamped import Timestamped, TimestampedCollection, to_datetime, to_ros_stamp +from reactivex import operators as ops + +from dimos.msgs.sensor_msgs import Image +from dimos.types.timestamped import ( + Timestamped, + TimestampedBufferCollection, + TimestampedCollection, + align_timestamped, + to_datetime, + to_ros_stamp, +) +from dimos.utils import testing +from dimos.utils.data import get_data +from dimos.utils.reactive import backpressure def test_timestamped_dt_method(): @@ -145,19 +158,24 @@ def test_find_closest(collection): assert collection.find_closest(3.0).data == "third" # Between items (closer to left) - assert collection.find_closest(1.5).data == "first" + assert collection.find_closest(1.5, tolerance=1.0).data == "first" # Between items (closer to right) - assert collection.find_closest(3.5).data == "third" + assert collection.find_closest(3.5, tolerance=1.0).data == "third" # Exactly in the middle (should pick the later one due to >= comparison) - assert collection.find_closest(4.0).data == 
"fifth" # 4.0 is equidistant from 3.0 and 5.0 + assert ( + collection.find_closest(4.0, tolerance=1.0).data == "fifth" + ) # 4.0 is equidistant from 3.0 and 5.0 # Before all items - assert collection.find_closest(0.0).data == "first" + assert collection.find_closest(0.0, tolerance=1.0).data == "first" # After all items - assert collection.find_closest(10.0).data == "seventh" + assert collection.find_closest(10.0, tolerance=4.0).data == "seventh" + + # low tolerance, should return None + assert collection.find_closest(10.0, tolerance=2.0) is None def test_find_before_after(collection): @@ -223,4 +241,85 @@ def test_single_item_collection(): single = TimestampedCollection([SimpleTimestamped(5.0, "only")]) assert single.duration() == 0.0 assert single.time_range() == (5.0, 5.0) - assert single.find_closest(100.0).data == "only" + + +def test_time_window_collection(): + # Create a collection with a 2-second window + window = TimestampedBufferCollection[SimpleTimestamped](window_duration=2.0) + + # Add messages at different timestamps + window.add(SimpleTimestamped(1.0, "msg1")) + window.add(SimpleTimestamped(2.0, "msg2")) + window.add(SimpleTimestamped(3.0, "msg3")) + + # At this point, all messages should be present (within 2s window) + assert len(window) == 3 + + # Add a message at t=4.0, should keep messages from t=2.0 onwards + window.add(SimpleTimestamped(4.0, "msg4")) + assert len(window) == 3 # msg1 should be dropped + assert window[0].data == "msg2" # oldest is now msg2 + assert window[-1].data == "msg4" # newest is msg4 + + # Add a message at t=5.5, should drop msg2 and msg3 + window.add(SimpleTimestamped(5.5, "msg5")) + assert len(window) == 2 # only msg4 and msg5 remain + assert window[0].data == "msg4" + assert window[1].data == "msg5" + + # Verify time range + assert window.start_ts == 4.0 + assert window.end_ts == 5.5 + + +def test_timestamp_alignment(): + speed = 5.0 + + # ensure that lfs package is downloaded + get_data("unitree_office_walk") + + 
raw_frames = [] + + def spy(image): + raw_frames.append(image.ts) + print(image.ts) + return image + + # sensor replay of raw video frames + video_raw = ( + testing.TimedSensorReplay( + "unitree_office_walk/video", autocast=lambda x: Image.from_numpy(x).to_rgb() + ) + .stream(speed) + .pipe(ops.take(30)) + ) + + processed_frames = [] + + def process_video_frame(frame): + processed_frames.append(frame.ts) + print("PROCESSING", frame.ts) + time.sleep(0.5 / speed) + return frame + + # fake replay of some 0.5s processor of video frames that drops messages + fake_video_processor = backpressure(video_raw.pipe(ops.map(spy))).pipe( + ops.map(process_video_frame) + ) + + aligned_frames = align_timestamped(fake_video_processor, video_raw).pipe(ops.to_list()).run() + + assert len(raw_frames) == 30 + assert len(processed_frames) > 2 + assert len(aligned_frames) > 2 + + # Due to async processing, the last frame might not be aligned before completion + assert len(aligned_frames) >= len(processed_frames) - 1 + + for value in aligned_frames: + [primary, secondary] = value + diff = abs(primary.ts - secondary.ts) + print( + f"Aligned pair: primary={primary.ts:.6f}, secondary={secondary.ts:.6f}, diff={diff:.6f}s" + ) + assert diff <= 0.05 diff --git a/dimos/types/timestamped.py b/dimos/types/timestamped.py index 858a2bdaad..6446c5167b 100644 --- a/dimos/types/timestamped.py +++ b/dimos/types/timestamped.py @@ -16,6 +16,7 @@ from datetime import datetime, timezone from typing import Generic, Iterable, Optional, Tuple, TypedDict, TypeVar, Union +from reactivex.observable import Observable from sortedcontainers import SortedList # any class that carries a timestamp should inherit from this @@ -102,26 +103,42 @@ def add(self, item: T) -> None: """Add a timestamped item to the collection.""" self._items.add(item) - def find_closest(self, timestamp: float) -> Optional[T]: + def find_closest(self, timestamp: float, tolerance: Optional[float] = None) -> Optional[T]: """Find the timestamped 
object closest to the given timestamp.""" if not self._items: return None - # Find insertion point using binary search on timestamps + # Use binary search to find insertion point timestamps = [item.ts for item in self._items] idx = bisect.bisect_left(timestamps, timestamp) - # Check boundaries - if idx == 0: - return self._items[0] - if idx == len(self._items): - return self._items[-1] + # Check exact match + if idx < len(self._items) and self._items[idx].ts == timestamp: + return self._items[idx] - # Compare distances to neighbors - left_diff = abs(timestamp - self._items[idx - 1].ts) - right_diff = abs(self._items[idx].ts - timestamp) + # Find candidates: item before and after + candidates = [] - return self._items[idx - 1] if left_diff < right_diff else self._items[idx] + # Item before + if idx > 0: + candidates.append((idx - 1, abs(self._items[idx - 1].ts - timestamp))) + + # Item after + if idx < len(self._items): + candidates.append((idx, abs(self._items[idx].ts - timestamp))) + + if not candidates: + return None + + # Find closest + # When distances are equal, prefer the later item (higher index) + closest_idx, closest_distance = min(candidates, key=lambda x: (x[1], -x[0])) + + # Check tolerance if provided + if tolerance is not None and closest_distance > tolerance: + return None + + return self._items[closest_idx] def find_before(self, timestamp: float) -> Optional[T]: """Find the last item before the given timestamp.""" @@ -178,3 +195,77 @@ def __iter__(self): def __getitem__(self, idx: int) -> T: return self._items[idx] + + +PRIMARY = TypeVar("PRIMARY", bound=Timestamped) +SECONDARY = TypeVar("SECONDARY", bound=Timestamped) + + +class TimestampedBufferCollection(TimestampedCollection[T]): + """A timestamped collection that maintains a sliding time window, dropping old messages.""" + + def __init__(self, window_duration: float, items: Optional[Iterable[T]] = None): + """ + Initialize with a time window duration in seconds. 
+ + Args: + window_duration: Maximum age of messages to keep in seconds + items: Optional initial items + """ + super().__init__(items) + self.window_duration = window_duration + + def add(self, item: T) -> None: + """Add a timestamped item and remove any items outside the time window.""" + super().add(item) + self._prune_old_messages(item.ts) + + def _prune_old_messages(self, current_ts: float) -> None: + """Remove messages older than window_duration from the given timestamp.""" + cutoff_ts = current_ts - self.window_duration + + # Find the index of the first item that should be kept + timestamps = [item.ts for item in self._items] + keep_idx = bisect.bisect_left(timestamps, cutoff_ts) + + # Remove old items + if keep_idx > 0: + # Create new SortedList with items to keep + self._items = SortedList(self._items[keep_idx:], key=lambda x: x.ts) + + +def align_timestamped( + primary_observable: Observable[PRIMARY], + secondary_observable: Observable[SECONDARY], + buffer_size: float = 1.0, # seconds + match_tolerance: float = 0.05, # seconds +) -> Observable[Tuple[PRIMARY, SECONDARY]]: + from reactivex import create + + def subscribe(observer, scheduler=None): + secondary_collection: TimestampedBufferCollection[SECONDARY] = TimestampedBufferCollection( + buffer_size + ) + # Subscribe to secondary to populate the buffer + secondary_sub = secondary_observable.subscribe(secondary_collection.add) + + def on_primary(primary_item: PRIMARY): + secondary_item = secondary_collection.find_closest( + primary_item.ts, tolerance=match_tolerance + ) + if secondary_item is not None: + observer.on_next((primary_item, secondary_item)) + + # Subscribe to primary and emit aligned pairs + primary_sub = primary_observable.subscribe( + on_next=on_primary, on_error=observer.on_error, on_completed=observer.on_completed + ) + + # Return cleanup function + def dispose(): + secondary_sub.dispose() + primary_sub.dispose() + + return dispose + + return create(subscribe)