diff --git a/.envrc b/.envrc deleted file mode 100644 index a73ae4a035..0000000000 --- a/.envrc +++ /dev/null @@ -1,5 +0,0 @@ -if ! has nix_direnv_version || ! nix_direnv_version 3.0.6; then - source_url "https://raw.githubusercontent.com/nix-community/nix-direnv/3.0.6/direnvrc" "sha256-RYcUJaRMf8oF5LznDrlCXbkOQrywm0HDv1VjYGaJGdM=" -fi -use flake . -dotenv diff --git a/.envrc b/.envrc new file mode 120000 index 0000000000..6da2c886b2 --- /dev/null +++ b/.envrc @@ -0,0 +1 @@ +.envrc.nix \ No newline at end of file diff --git a/.gitignore b/.gitignore index adc50a7ef6..12cb51509a 100644 --- a/.gitignore +++ b/.gitignore @@ -45,3 +45,6 @@ FastSAM-x.pt yolo11n.pt /thread_monitor_report.csv + +# symlink one of .envrc.* if you'd like to use +.envrc diff --git a/dimos/msgs/sensor_msgs/test_image.py b/dimos/msgs/sensor_msgs/test_image.py index 0af559904a..6fa0b9d37b 100644 --- a/dimos/msgs/sensor_msgs/test_image.py +++ b/dimos/msgs/sensor_msgs/test_image.py @@ -118,6 +118,8 @@ def track_output(img): # Emit images at 100Hz to get ~5 per window from reactivex import from_iterable, interval + window_duration = 0.05 # 20Hz = 0.05s windows + source = from_iterable(mock_images).pipe( ops.zip(interval(0.01)), # 100Hz emission rate ops.map(lambda x: x[0]), # Extract just the image @@ -132,28 +134,17 @@ def track_output(img): # Only need 0.08s for 1 full window at 20Hz plus buffer time.sleep(0.08) - # Verify we got correct emissions - assert len(emitted_images) >= 1, f"Expected at least 1 emission, got {len(emitted_images)}" + # Verify we got correct emissions (items span across 2 windows due to timing) + # Items 1-4 arrive in first window (0-50ms), item 5 arrives in second window (50-100ms) + assert len(emitted_images) == 2, ( + f"Expected exactly 2 emissions (one per window), got {len(emitted_images)}" + ) # Group inputs by wall-clock windows and verify we got the sharpest - window_duration = 0.05 # 20Hz - - # Test just the first window - for window_idx in range(min(1, 
len(emitted_images))): - window_start = window_idx * window_duration - window_end = window_start + window_duration - - # Get all images that arrived during this wall-clock window - window_imgs = [ - img for wall_time, img in window_contents if window_start <= wall_time < window_end - ] - - if window_imgs: - max_sharp = max(img.sharpness for img in window_imgs) - emitted_sharp = emitted_images[window_idx].sharpness - - # Verify we emitted the sharpest - assert abs(emitted_sharp - max_sharp) < 0.0001, ( - f"Window {window_idx}: Emitted image (sharp={emitted_sharp}) " - f"is not the sharpest (max={max_sharp})" - ) + + # Verify each window emitted the sharpest image from that window + # First window (0-50ms): items 1-4 + assert emitted_images[0].sharpness == 0.3711 # Highest among first 4 items + + # Second window (50-100ms): only item 5 + assert emitted_images[1].sharpness == 0.3665 # Only item in second window diff --git a/dimos/navigation/frontier_exploration/test_wavefront_frontier_goal_selector.py b/dimos/navigation/frontier_exploration/test_wavefront_frontier_goal_selector.py index e3c21eba40..4daee48002 100644 --- a/dimos/navigation/frontier_exploration/test_wavefront_frontier_goal_selector.py +++ b/dimos/navigation/frontier_exploration/test_wavefront_frontier_goal_selector.py @@ -12,50 +12,99 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import threading import time -from typing import List, Optional, Tuple -from unittest.mock import MagicMock import numpy as np import pytest from PIL import Image, ImageDraw -from reactivex import operators as ops -from dimos import core from dimos.msgs.geometry_msgs import PoseStamped, Vector3 -from dimos.msgs.nav_msgs import OccupancyGrid, CostValues +from dimos.msgs.nav_msgs import CostValues, OccupancyGrid from dimos.navigation.frontier_exploration.utils import costmap_to_pil_image from dimos.navigation.frontier_exploration.wavefront_frontier_goal_selector import ( WavefrontFrontierExplorer, ) -def create_test_costmap(width=100, height=100, resolution=0.1): - """Create a simple test costmap with free, occupied, and unknown regions.""" +@pytest.fixture +def explorer(): + """Create a WavefrontFrontierExplorer instance for testing.""" + explorer = WavefrontFrontierExplorer( + min_frontier_perimeter=0.3, # Smaller for faster tests + safe_distance=0.5, # Smaller for faster distance calculations + info_gain_threshold=0.02, + ) + yield explorer + # Cleanup after test + try: + explorer.cleanup() + except: + pass + + +@pytest.fixture +def quick_costmap(): + """Create a very small costmap for quick tests.""" + width, height = 20, 20 grid = np.full((height, width), CostValues.UNKNOWN, dtype=np.int8) - # Create a larger free space region with more complex shape + # Simple free space in center + grid[8:12, 8:12] = CostValues.FREE + + # Small extensions + grid[9:11, 6:8] = CostValues.FREE # Left + grid[9:11, 12:14] = CostValues.FREE # Right + + # One obstacle + grid[9:10, 9:10] = CostValues.OCCUPIED + + from dimos.msgs.geometry_msgs import Pose + + origin = Pose() + origin.position.x = -1.0 + origin.position.y = -1.0 + origin.position.z = 0.0 + origin.orientation.w = 1.0 + + occupancy_grid = OccupancyGrid( + grid=grid, resolution=0.1, origin=origin, frame_id="map", ts=time.time() + ) + + class MockLidar: + def __init__(self): + self.origin = Vector3(0.0, 0.0, 0.0) + + 
return occupancy_grid, MockLidar() + + +def create_test_costmap(width=40, height=40, resolution=0.1): + """Create a simple test costmap with free, occupied, and unknown regions. + + Default size reduced from 100x100 to 40x40 for faster tests. + """ + grid = np.full((height, width), CostValues.UNKNOWN, dtype=np.int8) + + # Create a smaller free space region with simple shape # Central room - grid[40:60, 40:60] = CostValues.FREE + grid[15:25, 15:25] = CostValues.FREE - # Corridors extending from central room - grid[45:55, 20:40] = CostValues.FREE # Left corridor - grid[45:55, 60:80] = CostValues.FREE # Right corridor - grid[20:40, 45:55] = CostValues.FREE # Top corridor - grid[60:80, 45:55] = CostValues.FREE # Bottom corridor + # Small corridors extending from central room + grid[18:22, 10:15] = CostValues.FREE # Left corridor + grid[18:22, 25:30] = CostValues.FREE # Right corridor + grid[10:15, 18:22] = CostValues.FREE # Top corridor + grid[25:30, 18:22] = CostValues.FREE # Bottom corridor - # Add some obstacles - grid[48:52, 48:52] = CostValues.OCCUPIED # Central obstacle - grid[35:38, 45:55] = CostValues.OCCUPIED # Top corridor obstacle - grid[62:65, 45:55] = CostValues.OCCUPIED # Bottom corridor obstacle + # Add fewer obstacles for faster processing + grid[19:21, 19:21] = CostValues.OCCUPIED # Central obstacle + grid[13:14, 18:22] = CostValues.OCCUPIED # Top corridor obstacle - # Create origin at bottom-left + # Create origin at bottom-left, adjusted for map size from dimos.msgs.geometry_msgs import Pose origin = Pose() - origin.position.x = -5.0 # Center the map - origin.position.y = -5.0 + # Center the map around (0, 0) in world coordinates + origin.position.x = -(width * resolution) / 2.0 + origin.position.y = -(height * resolution) / 2.0 origin.position.z = 0.0 origin.orientation.w = 1.0 @@ -71,10 +120,10 @@ def __init__(self): return occupancy_grid, MockLidar() -def test_frontier_detection_with_office_lidar(): +def 
test_frontier_detection_with_office_lidar(explorer, quick_costmap): """Test frontier detection using a test costmap.""" # Get test costmap - costmap, first_lidar = create_test_costmap() + costmap, first_lidar = quick_costmap # Verify we have a valid costmap assert costmap is not None, "Costmap should not be None" @@ -86,9 +135,6 @@ def test_frontier_detection_with_office_lidar(): print(f"Free percent: {costmap.free_percent:.1f}%") print(f"Occupied percent: {costmap.occupied_percent:.1f}%") - # Initialize frontier explorer with default parameters - explorer = WavefrontFrontierExplorer() - # Set robot pose near the center of free space in the costmap # We'll use the lidar origin as a reasonable robot position robot_pose = first_lidar.origin @@ -115,17 +161,12 @@ def test_frontier_detection_with_office_lidar(): else: print("No frontiers detected - map may be fully explored or parameters too restrictive") - explorer.cleanup() # TODO: this should be a in try-finally - -def test_exploration_goal_selection(): +def test_exploration_goal_selection(explorer): """Test the complete exploration goal selection pipeline.""" - # Get test costmap + # Get test costmap - use regular size for more realistic test costmap, first_lidar = create_test_costmap() - # Initialize frontier explorer with default parameters - explorer = WavefrontFrontierExplorer() - # Use lidar origin as robot position robot_pose = first_lidar.origin @@ -152,24 +193,22 @@ def test_exploration_goal_selection(): else: print("No exploration goal selected - map may be fully explored") - explorer.cleanup() # TODO: this should be a in try-finally - -def test_exploration_session_reset(): +def test_exploration_session_reset(explorer): """Test exploration session reset functionality.""" # Get test costmap costmap, first_lidar = create_test_costmap() - # Initialize explorer and select a goal - explorer = WavefrontFrontierExplorer() + # Use lidar origin as robot position robot_pose = first_lidar.origin # Select a goal to 
populate exploration state goal = explorer.get_exploration_goal(robot_pose, costmap) - # Verify state is populated - initial_explored_count = len(explorer.explored_goals) - initial_direction = explorer.exploration_direction + # Verify state is populated (skip if no goals available) + if goal: + initial_explored_count = len(explorer.explored_goals) + assert initial_explored_count > 0, "Should have at least one explored goal" # Reset exploration session explorer.reset_exploration_session() @@ -183,19 +222,13 @@ def test_exploration_session_reset(): assert explorer.no_gain_counter == 0, "No-gain counter should be reset" print("Exploration session reset successfully") - explorer.cleanup() # TODO: this should be a in try-finally -def test_frontier_ranking(): +def test_frontier_ranking(explorer): """Test frontier ranking and scoring logic.""" # Get test costmap costmap, first_lidar = create_test_costmap() - # Initialize explorer with custom parameters - explorer = WavefrontFrontierExplorer( - min_frontier_perimeter=0.5, safe_distance=0.5, info_gain_threshold=0.02 - ) - robot_pose = first_lidar.origin # Get first set of frontiers @@ -234,8 +267,6 @@ def test_frontier_ranking(): else: print("No frontiers found for ranking test") - explorer.cleanup() # TODO: this should be a in try-finally - def test_exploration_with_no_gain_detection(): """Test information gain detection and exploration termination.""" @@ -245,34 +276,35 @@ def test_exploration_with_no_gain_detection(): # Initialize explorer with low no-gain threshold for testing explorer = WavefrontFrontierExplorer(info_gain_threshold=0.01, num_no_gain_attempts=2) - robot_pose = first_lidar.origin - - # Select multiple goals to populate history - for i in range(6): - goal = explorer.get_exploration_goal(robot_pose, costmap1) - if goal: - print(f"Goal {i + 1}: ({goal.x:.2f}, {goal.y:.2f})") + try: + robot_pose = first_lidar.origin - # Now use same costmap repeatedly to trigger no-gain detection - initial_counter = 
explorer.no_gain_counter + # Select multiple goals to populate history + for i in range(6): + goal = explorer.get_exploration_goal(robot_pose, costmap1) + if goal: + print(f"Goal {i + 1}: ({goal.x:.2f}, {goal.y:.2f})") - # This should increment no-gain counter - goal = explorer.get_exploration_goal(robot_pose, costmap1) - assert explorer.no_gain_counter > initial_counter, "No-gain counter should increment" + # Now use same costmap repeatedly to trigger no-gain detection + initial_counter = explorer.no_gain_counter - # Continue until exploration stops - for _ in range(3): + # This should increment no-gain counter goal = explorer.get_exploration_goal(robot_pose, costmap1) - if goal is None: - break + assert explorer.no_gain_counter > initial_counter, "No-gain counter should increment" - # Should have stopped due to no information gain - assert goal is None, "Exploration should stop after no-gain threshold" - assert explorer.no_gain_counter == 0, "Counter should reset after stopping" + # Continue until exploration stops + for _ in range(3): + goal = explorer.get_exploration_goal(robot_pose, costmap1) + if goal is None: + break - print("No-gain detection test passed") + # Should have stopped due to no information gain + assert goal is None, "Exploration should stop after no-gain threshold" + assert explorer.no_gain_counter == 0, "Counter should reset after stopping" - explorer.cleanup() # TODO: this should be a in try-finally + print("No-gain detection test passed") + finally: + explorer.cleanup() @pytest.mark.vis @@ -284,78 +316,138 @@ def test_frontier_detection_visualization(): # Initialize frontier explorer with default parameters explorer = WavefrontFrontierExplorer() - # Use lidar origin as robot position - robot_pose = first_lidar.origin + try: + # Use lidar origin as robot position + robot_pose = first_lidar.origin + + # Detect all frontiers for visualization + all_frontiers = explorer.detect_frontiers(robot_pose, costmap) + + # Get selected goal + 
selected_goal = explorer.get_exploration_goal(robot_pose, costmap) + + print(f"Visualizing {len(all_frontiers)} frontier candidates") + if selected_goal: + print(f"Selected goal: ({selected_goal.x:.2f}, {selected_goal.y:.2f})") + + # Create visualization + image_scale_factor = 4 + base_image = costmap_to_pil_image(costmap, image_scale_factor) + + # Helper function to convert world coordinates to image coordinates + def world_to_image_coords(world_pos: Vector3) -> tuple[int, int]: + grid_pos = costmap.world_to_grid(world_pos) + img_x = int(grid_pos.x * image_scale_factor) + img_y = int((costmap.height - grid_pos.y) * image_scale_factor) # Flip Y + return img_x, img_y + + # Draw visualization + draw = ImageDraw.Draw(base_image) + + # Draw frontier candidates as gray dots + for frontier in all_frontiers[:20]: # Limit to top 20 + x, y = world_to_image_coords(frontier) + radius = 6 + draw.ellipse( + [x - radius, y - radius, x + radius, y + radius], + fill=(128, 128, 128), # Gray + outline=(64, 64, 64), + width=1, + ) - # Detect all frontiers for visualization - all_frontiers = explorer.detect_frontiers(robot_pose, costmap) + # Draw robot position as blue dot + robot_x, robot_y = world_to_image_coords(robot_pose) + robot_radius = 10 + draw.ellipse( + [ + robot_x - robot_radius, + robot_y - robot_radius, + robot_x + robot_radius, + robot_y + robot_radius, + ], + fill=(0, 0, 255), # Blue + outline=(0, 0, 128), + width=3, + ) - # Get selected goal - selected_goal = explorer.get_exploration_goal(robot_pose, costmap) + # Draw selected goal as red dot + if selected_goal: + goal_x, goal_y = world_to_image_coords(selected_goal) + goal_radius = 12 + draw.ellipse( + [ + goal_x - goal_radius, + goal_y - goal_radius, + goal_x + goal_radius, + goal_y + goal_radius, + ], + fill=(255, 0, 0), # Red + outline=(128, 0, 0), + width=3, + ) - print(f"Visualizing {len(all_frontiers)} frontier candidates") - if selected_goal: - print(f"Selected goal: ({selected_goal.x:.2f}, 
{selected_goal.y:.2f})") + # Display the image + base_image.show(title="Frontier Detection - Office Lidar") - # Create visualization - image_scale_factor = 4 - base_image = costmap_to_pil_image(costmap, image_scale_factor) + print("Visualization displayed. Close the image window to continue.") + finally: + explorer.cleanup() - # Helper function to convert world coordinates to image coordinates - def world_to_image_coords(world_pos: Vector3) -> tuple[int, int]: - grid_pos = costmap.world_to_grid(world_pos) - img_x = int(grid_pos.x * image_scale_factor) - img_y = int((costmap.height - grid_pos.y) * image_scale_factor) # Flip Y - return img_x, img_y - # Draw visualization - draw = ImageDraw.Draw(base_image) +def test_performance_timing(): + """Test performance by timing frontier detection operations.""" + import time - # Draw frontier candidates as gray dots - for frontier in all_frontiers[:20]: # Limit to top 20 - x, y = world_to_image_coords(frontier) - radius = 6 - draw.ellipse( - [x - radius, y - radius, x + radius, y + radius], - fill=(128, 128, 128), # Gray - outline=(64, 64, 64), - width=1, - ) + # Test with different costmap sizes + sizes = [(20, 20), (40, 40), (60, 60)] + results = [] - # Draw robot position as blue dot - robot_x, robot_y = world_to_image_coords(robot_pose) - robot_radius = 10 - draw.ellipse( - [ - robot_x - robot_radius, - robot_y - robot_radius, - robot_x + robot_radius, - robot_y + robot_radius, - ], - fill=(0, 0, 255), # Blue - outline=(0, 0, 128), - width=3, - ) + for width, height in sizes: + # Create costmap of specified size + costmap, lidar = create_test_costmap(width, height) - # Draw selected goal as red dot - if selected_goal: - goal_x, goal_y = world_to_image_coords(selected_goal) - goal_radius = 12 - draw.ellipse( - [ - goal_x - goal_radius, - goal_y - goal_radius, - goal_x + goal_radius, - goal_y + goal_radius, - ], - fill=(255, 0, 0), # Red - outline=(128, 0, 0), - width=3, + # Create explorer with optimized parameters + 
explorer = WavefrontFrontierExplorer( + min_frontier_perimeter=0.3, + safe_distance=0.5, + info_gain_threshold=0.02, ) - # Display the image - base_image.show(title="Frontier Detection - Office Lidar") - - print("Visualization displayed. Close the image window to continue.") + try: + robot_pose = lidar.origin + + # Time frontier detection + start = time.time() + frontiers = explorer.detect_frontiers(robot_pose, costmap) + detect_time = time.time() - start + + # Time goal selection + start = time.time() + goal = explorer.get_exploration_goal(robot_pose, costmap) + goal_time = time.time() - start + + results.append( + { + "size": f"{width}x{height}", + "cells": width * height, + "detect_time": detect_time, + "goal_time": goal_time, + "frontiers": len(frontiers), + } + ) - explorer.cleanup() # TODO: this should be a in try-finally + print(f"\nSize {width}x{height}:") + print(f" Cells: {width * height}") + print(f" Frontier detection: {detect_time:.4f}s") + print(f" Goal selection: {goal_time:.4f}s") + print(f" Frontiers found: {len(frontiers)}") + finally: + explorer.cleanup() + + # Check that larger maps take more time (expected behavior) + # But verify times are reasonable + for result in results: + assert result["detect_time"] < 1.0, f"Detection too slow: {result['detect_time']}s" + assert result["goal_time"] < 1.5, f"Goal selection too slow: {result['goal_time']}s" + + print("\nPerformance test passed - all operations completed within time limits") diff --git a/dimos/perception/detection2d/__init__.py b/dimos/perception/detection2d/__init__.py index bdcf9ca827..6dc59e7366 100644 --- a/dimos/perception/detection2d/__init__.py +++ b/dimos/perception/detection2d/__init__.py @@ -1,3 +1,4 @@ +from dimos.perception.detection2d.detectors import * from dimos.perception.detection2d.module2D import ( Detection2DModule, ) @@ -5,4 +6,3 @@ Detection3DModule, ) from dimos.perception.detection2d.utils import * -from dimos.perception.detection2d.yolo_2d_det import * diff --git 
a/dimos/perception/detection2d/conftest.py b/dimos/perception/detection2d/conftest.py index b19cdcbd00..8ada4ec356 100644 --- a/dimos/perception/detection2d/conftest.py +++ b/dimos/perception/detection2d/conftest.py @@ -19,17 +19,20 @@ from dimos_lcm.foxglove_msgs.SceneUpdate import SceneUpdate from dimos_lcm.visualization_msgs.MarkerArray import MarkerArray +from dimos.core import LCMTransport from dimos.msgs.geometry_msgs import Transform -from dimos.msgs.sensor_msgs import CameraInfo -from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs import CameraInfo, Image, PointCloud2 +from dimos.msgs.vision_msgs import Detection2DArray from dimos.perception.detection2d.module2D import Detection2DModule from dimos.perception.detection2d.module3D import Detection3DModule from dimos.perception.detection2d.moduleDB import ObjectDBModule from dimos.perception.detection2d.type import ( Detection2D, Detection3D, + Detection3DPC, ImageDetections2D, ImageDetections3D, + ImageDetections3DPC, ) from dimos.protocol.tf import TF from dimos.robot.unitree_webrtc.modular.connection_module import ConnectionModule @@ -47,7 +50,7 @@ class Moment(TypedDict, total=False): transforms: list[Transform] tf: TF annotations: Optional[ImageAnnotations] - detections: Optional[ImageDetections3D] + detections: Optional[ImageDetections3DPC] markers: Optional[MarkerArray] scene_update: Optional[SceneUpdate] @@ -57,7 +60,7 @@ class Moment2D(Moment): class Moment3D(Moment): - detections3d: ImageDetections3D + detections3dpc: ImageDetections3D @pytest.fixture @@ -102,6 +105,47 @@ def moment_provider(**kwargs) -> Moment: return moment_provider +@pytest.fixture +def publish_moment(): + def publisher(moment: Moment | Moment2D | Moment3D): + if moment.get("detections2d"): + # 2d annotations + annotations = LCMTransport("/annotations", ImageAnnotations) + annotations.publish(moment.get("detections2d").to_foxglove_annotations()) + + detections = LCMTransport("/detections", Detection2DArray) 
+ detections.publish(moment.get("detections2d").to_ros_detection2d_array()) + + annotations.lcm.stop() + detections.lcm.stop() + + if moment.get("detections3dpc"): + scene_update = LCMTransport("/scene_update", SceneUpdate) + # 3d scene update + scene_update.publish(moment.get("detections3dpc").to_foxglove_scene_update()) + scene_update.lcm.stop() + + lidar = LCMTransport("/lidar", PointCloud2) + lidar.publish(moment.get("lidar_frame")) + lidar.lcm.stop() + + image = LCMTransport("/image", Image) + image.publish(moment.get("image_frame")) + image.lcm.stop() + + camera_info = LCMTransport("/camera_info", CameraInfo) + camera_info.publish(moment.get("camera_info")) + camera_info.lcm.stop() + + tf = moment.get("tf") + tf.publish(*moment.get("transforms")) + + # moduleDB.scene_update.transport = LCMTransport("/scene_update", SceneUpdate) + # moduleDB.target.transport = LCMTransport("/target", PoseStamped) + + return publisher + + @pytest.fixture def detection2d(get_moment_2d) -> Detection2D: moment = get_moment_2d(seek=10.0) @@ -110,11 +154,10 @@ def detection2d(get_moment_2d) -> Detection2D: @pytest.fixture -def detection3d(get_moment_3d) -> Detection3D: - moment = get_moment_3d(seek=10.0) - assert len(moment["detections3d"]) > 0, "No detections found in the moment" - print(moment["detections3d"]) - return moment["detections3d"][0] +def detection3dpc(get_moment_3dpc) -> Detection3DPC: + moment = get_moment_3dpc(seek=10.0) + assert len(moment["detections3dpc"]) > 0, "No detections found in the moment" + return moment["detections3dpc"][0] @pytest.fixture @@ -135,22 +178,22 @@ def moment_provider(**kwargs) -> Moment2D: @pytest.fixture -def get_moment_3d(get_moment_2d) -> Callable[[], Moment2D]: +def get_moment_3dpc(get_moment_2d) -> Callable[[], Moment2D]: module = None def moment_provider(**kwargs) -> Moment2D: nonlocal module moment = get_moment_2d(**kwargs) - module = Detection3DModule(camera_info=moment["camera_info"]) + if not module: + module = 
Detection3DModule(camera_info=moment["camera_info"]) camera_transform = moment["tf"].get("camera_optical", moment.get("lidar_frame").frame_id) if camera_transform is None: raise ValueError("No camera_optical transform in tf") - return { **moment, - "detections3d": module.process_frame( + "detections3dpc": module.process_frame( moment["detections2d"], moment["lidar_frame"], camera_transform ), } diff --git a/dimos/perception/detection2d/detectors/__init__.py b/dimos/perception/detection2d/detectors/__init__.py new file mode 100644 index 0000000000..287fff1a15 --- /dev/null +++ b/dimos/perception/detection2d/detectors/__init__.py @@ -0,0 +1,3 @@ +# from dimos.perception.detection2d.detectors.detic import Detic2DDetector +from dimos.perception.detection2d.detectors.types import Detector +from dimos.perception.detection2d.detectors.yolo import Yolo2DDetector diff --git a/dimos/perception/detection2d/config/custom_tracker.yaml b/dimos/perception/detection2d/detectors/config/custom_tracker.yaml similarity index 100% rename from dimos/perception/detection2d/config/custom_tracker.yaml rename to dimos/perception/detection2d/detectors/config/custom_tracker.yaml diff --git a/dimos/perception/detection2d/detic_2d_det.py b/dimos/perception/detection2d/detectors/detic.py similarity index 96% rename from dimos/perception/detection2d/detic_2d_det.py rename to dimos/perception/detection2d/detectors/detic.py index 44b77cb397..0b7b63276f 100644 --- a/dimos/perception/detection2d/detic_2d_det.py +++ b/dimos/perception/detection2d/detectors/detic.py @@ -12,13 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import numpy as np import os import sys + +import numpy as np + +from dimos.msgs.sensor_msgs import Image +from dimos.perception.detection2d.detectors.types import Detector from dimos.perception.detection2d.utils import plot_results # Add Detic to Python path -detic_path = os.path.join(os.path.dirname(__file__), "..", "..", "models", "Detic") +from dimos.constants import DIMOS_PROJECT_ROOT + +detic_path = DIMOS_PROJECT_ROOT / "dimos/models/Detic" if detic_path not in sys.path: sys.path.append(detic_path) sys.path.append(os.path.join(detic_path, "third_party/CenterNet2")) @@ -154,7 +160,7 @@ def update(self, detections, masks): return result -class Detic2DDetector: +class Detic2DDetector(Detector): def __init__(self, model_path=None, device="cuda", vocabulary=None, threshold=0.5): """ Initialize the Detic detector with open vocabulary support. @@ -173,8 +179,8 @@ def __init__(self, model_path=None, device="cuda", vocabulary=None, threshold=0. # Import Detic modules from centernet.config import add_centernet_config from detic.config import add_detic_config - from detic.modeling.utils import reset_cls_test from detic.modeling.text.text_encoder import build_text_encoder + from detic.modeling.utils import reset_cls_test # Keep reference to these functions for later use self.reset_cls_test = reset_cls_test @@ -312,7 +318,7 @@ def _get_clip_embeddings(self, vocabulary, prompt="a "): emb = text_encoder(texts).detach().permute(1, 0).contiguous().cpu() return emb - def process_image(self, image): + def process_image(self, image: Image): """ Process an image and return detection results. 
@@ -329,12 +335,12 @@ def process_image(self, image): - masks: list of segmentation masks (numpy arrays) """ # Run inference with Detic - outputs = self.predictor(image) + outputs = self.predictor(image.to_opencv()) instances = outputs["instances"].to("cpu") # Extract bounding boxes, classes, scores, and masks if len(instances) == 0: - return [], [], [], [], [], [] + return [], [], [], [], [] # , [] boxes = instances.pred_boxes.tensor.numpy() class_ids = instances.pred_classes.numpy() @@ -360,7 +366,7 @@ def process_image(self, image): filtered_masks.append(masks[i]) if not detections: - return [], [], [], [], [], [] + return [], [], [], [], [] # , [] # Update tracker with detections and correctly aligned masks track_results = self.tracker.update(detections, filtered_masks) @@ -387,7 +393,7 @@ def process_image(self, image): tracked_class_ids, tracked_scores, tracked_names, - tracked_masks, + # tracked_masks, ) def visualize_results(self, image, bboxes, track_ids, class_ids, confidences, names): diff --git a/dimos/perception/detection2d/detectors/person/test_annotations.py b/dimos/perception/detection2d/detectors/person/test_annotations.py new file mode 100644 index 0000000000..c686c33bd9 --- /dev/null +++ b/dimos/perception/detection2d/detectors/person/test_annotations.py @@ -0,0 +1,70 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Test person annotations work correctly.""" + +import sys + +from dimos.msgs.sensor_msgs import Image +from dimos.perception.detection2d.detectors.person.yolo import YoloPersonDetector +from dimos.utils.data import get_data + + +def test_person_annotations(): + """Test that Person annotations include keypoints and skeleton.""" + image = Image.from_file(get_data("cafe.jpg")) + detector = YoloPersonDetector() + people = detector.detect_people(image) + + assert len(people) > 0 + person = people[0] + + # Test text annotations + text_anns = person.to_text_annotation() + print(f"\nText annotations: {len(text_anns)}") + for i, ann in enumerate(text_anns): + print(f" {i}: {ann.text}") + assert len(text_anns) == 3 # confidence, name/track_id, keypoints count + assert any("keypoints:" in ann.text for ann in text_anns) + + # Test points annotations + points_anns = person.to_points_annotation() + print(f"\nPoints annotations: {len(points_anns)}") + + # Count different types (use actual LCM constants) + from dimos_lcm.foxglove_msgs.ImageAnnotations import PointsAnnotation + + bbox_count = sum(1 for ann in points_anns if ann.type == PointsAnnotation.LINE_LOOP) # 2 + keypoint_count = sum(1 for ann in points_anns if ann.type == PointsAnnotation.POINTS) # 1 + skeleton_count = sum(1 for ann in points_anns if ann.type == PointsAnnotation.LINE_LIST) # 4 + + print(f" - Bounding boxes: {bbox_count}") + print(f" - Keypoint circles: {keypoint_count}") + print(f" - Skeleton lines: {skeleton_count}") + + assert bbox_count >= 1 # At least the person bbox + assert keypoint_count >= 1 # At least some visible keypoints + assert skeleton_count >= 1 # At least some skeleton connections + + # Test full image annotations + img_anns = person.to_image_annotations() + assert img_anns.texts_length == len(text_anns) + assert img_anns.points_length == len(points_anns) + + print(f"\n✓ Person annotations working correctly!") + print(f" - {len(person.get_visible_keypoints(0.5))}/17 visible keypoints") 
+ + +if __name__ == "__main__": + test_person_annotations() diff --git a/dimos/perception/detection2d/detectors/person/test_detection2d_conformance.py b/dimos/perception/detection2d/detectors/person/test_detection2d_conformance.py new file mode 100644 index 0000000000..f7c7cc088c --- /dev/null +++ b/dimos/perception/detection2d/detectors/person/test_detection2d_conformance.py @@ -0,0 +1,82 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from dimos.msgs.sensor_msgs import Image +from dimos.perception.detection2d.detectors.person.yolo import YoloPersonDetector +from dimos.perception.detection2d.type.person import Person +from dimos.utils.data import get_data + + +def test_person_detection2d_bbox_conformance(): + """Test that Person conforms to Detection2DBBox interface.""" + image = Image.from_file(get_data("cafe.jpg")) + detector = YoloPersonDetector() + people = detector.detect_people(image) + + assert len(people) > 0 + person = people[0] + + # Test Detection2DBBox methods + # Test bbox operations + assert hasattr(person, "bbox") + assert len(person.bbox) == 4 + assert all(isinstance(x, float) for x in person.bbox) + + # Test inherited properties + assert hasattr(person, "get_bbox_center") + center_bbox = person.get_bbox_center() + assert len(center_bbox) == 4 # center_x, center_y, width, height + + # Test volume calculation + volume = person.bbox_2d_volume() + assert volume > 0 + + # Test cropped image + cropped 
= person.cropped_image(padding=10) + assert isinstance(cropped, Image) + + # Test annotation methods + text_annotations = person.to_text_annotation() + assert len(text_annotations) == 3 # confidence, name/track_id, and keypoints count + + points_annotations = person.to_points_annotation() + # Should have: 1 bbox + 1 keypoints + multiple skeleton lines + assert len(points_annotations) > 1 + print(f" - Points annotations: {len(points_annotations)} (bbox + keypoints + skeleton)") + + # Test image annotations + annotations = person.to_image_annotations() + assert annotations.texts_length == 3 + assert annotations.points_length > 1 + + # Test ROS conversion + ros_det = person.to_ros_detection2d() + assert ros_det.bbox.size_x == person.width + assert ros_det.bbox.size_y == person.height + + # Test string representation + str_repr = str(person) + assert "Person" in str_repr + assert "person" in str_repr # name field + + print("\n✓ Person class fully conforms to Detection2DBBox interface") + print(f" - Detected {len(people)} people") + print(f" - First person confidence: {person.confidence:.3f}") + print(f" - Bbox volume: {volume:.1f}") + print(f" - Has {len(person.get_visible_keypoints(0.5))} visible keypoints") + + +if __name__ == "__main__": + test_person_detection2d_bbox_conformance() diff --git a/dimos/perception/detection2d/detectors/person/test_imagedetections2d.py b/dimos/perception/detection2d/detectors/person/test_imagedetections2d.py new file mode 100644 index 0000000000..89fd770aa6 --- /dev/null +++ b/dimos/perception/detection2d/detectors/person/test_imagedetections2d.py @@ -0,0 +1,55 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test ImageDetections2D with pose detections.""" + +from dimos.msgs.sensor_msgs import Image +from dimos.perception.detection2d.detectors.person.yolo import YoloPersonDetector +from dimos.perception.detection2d.type import ImageDetections2D +from dimos.utils.data import get_data + + +def test_image_detections_2d_with_person(): + """Test creating ImageDetections2D from person detector.""" + # Load image and detect people + image = Image.from_file(get_data("cafe.jpg")) + detector = YoloPersonDetector() + people = detector.detect_people(image) + + # Create ImageDetections2D using from_pose_detector + image_detections = ImageDetections2D.from_pose_detector(image, people) + + # Verify structure + assert image_detections.image is image + assert len(image_detections.detections) == len(people) + assert all(det in people for det in image_detections.detections) + + # Test image annotations (includes pose keypoints) + annotations = image_detections.to_foxglove_annotations() + print(f"\nImageDetections2D created with {len(people)} people") + print(f"Total text annotations: {annotations.texts_length}") + print(f"Total points annotations: {annotations.points_length}") + + # Points should include: bounding boxes + keypoints + skeleton lines + # At least 3 annotations per person (bbox, keypoints, skeleton) + assert annotations.points_length >= len(people) * 3 + + # Text annotations should include confidence, name/id, and keypoint count + assert annotations.texts_length >= len(people) * 3 + + print("\n✓ ImageDetections2D.from_pose_detector working correctly!") + + +if 
__name__ == "__main__": + test_image_detections_2d_with_person() diff --git a/dimos/perception/detection2d/detectors/person/test_yolo.py b/dimos/perception/detection2d/detectors/person/test_yolo.py new file mode 100644 index 0000000000..454997ca27 --- /dev/null +++ b/dimos/perception/detection2d/detectors/person/test_yolo.py @@ -0,0 +1,124 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from dimos.msgs.sensor_msgs import Image +from dimos.perception.detection2d.detectors.person.yolo import YoloPersonDetector +from dimos.perception.detection2d.type.person import Person +from dimos.utils.data import get_data + + +@pytest.fixture() +def detector(): + return YoloPersonDetector() + + +@pytest.fixture() +def test_image(): + return Image.from_file(get_data("cafe.jpg")) + + +@pytest.fixture() +def people(detector, test_image): + return detector.detect_people(test_image) + + +def test_person_detection(people): + """Test that we can detect people with pose keypoints.""" + assert len(people) > 0 + + # Check first person + person = people[0] + assert isinstance(person, Person) + assert person.confidence > 0 + assert len(person.bbox) == 4 # bbox is now a tuple + assert person.keypoints.shape == (17, 2) + assert person.keypoint_scores.shape == (17,) + + +def test_person_properties(people): + """Test Person object properties and methods.""" + person = people[0] + + # Test bounding box properties + assert person.width > 0 + 
assert person.height > 0 + assert len(person.center) == 2 + + # Test keypoint access + nose_xy, nose_conf = person.get_keypoint("nose") + assert nose_xy.shape == (2,) + assert 0 <= nose_conf <= 1 + + # Test visible keypoints + visible = person.get_visible_keypoints(threshold=0.5) + assert len(visible) > 0 + assert all(isinstance(name, str) for name, _, _ in visible) + assert all(xy.shape == (2,) for _, xy, _ in visible) + assert all(0 <= conf <= 1 for _, _, conf in visible) + + +def test_person_normalized_coords(people): + """Test normalized coordinates if available.""" + person = people[0] + + if person.keypoints_normalized is not None: + assert person.keypoints_normalized.shape == (17, 2) + # Check all values are in 0-1 range + assert (person.keypoints_normalized >= 0).all() + assert (person.keypoints_normalized <= 1).all() + + if person.bbox_normalized is not None: + assert person.bbox_normalized.shape == (4,) + assert (person.bbox_normalized >= 0).all() + assert (person.bbox_normalized <= 1).all() + + +def test_multiple_people(people): + """Test that multiple people can be detected.""" + print(f"\nDetected {len(people)} people in test image") + + for i, person in enumerate(people[:3]): # Show first 3 + print(f"\nPerson {i}:") + print(f" Confidence: {person.confidence:.3f}") + print(f" Size: {person.width:.1f} x {person.height:.1f}") + + visible = person.get_visible_keypoints(threshold=0.8) + print(f" High-confidence keypoints (>0.8): {len(visible)}") + for name, xy, conf in visible[:5]: + print(f" {name}: ({xy[0]:.1f}, {xy[1]:.1f}) conf={conf:.3f}") + + +def test_invalid_keypoint(test_image): + """Test error handling for invalid keypoint names.""" + # Create a dummy person + import numpy as np + + person = Person( + # Detection2DBBox fields + bbox=(0.0, 0.0, 100.0, 100.0), + track_id=0, + class_id=0, + confidence=0.9, + name="person", + ts=test_image.ts, + image=test_image, + # Person fields + keypoints=np.zeros((17, 2)), + keypoint_scores=np.zeros(17), + ) + + 
with pytest.raises(ValueError): + person.get_keypoint("invalid_keypoint") diff --git a/dimos/perception/detection2d/detectors/person/yolo.py b/dimos/perception/detection2d/detectors/person/yolo.py new file mode 100644 index 0000000000..fb4fe4769e --- /dev/null +++ b/dimos/perception/detection2d/detectors/person/yolo.py @@ -0,0 +1,138 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import torch +from ultralytics import YOLO +from ultralytics.engine.results import Boxes, Keypoints, Results + +from dimos.msgs.sensor_msgs import Image +from dimos.perception.detection2d.detectors.types import Detector +from dimos.utils.data import get_data +from dimos.utils.logging_config import setup_logger + +logger = setup_logger("dimos.perception.detection2d.yolo.person") + + +# Type alias for YOLO person detection results +YoloPersonResults = List[Results] + +""" +YOLO Person Detection Results Structure: + +Each Results object in the list contains: + +1. 
boxes (Boxes object): + - boxes.xyxy: torch.Tensor [N, 4] - bounding boxes in [x1, y1, x2, y2] format + - boxes.xywh: torch.Tensor [N, 4] - boxes in [x_center, y_center, width, height] format + - boxes.conf: torch.Tensor [N] - confidence scores (0-1) + - boxes.cls: torch.Tensor [N] - class IDs (0 for person) + - boxes.xyxyn: torch.Tensor [N, 4] - normalized xyxy coordinates (0-1) + - boxes.xywhn: torch.Tensor [N, 4] - normalized xywh coordinates (0-1) + +2. keypoints (Keypoints object): + - keypoints.xy: torch.Tensor [N, 17, 2] - absolute x,y coordinates for 17 keypoints + - keypoints.conf: torch.Tensor [N, 17] - confidence/visibility scores for each keypoint + - keypoints.xyn: torch.Tensor [N, 17, 2] - normalized coordinates (0-1) + + Keypoint order (COCO format): + 0: nose, 1: left_eye, 2: right_eye, 3: left_ear, 4: right_ear, + 5: left_shoulder, 6: right_shoulder, 7: left_elbow, 8: right_elbow, + 9: left_wrist, 10: right_wrist, 11: left_hip, 12: right_hip, + 13: left_knee, 14: right_knee, 15: left_ankle, 16: right_ankle + +3. Other attributes: + - names: Dict[int, str] - class names mapping {0: 'person'} + - orig_shape: Tuple[int, int] - original image (height, width) + - speed: Dict[str, float] - timing info {'preprocess': ms, 'inference': ms, 'postprocess': ms} + - path: str - image path + - orig_img: np.ndarray - original image array + +Note: All tensor data is on GPU by default. Use .cpu() to move to CPU. +""" +from dimos.perception.detection2d.type.person import Person + + +class YoloPersonDetector(Detector): + def __init__(self, model_path="models_yolo", model_name="yolo11n-pose.pt"): + self.model = YOLO(get_data(model_path) / model_name, task="pose") + + def process_image(self, image: Image) -> YoloPersonResults: + """Process image and return YOLO person detection results. + + Returns: + List of Results objects, typically one per image. 
+ Each Results object contains: + - boxes: Boxes with xyxy, xywh, conf, cls tensors + - keypoints: Keypoints with xy, conf, xyn tensors + - names: {0: 'person'} class mapping + - orig_shape: original image dimensions + - speed: inference timing + """ + return self.model(source=image.to_opencv()) + + def detect_people(self, image: Image) -> List[Person]: + """Process image and return list of Person objects. + + Returns: + List of Person objects with pose keypoints + """ + results = self.process_image(image) + + people = [] + for result in results: + if result.keypoints is None or result.boxes is None: + continue + + # Create Person object for each detection + num_detections = len(result.boxes.xyxy) + for i in range(num_detections): + person = Person.from_yolo(result, i, image) + people.append(person) + + return people + + +def main(): + image = Image.from_file(get_data("cafe.jpg")) + detector = YoloPersonDetector() + + # Get Person objects + people = detector.detect_people(image) + + print(f"Detected {len(people)} people") + for i, person in enumerate(people): + print(f"\nPerson {i}:") + print(f" Confidence: {person.confidence:.3f}") + print(f" Bounding box: {person.bbox}") + cx, cy = person.center + print(f" Center: ({cx:.1f}, {cy:.1f})") + print(f" Size: {person.width:.1f} x {person.height:.1f}") + + # Get specific keypoints + nose_xy, nose_conf = person.get_keypoint("nose") + print(f" Nose: {nose_xy} (conf: {nose_conf:.3f})") + + # Get all visible keypoints + visible = person.get_visible_keypoints(threshold=0.7) + print(f" Visible keypoints (>0.7): {len(visible)}") + for name, xy, conf in visible[:3]: # Show first 3 + print(f" {name}: {xy} (conf: {conf:.3f})") + + +if __name__ == "__main__": + main() diff --git a/dimos/perception/detection2d/detectors/types.py b/dimos/perception/detection2d/detectors/types.py new file mode 100644 index 0000000000..639fc09247 --- /dev/null +++ b/dimos/perception/detection2d/detectors/types.py @@ -0,0 +1,25 @@ +# Copyright 2025 
Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod + +from dimos.msgs.sensor_msgs import Image +from dimos.perception.detection2d.type import ( + InconvinientDetectionFormat, +) + + +class Detector(ABC): + @abstractmethod + def process_image(self, image: Image) -> InconvinientDetectionFormat: ... diff --git a/dimos/perception/detection2d/yolo_2d_det.py b/dimos/perception/detection2d/detectors/yolo.py similarity index 94% rename from dimos/perception/detection2d/yolo_2d_det.py rename to dimos/perception/detection2d/detectors/yolo.py index 02c4ee5325..2d8681f0ef 100644 --- a/dimos/perception/detection2d/yolo_2d_det.py +++ b/dimos/perception/detection2d/detectors/yolo.py @@ -18,6 +18,8 @@ import onnxruntime from ultralytics import YOLO +from dimos.msgs.sensor_msgs import Image +from dimos.perception.detection2d.detectors.types import Detector from dimos.perception.detection2d.utils import ( extract_detection_results, filter_detections, @@ -30,7 +32,7 @@ logger = setup_logger("dimos.perception.detection2d.yolo_2d_det") -class Yolo2DDetector: +class Yolo2DDetector(Detector): def __init__(self, model_path="models_yolo", model_name="yolo11n.onnx", device="cpu"): """ Initialize the YOLO detector. 
@@ -49,12 +51,12 @@ def __init__(self, model_path="models_yolo", model_name="yolo11n.onnx", device=" if hasattr(onnxruntime, "preload_dlls"): # Handles CUDA 11 / onnxruntime-gpu<=1.18 onnxruntime.preload_dlls(cuda=True, cudnn=True) self.device = "cuda" - logger.info("Using CUDA for YOLO 2d detector") + logger.debug("Using CUDA for YOLO 2d detector") else: self.device = "cpu" - logger.info("Using CPU for YOLO 2d detector") + logger.debug("Using CPU for YOLO 2d detector") - def process_image(self, image): + def process_image(self, image: Image): """ Process an image and return detection results. @@ -70,7 +72,7 @@ def process_image(self, image): - names: list of class names """ results = self.model.track( - source=image, + source=image.to_opencv(), device=self.device, conf=0.5, iou=0.6, diff --git a/dimos/perception/detection2d/module2D.py b/dimos/perception/detection2d/module2D.py index 82fb181be9..d11875315f 100644 --- a/dimos/perception/detection2d/module2D.py +++ b/dimos/perception/detection2d/module2D.py @@ -11,10 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import functools -import json -import time -from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Any, Callable, Optional @@ -31,24 +27,19 @@ from dimos.msgs.sensor_msgs import Image from dimos.msgs.sensor_msgs.Image import sharpness_barrier from dimos.msgs.vision_msgs import Detection2DArray +from dimos.perception.detection2d.detectors import Detector, Yolo2DDetector +from dimos.perception.detection2d.detectors.person.yolo import YoloPersonDetector from dimos.perception.detection2d.type import ( - Detection2D, ImageDetections2D, - InconvinientDetectionFormat, ) -from dimos.perception.detection2d.yolo_2d_det import Yolo2DDetector +from dimos.utils.decorators.decorators import simple_mcache from dimos.utils.reactive import backpressure -class Detector(ABC): - @abstractmethod - def process_image(self, image: np.ndarray) -> InconvinientDetectionFormat: ... - - @dataclass class Config: - detector: Optional[Callable[[Any], Detector]] = Yolo2DDetector - max_freq: float = 0.5 # hz + max_freq: float = 5 # hz + detector: Optional[Callable[[Any], Detector]] = lambda: Yolo2DDetector() vlmodel: VlModel = QwenVlModel @@ -65,10 +56,6 @@ class Detection2DModule(Module): detected_image_1: Out[Image] = None # type: ignore detected_image_2: Out[Image] = None # type: ignore - detected_image_0: Out[Image] = None # type: ignore - detected_image_1: Out[Image] = None # type: ignore - detected_image_2: Out[Image] = None # type: ignore - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.config: Config = Config(**kwargs) @@ -76,68 +63,26 @@ def __init__(self, *args, **kwargs): self.vlmodel = self.config.vlmodel() self.vlm_detections_subject = Subject() - def vlm_query(self, query: str) -> ImageDetections2D: - image = self.sharp_image_stream().pipe(ops.take(1)).run() - - full_query = f"""show me a bounding boxes in pixels for this query: `{query}` - - format should be: - `[ - [label, x1, y1, x2, y2] - ... 
- ]` - - (etc, multiple matches are possible) - - If there's no match return `[]`. Label is whatever you think is appropriate - - Only respond with the coordinates, no other text.""" - - response = self.vlmodel.query(image, full_query) - coords = json.loads(response) - - imageDetections = ImageDetections2D(image) - - for track_id, detection_list in enumerate(coords): - if len(detection_list) != 5: - continue - name = detection_list[0] - bbox = list(map(float, detection_list[1:])) - imageDetections.detections.append( - Detection2D( - bbox=bbox, - track_id=track_id, - class_id=-100, - confidence=1.0, - name=name, - ts=time.time(), - image=image, - ) - ) - - print("vlm detected", imageDetections) - # Emit the VLM detections to the subject - self.vlm_detections_subject.on_next(imageDetections) - - return imageDetections - def process_image_frame(self, image: Image) -> ImageDetections2D: - print("Processing image frame for detections", image) - return ImageDetections2D.from_detector( - image, self.detector.process_image(image.to_opencv()) - ) - - @functools.cache + # Use person detection specifically if it's a YoloPersonDetector + if isinstance(self.detector, YoloPersonDetector): + people = self.detector.detect_people(image) + return ImageDetections2D.from_pose_detector(image, people) + else: + # Fallback to generic detection for other detectors + return ImageDetections2D.from_bbox_detector(image, self.detector.process_image(image)) + + @simple_mcache def sharp_image_stream(self) -> Observable[Image]: return backpressure( - self.image.observable().pipe( + self.image.pure_observable().pipe( sharpness_barrier(self.config.max_freq), ) ) - @functools.cache + @simple_mcache def detection_stream_2d(self) -> Observable[ImageDetections2D]: - # self.vlm_detections_subject + # return self.vlm_detections_subject # Regular detection stream from the detector regular_detections = self.sharp_image_stream().pipe(ops.map(self.process_image_frame)) # Merge with VL model detections @@
-145,25 +90,20 @@ def detection_stream_2d(self) -> Observable[ImageDetections2D]: @rpc def start(self): - # self.detection_stream_2d().subscribe( - # lambda det: self.detections.publish(det.to_ros_detection2d_array()) - # ) - - def publish_cropped_images(detections: ImageDetections2D): - for index, detection in enumerate(detections[:3]): - image_topic = getattr(self, "detected_image_" + str(index)) - image_topic.publish(detection.cropped_image()) + self.detection_stream_2d().subscribe( + lambda det: self.detections.publish(det.to_ros_detection2d_array()) + ) self.detection_stream_2d().subscribe( lambda det: self.annotations.publish(det.to_foxglove_annotations()) ) - def publish_cropped(detections: ImageDetections2D): + def publish_cropped_images(detections: ImageDetections2D): for index, detection in enumerate(detections[:3]): image_topic = getattr(self, "detected_image_" + str(index)) image_topic.publish(detection.cropped_image()) - self.detection_stream_2d().subscribe(publish_cropped) + self.detection_stream_2d().subscribe(publish_cropped_images) @rpc def stop(self): ... 
diff --git a/dimos/perception/detection2d/module3D.py b/dimos/perception/detection2d/module3D.py index 0ad3517bf5..66475d85a5 100644 --- a/dimos/perception/detection2d/module3D.py +++ b/dimos/perception/detection2d/module3D.py @@ -24,8 +24,9 @@ from dimos.perception.detection2d.type import ( ImageDetections2D, ImageDetections3D, + ImageDetections3DPC, ) -from dimos.perception.detection2d.type.detection3d import Detection3D +from dimos.perception.detection2d.type.detection3dpc import Detection3DPC from dimos.types.timestamped import align_timestamped from dimos.utils.reactive import backpressure @@ -40,7 +41,7 @@ class Detection3DModule(Detection2DModule): detected_pointcloud_1: Out[PointCloud2] = None # type: ignore detected_pointcloud_2: Out[PointCloud2] = None # type: ignore - detection_3d_stream: Observable[ImageDetections3D] = None + detection_3d_stream: Observable[ImageDetections3DPC] = None def __init__(self, camera_info: CameraInfo, *args, **kwargs): super().__init__(*args, **kwargs) @@ -55,10 +56,9 @@ def process_frame( if not transform: return ImageDetections3D(detections.image, []) - print("3d projection", detections, pointcloud, transform) detection3d_list = [] for detection in detections: - detection3d = Detection3D.from_2d( + detection3d = Detection3DPC.from_2d( detection, world_pointcloud=pointcloud, camera_info=self.camera_info, @@ -67,9 +67,7 @@ def process_frame( if detection3d is not None: detection3d_list.append(detection3d) - ret = ImageDetections3D(detections.image, detection3d_list) - print("3d projection finished", ret) - return ret + return ImageDetections3D(detections.image, detection3d_list) @rpc def start(self): diff --git a/dimos/perception/detection2d/moduleDB.py b/dimos/perception/detection2d/moduleDB.py index 052b65d6c7..456b1d8c87 100644 --- a/dimos/perception/detection2d/moduleDB.py +++ b/dimos/perception/detection2d/moduleDB.py @@ -37,18 +37,17 @@ class Object3D(Detection3D): best_detection: Detection3D = None center: Vector3 = 
None track_id: str = None - detections: List[Detection3D] + detections: int = 0 def to_repr_dict(self) -> Dict[str, Any]: return { "object_id": self.track_id, - "detections": len(self.detections), + "detections": self.detections, "center": "[" + ", ".join(list(map(lambda n: f"{n:1f}", self.center.to_list()))) + "]", } def __init__(self, track_id: str, detection: Optional[Detection3D] = None, *args, **kwargs): if detection is None: - self.detections = [] return self.ts = detection.ts self.track_id = track_id @@ -60,7 +59,7 @@ def __init__(self, track_id: str, detection: Optional[Detection3D] = None, *args self.transform = detection.transform self.center = detection.center self.frame_id = detection.frame_id - self.detections = [detection] + self.detections = self.detections + 1 self.best_detection = detection def __add__(self, detection: Detection3D) -> "Object3D": @@ -75,7 +74,7 @@ def __add__(self, detection: Detection3D) -> "Object3D": new_object.pointcloud = self.pointcloud + detection.pointcloud new_object.frame_id = self.frame_id new_object.center = (self.center + detection.center) / 2 - new_object.detections = self.detections + [detection] + new_object.detections = self.detections + 1 if detection.bbox_2d_volume() > self.bbox_2d_volume(): new_object.best_detection = detection @@ -89,13 +88,13 @@ def image(self) -> Image: return self.best_detection.image def scene_entity_label(self) -> str: - return f"{self.name} ({len(self.detections)})" + return f"{self.name} ({self.detections})" def agent_encode(self): return { "id": self.track_id, "name": self.name, - "detections": len(self.detections), + "detections": self.detections, "last_seen": f"{round((time.time() - self.ts))}s ago", # "position": self.to_pose().position.agent_encode(), } @@ -302,7 +301,7 @@ def lookup(self, label: str) -> List[Detection3D]: @rpc def start(self): - super().start() + Detection3DModule.start(self) def update_objects(imageDetections: ImageDetections3D): for detection in 
imageDetections.detections: @@ -339,8 +338,8 @@ def to_foxglove_scene_update(self) -> "SceneUpdate": for obj in copy(self.objects).values(): # we need at least 3 detectieons to consider it a valid object # for this to be serious we need a ratio of detections within the window of observations - if obj.class_id != -100 and len(obj.detections) < 3: - continue + # if obj.class_id != -100 and obj.detections < 2: + # continue # print( # f"Object {obj.track_id}: {len(obj.detections)} detections, confidence {obj.confidence}" @@ -349,7 +348,7 @@ def to_foxglove_scene_update(self) -> "SceneUpdate": scene_update.entities.append( obj.to_foxglove_scene_entity( - entity_id=f"object_{obj.name}_{obj.track_id}_{len(obj.detections)}" + entity_id=f"object_{obj.name}_{obj.track_id}_{obj.detections}" ) ) @@ -358,6 +357,3 @@ def to_foxglove_scene_update(self) -> "SceneUpdate": def __len__(self): return len(self.objects.values()) - - def __iter__(self): - return iter(self.detections.values()) diff --git a/dimos/perception/detection2d/test_yolo_2d_det.py b/dimos/perception/detection2d/test_yolo_2d_det.py deleted file mode 100644 index c04152a1d7..0000000000 --- a/dimos/perception/detection2d/test_yolo_2d_det.py +++ /dev/null @@ -1,193 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import os -import time - -import cv2 -import numpy as np -import pytest -import reactivex as rx -from reactivex import operators as ops -from reactivex.scheduler import ThreadPoolScheduler - -from dimos.perception.detection2d.yolo_2d_det import Yolo2DDetector -from dimos.stream.video_provider import VideoProvider - - -class TestYolo2DDetector: - def test_yolo_detector_initialization(self): - """Test YOLO detector initializes correctly with default model path.""" - try: - detector = Yolo2DDetector() - assert detector is not None - assert detector.model is not None - except Exception as e: - # If the model file doesn't exist, the test should still pass with a warning - pytest.skip(f"Skipping test due to model initialization error: {e}") - - def test_yolo_detector_process_image(self): - """Test YOLO detector can process video frames and return detection results.""" - # Create a dedicated scheduler for this test to avoid thread leaks - test_scheduler = ThreadPoolScheduler(max_workers=6) - try: - # Import data inside method to avoid pytest fixture confusion - from dimos.utils.data import get_data - - detector = Yolo2DDetector() - - video_path = get_data("assets") / "trimmed_video_office.mov" - - # Create video provider and directly get a video stream observable - assert os.path.exists(video_path), f"Test video not found: {video_path}" - video_provider = VideoProvider( - dev_name="test_video", video_source=video_path, pool_scheduler=test_scheduler - ) - # Process more frames for thorough testing - video_stream = video_provider.capture_video_as_observable(realtime=False, fps=15) - - # Use ReactiveX operators to process the stream - def process_frame(frame): - try: - # Process frame with YOLO - bboxes, track_ids, class_ids, confidences, names = detector.process_image(frame) - print( - f"YOLO results - boxes: {(bboxes)}, tracks: {len(track_ids)}, classes: {(class_ids)}, confidences: {(confidences)}, names: {(names)}" - ) - - return { - "frame": frame, - "bboxes": bboxes, 
- "track_ids": track_ids, - "class_ids": class_ids, - "confidences": confidences, - "names": names, - } - except Exception as e: - print(f"Exception in process_frame: {e}") - return {} - - # Create the detection stream using pipe and map operator - detection_stream = video_stream.pipe(ops.map(process_frame)) - - # Collect results from the stream - results = [] - - frames_processed = 0 - target_frames = 10 - - def on_next(result): - nonlocal frames_processed - if not result: - return - - results.append(result) - frames_processed += 1 - - # Stop after processing target number of frames - if frames_processed >= target_frames: - subscription.dispose() - - def on_error(error): - pytest.fail(f"Error in detection stream: {error}") - - def on_completed(): - pass - - # Subscribe and wait for results - subscription = detection_stream.subscribe( - on_next=on_next, on_error=on_error, on_completed=on_completed - ) - - timeout = 10.0 - start_time = time.time() - while frames_processed < target_frames and time.time() - start_time < timeout: - time.sleep(0.5) - - # Clean up subscription - subscription.dispose() - video_provider.dispose_all() - detector.stop() - # Shutdown the scheduler to clean up threads - test_scheduler.executor.shutdown(wait=True) - # Check that we got detection results - if len(results) == 0: - pytest.skip("Skipping test due to error: Failed to get any detection results") - - # Verify we have detection results with expected properties - assert len(results) > 0, "No detection results were received" - - # Print statistics about detections - total_detections = sum(len(r["bboxes"]) for r in results if r.get("bboxes")) - avg_detections = total_detections / len(results) if results else 0 - print(f"Total detections: {total_detections}, Average per frame: {avg_detections:.2f}") - - # Print most common detected objects - object_counts = {} - for r in results: - if r.get("names"): - for name in r["names"]: - if name: - object_counts[name] = object_counts.get(name, 0) + 
1 - - if object_counts: - print("Detected objects:") - for obj, count in sorted(object_counts.items(), key=lambda x: x[1], reverse=True)[ - :5 - ]: - print(f" - {obj}: {count} times") - - # Analyze the first result - result = results[0] - - # Check that we have a frame - assert "frame" in result, "Result doesn't contain a frame" - assert isinstance(result["frame"], np.ndarray), "Frame is not a numpy array" - - # Check that detection results are valid - assert isinstance(result["bboxes"], list) - assert isinstance(result["track_ids"], list) - assert isinstance(result["class_ids"], list) - assert isinstance(result["confidences"], list) - assert isinstance(result["names"], list) - - # All result lists should be the same length - assert ( - len(result["bboxes"]) - == len(result["track_ids"]) - == len(result["class_ids"]) - == len(result["confidences"]) - == len(result["names"]) - ) - - # If we have detections, check that bbox format is valid - if result["bboxes"]: - assert len(result["bboxes"][0]) == 4, ( - "Bounding boxes should be in [x1, y1, x2, y2] format" - ) - - except Exception as e: - # Ensure cleanup happens even on exception - if "detector" in locals(): - detector.stop() - if "video_provider" in locals(): - video_provider.dispose_all() - pytest.skip(f"Skipping test due to error: {e}") - finally: - # Always shutdown the scheduler - test_scheduler.executor.shutdown(wait=True) - - -if __name__ == "__main__": - pytest.main(["-v", __file__]) diff --git a/dimos/perception/detection2d/type/__init__.py b/dimos/perception/detection2d/type/__init__.py index fb7c435c0c..aee8597d5c 100644 --- a/dimos/perception/detection2d/type/__init__.py +++ b/dimos/perception/detection2d/type/__init__.py @@ -1,7 +1,16 @@ from dimos.perception.detection2d.type.detection2d import ( Detection2D, + Detection2DBBox, ImageDetections2D, InconvinientDetectionFormat, ) -from dimos.perception.detection2d.type.detection3d import Detection3D, ImageDetections3D +from 
dimos.perception.detection2d.type.detection3d import ( + Detection3D, + ImageDetections3D, +) +from dimos.perception.detection2d.type.detection3dpc import ( + Detection3DPC, + ImageDetections3DPC, +) from dimos.perception.detection2d.type.imageDetections import ImageDetections, TableStr +from dimos.perception.detection2d.type.person import Person diff --git a/dimos/perception/detection2d/type/detection2d.py b/dimos/perception/detection2d/type/detection2d.py index 5bbedfb6ae..48e1a5191d 100644 --- a/dimos/perception/detection2d/type/detection2d.py +++ b/dimos/perception/detection2d/type/detection2d.py @@ -15,8 +15,9 @@ from __future__ import annotations import hashlib +from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Dict, List, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Tuple from dimos_lcm.foxglove_msgs.ImageAnnotations import ( PointsAnnotation, @@ -43,8 +44,12 @@ from dimos.perception.detection2d.type.imageDetections import ImageDetections from dimos.types.timestamped import Timestamped, to_ros_stamp, to_timestamp +if TYPE_CHECKING: + from dimos.perception.detection2d.type.person import Person + Bbox = Tuple[float, float, float, float] CenteredBbox = Tuple[float, float, float, float] + # yolo and detic have bad output formats InconvinientDetectionFormat = Tuple[List[Bbox], List[int], List[int], List[float], List[str]] @@ -89,8 +94,16 @@ def better_detection_format(inconvinient_detections: InconvinientDetectionFormat ] -@dataclass class Detection2D(Timestamped): + @abstractmethod + def cropped_image(self, padding: int = 20) -> Image: ... + + @abstractmethod + def to_image_annotations(self) -> ImageAnnotations: ... 
+ + +@dataclass +class Detection2DBBox(Detection2D): bbox: Bbox track_id: int class_id: int @@ -215,7 +228,7 @@ def to_text_annotation(self) -> List[TextAnnotation]: TextAnnotation( timestamp=to_ros_stamp(self.ts), position=Point2(x=x1, y=y1), - text=f"{self.name}_{self.track_id}", + text=f"{self.name}_{self.class_id}_{self.track_id}", font_size=font_size, text_color=Color(r=1.0, g=1.0, b=1.0, a=1), background_color=Color(r=0, g=0, b=0, a=1), @@ -320,10 +333,26 @@ def to_ros_detection2d(self) -> ROSDetection2D: class ImageDetections2D(ImageDetections[Detection2D]): @classmethod - def from_detector( + def from_bbox_detector( cls, image: Image, raw_detections: InconvinientDetectionFormat, **kwargs ) -> "ImageDetections2D": return cls( image=image, - detections=Detection2D.from_detector(raw_detections, image=image, ts=image.ts), + detections=Detection2DBBox.from_detector(raw_detections, image=image, ts=image.ts), + ) + + @classmethod + def from_pose_detector( + cls, image: Image, people: List["Person"], **kwargs + ) -> "ImageDetections2D": + """Create ImageDetections2D from a list of Person detections. 
+ Args: + image: Source image + people: List of Person objects with pose keypoints + Returns: + ImageDetections2D containing the pose detections + """ + return cls( + image=image, + detections=people, # Person objects are already Detection2D subclasses ) diff --git a/dimos/perception/detection2d/type/detection3d.py b/dimos/perception/detection2d/type/detection3d.py index 8c0919b700..a203bb1a4b 100644 --- a/dimos/perception/detection2d/type/detection3d.py +++ b/dimos/perception/detection2d/type/detection3d.py @@ -28,199 +28,25 @@ from dimos.msgs.foxglove_msgs.Color import Color from dimos.msgs.geometry_msgs import PoseStamped, Transform, Vector3 from dimos.msgs.sensor_msgs import PointCloud2 -from dimos.perception.detection2d.type.detection2d import Detection2D +from dimos.perception.detection2d.type.detection2d import Detection2D, Detection2DBBox from dimos.perception.detection2d.type.imageDetections import ImageDetections from dimos.types.timestamped import to_ros_stamp -Detection3DFilter = Callable[ - [Detection2D, PointCloud2, CameraInfo, Transform], Optional["Detection3D"] -] - - -def height_filter(height=0.1) -> Detection3DFilter: - return lambda det, pc, ci, tf: pc.filter_by_height(height) - - -def statistical(nb_neighbors=40, std_ratio=0.5) -> Detection3DFilter: - def filter_func( - det: Detection2D, pc: PointCloud2, ci: CameraInfo, tf: Transform - ) -> Optional[PointCloud2]: - try: - statistical, removed = pc.pointcloud.remove_statistical_outlier( - nb_neighbors=nb_neighbors, std_ratio=std_ratio - ) - return PointCloud2(statistical, pc.frame_id, pc.ts) - except Exception as e: - # print("statistical filter failed:", e) - return None - - return filter_func - - -def raycast() -> Detection3DFilter: - def filter_func( - det: Detection2D, pc: PointCloud2, ci: CameraInfo, tf: Transform - ) -> Optional[PointCloud2]: - try: - camera_pos = tf.inverse().translation - camera_pos_np = camera_pos.to_numpy() - _, visible_indices = 
pc.pointcloud.hidden_point_removal(camera_pos_np, radius=100.0) - visible_pcd = pc.pointcloud.select_by_index(visible_indices) - return PointCloud2(visible_pcd, pc.frame_id, pc.ts) - except Exception as e: - # print("raycast filter failed:", e) - return None - - return filter_func - - -def radius_outlier(min_neighbors: int = 20, radius: float = 0.3) -> Detection3DFilter: - """ - Remove isolated points: keep only points that have at least `min_neighbors` - neighbors within `radius` meters (same units as your point cloud). - """ - - def filter_func( - det: Detection2D, pc: PointCloud2, ci: CameraInfo, tf: Transform - ) -> Optional[PointCloud2]: - filtered_pcd, removed = pc.pointcloud.remove_radius_outlier( - nb_points=min_neighbors, radius=radius - ) - return PointCloud2(filtered_pcd, pc.frame_id, pc.ts) - - return filter_func - @dataclass -class Detection3D(Detection2D): - pointcloud: PointCloud2 +class Detection3D(Detection2DBBox): transform: Transform - frame_id: str = "unknown" + frame_id: str @classmethod def from_2d( cls, det: Detection2D, - world_pointcloud: PointCloud2, + distance: float, camera_info: CameraInfo, world_to_optical_transform: Transform, - # filters are to be adjusted based on the sensor noise characteristics if feeding - # sensor data directly - filters: list[Callable[[PointCloud2], PointCloud2]] = [ - # height_filter(0.1), - raycast(), - radius_outlier(), - statistical(), - ], ) -> Optional["Detection3D"]: - """Create a Detection3D from a 2D detection by projecting world pointcloud. - - This method handles: - 1. Projecting world pointcloud to camera frame - 2. Filtering points within the 2D detection bounding box - 3. Cleaning up the pointcloud (height filter, outlier removal) - 4. 
Hidden point removal from camera perspective - - Args: - det: The 2D detection - world_pointcloud: Full pointcloud in world frame - camera_info: Camera calibration info - world_to_camerlka_transform: Transform from world to camera frame - filters: List of functions to apply to the pointcloud for filtering - Returns: - Detection3D with filtered pointcloud, or None if no valid points - """ - # Extract camera parameters - fx, fy = camera_info.K[0], camera_info.K[4] - cx, cy = camera_info.K[2], camera_info.K[5] - image_width = camera_info.width - image_height = camera_info.height - - camera_matrix = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]]) - - # Convert pointcloud to numpy array - world_points = world_pointcloud.as_numpy() - - # Project points to camera frame - points_homogeneous = np.hstack([world_points, np.ones((world_points.shape[0], 1))]) - extrinsics_matrix = world_to_optical_transform.to_matrix() - points_camera = (extrinsics_matrix @ points_homogeneous.T).T - - # Filter out points behind the camera - valid_mask = points_camera[:, 2] > 0 - points_camera = points_camera[valid_mask] - world_points = world_points[valid_mask] - - if len(world_points) == 0: - return None - - # Project to 2D - points_2d_homogeneous = (camera_matrix @ points_camera[:, :3].T).T - points_2d = points_2d_homogeneous[:, :2] / points_2d_homogeneous[:, 2:3] - - # Filter points within image bounds - in_image_mask = ( - (points_2d[:, 0] >= 0) - & (points_2d[:, 0] < image_width) - & (points_2d[:, 1] >= 0) - & (points_2d[:, 1] < image_height) - ) - points_2d = points_2d[in_image_mask] - world_points = world_points[in_image_mask] - - if len(world_points) == 0: - return None - - # Extract bbox from Detection2D - x_min, y_min, x_max, y_max = det.bbox - - # Find points within this detection box (with small margin) - margin = 5 # pixels - in_box_mask = ( - (points_2d[:, 0] >= x_min - margin) - & (points_2d[:, 0] <= x_max + margin) - & (points_2d[:, 1] >= y_min - margin) - & (points_2d[:, 1] <= 
y_max + margin) - ) - - detection_points = world_points[in_box_mask] - - if detection_points.shape[0] == 0: - # print(f"No points found in detection bbox after projection. {det.name}") - return None - - # Create initial pointcloud for this detection - initial_pc = PointCloud2.from_numpy( - detection_points, - frame_id=world_pointcloud.frame_id, - timestamp=world_pointcloud.ts, - ) - - # Apply filters - each filter needs all 4 arguments - detection_pc = initial_pc - for filter_func in filters: - result = filter_func(det, detection_pc, camera_info, world_to_optical_transform) - if result is None: - return None - detection_pc = result - - # Final check for empty pointcloud - if len(detection_pc.pointcloud.points) == 0: - return None - - # Create Detection3D with filtered pointcloud - return Detection3D( - image=det.image, - bbox=det.bbox, - track_id=det.track_id, - class_id=det.class_id, - confidence=det.confidence, - name=det.name, - ts=det.ts, - pointcloud=detection_pc, - transform=world_to_optical_transform, - frame_id=world_pointcloud.frame_id, - ) + raise NotImplementedError() @functools.cached_property def center(self) -> Vector3: diff --git a/dimos/perception/detection2d/type/detection3dpc.py b/dimos/perception/detection2d/type/detection3dpc.py new file mode 100644 index 0000000000..44d242de9e --- /dev/null +++ b/dimos/perception/detection2d/type/detection3dpc.py @@ -0,0 +1,247 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import functools +from dataclasses import dataclass +from typing import Any, Callable, Dict, Optional, TypeVar + +import numpy as np +from dimos_lcm.sensor_msgs import CameraInfo +from lcm_msgs.builtin_interfaces import Duration +from lcm_msgs.foxglove_msgs import CubePrimitive, SceneEntity, SceneUpdate, TextPrimitive +from lcm_msgs.geometry_msgs import Point, Pose, Quaternion +from lcm_msgs.geometry_msgs import Vector3 as LCMVector3 + +from dimos.msgs.foxglove_msgs.Color import Color +from dimos.msgs.geometry_msgs import PoseStamped, Transform, Vector3 +from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.perception.detection2d.type.detection2d import Detection2D +from dimos.perception.detection2d.type.detection3d import Detection3D +from dimos.perception.detection2d.type.imageDetections import ImageDetections +from dimos.types.timestamped import to_ros_stamp + +Detection3DPCFilter = Callable[ + [Detection2D, PointCloud2, CameraInfo, Transform], Optional["Detection3DPC"] +] + + +def height_filter(height=0.1) -> Detection3DPCFilter: + return lambda det, pc, ci, tf: pc.filter_by_height(height) + + +def statistical(nb_neighbors=40, std_ratio=0.5) -> Detection3DPCFilter: + def filter_func( + det: Detection2D, pc: PointCloud2, ci: CameraInfo, tf: Transform + ) -> Optional[PointCloud2]: + try: + statistical, removed = pc.pointcloud.remove_statistical_outlier( + nb_neighbors=nb_neighbors, std_ratio=std_ratio + ) + return PointCloud2(statistical, pc.frame_id, pc.ts) + except Exception as e: + # print("statistical filter failed:", e) + return None + + return filter_func + + +def raycast() -> Detection3DPCFilter: + def filter_func( + det: Detection2D, pc: PointCloud2, ci: CameraInfo, tf: Transform + ) -> Optional[PointCloud2]: + try: + camera_pos = tf.inverse().translation + camera_pos_np = camera_pos.to_numpy() + _, visible_indices = pc.pointcloud.hidden_point_removal(camera_pos_np, radius=100.0) + visible_pcd = 
pc.pointcloud.select_by_index(visible_indices) + return PointCloud2(visible_pcd, pc.frame_id, pc.ts) + except Exception as e: + # print("raycast filter failed:", e) + return None + + return filter_func + + +def radius_outlier(min_neighbors: int = 20, radius: float = 0.3) -> Detection3DPCFilter: + """ + Remove isolated points: keep only points that have at least `min_neighbors` + neighbors within `radius` meters (same units as your point cloud). + """ + + def filter_func( + det: Detection2D, pc: PointCloud2, ci: CameraInfo, tf: Transform + ) -> Optional[PointCloud2]: + filtered_pcd, removed = pc.pointcloud.remove_radius_outlier( + nb_points=min_neighbors, radius=radius + ) + return PointCloud2(filtered_pcd, pc.frame_id, pc.ts) + + return filter_func + + +@dataclass +class Detection3DPC(Detection3D): + pointcloud: PointCloud2 + + @classmethod + def from_2d( + cls, + det: Detection2D, + world_pointcloud: PointCloud2, + camera_info: CameraInfo, + world_to_optical_transform: Transform, + # filters are to be adjusted based on the sensor noise characteristics if feeding + # sensor data directly + filters: list[Callable[[PointCloud2], PointCloud2]] = [ + # height_filter(0.1), + raycast(), + radius_outlier(), + statistical(), + ], + ) -> Optional["Detection3D"]: + """Create a Detection3D from a 2D detection by projecting world pointcloud. + + This method handles: + 1. Projecting world pointcloud to camera frame + 2. Filtering points within the 2D detection bounding box + 3. Cleaning up the pointcloud (height filter, outlier removal) + 4. 
Hidden point removal from camera perspective + + Args: + det: The 2D detection + world_pointcloud: Full pointcloud in world frame + camera_info: Camera calibration info + world_to_camerlka_transform: Transform from world to camera frame + filters: List of functions to apply to the pointcloud for filtering + Returns: + Detection3D with filtered pointcloud, or None if no valid points + """ + # Extract camera parameters + fx, fy = camera_info.K[0], camera_info.K[4] + cx, cy = camera_info.K[2], camera_info.K[5] + image_width = camera_info.width + image_height = camera_info.height + + camera_matrix = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]]) + + # Convert pointcloud to numpy array + world_points = world_pointcloud.as_numpy() + + # Project points to camera frame + points_homogeneous = np.hstack([world_points, np.ones((world_points.shape[0], 1))]) + extrinsics_matrix = world_to_optical_transform.to_matrix() + points_camera = (extrinsics_matrix @ points_homogeneous.T).T + + # Filter out points behind the camera + valid_mask = points_camera[:, 2] > 0 + points_camera = points_camera[valid_mask] + world_points = world_points[valid_mask] + + if len(world_points) == 0: + return None + + # Project to 2D + points_2d_homogeneous = (camera_matrix @ points_camera[:, :3].T).T + points_2d = points_2d_homogeneous[:, :2] / points_2d_homogeneous[:, 2:3] + + # Filter points within image bounds + in_image_mask = ( + (points_2d[:, 0] >= 0) + & (points_2d[:, 0] < image_width) + & (points_2d[:, 1] >= 0) + & (points_2d[:, 1] < image_height) + ) + points_2d = points_2d[in_image_mask] + world_points = world_points[in_image_mask] + + if len(world_points) == 0: + return None + + # Extract bbox from Detection2D + x_min, y_min, x_max, y_max = det.bbox + + # Find points within this detection box (with small margin) + margin = 5 # pixels + in_box_mask = ( + (points_2d[:, 0] >= x_min - margin) + & (points_2d[:, 0] <= x_max + margin) + & (points_2d[:, 1] >= y_min - margin) + & (points_2d[:, 1] <= 
y_max + margin) + ) + + detection_points = world_points[in_box_mask] + + if detection_points.shape[0] == 0: + # print(f"No points found in detection bbox after projection. {det.name}") + return None + + # Create initial pointcloud for this detection + initial_pc = PointCloud2.from_numpy( + detection_points, + frame_id=world_pointcloud.frame_id, + timestamp=world_pointcloud.ts, + ) + + # Apply filters - each filter needs all 4 arguments + detection_pc = initial_pc + for filter_func in filters: + result = filter_func(det, detection_pc, camera_info, world_to_optical_transform) + if result is None: + return None + detection_pc = result + + # Final check for empty pointcloud + if len(detection_pc.pointcloud.points) == 0: + return None + + # Create Detection3D with filtered pointcloud + return cls( + image=det.image, + bbox=det.bbox, + track_id=det.track_id, + class_id=det.class_id, + confidence=det.confidence, + name=det.name, + ts=det.ts, + pointcloud=detection_pc, + transform=world_to_optical_transform, + frame_id=world_pointcloud.frame_id, + ) + + +class ImageDetections3DPC(ImageDetections[Detection3DPC]): + """Specialized class for 3D detections in an image.""" + + def to_foxglove_scene_update(self) -> "SceneUpdate": + """Convert all detections to a Foxglove SceneUpdate message. 
+ + Returns: + SceneUpdate containing SceneEntity objects for all detections + """ + + # Create SceneUpdate message with all detections + scene_update = SceneUpdate() + scene_update.deletions_length = 0 + scene_update.deletions = [] + scene_update.entities = [] + + # Process each detection + for i, detection in enumerate(self.detections): + entity = detection.to_foxglove_scene_entity(entity_id=f"detection_{detection.name}_{i}") + scene_update.entities.append(entity) + + scene_update.entities_length = len(scene_update.entities) + return scene_update diff --git a/dimos/perception/detection2d/type/person.py b/dimos/perception/detection2d/type/person.py new file mode 100644 index 0000000000..b61045f48c --- /dev/null +++ b/dimos/perception/detection2d/type/person.py @@ -0,0 +1,267 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dataclasses import dataclass + +# Import for type checking only to avoid circular imports +from typing import TYPE_CHECKING, List, Optional, Tuple + +import numpy as np +from dimos_lcm.foxglove_msgs.ImageAnnotations import PointsAnnotation, TextAnnotation +from dimos_lcm.foxglove_msgs.Point2 import Point2 + +from dimos.msgs.foxglove_msgs.Color import Color +from dimos.msgs.sensor_msgs import Image +from dimos.perception.detection2d.type.detection2d import Bbox, Detection2DBBox +from dimos.types.timestamped import to_ros_stamp + +if TYPE_CHECKING: + from ultralytics.engine.results import Results + + +@dataclass +class Person(Detection2DBBox): + """Represents a detected person with pose keypoints.""" + + # Pose keypoints - additional fields beyond Detection2DBBox + keypoints: np.ndarray # [17, 2] - x,y coordinates + keypoint_scores: np.ndarray # [17] - confidence scores + + # Optional normalized coordinates + bbox_normalized: Optional[np.ndarray] = None # [x1, y1, x2, y2] in 0-1 range + keypoints_normalized: Optional[np.ndarray] = None # [17, 2] in 0-1 range + + # Image dimensions for context + image_width: Optional[int] = None + image_height: Optional[int] = None + + # Keypoint names (class attribute) + KEYPOINT_NAMES = [ + "nose", + "left_eye", + "right_eye", + "left_ear", + "right_ear", + "left_shoulder", + "right_shoulder", + "left_elbow", + "right_elbow", + "left_wrist", + "right_wrist", + "left_hip", + "right_hip", + "left_knee", + "right_knee", + "left_ankle", + "right_ankle", + ] + + @classmethod + def from_yolo(cls, result: "Results", person_idx: int, image: Image) -> "Person": + """Create Person instance from YOLO results. 
+ + Args: + result: Single Results object from YOLO + person_idx: Index of the person in the detection results + image: Original image for the detection + """ + # Extract bounding box as tuple for Detection2DBBox + bbox_array = result.boxes.xyxy[person_idx].cpu().numpy() + + bbox: Bbox = ( + float(bbox_array[0]), + float(bbox_array[1]), + float(bbox_array[2]), + float(bbox_array[3]), + ) + + bbox_norm = ( + result.boxes.xyxyn[person_idx].cpu().numpy() if hasattr(result.boxes, "xyxyn") else None + ) + + confidence = float(result.boxes.conf[person_idx].cpu()) + class_id = int(result.boxes.cls[person_idx].cpu()) + + # Extract keypoints + keypoints = result.keypoints.xy[person_idx].cpu().numpy() + keypoint_scores = result.keypoints.conf[person_idx].cpu().numpy() + keypoints_norm = ( + result.keypoints.xyn[person_idx].cpu().numpy() + if hasattr(result.keypoints, "xyn") + else None + ) + + # Get image dimensions + height, width = result.orig_shape + + return cls( + # Detection2DBBox fields + bbox=bbox, + track_id=person_idx, # Use person index as track_id for now + class_id=class_id, + confidence=confidence, + name="person", + ts=image.ts, + image=image, + # Person specific fields + keypoints=keypoints, + keypoint_scores=keypoint_scores, + bbox_normalized=bbox_norm, + keypoints_normalized=keypoints_norm, + image_width=width, + image_height=height, + ) + + def get_keypoint(self, name: str) -> Tuple[np.ndarray, float]: + """Get specific keypoint by name. + Returns: + Tuple of (xy_coordinates, confidence_score) + """ + if name not in self.KEYPOINT_NAMES: + raise ValueError(f"Invalid keypoint name: {name}. Must be one of {self.KEYPOINT_NAMES}") + + idx = self.KEYPOINT_NAMES.index(name) + return self.keypoints[idx], self.keypoint_scores[idx] + + def get_visible_keypoints(self, threshold: float = 0.5) -> List[Tuple[str, np.ndarray, float]]: + """Get all keypoints above confidence threshold. 
+ Returns: + List of tuples: (keypoint_name, xy_coordinates, confidence) + """ + visible = [] + for i, (name, score) in enumerate(zip(self.KEYPOINT_NAMES, self.keypoint_scores)): + if score > threshold: + visible.append((name, self.keypoints[i], score)) + return visible + + @property + def width(self) -> float: + """Get width of bounding box.""" + x1, _, x2, _ = self.bbox + return x2 - x1 + + @property + def height(self) -> float: + """Get height of bounding box.""" + _, y1, _, y2 = self.bbox + return y2 - y1 + + @property + def center(self) -> Tuple[float, float]: + """Get center point of bounding box.""" + x1, y1, x2, y2 = self.bbox + return ((x1 + x2) / 2, (y1 + y2) / 2) + + def to_points_annotation(self) -> List[PointsAnnotation]: + """Override to include keypoint visualizations along with bounding box.""" + annotations = [] + + # First add the bounding box from parent class + annotations.extend(super().to_points_annotation()) + + # Add keypoints as circles + visible_keypoints = self.get_visible_keypoints(threshold=0.3) + + # Create points for visible keypoints + if visible_keypoints: + keypoint_points = [] + for name, xy, conf in visible_keypoints: + keypoint_points.append(Point2(float(xy[0]), float(xy[1]))) + + # Add keypoints as circles + annotations.append( + PointsAnnotation( + timestamp=to_ros_stamp(self.ts), + outline_color=Color(r=0.0, g=1.0, b=0.0, a=1.0), # Green outline + fill_color=Color(r=0.0, g=1.0, b=0.0, a=0.5), # Semi-transparent green + thickness=2.0, + points_length=len(keypoint_points), + points=keypoint_points, + type=PointsAnnotation.POINTS, # Draw as individual points/circles + ) + ) + + # Add skeleton connections (COCO skeleton) + skeleton_connections = [ + # Face + (0, 1), + (0, 2), + (1, 3), + (2, 4), # nose to eyes, eyes to ears + # Arms + (5, 6), # shoulders + (5, 7), + (7, 9), # left arm + (6, 8), + (8, 10), # right arm + # Torso + (5, 11), + (6, 12), + (11, 12), # shoulders to hips, hip to hip + # Legs + (11, 13), + (13, 15), # 
left leg + (12, 14), + (14, 16), # right leg + ] + + # Draw skeleton lines between connected keypoints + for start_idx, end_idx in skeleton_connections: + if ( + start_idx < len(self.keypoint_scores) + and end_idx < len(self.keypoint_scores) + and self.keypoint_scores[start_idx] > 0.3 + and self.keypoint_scores[end_idx] > 0.3 + ): + start_point = Point2( + float(self.keypoints[start_idx][0]), float(self.keypoints[start_idx][1]) + ) + end_point = Point2( + float(self.keypoints[end_idx][0]), float(self.keypoints[end_idx][1]) + ) + + annotations.append( + PointsAnnotation( + timestamp=to_ros_stamp(self.ts), + outline_color=Color(r=0.0, g=0.8, b=1.0, a=0.8), # Cyan + thickness=1.5, + points_length=2, + points=[start_point, end_point], + type=PointsAnnotation.LINE_LIST, + ) + ) + + return annotations + + def to_text_annotation(self) -> List[TextAnnotation]: + """Override to include pose information in text annotations.""" + # Get base annotations from parent + annotations = super().to_text_annotation() + + # Add pose-specific info + visible_count = len(self.get_visible_keypoints(threshold=0.5)) + x1, y1, x2, y2 = self.bbox + + annotations.append( + TextAnnotation( + timestamp=to_ros_stamp(self.ts), + position=Point2(x=x1, y=y2 + 40), # Below confidence text + text=f"keypoints: {visible_count}/17", + font_size=18, + text_color=Color(r=0.0, g=1.0, b=0.0, a=1), + background_color=Color(r=0, g=0, b=0, a=0.7), + ) + ) + + return annotations diff --git a/dimos/perception/detection2d/type/test_detection2d.py b/dimos/perception/detection2d/type/test_detection2d.py index 515bdee339..3bf37c0fb6 100644 --- a/dimos/perception/detection2d/type/test_detection2d.py +++ b/dimos/perception/detection2d/type/test_detection2d.py @@ -14,7 +14,8 @@ import pytest -def test_detection_basic_properties(detection2d): +def test_detection2d(detection2d): + # def test_detection_basic_properties(detection2d): """Test basic detection properties.""" assert detection2d.track_id >= 0 assert 
detection2d.class_id >= 0 @@ -22,8 +23,7 @@ def test_detection_basic_properties(detection2d): assert detection2d.name is not None assert detection2d.ts > 0 - -def test_bounding_box_format(detection2d): + # def test_bounding_box_format(detection2d): """Test bounding box format and validity.""" bbox = detection2d.bbox assert len(bbox) == 4, "Bounding box should have 4 values" @@ -34,8 +34,7 @@ def test_bounding_box_format(detection2d): assert x1 >= 0, "x1 should be non-negative" assert y1 >= 0, "y1 should be non-negative" - -def test_bbox_2d_volume(detection2d): + # def test_bbox_2d_volume(detection2d): """Test bounding box volume calculation.""" volume = detection2d.bbox_2d_volume() assert volume > 0, "Bounding box volume should be positive" @@ -45,8 +44,7 @@ def test_bbox_2d_volume(detection2d): expected_volume = (x2 - x1) * (y2 - y1) assert volume == pytest.approx(expected_volume, abs=0.001) - -def test_bbox_center_calculation(detection2d): + # def test_bbox_center_calculation(detection2d): """Test bounding box center calculation.""" center_bbox = detection2d.get_bbox_center() assert len(center_bbox) == 4, "Center bbox should have 4 values" @@ -60,8 +58,7 @@ def test_bbox_center_calculation(detection2d): assert width == pytest.approx(x2 - x1, abs=0.001) assert height == pytest.approx(y2 - y1, abs=0.001) - -def test_cropped_image(detection2d): + # def test_cropped_image(detection2d): """Test cropped image generation.""" padding = 20 cropped = detection2d.cropped_image(padding=padding) @@ -73,8 +70,7 @@ def test_cropped_image(detection2d): assert cropped.height == 260 assert cropped.shape == (260, 192, 3) - -def test_to_ros_bbox(detection2d): + # def test_to_ros_bbox(detection2d): """Test ROS bounding box conversion.""" ros_bbox = detection2d.to_ros_bbox() diff --git a/dimos/perception/detection2d/type/test_detection3d.py b/dimos/perception/detection2d/type/test_detection3d.py index c8215b9601..642e6c7542 100644 --- 
a/dimos/perception/detection2d/type/test_detection3d.py +++ b/dimos/perception/detection2d/type/test_detection3d.py @@ -12,130 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np -import pytest +import time +from dimos.perception.detection2d.type.detection3d import Detection3D -def test_oriented_bounding_box(detection3d): - """Test oriented bounding box calculation and values.""" - obb = detection3d.get_oriented_bounding_box() - assert obb is not None, "Oriented bounding box should not be None" - # Verify OBB center values - assert obb.center[0] == pytest.approx(-3.36002, abs=0.1) - assert obb.center[1] == pytest.approx(-0.196446, abs=0.1) - assert obb.center[2] == pytest.approx(0.220184, abs=0.1) +def test_guess_projection(get_moment_2d, publish_moment): + moment = get_moment_2d(seek=10.0) + for key, value in moment.items(): + print(key, "====================================") + print(value) - # Verify OBB extent values - assert obb.extent[0] == pytest.approx(0.531275, abs=0.1) - assert obb.extent[1] == pytest.approx(0.461054, abs=0.1) - assert obb.extent[2] == pytest.approx(0.155, abs=0.1) + camera_info = moment.get("camera_info") + detection2d = moment.get("detections2d")[0] + tf = moment.get("tf") + transform = tf.get("camera_optical", "world", detection2d.ts, 5.0) + # for stash + # detection3d = Detection3D.from_2d(detection2d, 1.5, camera_info, transform) + # print(detection3d) -def test_bounding_box_dimensions(detection3d): - """Test bounding box dimension calculation.""" - dims = detection3d.get_bounding_box_dimensions() - assert len(dims) == 3, "Bounding box dimensions should have 3 values" - assert dims[0] == pytest.approx(0.350, abs=0.1) - assert dims[1] == pytest.approx(0.250, abs=0.1) - assert dims[2] == pytest.approx(0.550, abs=0.1) - - -def test_axis_aligned_bounding_box(detection3d): - """Test axis-aligned bounding box calculation.""" - aabb = 
detection3d.get_bounding_box() - assert aabb is not None, "Axis-aligned bounding box should not be None" - - # Verify AABB min values - assert aabb.min_bound[0] == pytest.approx(-3.575, abs=0.1) - assert aabb.min_bound[1] == pytest.approx(-0.375, abs=0.1) - assert aabb.min_bound[2] == pytest.approx(-0.075, abs=0.1) - - # Verify AABB max values - assert aabb.max_bound[0] == pytest.approx(-3.075, abs=0.1) - assert aabb.max_bound[1] == pytest.approx(-0.125, abs=0.1) - assert aabb.max_bound[2] == pytest.approx(0.475, abs=0.1) - - -def test_point_cloud_properties(detection3d): - """Test point cloud data and boundaries.""" - pc_points = detection3d.pointcloud.points() - assert len(pc_points) in [69, 70] - assert detection3d.pointcloud.frame_id == "world", ( - f"Expected frame_id 'world', got '{detection3d.pointcloud.frame_id}'" - ) - - # Extract xyz coordinates from points - points = np.array([[pt[0], pt[1], pt[2]] for pt in pc_points]) - - min_pt = np.min(points, axis=0) - max_pt = np.max(points, axis=0) - center = np.mean(points, axis=0) - - # Verify point cloud boundaries - assert min_pt[0] == pytest.approx(-3.575, abs=0.1) - assert min_pt[1] == pytest.approx(-0.375, abs=0.1) - assert min_pt[2] == pytest.approx(-0.075, abs=0.1) - - assert max_pt[0] == pytest.approx(-3.075, abs=0.1) - assert max_pt[1] == pytest.approx(-0.125, abs=0.1) - assert max_pt[2] == pytest.approx(0.475, abs=0.1) - - assert center[0] == pytest.approx(-3.326, abs=0.1) - assert center[1] == pytest.approx(-0.202, abs=0.1) - assert center[2] == pytest.approx(0.160, abs=0.1) - - -def test_foxglove_scene_entity_generation(detection3d): - """Test Foxglove scene entity creation and structure.""" - entity = detection3d.to_foxglove_scene_entity("test_entity_123") - - # Verify entity metadata - assert entity.id == "1", f"Expected entity ID '1', got '{entity.id}'" - assert entity.frame_id == "world", f"Expected frame_id 'world', got '{entity.frame_id}'" - assert entity.cubes_length == 1, f"Expected 1 cube, 
got {entity.cubes_length}" - assert entity.texts_length == 1, f"Expected 1 text, got {entity.texts_length}" - - -def test_foxglove_cube_properties(detection3d): - """Test Foxglove cube primitive properties.""" - entity = detection3d.to_foxglove_scene_entity("test_entity_123") - cube = entity.cubes[0] - - # Verify position - assert cube.pose.position.x == pytest.approx(-3.325, abs=0.1) - assert cube.pose.position.y == pytest.approx(-0.250, abs=0.1) - assert cube.pose.position.z == pytest.approx(0.200, abs=0.1) - - # Verify size - assert cube.size.x == pytest.approx(0.350, abs=0.1) - assert cube.size.y == pytest.approx(0.250, abs=0.1) - assert cube.size.z == pytest.approx(0.550, abs=0.1) - - # Verify color (green with alpha) - assert cube.color.r == pytest.approx(0.08235294117647059, abs=0.1) - assert cube.color.g == pytest.approx(0.7176470588235294, abs=0.1) - assert cube.color.b == pytest.approx(0.28627450980392155, abs=0.1) - assert cube.color.a == pytest.approx(0.2, abs=0.1) - - -def test_foxglove_text_label(detection3d): - """Test Foxglove text label properties.""" - entity = detection3d.to_foxglove_scene_entity("test_entity_123") - text = entity.texts[0] - - assert text.text == "1/suitcase (81%)", f"Expected text '1/suitcase (81%)', got '{text.text}'" - assert text.pose.position.x == pytest.approx(-3.325, abs=0.1) - assert text.pose.position.y == pytest.approx(-0.250, abs=0.1) - assert text.pose.position.z == pytest.approx(0.575, abs=0.1) - assert text.font_size == 20.0, f"Expected font size 20.0, got {text.font_size}" - - -def test_detection_pose(detection3d): - """Test detection pose and frame information.""" - assert detection3d.pose.x == pytest.approx(-3.327, abs=0.1) - assert detection3d.pose.y == pytest.approx(-0.202, abs=0.1) - assert detection3d.pose.z == pytest.approx(0.160, abs=0.1) - assert detection3d.pose.frame_id == "world", ( - f"Expected frame_id 'world', got '{detection3d.pose.frame_id}'" - ) + # foxglove bridge needs 2 messages per topic to 
pass to foxglove + publish_moment(moment) + time.sleep(0.1) + publish_moment(moment) diff --git a/dimos/perception/detection2d/type/test_detection3dpc.py b/dimos/perception/detection2d/type/test_detection3dpc.py new file mode 100644 index 0000000000..a25e27d458 --- /dev/null +++ b/dimos/perception/detection2d/type/test_detection3dpc.py @@ -0,0 +1,135 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest + + +def test_detection3dpc(detection3dpc): + # def test_oriented_bounding_box(detection3dpc): + """Test oriented bounding box calculation and values.""" + obb = detection3dpc.get_oriented_bounding_box() + assert obb is not None, "Oriented bounding box should not be None" + + # Verify OBB center values + assert obb.center[0] == pytest.approx(-3.36002, abs=0.1) + assert obb.center[1] == pytest.approx(-0.196446, abs=0.1) + assert obb.center[2] == pytest.approx(0.220184, abs=0.1) + + # Verify OBB extent values + assert obb.extent[0] == pytest.approx(0.531275, abs=0.1) + assert obb.extent[1] == pytest.approx(0.461054, abs=0.1) + assert obb.extent[2] == pytest.approx(0.155, abs=0.1) + + # def test_bounding_box_dimensions(detection3dpc): + """Test bounding box dimension calculation.""" + dims = detection3dpc.get_bounding_box_dimensions() + assert len(dims) == 3, "Bounding box dimensions should have 3 values" + assert dims[0] == pytest.approx(0.350, abs=0.1) + assert dims[1] == pytest.approx(0.250, abs=0.1) + 
assert dims[2] == pytest.approx(0.550, abs=0.1) + + # def test_axis_aligned_bounding_box(detection3dpc): + """Test axis-aligned bounding box calculation.""" + aabb = detection3dpc.get_bounding_box() + assert aabb is not None, "Axis-aligned bounding box should not be None" + + # Verify AABB min values + assert aabb.min_bound[0] == pytest.approx(-3.575, abs=0.1) + assert aabb.min_bound[1] == pytest.approx(-0.375, abs=0.1) + assert aabb.min_bound[2] == pytest.approx(-0.075, abs=0.1) + + # Verify AABB max values + assert aabb.max_bound[0] == pytest.approx(-3.075, abs=0.1) + assert aabb.max_bound[1] == pytest.approx(-0.125, abs=0.1) + assert aabb.max_bound[2] == pytest.approx(0.475, abs=0.1) + + # def test_point_cloud_properties(detection3dpc): + """Test point cloud data and boundaries.""" + pc_points = detection3dpc.pointcloud.points() + assert len(pc_points) in [69, 70] + assert detection3dpc.pointcloud.frame_id == "world", ( + f"Expected frame_id 'world', got '{detection3dpc.pointcloud.frame_id}'" + ) + + # Extract xyz coordinates from points + points = np.array([[pt[0], pt[1], pt[2]] for pt in pc_points]) + + min_pt = np.min(points, axis=0) + max_pt = np.max(points, axis=0) + center = np.mean(points, axis=0) + + # Verify point cloud boundaries + assert min_pt[0] == pytest.approx(-3.575, abs=0.1) + assert min_pt[1] == pytest.approx(-0.375, abs=0.1) + assert min_pt[2] == pytest.approx(-0.075, abs=0.1) + + assert max_pt[0] == pytest.approx(-3.075, abs=0.1) + assert max_pt[1] == pytest.approx(-0.125, abs=0.1) + assert max_pt[2] == pytest.approx(0.475, abs=0.1) + + assert center[0] == pytest.approx(-3.326, abs=0.1) + assert center[1] == pytest.approx(-0.202, abs=0.1) + assert center[2] == pytest.approx(0.160, abs=0.1) + + # def test_foxglove_scene_entity_generation(detection3dpc): + """Test Foxglove scene entity creation and structure.""" + entity = detection3dpc.to_foxglove_scene_entity("test_entity_123") + + # Verify entity metadata + assert entity.id == "1", 
f"Expected entity ID '1', got '{entity.id}'" + assert entity.frame_id == "world", f"Expected frame_id 'world', got '{entity.frame_id}'" + assert entity.cubes_length == 1, f"Expected 1 cube, got {entity.cubes_length}" + assert entity.texts_length == 1, f"Expected 1 text, got {entity.texts_length}" + + # def test_foxglove_cube_properties(detection3dpc): + """Test Foxglove cube primitive properties.""" + entity = detection3dpc.to_foxglove_scene_entity("test_entity_123") + cube = entity.cubes[0] + + # Verify position + assert cube.pose.position.x == pytest.approx(-3.325, abs=0.1) + assert cube.pose.position.y == pytest.approx(-0.250, abs=0.1) + assert cube.pose.position.z == pytest.approx(0.200, abs=0.1) + + # Verify size + assert cube.size.x == pytest.approx(0.350, abs=0.1) + assert cube.size.y == pytest.approx(0.250, abs=0.1) + assert cube.size.z == pytest.approx(0.550, abs=0.1) + + # Verify color (green with alpha) + assert cube.color.r == pytest.approx(0.08235294117647059, abs=0.1) + assert cube.color.g == pytest.approx(0.7176470588235294, abs=0.1) + assert cube.color.b == pytest.approx(0.28627450980392155, abs=0.1) + assert cube.color.a == pytest.approx(0.2, abs=0.1) + + # def test_foxglove_text_label(detection3dpc): + """Test Foxglove text label properties.""" + entity = detection3dpc.to_foxglove_scene_entity("test_entity_123") + text = entity.texts[0] + + assert text.text == "1/suitcase (81%)", f"Expected text '1/suitcase (81%)', got '{text.text}'" + assert text.pose.position.x == pytest.approx(-3.325, abs=0.1) + assert text.pose.position.y == pytest.approx(-0.250, abs=0.1) + assert text.pose.position.z == pytest.approx(0.575, abs=0.1) + assert text.font_size == 20.0, f"Expected font size 20.0, got {text.font_size}" + + # def test_detection_pose(detection3dpc): + """Test detection pose and frame information.""" + assert detection3dpc.pose.x == pytest.approx(-3.327, abs=0.1) + assert detection3dpc.pose.y == pytest.approx(-0.202, abs=0.1) + assert 
detection3dpc.pose.z == pytest.approx(0.160, abs=0.1) + assert detection3dpc.pose.frame_id == "world", ( + f"Expected frame_id 'world', got '{detection3dpc.pose.frame_id}'" + ) diff --git a/dimos/perception/detection2d/type/test_object3d.py b/dimos/perception/detection2d/type/test_object3d.py index 4fac1dcf10..b7933e86d5 100644 --- a/dimos/perception/detection2d/type/test_object3d.py +++ b/dimos/perception/detection2d/type/test_object3d.py @@ -21,24 +21,8 @@ from dimos.robot.unitree_webrtc.modular.connection_module import ConnectionModule -def test_object_db_module_populated(object_db_module): - """Test that ObjectDBModule is properly populated.""" - assert len(object_db_module.objects) > 0, "Database should contain objects" - assert object_db_module.cnt > 0, "Object counter should be greater than 0" - - -def test_object_db_module_objects_structure(all_objects): - """Test the structure of objects in the database.""" - for obj in all_objects: - assert isinstance(obj, Object3D) - assert hasattr(obj, "track_id") - assert hasattr(obj, "detections") - assert hasattr(obj, "best_detection") - assert hasattr(obj, "center") - assert len(obj.detections) >= 1 - - -def test_object3d_properties(first_object): +def test_first_object(first_object): + # def test_object3d_properties(first_object): """Test basic properties of an Object3D.""" assert first_object.track_id is not None assert isinstance(first_object.track_id, str) @@ -49,30 +33,7 @@ def test_object3d_properties(first_object): assert first_object.frame_id is not None assert first_object.best_detection is not None - -def test_object3d_multiple_detections(all_objects): - """Test objects that have been built from multiple detections.""" - # Find objects with multiple detections - multi_detection_objects = [obj for obj in all_objects if len(obj.detections) > 1] - - if multi_detection_objects: - obj = multi_detection_objects[0] - - # Test that confidence is the max of all detections - max_conf = max(d.confidence for d in 
obj.detections) - assert obj.confidence == max_conf - - # Test that timestamp is the max (most recent) - max_ts = max(d.ts for d in obj.detections) - assert obj.ts == max_ts - - # Test that best_detection has the largest bbox volume - best_volume = obj.best_detection.bbox_2d_volume() - for det in obj.detections: - assert det.bbox_2d_volume() <= best_volume - - -def test_object3d_center(first_object): + # def test_object3d_center(first_object): """Test Object3D center calculation.""" assert first_object.center is not None assert hasattr(first_object.center, "x") @@ -94,24 +55,22 @@ def test_object3d_repr_dict(first_object): assert "center" in repr_dict assert repr_dict["object_id"] == first_object.track_id - assert repr_dict["detections"] == len(first_object.detections) + assert repr_dict["detections"] == first_object.detections # Center should be formatted as string with coordinates assert isinstance(repr_dict["center"], str) assert repr_dict["center"].startswith("[") assert repr_dict["center"].endswith("]") - -def test_object3d_scene_entity_label(first_object): + # def test_object3d_scene_entity_label(first_object): """Test scene entity label generation.""" label = first_object.scene_entity_label() assert isinstance(label, str) assert first_object.name in label - assert f"({len(first_object.detections)})" in label - + assert f"({first_object.detections})" in label -def test_object3d_agent_encode(first_object): + # def test_object3d_agent_encode(first_object): """Test agent encoding.""" encoded = first_object.agent_encode() @@ -123,17 +82,52 @@ def test_object3d_agent_encode(first_object): assert encoded["id"] == first_object.track_id assert encoded["name"] == first_object.name - assert encoded["detections"] == len(first_object.detections) + assert encoded["detections"] == first_object.detections assert encoded["last_seen"].endswith("s ago") - -def test_object3d_image_property(first_object): + # def test_object3d_image_property(first_object): """Test image property 
returns best_detection's image.""" assert first_object.image is not None assert first_object.image is first_object.best_detection.image -def test_object3d_addition(object_db_module): +def test_all_objeects(all_objects): + # def test_object3d_multiple_detections(all_objects): + """Test objects that have been built from multiple detections.""" + # Find objects with multiple detections + multi_detection_objects = [obj for obj in all_objects if obj.detections > 1] + + if multi_detection_objects: + obj = multi_detection_objects[0] + + # Since detections is now a counter, we can only test that we have multiple detections + # and that best_detection exists + assert obj.detections > 1 + assert obj.best_detection is not None + assert obj.confidence is not None + assert obj.ts > 0 + + # Test that best_detection has reasonable properties + assert obj.best_detection.bbox_2d_volume() > 0 + + # def test_object_db_module_objects_structure(all_objects): + """Test the structure of objects in the database.""" + for obj in all_objects: + assert isinstance(obj, Object3D) + assert hasattr(obj, "track_id") + assert hasattr(obj, "detections") + assert hasattr(obj, "best_detection") + assert hasattr(obj, "center") + assert obj.detections >= 1 + + +def test_objectdb_module(object_db_module): + # def test_object_db_module_populated(object_db_module): + """Test that ObjectDBModule is properly populated.""" + assert len(object_db_module.objects) > 0, "Database should contain objects" + assert object_db_module.cnt > 0, "Object counter should be greater than 0" + + # def test_object3d_addition(object_db_module): """Test Object3D addition operator.""" # Get existing objects from the database objects = list(object_db_module.objects.values()) @@ -151,11 +145,10 @@ def test_object3d_addition(object_db_module): combined = obj + det2 assert combined.track_id == "test_track_combined" - assert len(combined.detections) == 2 + assert combined.detections == 2 - # The combined object should have properties 
from both detections - assert det1 in combined.detections - assert det2 in combined.detections + # Since detections is now a counter, we can't check if specific detections are in the list + # We can only verify the count and that best_detection is properly set # Best detection should be determined by the Object3D logic assert combined.best_detection is not None @@ -164,8 +157,7 @@ def test_object3d_addition(object_db_module): assert hasattr(combined, "center") assert combined.center is not None - -def test_image_detections3d_scene_update(object_db_module): + # def test_image_detections3d_scene_update(object_db_module): """Test ImageDetections3D to Foxglove scene update conversion.""" # Get some detections objects = list(object_db_module.objects.values()) diff --git a/dimos/robot/unitree_webrtc/unitree_b1/connection.py b/dimos/robot/unitree_webrtc/unitree_b1/connection.py index 657ad8f6e4..73d24bdc3c 100644 --- a/dimos/robot/unitree_webrtc/unitree_b1/connection.py +++ b/dimos/robot/unitree_webrtc/unitree_b1/connection.py @@ -23,11 +23,12 @@ import time from typing import Optional -from dimos.core import In, Out, Module, rpc -from dimos.msgs.geometry_msgs import Twist, TwistStamped, PoseStamped +from dimos.core import In, Module, Out, rpc +from dimos.msgs.geometry_msgs import PoseStamped, Twist, TwistStamped from dimos.msgs.nav_msgs.Odometry import Odometry from dimos.msgs.std_msgs import Int32 from dimos.utils.logging_config import setup_logger + from .b1_command import B1Command # Setup logger with DEBUG level for troubleshooting @@ -354,7 +355,7 @@ def cleanup(self): self.stop() -class TestB1ConnectionModule(B1ConnectionModule): +class MockB1ConnectionModule(B1ConnectionModule): """Test connection module that prints commands instead of sending UDP.""" def __init__(self, ip: str = "127.0.0.1", port: int = 9090, *args, **kwargs): diff --git a/dimos/robot/unitree_webrtc/unitree_b1/test_connection.py b/dimos/robot/unitree_webrtc/unitree_b1/test_connection.py index 
e595a1adde..a9451acdf0 100644 --- a/dimos/robot/unitree_webrtc/unitree_b1/test_connection.py +++ b/dimos/robot/unitree_webrtc/unitree_b1/test_connection.py @@ -22,20 +22,21 @@ # should be used and tested. Additionally, tests should always use `try-finally` # to clean up even if the test fails. -import time import threading +import time -from .connection import TestB1ConnectionModule from dimos.msgs.geometry_msgs import TwistStamped, Vector3 from dimos.msgs.std_msgs.Int32 import Int32 +from .connection import MockB1ConnectionModule + class TestB1Connection: """Test suite for B1 connection module with Timer implementation.""" def test_watchdog_actually_zeros_commands(self): """Test that watchdog thread zeros commands after timeout.""" - conn = TestB1ConnectionModule(ip="127.0.0.1", port=9090) + conn = MockB1ConnectionModule(ip="127.0.0.1", port=9090) conn.running = True conn.watchdog_running = True conn.send_thread = threading.Thread(target=conn._send_loop, daemon=True) @@ -76,7 +77,7 @@ def test_watchdog_actually_zeros_commands(self): def test_watchdog_resets_on_new_command(self): """Test that watchdog timeout resets when new command arrives.""" - conn = TestB1ConnectionModule(ip="127.0.0.1", port=9090) + conn = MockB1ConnectionModule(ip="127.0.0.1", port=9090) conn.running = True conn.watchdog_running = True conn.send_thread = threading.Thread(target=conn._send_loop, daemon=True) @@ -124,7 +125,7 @@ def test_watchdog_resets_on_new_command(self): def test_watchdog_thread_efficiency(self): """Test that watchdog uses only one thread regardless of command rate.""" - conn = TestB1ConnectionModule(ip="127.0.0.1", port=9090) + conn = MockB1ConnectionModule(ip="127.0.0.1", port=9090) conn.running = True conn.watchdog_running = True conn.send_thread = threading.Thread(target=conn._send_loop, daemon=True) @@ -158,7 +159,7 @@ def test_watchdog_thread_efficiency(self): def test_watchdog_with_send_loop_blocking(self): """Test that watchdog still works if send loop blocks.""" - 
conn = TestB1ConnectionModule(ip="127.0.0.1", port=9090) + conn = MockB1ConnectionModule(ip="127.0.0.1", port=9090) # Mock the send loop to simulate blocking original_send_loop = conn._send_loop @@ -205,7 +206,7 @@ def blocking_send_loop(): def test_continuous_commands_prevent_timeout(self): """Test that continuous commands prevent watchdog timeout.""" - conn = TestB1ConnectionModule(ip="127.0.0.1", port=9090) + conn = MockB1ConnectionModule(ip="127.0.0.1", port=9090) conn.running = True conn.watchdog_running = True conn.send_thread = threading.Thread(target=conn._send_loop, daemon=True) @@ -240,7 +241,7 @@ def test_continuous_commands_prevent_timeout(self): def test_watchdog_timing_accuracy(self): """Test that watchdog zeros commands at approximately 200ms.""" - conn = TestB1ConnectionModule(ip="127.0.0.1", port=9090) + conn = MockB1ConnectionModule(ip="127.0.0.1", port=9090) conn.running = True conn.watchdog_running = True conn.send_thread = threading.Thread(target=conn._send_loop, daemon=True) @@ -281,7 +282,7 @@ def test_watchdog_timing_accuracy(self): def test_mode_changes_with_watchdog(self): """Test that mode changes work correctly with watchdog.""" - conn = TestB1ConnectionModule(ip="127.0.0.1", port=9090) + conn = MockB1ConnectionModule(ip="127.0.0.1", port=9090) conn.running = True conn.watchdog_running = True conn.send_thread = threading.Thread(target=conn._send_loop, daemon=True) @@ -321,7 +322,7 @@ def test_mode_changes_with_watchdog(self): def test_watchdog_stops_movement_when_commands_stop(self): """Verify watchdog zeros commands when packets stop being sent.""" - conn = TestB1ConnectionModule(ip="127.0.0.1", port=9090) + conn = MockB1ConnectionModule(ip="127.0.0.1", port=9090) conn.running = True conn.watchdog_running = True conn.send_thread = threading.Thread(target=conn._send_loop, daemon=True) @@ -379,7 +380,7 @@ def test_watchdog_stops_movement_when_commands_stop(self): def test_rapid_command_thread_safety(self): """Test thread safety with rapid 
commands from multiple threads.""" - conn = TestB1ConnectionModule(ip="127.0.0.1", port=9090) + conn = MockB1ConnectionModule(ip="127.0.0.1", port=9090) conn.running = True conn.watchdog_running = True conn.send_thread = threading.Thread(target=conn._send_loop, daemon=True) diff --git a/dimos/robot/unitree_webrtc/unitree_b1/unitree_b1.py b/dimos/robot/unitree_webrtc/unitree_b1/unitree_b1.py index 77cf3ce19c..bef0fafbfa 100644 --- a/dimos/robot/unitree_webrtc/unitree_b1/unitree_b1.py +++ b/dimos/robot/unitree_webrtc/unitree_b1/unitree_b1.py @@ -20,21 +20,21 @@ Uses standard Twist interface for velocity commands. """ -import os import logging +import os from typing import Optional from dimos import core -from dimos.msgs.geometry_msgs import Twist, TwistStamped, PoseStamped +from dimos.msgs.geometry_msgs import PoseStamped, Twist, TwistStamped from dimos.msgs.nav_msgs.Odometry import Odometry -from dimos.msgs.tf2_msgs.TFMessage import TFMessage from dimos.msgs.std_msgs import Int32 +from dimos.msgs.tf2_msgs.TFMessage import TFMessage from dimos.protocol.pubsub.lcmpubsub import LCM from dimos.robot.robot import Robot -from dimos.robot.ros_bridge import ROSBridge, BridgeDirection +from dimos.robot.ros_bridge import BridgeDirection, ROSBridge from dimos.robot.unitree_webrtc.unitree_b1.connection import ( B1ConnectionModule, - TestB1ConnectionModule, + MockB1ConnectionModule, ) from dimos.skills.skills import SkillLibrary from dimos.types.robot_capabilities import RobotCapability @@ -109,7 +109,7 @@ def start(self): logger.info("Deploying connection module...") if self.test_mode: - self.connection = self.dimos.deploy(TestB1ConnectionModule, self.ip, self.port) + self.connection = self.dimos.deploy(MockB1ConnectionModule, self.ip, self.port) else: self.connection = self.dimos.deploy(B1ConnectionModule, self.ip, self.port) diff --git a/dimos/types/test_timestamped.py b/dimos/types/test_timestamped.py index a4617774ef..052c596d2e 100644 --- a/dimos/types/test_timestamped.py 
+++ b/dimos/types/test_timestamped.py @@ -113,6 +113,16 @@ def __init__(self, ts: float, data: str): self.data = data +@pytest.fixture +def test_scheduler(): + """Fixture that provides a ThreadPoolScheduler and cleans it up after the test.""" + scheduler = ThreadPoolScheduler(max_workers=6) + yield scheduler + # Cleanup after test + scheduler.executor.shutdown(wait=True) + time.sleep(0.2) # Give threads time to finish cleanup + + @pytest.fixture def sample_items(): return [ @@ -268,64 +278,301 @@ def test_time_window_collection(): assert window.end_ts == 5.5 -@pytest.mark.tofix -def test_timestamp_alignment(): - # Create a dedicated scheduler for this test to avoid thread leaks - test_scheduler = ThreadPoolScheduler(max_workers=6) - try: - speed = 5.0 +def test_timestamp_alignment(test_scheduler): + speed = 5.0 - # ensure that lfs package is downloaded - get_data("unitree_office_walk") + # ensure that lfs package is downloaded + get_data("unitree_office_walk") - raw_frames = [] + raw_frames = [] - def spy(image): - raw_frames.append(image.ts) - print(image.ts) - return image + def spy(image): + raw_frames.append(image.ts) + print(image.ts) + return image - # sensor reply of raw video frames - video_raw = ( - testing.TimedSensorReplay( - "unitree_office_walk/video", autocast=lambda x: Image.from_numpy(x).to_rgb() - ) - .stream(speed) - .pipe(ops.take(30)) + # sensor reply of raw video frames + video_raw = ( + testing.TimedSensorReplay( + "unitree_office_walk/video", autocast=lambda x: Image.from_numpy(x).to_rgb() ) + .stream(speed) + .pipe(ops.take(30)) + ) - processed_frames = [] + processed_frames = [] - def process_video_frame(frame): - processed_frames.append(frame.ts) - time.sleep(0.5 / speed) - return frame + def process_video_frame(frame): + processed_frames.append(frame.ts) + time.sleep(0.5 / speed) + return frame - # fake reply of some 0.5s processor of video frames that drops messages - fake_video_processor = backpressure( - video_raw.pipe(ops.map(spy)), 
scheduler=test_scheduler - ).pipe(ops.map(process_video_frame)) + # fake reply of some 0.5s processor of video frames that drops messages + # Pass the scheduler to backpressure to manage threads properly + fake_video_processor = backpressure( + video_raw.pipe(ops.map(spy)), scheduler=test_scheduler + ).pipe(ops.map(process_video_frame)) - aligned_frames = ( - align_timestamped(fake_video_processor, video_raw).pipe(ops.to_list()).run() + aligned_frames = align_timestamped(fake_video_processor, video_raw).pipe(ops.to_list()).run() + + assert len(raw_frames) == 30 + assert len(processed_frames) > 2 + assert len(aligned_frames) > 2 + + # Due to async processing, the last frame might not be aligned before completion + assert len(aligned_frames) >= len(processed_frames) - 1 + + for value in aligned_frames: + [primary, secondary] = value + diff = abs(primary.ts - secondary.ts) + print( + f"Aligned pair: primary={primary.ts:.6f}, secondary={secondary.ts:.6f}, diff={diff:.6f}s" ) + assert diff <= 0.05 + + assert len(aligned_frames) > 3 + + +def test_timestamp_alignment_primary_first(): + """Test alignment when primary messages arrive before secondary messages.""" + from reactivex import Subject + + primary_subject = Subject() + secondary_subject = Subject() + + results = [] + + # Set up alignment with a 2-second buffer + aligned = align_timestamped( + primary_subject, secondary_subject, buffer_size=2.0, match_tolerance=0.1 + ) + + # Subscribe to collect results + aligned.subscribe(lambda x: results.append(x)) + + # Send primary messages first + primary1 = SimpleTimestamped(1.0, "primary1") + primary2 = SimpleTimestamped(2.0, "primary2") + primary3 = SimpleTimestamped(3.0, "primary3") + + primary_subject.on_next(primary1) + primary_subject.on_next(primary2) + primary_subject.on_next(primary3) + + # At this point, no results should be emitted (no secondaries yet) + assert len(results) == 0 + + # Send secondary messages that match primary1 and primary2 + secondary1 = 
SimpleTimestamped(1.05, "secondary1") # Matches primary1 + secondary2 = SimpleTimestamped(2.02, "secondary2") # Matches primary2 + + secondary_subject.on_next(secondary1) + assert len(results) == 1 # primary1 should now be matched + assert results[0][0].data == "primary1" + assert results[0][1].data == "secondary1" + + secondary_subject.on_next(secondary2) + assert len(results) == 2 # primary2 should now be matched + assert results[1][0].data == "primary2" + assert results[1][1].data == "secondary2" + + # Send a secondary that's too far from primary3 + secondary_far = SimpleTimestamped(3.5, "secondary_far") # Too far from primary3 + secondary_subject.on_next(secondary_far) + # At this point primary3 is removed as unmatchable since secondary progressed past it + assert len(results) == 2 # primary3 should not match (outside tolerance) + + # Send a new primary that can match with the future secondary + primary4 = SimpleTimestamped(3.45, "primary4") + primary_subject.on_next(primary4) + assert len(results) == 3 # Should match with secondary_far + assert results[2][0].data == "primary4" + assert results[2][1].data == "secondary_far" + + # Complete the streams + primary_subject.on_completed() + secondary_subject.on_completed() + + +def test_timestamp_alignment_multiple_secondaries(): + """Test alignment with multiple secondary observables.""" + from reactivex import Subject + + primary_subject = Subject() + secondary1_subject = Subject() + secondary2_subject = Subject() + + results = [] + + # Set up alignment with two secondary streams + aligned = align_timestamped( + primary_subject, + secondary1_subject, + secondary2_subject, + buffer_size=1.0, + match_tolerance=0.05, + ) + + # Subscribe to collect results + aligned.subscribe(lambda x: results.append(x)) + + # Send a primary message + primary1 = SimpleTimestamped(1.0, "primary1") + primary_subject.on_next(primary1) + + # No results yet (waiting for both secondaries) + assert len(results) == 0 + + # Send first secondary 
+ sec1_msg1 = SimpleTimestamped(1.01, "sec1_msg1") + secondary1_subject.on_next(sec1_msg1) + + # Still no results (waiting for secondary2) + assert len(results) == 0 + + # Send second secondary + sec2_msg1 = SimpleTimestamped(1.02, "sec2_msg1") + secondary2_subject.on_next(sec2_msg1) + + # Now we should have a result + assert len(results) == 1 + assert results[0][0].data == "primary1" + assert results[0][1].data == "sec1_msg1" + assert results[0][2].data == "sec2_msg1" + + # Test partial match (one secondary missing) + primary2 = SimpleTimestamped(2.0, "primary2") + primary_subject.on_next(primary2) + + # Send only one secondary + sec1_msg2 = SimpleTimestamped(2.01, "sec1_msg2") + secondary1_subject.on_next(sec1_msg2) + + # No result yet + assert len(results) == 1 + + # Send a secondary2 that's too far + sec2_far = SimpleTimestamped(2.1, "sec2_far") # Outside tolerance + secondary2_subject.on_next(sec2_far) + + # Still no result (secondary2 is outside tolerance) + assert len(results) == 1 + + # Complete the streams + primary_subject.on_completed() + secondary1_subject.on_completed() + secondary2_subject.on_completed() + + +def test_timestamp_alignment_delayed_secondary(): + """Test alignment when secondary messages arrive late but still within tolerance.""" + from reactivex import Subject + + primary_subject = Subject() + secondary_subject = Subject() + + results = [] + + # Set up alignment with a 2-second buffer + aligned = align_timestamped( + primary_subject, secondary_subject, buffer_size=2.0, match_tolerance=0.1 + ) + + # Subscribe to collect results + aligned.subscribe(lambda x: results.append(x)) + + # Send primary messages + primary1 = SimpleTimestamped(1.0, "primary1") + primary2 = SimpleTimestamped(2.0, "primary2") + primary3 = SimpleTimestamped(3.0, "primary3") + + primary_subject.on_next(primary1) + primary_subject.on_next(primary2) + primary_subject.on_next(primary3) + + # No results yet + assert len(results) == 0 + + # Send delayed secondaries (in 
timestamp order) + secondary1 = SimpleTimestamped(1.05, "secondary1") # Matches primary1 + secondary_subject.on_next(secondary1) + assert len(results) == 1 # primary1 matched + assert results[0][0].data == "primary1" + assert results[0][1].data == "secondary1" + + secondary2 = SimpleTimestamped(2.02, "secondary2") # Matches primary2 + secondary_subject.on_next(secondary2) + assert len(results) == 2 # primary2 matched + assert results[1][0].data == "primary2" + assert results[1][1].data == "secondary2" + + # Now send a secondary that's past primary3's match window + secondary_future = SimpleTimestamped(3.2, "secondary_future") # Too far from primary3 + secondary_subject.on_next(secondary_future) + # At this point, primary3 should be removed as unmatchable + assert len(results) == 2 # No new matches + + # Send a new primary that can match with secondary_future + primary4 = SimpleTimestamped(3.15, "primary4") + primary_subject.on_next(primary4) + assert len(results) == 3 # Should match immediately + assert results[2][0].data == "primary4" + assert results[2][1].data == "secondary_future" + + # Complete the streams + primary_subject.on_completed() + secondary_subject.on_completed() + + +def test_timestamp_alignment_buffer_cleanup(): + """Test that old buffered primaries are cleaned up.""" + import time as time_module + + from reactivex import Subject + + primary_subject = Subject() + secondary_subject = Subject() + + results = [] + + # Set up alignment with a 0.5-second buffer + aligned = align_timestamped( + primary_subject, secondary_subject, buffer_size=0.5, match_tolerance=0.05 + ) + + # Subscribe to collect results + aligned.subscribe(lambda x: results.append(x)) + + # Use real timestamps for this test + now = time_module.time() + + # Send an old primary + old_primary = Timestamped(now - 1.0) # 1 second ago + old_primary.data = "old" + primary_subject.on_next(old_primary) + + # Send a recent secondary to trigger cleanup + recent_secondary = Timestamped(now) + 
recent_secondary.data = "recent" + secondary_subject.on_next(recent_secondary) + + # Old primary should not match (outside buffer window) + assert len(results) == 0 + + # Send a matching pair within buffer + new_primary = Timestamped(now + 0.1) + new_primary.data = "new_primary" + new_secondary = Timestamped(now + 0.11) + new_secondary.data = "new_secondary" + + primary_subject.on_next(new_primary) + secondary_subject.on_next(new_secondary) + + # Should have one match + assert len(results) == 1 + assert results[0][0].data == "new_primary" + assert results[0][1].data == "new_secondary" - assert len(raw_frames) == 30 - assert len(processed_frames) > 2 - assert len(aligned_frames) > 2 - - # Due to async processing, the last frame might not be aligned before completion - assert len(aligned_frames) >= len(processed_frames) - 1 - - for value in aligned_frames: - [primary, secondary] = value - diff = abs(primary.ts - secondary.ts) - print( - f"Aligned pair: primary={primary.ts:.6f}, secondary={secondary.ts:.6f}, diff={diff:.6f}s" - ) - assert diff <= 0.05 - finally: - # Always shutdown the scheduler to clean up threads - test_scheduler.executor.shutdown(wait=True) - # Give threads time to finish cleanup - time.sleep(0.2) + # Complete the streams + primary_subject.on_completed() + secondary_subject.on_completed() diff --git a/dimos/types/test_weaklist.py b/dimos/types/test_weaklist.py index fd0b05d74f..c4dfe27616 100644 --- a/dimos/types/test_weaklist.py +++ b/dimos/types/test_weaklist.py @@ -15,18 +15,20 @@ """Tests for WeakList implementation.""" import gc + import pytest + from dimos.types.weaklist import WeakList -class TestObject: +class SampleObject: """Simple test object.""" def __init__(self, value): self.value = value def __repr__(self): - return f"TestObject({self.value})" + return f"SampleObject({self.value})" def test_weaklist_basic_operations(): @@ -34,9 +36,9 @@ def test_weaklist_basic_operations(): wl = WeakList() # Add objects - obj1 = TestObject(1) - obj2 
= TestObject(2) - obj3 = TestObject(3) + obj1 = SampleObject(1) + obj2 = SampleObject(2) + obj3 = SampleObject(3) wl.append(obj1) wl.append(obj2) @@ -49,16 +51,16 @@ def test_weaklist_basic_operations(): # Check contains assert obj1 in wl assert obj2 in wl - assert TestObject(4) not in wl + assert SampleObject(4) not in wl def test_weaklist_auto_removal(): """Test that objects are automatically removed when garbage collected.""" wl = WeakList() - obj1 = TestObject(1) - obj2 = TestObject(2) - obj3 = TestObject(3) + obj1 = SampleObject(1) + obj2 = SampleObject(2) + obj3 = SampleObject(3) wl.append(obj1) wl.append(obj2) @@ -79,8 +81,8 @@ def test_weaklist_explicit_remove(): """Test explicit removal of objects.""" wl = WeakList() - obj1 = TestObject(1) - obj2 = TestObject(2) + obj1 = SampleObject(1) + obj2 = SampleObject(2) wl.append(obj1) wl.append(obj2) @@ -93,16 +95,16 @@ def test_weaklist_explicit_remove(): # Try to remove non-existent object with pytest.raises(ValueError): - wl.remove(TestObject(3)) + wl.remove(SampleObject(3)) def test_weaklist_indexing(): """Test index access.""" wl = WeakList() - obj1 = TestObject(1) - obj2 = TestObject(2) - obj3 = TestObject(3) + obj1 = SampleObject(1) + obj2 = SampleObject(2) + obj3 = SampleObject(3) wl.append(obj1) wl.append(obj2) @@ -121,8 +123,8 @@ def test_weaklist_clear(): """Test clearing the list.""" wl = WeakList() - obj1 = TestObject(1) - obj2 = TestObject(2) + obj1 = SampleObject(1) + obj2 = SampleObject(2) wl.append(obj1) wl.append(obj2) @@ -138,7 +140,7 @@ def test_weaklist_iteration_during_modification(): """Test that iteration works even if objects are deleted during iteration.""" wl = WeakList() - objects = [TestObject(i) for i in range(5)] + objects = [SampleObject(i) for i in range(5)] for obj in objects: wl.append(obj) @@ -151,7 +153,7 @@ def test_weaklist_iteration_during_modification(): seen_values.append(obj.value) if obj.value == 2: # Delete another object (not the current one) - del objects[3] # Delete 
TestObject(3) + del objects[3] # Delete SampleObject(3) gc.collect() # The object with value 3 gets garbage collected during iteration diff --git a/dimos/types/timestamped.py b/dimos/types/timestamped.py index 0e5427d0b6..412ba08c03 100644 --- a/dimos/types/timestamped.py +++ b/dimos/types/timestamped.py @@ -11,16 +11,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from collections import defaultdict from datetime import datetime, timezone -from typing import Generic, Iterable, Optional, Tuple, TypedDict, TypeVar, Union +from typing import Generic, Iterable, List, Optional, Tuple, TypeVar, Union from dimos_lcm.builtin_interfaces import Time as ROSTime +from reactivex import create +from reactivex.disposable import CompositeDisposable # from dimos_lcm.std_msgs import Time as ROSTime from reactivex.observable import Observable from sortedcontainers import SortedKeyList +from dimos.types.weaklist import WeakList +from dimos.utils.logging_config import setup_logger + +logger = setup_logger("dimos.timestampAlignment") + # any class that carries a timestamp should inherit from this # this allows us to work with timeseries in consistent way, allign messages, replay etc # aditional functionality will come to this class soon @@ -243,100 +250,160 @@ def _prune_old_messages(self, current_ts: float) -> None: if keep_idx > 0: del self._items[:keep_idx] + def remove_by_timestamp(self, timestamp: float) -> bool: + """Remove an item with the given timestamp. 
Returns True if item was found and removed.""" + idx = self._items.bisect_key_left(timestamp) -def align_timestamped( - primary_observable: Observable[PRIMARY], - secondary_observable: Observable[SECONDARY], - buffer_size: float = 1.0, # seconds - match_tolerance: float = 0.05, # seconds -) -> Observable[Tuple[PRIMARY, SECONDARY]]: - from reactivex import create - from reactivex.disposable import CompositeDisposable + if idx < len(self._items) and self._items[idx].ts == timestamp: + del self._items[idx] + return True + return False - def subscribe(observer, scheduler=None): - secondary_collection: TimestampedBufferCollection[SECONDARY] = TimestampedBufferCollection( - buffer_size - ) - # Subscribe to secondary to populate the buffer with proper error/complete handling - secondary_sub = secondary_observable.subscribe( - on_next=secondary_collection.add, - on_error=lambda e: None, # Silently ignore errors from secondary - on_completed=lambda: None, # Silently ignore completion from secondary - ) + def remove(self, item: T) -> bool: + """Remove a timestamped item from the collection. Returns True if item was found and removed.""" + return self.remove_by_timestamp(item.ts) - def on_primary(primary_item: PRIMARY): - secondary_item = secondary_collection.find_closest( - primary_item.ts, tolerance=match_tolerance - ) - if secondary_item is not None: - observer.on_next((primary_item, secondary_item)) - # Subscribe to primary and emit aligned pairs - primary_sub = primary_observable.subscribe( - on_next=on_primary, on_error=observer.on_error, on_completed=observer.on_completed - ) +class MatchContainer(Timestamped, Generic[PRIMARY, SECONDARY]): + """ + This class stores a primary item along with its partial matches to secondary items, + tracking which secondaries are still missing to avoid redundant searches. 
+ """ - # Return cleanup disposable - return CompositeDisposable(secondary_sub, primary_sub) + def __init__(self, primary: PRIMARY, matches: List[Optional[SECONDARY]]): + super().__init__(primary.ts) + self.primary = primary + self.matches = matches # Direct list with None for missing matches - return create(subscribe) + def message_received(self, secondary_idx: int, secondary_item: SECONDARY): + """Process a secondary message and check if it matches this primary.""" + if self.matches[secondary_idx] is None: + self.matches[secondary_idx] = secondary_item + def is_complete(self) -> bool: + """Check if all secondary matches have been found.""" + return all(match is not None for match in self.matches) -def align_timestamped_multiple( + def get_tuple(self) -> Tuple[PRIMARY, ...]: + """Get the result tuple for emission.""" + return (self.primary, *self.matches) + + +def align_timestamped( primary_observable: Observable[PRIMARY], *secondary_observables: Observable[SECONDARY], buffer_size: float = 1.0, # seconds - match_tolerance: float = 0.05, # seconds + match_tolerance: float = 0.1, # seconds ) -> Observable[Tuple[PRIMARY, ...]]: - """Align a primary observable with multiple secondary observables. + """Align a primary observable with one or more secondary observables. Args: primary_observable: The primary stream to align against - *secondary_observables: Secondary streams to align - buffer_size: Time window to keep secondary messages in seconds + *secondary_observables: One or more secondary streams to align + buffer_size: Time window to keep messages in seconds match_tolerance: Maximum time difference for matching in seconds Returns: - Observable that emits tuples of (primary_item, secondary1, secondary2, ...) 
- where each secondary item is the closest match from the corresponding + If single secondary observable: Observable that emits tuples of (primary_item, secondary_item) + If multiple secondary observables: Observable that emits tuples of (primary_item, secondary1, secondary2, ...) + Each secondary item is the closest match from the corresponding secondary observable, or None if no match within tolerance. """ - from reactivex import create def subscribe(observer, scheduler=None): - from reactivex.disposable import CompositeDisposable - - # Create a buffer collection for each secondary observable - secondary_collections: list[TimestampedBufferCollection[SECONDARY]] = [ + # Create a timed buffer collection for each secondary observable + secondary_collections: List[TimestampedBufferCollection[SECONDARY]] = [ TimestampedBufferCollection(buffer_size) for _ in secondary_observables ] - # Subscribe to all secondary observables with proper error/complete handling + # WeakLists to track subscribers to each secondary observable + secondary_stakeholders = defaultdict(WeakList) + + # Buffer for unmatched MatchContainers - automatically expires old items + primary_buffer: TimestampedBufferCollection[MatchContainer[PRIMARY, SECONDARY]] = ( + TimestampedBufferCollection(buffer_size) + ) + + # Subscribe to all secondary observables secondary_subs = [] + + def has_secondary_progressed_past(secondary_ts: float, primary_ts: float) -> bool: + """Check if secondary stream has progressed past the primary + tolerance.""" + return secondary_ts > primary_ts + match_tolerance + + def remove_stakeholder(stakeholder: MatchContainer): + """Remove a stakeholder from all tracking structures.""" + primary_buffer.remove(stakeholder) + for weak_list in secondary_stakeholders.values(): + weak_list.discard(stakeholder) + + def on_secondary(i: int, secondary_item: SECONDARY): + # Add the secondary item to its collection + secondary_collections[i].add(secondary_item) + + # Check all stakeholders for 
this secondary stream + for stakeholder in secondary_stakeholders[i]: + # If the secondary stream has progressed past this primary, + # we won't be able to match it anymore + if has_secondary_progressed_past(secondary_item.ts, stakeholder.ts): + logger.debug(f"secondary progressed, giving up {stakeholder.ts}") + + remove_stakeholder(stakeholder) + continue + + # Check if this secondary is within tolerance of the primary + if abs(stakeholder.ts - secondary_item.ts) <= match_tolerance: + stakeholder.message_received(i, secondary_item) + + # If all secondaries matched, emit result + if stakeholder.is_complete(): + logger.debug(f"Emitting deferred match {stakeholder.ts}") + observer.on_next(stakeholder.get_tuple()) + remove_stakeholder(stakeholder) + for i, secondary_obs in enumerate(secondary_observables): - sub = secondary_obs.subscribe( - on_next=secondary_collections[i].add, - on_error=lambda e: None, # Silently ignore errors from secondary - on_completed=lambda: None, # Silently ignore completion from secondary + secondary_subs.append( + secondary_obs.subscribe( + lambda x, idx=i: on_secondary(idx, x), on_error=observer.on_error + ) ) - secondary_subs.append(sub) def on_primary(primary_item: PRIMARY): - # Find closest match from each secondary collection - secondary_items = [] - for collection in secondary_collections: - secondary_item = collection.find_closest(primary_item.ts, tolerance=match_tolerance) - secondary_items.append(secondary_item) - - # Emit the aligned tuple (flatten into single tuple) - observer.on_next((primary_item, *secondary_items)) - - # Subscribe to primary and emit aligned tuples + # Try to find matches in existing secondary collections + matches = [None] * len(secondary_observables) + + for i, collection in enumerate(secondary_collections): + closest = collection.find_closest(primary_item.ts, tolerance=match_tolerance) + if closest is not None: + matches[i] = closest + else: + # Check if this secondary stream has already progressed past 
this primary + if collection.end_ts is not None and has_secondary_progressed_past( + collection.end_ts, primary_item.ts + ): + # This secondary won't match, so don't buffer this primary + return + + # If all matched, emit immediately without creating MatchContainer + if all(match is not None for match in matches): + logger.debug(f"Immadiate match {primary_item.ts}") + result = (primary_item, *matches) + observer.on_next(result) + else: + logger.debug(f"Deferred match attempt {primary_item.ts}") + match_container = MatchContainer(primary_item, matches) + primary_buffer.add(match_container) + + for i, match in enumerate(matches): + if match is None: + secondary_stakeholders[i].append(match_container) + + # Subscribe to primary observable primary_sub = primary_observable.subscribe( - on_next=on_primary, on_error=observer.on_error, on_completed=observer.on_completed + on_primary, on_error=observer.on_error, on_completed=observer.on_completed ) - # Return cleanup disposable + # Return a CompositeDisposable for proper cleanup return CompositeDisposable(primary_sub, *secondary_subs) return create(subscribe) diff --git a/dimos/utils/decorators/decorators.py b/dimos/utils/decorators/decorators.py index 13ca5844a8..c54e3530e1 100644 --- a/dimos/utils/decorators/decorators.py +++ b/dimos/utils/decorators/decorators.py @@ -100,3 +100,46 @@ def wrapper(*args, **kwargs): return wrapper return decorator + + +def simple_mcache(method: Callable) -> Callable: + """ + Decorator to cache the result of a method call on the instance. + + The cached value is stored as an attribute on the instance with the name + `_cached_`. Subsequent calls to the method will return the + cached value instead of recomputing it. + + Thread-safe: Uses a lock per instance to ensure the cached value is + computed only once even in multi-threaded environments. + + Args: + method: The method to be decorated. + + Returns: + The decorated method with caching behavior. 
+ """ + + attr_name = f"_cached_{method.__name__}" + lock_name = f"_lock_{method.__name__}" + + @wraps(method) + def getter(self): + # Get or create the lock for this instance + if not hasattr(self, lock_name): + # This is a one-time operation, race condition here is acceptable + # as worst case we create multiple locks but only one gets stored + setattr(self, lock_name, threading.Lock()) + + lock = getattr(self, lock_name) + + if hasattr(self, attr_name): + return getattr(self, attr_name) + + with lock: + # Check again inside the lock + if not hasattr(self, attr_name): + setattr(self, attr_name, method(self)) + return getattr(self, attr_name) + + return getter diff --git a/dimos/utils/reactive.py b/dimos/utils/reactive.py index 8ab8fe66ae..74c7044648 100644 --- a/dimos/utils/reactive.py +++ b/dimos/utils/reactive.py @@ -214,8 +214,8 @@ def quality_barrier(quality_func: Callable[[T], float], target_frequency: float) def _quality_barrier(source: Observable[T]) -> Observable[T]: return source.pipe( - # Create time-based windows - ops.window_with_time(window_duration), + # Create non-overlapping time-based windows + ops.window_with_time(window_duration, window_duration), # For each window, find the highest quality item ops.flat_map( lambda window: window.pipe( diff --git a/dimos/utils/test_reactive.py b/dimos/utils/test_reactive.py index 21f2bd7894..1a9b759ab3 100644 --- a/dimos/utils/test_reactive.py +++ b/dimos/utils/test_reactive.py @@ -12,19 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import pytest import time +from typing import Any, Callable, TypeVar + import numpy as np +import pytest import reactivex as rx from reactivex import operators as ops -from reactivex.scheduler import ThreadPoolScheduler -from typing import Callable, TypeVar, Any from reactivex.disposable import Disposable +from reactivex.scheduler import ThreadPoolScheduler + from dimos.utils.reactive import ( backpressure, - getter_streaming, - getter_ondemand, callback_to_observable, + getter_ondemand, + getter_streaming, ) @@ -175,7 +177,7 @@ def test_getter_streaming_nonblocking(): 0.1, "nonblocking getter init shouldn't block", ) - min_time(getter, 0.2, "Expected for first value call to block if cache is empty") + min_time(getter, 0.1, "Expected for first value call to block if cache is empty") assert getter() == 0 time.sleep(0.5) diff --git a/dimos/utils/testing.py b/dimos/utils/testing.py index a7c1541d87..c5984cf3fd 100644 --- a/dimos/utils/testing.py +++ b/dimos/utils/testing.py @@ -17,6 +17,7 @@ import os import pickle import re +import shutil import time from pathlib import Path from typing import Any, Callable, Generic, Iterator, Optional, Tuple, TypeVar, Union @@ -107,13 +108,13 @@ def stream( class SensorStorage(Generic[T]): - """Generic sensor data storage utility. + """Generic sensor data storage utility + . + Creates a directory in the test data directory and stores pickled sensor data. - Creates a directory in the test data directory and stores pickled sensor data. - - Args: - name: The name of the storage directory - autocast: Optional function that takes data and returns a processed result before storage. + Args: + name: The name of the storage directory + autocast: Optional function that takes data and returns a processed result before storage. 
""" def __init__(self, name: str, autocast: Optional[Callable[[T], Any]] = None): @@ -136,6 +137,10 @@ def __init__(self, name: str, autocast: Optional[Callable[[T], Any]] = None): # Create the directory self.root_dir.mkdir(parents=True, exist_ok=True) + def consume_stream(self, observable: Observable[Union[T, Any]]) -> None: + """Consume an observable stream of sensor data without saving.""" + return observable.subscribe(self.save_one) + def save_stream(self, observable: Observable[Union[T, Any]]) -> Observable[int]: """Save an observable stream of sensor data to pickle files.""" return observable.pipe(ops.map(lambda frame: self.save_one(frame))) @@ -296,50 +301,77 @@ def stream( def _subscribe(observer, scheduler=None): from reactivex.disposable import CompositeDisposable, Disposable - scheduler = scheduler or TimeoutScheduler() # default thread-based + scheduler = scheduler or TimeoutScheduler() + disp = CompositeDisposable() + is_disposed = False iterator = self.iterate_ts( seek=seek, duration=duration, from_timestamp=from_timestamp, loop=loop ) + # Get first message try: - prev_ts, first_data = next(iterator) + first_ts, first_data = next(iterator) except StopIteration: observer.on_completed() return Disposable() - # Emit the first sample immediately + # Establish timing reference + start_local_time = time.time() + start_replay_time = first_ts + + # Emit first sample immediately observer.on_next(first_data) - disp = CompositeDisposable() - completed = [False] # Use list to allow mutation in nested function + # Pre-load next message + try: + next_message = next(iterator) + except StopIteration: + observer.on_completed() + return disp - def emit_next(prev_timestamp): - if completed[0]: + def schedule_emission(message): + nonlocal next_message, is_disposed + + if is_disposed: return + ts, data = message + + # Pre-load the following message while we have time try: - ts, data = next(iterator) + next_message = next(iterator) except StopIteration: - completed[0] = 
True - observer.on_completed() - return + next_message = None - delay = max(0.0, ts - prev_timestamp) / speed + # Calculate absolute emission time + target_time = start_local_time + (ts - start_replay_time) / speed + delay = max(0.0, target_time - time.time()) - def _action(sc, _state=None): - if not completed[0]: - observer.on_next(data) - emit_next(ts) # schedule the following sample + def emit(): + if is_disposed: + return + observer.on_next(data) + if next_message is not None: + schedule_emission(next_message) + else: + observer.on_completed() + # Dispose of the scheduler to clean up threads + if hasattr(scheduler, "dispose"): + scheduler.dispose() - # Schedule the next emission relative to previous timestamp - disp.add(scheduler.schedule_relative(delay, _action)) + disp.add(scheduler.schedule_relative(delay, lambda sc, _: emit())) - emit_next(prev_ts) + schedule_emission(next_message) + # Create a custom disposable that properly cleans up def dispose(): - completed[0] = True + nonlocal is_disposed + is_disposed = True disp.dispose() + # Ensure scheduler is disposed to clean up any threads + if hasattr(scheduler, "dispose"): + scheduler.dispose() return Disposable(dispose) diff --git a/pyproject.toml b/pyproject.toml index 670f98a0b7..7979293411 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -111,7 +111,6 @@ dependencies = [ [project.scripts] lcmspy = "dimos.utils.cli.lcmspy.run_lcmspy:main" -lcm-recorder = "dimos.utils.cli.recorder.run_recorder:main" foxglove-bridge = "dimos.utils.cli.foxglove_bridge.run_foxglove_bridge:main" skillspy = "dimos.utils.cli.skillspy.skillspy:main" agentspy = "dimos.utils.cli.agentspy.agentspy:main"