diff --git a/assets/drone_foxglove_lcm_dashboard.json b/assets/drone_foxglove_lcm_dashboard.json new file mode 100644 index 0000000000..cfcd8afb47 --- /dev/null +++ b/assets/drone_foxglove_lcm_dashboard.json @@ -0,0 +1,381 @@ +{ + "configById": { + "RawMessages!3zk027p": { + "diffEnabled": false, + "diffMethod": "custom", + "diffTopicPath": "", + "showFullMessageForDiff": false, + "topicPath": "/drone/telemetry", + "fontSize": 12 + }, + "RawMessages!ra9m3n": { + "diffEnabled": false, + "diffMethod": "custom", + "diffTopicPath": "", + "showFullMessageForDiff": false, + "topicPath": "/drone/status", + "fontSize": 12 + }, + "RawMessages!2rdgzs9": { + "diffEnabled": false, + "diffMethod": "custom", + "diffTopicPath": "", + "showFullMessageForDiff": false, + "topicPath": "/drone/odom", + "fontSize": 12 + }, + "3D!18i6zy7": { + "layers": { + "845139cb-26bc-40b3-8161-8ab60af4baf5": { + "visible": true, + "frameLocked": true, + "label": "Grid", + "instanceId": "845139cb-26bc-40b3-8161-8ab60af4baf5", + "layerId": "foxglove.Grid", + "lineWidth": 0.5, + "position": [ + 0, + 0, + 0 + ], + "rotation": [ + 0, + 0, + 0 + ], + "order": 1, + "size": 30, + "divisions": 30, + "color": "#248eff57" + }, + "ff758451-8c06-4419-a995-e93c825eb8be": { + "visible": true, + "frameLocked": true, + "label": "Grid", + "instanceId": "ff758451-8c06-4419-a995-e93c825eb8be", + "layerId": "foxglove.Grid", + "frameId": "base_link", + "size": 3, + "divisions": 3, + "lineWidth": 1.5, + "color": "#24fff4ff", + "position": [ + 0, + 0, + 0 + ], + "rotation": [ + 0, + 0, + 0 + ], + "order": 2 + } + }, + "cameraState": { + "perspective": true, + "distance": 35.161738318180966, + "phi": 54.90139603020621, + "thetaOffset": -55.91718358847429, + "targetOffset": [ + -1.0714086708240587, + -1.3106525624032879, + 2.481084387307447e-16 + ], + "target": [ + 0, + 0, + 0 + ], + "targetOrientation": [ + 0, + 0, + 0, + 1 + ], + "fovy": 45, + "near": 0.5, + "far": 5000 + }, + "followMode": "follow-pose", + "scene": { + 
"enableStats": true, + "ignoreColladaUpAxis": false, + "syncCamera": false, + "transforms": { + "visible": true + } + }, + "transforms": {}, + "topics": { + "/lidar": { + "stixelsEnabled": false, + "visible": true, + "colorField": "z", + "colorMode": "colormap", + "colorMap": "turbo", + "pointShape": "circle", + "pointSize": 10, + "explicitAlpha": 1, + "decayTime": 0, + "cubeSize": 0.1, + "minValue": -0.3, + "cubeOutline": false + }, + "/odom": { + "visible": true, + "axisScale": 1 + }, + "/video": { + "visible": false + }, + "/global_map": { + "visible": true, + "colorField": "z", + "colorMode": "colormap", + "colorMap": "turbo", + "pointSize": 10, + "decayTime": 0, + "pointShape": "cube", + "cubeOutline": false, + "cubeSize": 0.08, + "gradient": [ + "#06011dff", + "#d1e2e2ff" + ], + "stixelsEnabled": false, + "explicitAlpha": 1, + "minValue": -0.2 + }, + "/global_path": { + "visible": true, + "type": "line", + "arrowScale": [ + 1, + 0.15, + 0.15 + ], + "lineWidth": 0.132, + "gradient": [ + "#6bff7cff", + "#0081ffff" + ] + }, + "/global_target": { + "visible": true + }, + "/pt": { + "visible": false + }, + "/global_costmap": { + "visible": true, + "maxColor": "#8d3939ff", + "frameLocked": false, + "unknownColor": "#80808000", + "colorMode": "custom", + "alpha": 0.517, + "minColor": "#1e00ff00" + }, + "/global_gradient": { + "visible": true, + "maxColor": "#690066ff", + "unknownColor": "#30b89a00", + "minColor": "#00000000", + "colorMode": "custom", + "alpha": 0.3662, + "frameLocked": false, + "drawBehind": false + }, + "/global_cost_field": { + "visible": false, + "maxColor": "#ff0000ff", + "unknownColor": "#80808000" + }, + "/global_passable": { + "visible": false, + "maxColor": "#ffffff00", + "minColor": "#ff0000ff", + "unknownColor": "#80808000" + } + }, + "publish": { + "type": "point", + "poseTopic": "/move_base_simple/goal", + "pointTopic": "/clicked_point", + "poseEstimateTopic": "/estimate", + "poseEstimateXDeviation": 0.5, + "poseEstimateYDeviation": 0.5, 
+ "poseEstimateThetaDeviation": 0.26179939 + }, + "imageMode": {}, + "foxglovePanelTitle": "test", + "followTf": "world" + }, + "Image!3mnp456": { + "cameraState": { + "distance": 20, + "perspective": true, + "phi": 60, + "target": [ + 0, + 0, + 0 + ], + "targetOffset": [ + 0, + 0, + 0 + ], + "targetOrientation": [ + 0, + 0, + 0, + 1 + ], + "thetaOffset": 45, + "fovy": 45, + "near": 0.5, + "far": 5000 + }, + "followMode": "follow-pose", + "scene": { + "enableStats": true + }, + "transforms": {}, + "topics": {}, + "layers": {}, + "publish": { + "type": "point", + "poseTopic": "/move_base_simple/goal", + "pointTopic": "/clicked_point", + "poseEstimateTopic": "/initialpose", + "poseEstimateXDeviation": 0.5, + "poseEstimateYDeviation": 0.5, + "poseEstimateThetaDeviation": 0.26179939 + }, + "imageMode": { + "imageTopic": "/drone/color_image", + "colorMode": "gradient", + "calibrationTopic": "/drone/camera_info" + }, + "foxglovePanelTitle": "/video" + }, + "Image!1gtgk2x": { + "cameraState": { + "distance": 20, + "perspective": true, + "phi": 60, + "target": [ + 0, + 0, + 0 + ], + "targetOffset": [ + 0, + 0, + 0 + ], + "targetOrientation": [ + 0, + 0, + 0, + 1 + ], + "thetaOffset": 45, + "fovy": 45, + "near": 0.5, + "far": 5000 + }, + "followMode": "follow-pose", + "scene": { + "enableStats": true + }, + "transforms": {}, + "topics": {}, + "layers": {}, + "publish": { + "type": "point", + "poseTopic": "/move_base_simple/goal", + "pointTopic": "/clicked_point", + "poseEstimateTopic": "/initialpose", + "poseEstimateXDeviation": 0.5, + "poseEstimateYDeviation": 0.5, + "poseEstimateThetaDeviation": 0.26179939 + }, + "imageMode": { + "imageTopic": "/drone/depth_colorized", + "colorMode": "gradient", + "calibrationTopic": "/drone/camera_info" + }, + "foxglovePanelTitle": "/video" + }, + "Plot!a1gj37": { + "paths": [ + { + "timestampMethod": "receiveTime", + "value": "/drone/odom.pose.position.x", + "enabled": true, + "color": "#4e98e2" + }, + { + "timestampMethod": 
"receiveTime", + "value": "/drone/odom.pose.orientation.y", + "enabled": true, + "color": "#f5774d" + }, + { + "timestampMethod": "receiveTime", + "value": "/drone/odom.pose.position.z", + "enabled": true, + "color": "#f7df71" + } + ], + "showXAxisLabels": true, + "showYAxisLabels": true, + "showLegend": true, + "legendDisplay": "floating", + "showPlotValuesInLegend": false, + "isSynced": true, + "xAxisVal": "timestamp", + "sidebarDimension": 240 + } + }, + "globalVariables": {}, + "userNodes": {}, + "playbackConfig": { + "speed": 1 + }, + "drawerConfig": { + "tracks": [] + }, + "layout": { + "direction": "row", + "first": { + "first": { + "first": "RawMessages!3zk027p", + "second": "RawMessages!ra9m3n", + "direction": "column", + "splitPercentage": 69.92084432717678 + }, + "second": "RawMessages!2rdgzs9", + "direction": "column", + "splitPercentage": 70.97625329815304 + }, + "second": { + "first": "3D!18i6zy7", + "second": { + "first": "Image!3mnp456", + "second": { + "first": "Image!1gtgk2x", + "second": "Plot!a1gj37", + "direction": "column" + }, + "direction": "column", + "splitPercentage": 36.93931398416886 + }, + "direction": "row", + "splitPercentage": 52.45307143723201 + }, + "splitPercentage": 39.13203076769059 + } +} diff --git a/data/.lfs/drone.tar.gz b/data/.lfs/drone.tar.gz new file mode 100644 index 0000000000..2973c649cd --- /dev/null +++ b/data/.lfs/drone.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dd73f988eee8fd7b99d6c0bf6a905c2f43a6145a4ef33e9eef64bee5f53e04dd +size 709946060 diff --git a/dimos/agents2/skills/google_maps_skill_container.py b/dimos/agents2/skills/google_maps_skill_container.py index 09eee6e490..ba1af7831d 100644 --- a/dimos/agents2/skills/google_maps_skill_container.py +++ b/dimos/agents2/skills/google_maps_skill_container.py @@ -15,6 +15,8 @@ import json from typing import Any +from reactivex.disposable import Disposable + from dimos.core.core import rpc from dimos.core.skill_module import 
SkillModule from dimos.core.stream import In @@ -35,6 +37,8 @@ class GoogleMapsSkillContainer(SkillModule): def __init__(self) -> None: super().__init__() self._client = GoogleMaps() + self._started = True + self._max_valid_distance = 20000 # meters @rpc def start(self) -> None: @@ -80,9 +84,10 @@ def where_am_i(self, context_radius: int = 200) -> str: return result.model_dump_json() @skill() - def get_gps_position_for_queries(self, *queries: str) -> str: - """Get the GPS position (latitude/longitude) - + def get_gps_position_for_queries(self, queries: list[str]) -> str: + """Get the GPS position (latitude/longitude) from Google Maps for know landmarks or searchable locations. + This includes anything that wouldn't be viewable on a physical OSM map including intersections (5th and Natoma) + landmarks (Dolores park), or locations (Tempest bar) Example: get_gps_position_for_queries(['Fort Mason', 'Lafayette Park']) diff --git a/dimos/agents2/skills/osm.py b/dimos/agents2/skills/osm.py index 52cd0137cd..a0dd9b37e2 100644 --- a/dimos/agents2/skills/osm.py +++ b/dimos/agents2/skills/osm.py @@ -47,14 +47,14 @@ def _on_gps_location(self, location: LatLon) -> None: self._latest_location = location @skill() - def street_map_query(self, query_sentence: str) -> str: + def map_query(self, query_sentence: str) -> str: """This skill uses a vision language model to find something on the map based on the query sentence. You can query it with something like "Where can I find a coffee shop?" and it returns the latitude and longitude. Example: - street_map_query("Where can I find a coffee shop?") + map_query("Where can I find a coffee shop?") Args: query_sentence (str): The query sentence. 
diff --git a/dimos/core/stream.py b/dimos/core/stream.py index 37a6fce766..2eab6de710 100644 --- a/dimos/core/stream.py +++ b/dimos/core/stream.py @@ -220,7 +220,7 @@ def __reduce__(self): # type: ignore[no-untyped-def] @property def transport(self) -> Transport[T]: - if not self._transport: + if not self._transport and self.connection: self._transport = self.connection.transport # type: ignore[union-attr] return self._transport diff --git a/dimos/mapping/osm/current_location_map.py b/dimos/mapping/osm/current_location_map.py index 7d6c132486..b34fb6c443 100644 --- a/dimos/mapping/osm/current_location_map.py +++ b/dimos/mapping/osm/current_location_map.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from PIL import Image as PILImage, ImageDraw from dimos.mapping.osm.osm import MapImage, get_osm_map from dimos.mapping.osm.query import query_for_one_position, query_for_one_position_and_context @@ -31,7 +32,7 @@ def __init__(self, vl_model: VlModel) -> None: self._vl_model = vl_model self._position = None self._map_image = None - self._zoom_level = 19 + self._zoom_level = 15 self._n_tiles = 6 # What ratio of the width is considered the center. 1.0 means the entire map is the center. 
self._center_width = 0.4 @@ -68,6 +69,24 @@ def _fetch_new_map(self) -> None: ) self._map_image = get_osm_map(self._position, self._zoom_level, self._n_tiles) # type: ignore[arg-type] + # Add position marker + import numpy as np + + assert self._map_image is not None + assert self._position is not None + pil_image = PILImage.fromarray(self._map_image.image.data) + draw = ImageDraw.Draw(pil_image) + x, y = self._map_image.latlon_to_pixel(self._position) + radius = 20 + draw.ellipse( + [x - radius, y - radius, x + radius, y + radius], + fill=(255, 0, 0), + outline=(0, 0, 0), + width=3, + ) + + self._map_image.image.data[:] = np.array(pil_image) + def _position_is_too_far_off_center(self) -> bool: x, y = self._map_image.latlon_to_pixel(self._position) # type: ignore[arg-type, union-attr] width = self._map_image.image.width # type: ignore[union-attr] @@ -75,3 +94,20 @@ def _position_is_too_far_off_center(self) -> bool: size_max = width * (0.5 + self._center_width / 2) return x < size_min or x > size_max or y < size_min or y > size_max + + def save_current_map_image(self, filepath: str = "osm_debug_map.png") -> str: + """Save the current OSM map image to a file for debugging. + + Args: + filepath: Path where to save the image + + Returns: + The filepath where the image was saved + """ + if not self._map_image: + self._get_current_map() # type: ignore[no-untyped-call] + + if self._map_image is not None: + self._map_image.image.save(filepath) + logger.info(f"Saved OSM map image to {filepath}") + return filepath diff --git a/dimos/robot/drone/README.md b/dimos/robot/drone/README.md new file mode 100644 index 0000000000..fbd7ddf2ae --- /dev/null +++ b/dimos/robot/drone/README.md @@ -0,0 +1,289 @@ +# DimOS Drone Module + +Complete integration for DJI drones via RosettaDrone MAVLink bridge with visual servoing and autonomous tracking capabilities. 
+ +## Quick Start + +### Test the System +```bash +# Test with replay mode (no hardware needed) +python dimos/robot/drone/drone.py --replay + +# Real drone - indoor (IMU odometry) +python dimos/robot/drone/drone.py + +# Real drone - outdoor (GPS odometry) +python dimos/robot/drone/drone.py --outdoor +``` + +### Python API Usage +```python +from dimos.robot.drone.drone import Drone + +# Connect to drone +drone = Drone(connection_string='udp:0.0.0.0:14550', outdoor=True) # Use outdoor=True for GPS +drone.start() + +# Basic operations +drone.arm() +drone.takeoff(altitude=5.0) +drone.move(Vector3(1.0, 0, 0), duration=2.0) # Forward 1m/s for 2s + +# Visual tracking +drone.tracking.track_object("person", duration=120) # Track for 2 minutes + +# Land and cleanup +drone.land() +drone.stop() +``` + +## Installation + +### Python Package +```bash +# Install DimOS with drone support +pip install -e .[drone] +``` + +### System Dependencies +```bash +# GStreamer for video streaming +sudo apt-get install -y gstreamer1.0-tools gstreamer1.0-plugins-base \ + gstreamer1.0-plugins-good gstreamer1.0-plugins-bad \ + gstreamer1.0-libav python3-gi python3-gi-cairo + +# LCM for communication +sudo apt-get install liblcm-dev +``` + +### Environment Setup +```bash +export DRONE_IP=0.0.0.0 # Listen on all interfaces +export DRONE_VIDEO_PORT=5600 +export DRONE_MAVLINK_PORT=14550 +``` + +## RosettaDrone Setup (Critical) + +RosettaDrone is an Android app that bridges DJI SDK to MAVLink protocol. Without it, the drone cannot communicate with DimOS. + +### Option 1: Pre-built APK +1. Download latest release: https://github.com/RosettaDrone/rosettadrone/releases +2. Install on Android device connected to DJI controller +3. 
Configure in app: + - MAVLink Target IP: Your computer's IP + - MAVLink Port: 14550 + - Video Port: 5600 + - Enable video streaming + +### Option 2: Build from Source + +#### Prerequisites +- Android Studio +- DJI Developer Account: https://developer.dji.com/ +- Git + +#### Build Steps +```bash +# Clone repository +git clone https://github.com/RosettaDrone/rosettadrone.git +cd rosettadrone + +# Build with Gradle +./gradlew assembleRelease + +# APK will be in: app/build/outputs/apk/release/ +``` + +#### Configure DJI API Key +1. Register app at https://developer.dji.com/user/apps + - Package name: `sq.rogue.rosettadrone` +2. Add key to `app/src/main/AndroidManifest.xml`: +```xml + +``` + +#### Install APK +```bash +adb install -r app/build/outputs/apk/release/rosettadrone-release.apk +``` + +### Hardware Connection +``` +DJI Drone ← Wireless → DJI Controller ← USB → Android Device ← WiFi → DimOS Computer +``` + +1. Connect Android to DJI controller via USB +2. Start RosettaDrone app +3. Wait for "DJI Connected" status +4. 
Verify "MAVLink Active" shows in app + +## Architecture + +### Module Structure +``` +drone.py # Main orchestrator +├── connection_module.py # MAVLink communication & skills +├── camera_module.py # Video processing & depth estimation +├── tracking_module.py # Visual servoing & object tracking +├── mavlink_connection.py # Low-level MAVLink protocol +└── dji_video_stream.py # GStreamer video capture +``` + +### Communication Flow +``` +DJI Drone → RosettaDrone → MAVLink UDP → connection_module → LCM Topics + → Video UDP → dji_video_stream → tracking_module +``` + +### LCM Topics +- `/drone/odom` - Position and orientation +- `/drone/status` - Armed state, battery +- `/drone/video` - Camera frames +- `/drone/tracking/cmd_vel` - Tracking velocity commands +- `/drone/tracking/overlay` - Visualization with tracking box + +## Visual Servoing & Tracking + +### Object Tracking +```python +# Track specific object +result = drone.tracking.track_object("red flag", duration=60) + +# Track nearest/most prominent object +result = drone.tracking.track_object(None, duration=60) + +# Stop tracking +drone.tracking.stop_tracking() +``` + +### PID Tuning +Configure in `drone.py` initialization: +```python +# Indoor (gentle, precise) +x_pid_params=(0.001, 0.0, 0.0001, (-0.5, 0.5), None, 30) + +# Outdoor (aggressive, wind-resistant) +x_pid_params=(0.003, 0.0001, 0.0002, (-1.0, 1.0), None, 10) +``` + +Parameters: `(Kp, Ki, Kd, (min_output, max_output), integral_limit, deadband_pixels)` + +### Visual Servoing Flow +1. Qwen model detects object → bounding box +2. CSRT tracker initialized on bbox +3. PID controller computes velocity from pixel error +4. Velocity commands sent via LCM stream +5. 
Connection module converts to MAVLink commands + +## Available Skills + +### Movement & Control +- `move(vector, duration)` - Move with velocity vector +- `takeoff(altitude)` - Takeoff to altitude +- `land()` - Land at current position +- `arm()/disarm()` - Arm/disarm motors +- `fly_to(lat, lon, alt)` - Fly to GPS coordinates + +### Perception +- `observe()` - Get current camera frame +- `follow_object(description, duration)` - Follow object with servoing + +### Tracking Module +- `track_object(name, duration)` - Track and follow object +- `stop_tracking()` - Stop current tracking +- `get_status()` - Get tracking status + +## Testing + +### Unit Tests +```bash +pytest -s dimos/robot/drone/ +``` + +### Replay Mode (No Hardware) +```python +# Use recorded data for testing +drone = Drone(connection_string='replay') +drone.start() +# All operations work with recorded data +``` + +## Troubleshooting + +### No MAVLink Connection +- Check Android and computer are on same network +- Verify IP address in RosettaDrone matches computer +- Test with: `nc -lu 14550` (should see data) +- Check firewall: `sudo ufw allow 14550/udp` + +### No Video Stream +- Enable video in RosettaDrone settings +- Test with: `nc -lu 5600` (should see data) +- Verify GStreamer installed: `gst-launch-1.0 --version` + +### Tracking Issues +- Increase lighting for better detection +- Adjust PID gains for environment +- Check `max_lost_frames` in tracking module +- Monitor with Foxglove on `ws://localhost:8765` + +### Wrong Movement Direction +- Don't modify coordinate conversions +- Verify with: `pytest test_drone.py::test_ned_to_ros_coordinate_conversion` +- Check camera orientation assumptions + +## Advanced Features + +### Coordinate Systems +- **MAVLink/NED**: X=North, Y=East, Z=Down +- **ROS/DimOS**: X=Forward, Y=Left, Z=Up +- Automatic conversion handled internally + +### Depth Estimation +Camera module can generate depth maps using Metric3D: +```python +# Depth published to /drone/depth and 
/drone/pointcloud +# Requires GPU with 8GB+ VRAM +``` + +### Foxglove Visualization +Connect Foxglove Studio to `ws://localhost:8765` to see: +- Live video with tracking overlay +- 3D drone position +- Telemetry plots +- Transform tree + +## Network Ports +- **14550**: MAVLink UDP +- **5600**: Video stream UDP +- **8765**: Foxglove WebSocket +- **7667**: LCM messaging + +## Development + +### Adding New Skills +Add to `connection_module.py` with `@skill()` decorator: +```python +@skill() +def my_skill(self, param: float) -> str: + """Skill description for LLM.""" + # Implementation + return "Result" +``` + +### Modifying PID Control +Edit gains in `drone.py` `_deploy_tracking()`: +- Increase Kp for faster response +- Add Ki for steady-state error +- Increase Kd for damping +- Adjust limits for max velocity + +## Safety Notes +- Always test in simulator or with propellers removed first +- Set conservative PID gains initially +- Implement geofencing for outdoor flights +- Monitor battery voltage continuously +- Have manual override ready diff --git a/dimos/robot/drone/__init__.py b/dimos/robot/drone/__init__.py new file mode 100644 index 0000000000..5fb7c2fc7b --- /dev/null +++ b/dimos/robot/drone/__init__.py @@ -0,0 +1,22 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Generic drone module for MAVLink-based drones.""" + +from .camera_module import DroneCameraModule +from .connection_module import DroneConnectionModule +from .drone import Drone +from .mavlink_connection import MavlinkConnection + +__all__ = ["Drone", "DroneCameraModule", "DroneConnectionModule", "MavlinkConnection"] diff --git a/dimos/robot/drone/camera_module.py b/dimos/robot/drone/camera_module.py new file mode 100644 index 0000000000..c2ccef606b --- /dev/null +++ b/dimos/robot/drone/camera_module.py @@ -0,0 +1,286 @@ +#!/usr/bin/env python3 +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2025 Dimensional Inc. + +"""Camera module for drone with depth estimation.""" + +import threading +import time +from typing import Any + +from dimos_lcm.sensor_msgs import CameraInfo # type: ignore[import-untyped] + +from dimos.core import In, Module, Out, rpc +from dimos.msgs.geometry_msgs import PoseStamped +from dimos.msgs.sensor_msgs import Image, ImageFormat +from dimos.msgs.std_msgs import Header +from dimos.perception.common.utils import colorize_depth +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class DroneCameraModule(Module): + """ + Camera module for drone that processes RGB images to generate depth using Metric3D. 
+ + Subscribes to: + - /video: RGB camera images from drone + + Publishes: + - /drone/color_image: RGB camera images + - /drone/depth_image: Depth images from Metric3D + - /drone/depth_colorized: Colorized depth + - /drone/camera_info: Camera calibration + - /drone/camera_pose: Camera pose from TF + """ + + # Inputs + video: In[Image] + + # Outputs + color_image: Out[Image] + depth_image: Out[Image] + depth_colorized: Out[Image] + camera_info: Out[CameraInfo] + camera_pose: Out[PoseStamped] + + def __init__( + self, + camera_intrinsics: list[float], + world_frame_id: str = "world", + camera_frame_id: str = "camera_link", + base_frame_id: str = "base_link", + gt_depth_scale: float = 2.0, + **kwargs: Any, + ) -> None: + """Initialize drone camera module. + + Args: + camera_intrinsics: [fx, fy, cx, cy] + camera_frame_id: TF frame for camera + base_frame_id: TF frame for drone base + gt_depth_scale: Depth scale factor + """ + super().__init__(**kwargs) + + if len(camera_intrinsics) != 4: + raise ValueError("Camera intrinsics must be [fx, fy, cx, cy]") + + self.camera_intrinsics = camera_intrinsics + self.camera_frame_id = camera_frame_id + self.base_frame_id = base_frame_id + self.world_frame_id = world_frame_id + self.gt_depth_scale = gt_depth_scale + + # Metric3D for depth + self.metric3d: Any = None # Lazy-loaded Metric3D model + + # Processing state + self._running = False + self._latest_frame: Image | None = None + self._processing_thread: threading.Thread | None = None + self._stop_processing = threading.Event() + + logger.info(f"DroneCameraModule initialized with intrinsics: {camera_intrinsics}") + + @rpc + def start(self) -> bool: + """Start the camera module.""" + if self._running: + logger.warning("Camera module already running") + return True + + # Start processing thread for depth (which will init Metric3D and handle video) + self._running = True + self._stop_processing.clear() + self._processing_thread = threading.Thread(target=self._processing_loop, 
daemon=True) + self._processing_thread.start() + + logger.info("Camera module started") + return True + + def _on_video_frame(self, frame: Image) -> None: + """Handle incoming video frame.""" + if not self._running: + return + + # Publish color image immediately + self.color_image.publish(frame) + + # Store for depth processing + self._latest_frame = frame + + def _processing_loop(self) -> None: + """Process depth estimation in background.""" + # Initialize Metric3D in the background thread + if self.metric3d is None: + try: + from dimos.models.depth.metric3d import Metric3D + + self.metric3d = Metric3D(camera_intrinsics=self.camera_intrinsics) + logger.info("Metric3D initialized") + except Exception as e: + logger.warning(f"Metric3D not available: {e}") + self.metric3d = None + + # Subscribe to video once connection is available + subscribed = False + while not subscribed and not self._stop_processing.is_set(): + try: + if self.video.connection is not None: + self.video.subscribe(self._on_video_frame) + subscribed = True + logger.info("Subscribed to video input") + else: + time.sleep(0.1) + except Exception as e: + logger.debug(f"Waiting for video connection: {e}") + time.sleep(0.1) + + logger.info("Depth processing loop started") + + _reported_error = False + + while not self._stop_processing.is_set(): + if self._latest_frame is not None and self.metric3d is not None: + try: + frame = self._latest_frame + self._latest_frame = None + + # Get numpy array from Image + img_array = frame.data + + # Generate depth + depth_array = self.metric3d.infer_depth(img_array) / self.gt_depth_scale + + # Create header + header = Header(self.camera_frame_id) + + # Publish depth + depth_msg = Image( + data=depth_array, + format=ImageFormat.DEPTH, + frame_id=header.frame_id, + ts=header.ts, + ) + self.depth_image.publish(depth_msg) + + # Publish colorized depth + depth_colorized_array = colorize_depth( + depth_array, max_depth=10.0, overlay_stats=True + ) + if depth_colorized_array 
is not None: + depth_colorized_msg = Image( + data=depth_colorized_array, + format=ImageFormat.RGB, + frame_id=header.frame_id, + ts=header.ts, + ) + self.depth_colorized.publish(depth_colorized_msg) + + # Publish camera info + self._publish_camera_info(header, img_array.shape) + + # Publish camera pose + self._publish_camera_pose(header) + + except Exception as e: + if not _reported_error: + _reported_error = True + logger.error(f"Error processing depth: {e}") + else: + time.sleep(0.01) + + logger.info("Depth processing loop stopped") + + def _publish_camera_info(self, header: Header, shape: tuple[int, ...]) -> None: + """Publish camera calibration info.""" + try: + fx, fy, cx, cy = self.camera_intrinsics + height, width = shape[:2] + + # Camera matrix K (3x3) + K = [fx, 0, cx, 0, fy, cy, 0, 0, 1] + + # No distortion for now + D = [0.0, 0.0, 0.0, 0.0, 0.0] + + # Identity rotation + R = [1, 0, 0, 0, 1, 0, 0, 0, 1] + + # Projection matrix P (3x4) + P = [fx, 0, cx, 0, 0, fy, cy, 0, 0, 0, 1, 0] + + msg = CameraInfo( + D_length=len(D), + header=header, + height=height, + width=width, + distortion_model="plumb_bob", + D=D, + K=K, + R=R, + P=P, + binning_x=0, + binning_y=0, + ) + + self.camera_info.publish(msg) + + except Exception as e: + logger.error(f"Error publishing camera info: {e}") + + def _publish_camera_pose(self, header: Header) -> None: + """Publish camera pose from TF.""" + try: + transform = self.tf.get( + parent_frame=self.world_frame_id, + child_frame=self.camera_frame_id, + time_point=header.ts, + time_tolerance=1.0, + ) + + if transform: + pose_msg = PoseStamped( + ts=header.ts, + frame_id=self.camera_frame_id, + position=transform.translation, + orientation=transform.rotation, + ) + self.camera_pose.publish(pose_msg) + + except Exception as e: + logger.error(f"Error publishing camera pose: {e}") + + @rpc + def stop(self) -> None: + """Stop the camera module.""" + if not self._running: + return + + self._running = False + self._stop_processing.set() + + 
# Wait for thread + if self._processing_thread and self._processing_thread.is_alive(): + self._processing_thread.join(timeout=2.0) + + # Cleanup Metric3D + if self.metric3d: + self.metric3d.cleanup() + + logger.info("Camera module stopped") diff --git a/dimos/robot/drone/connection_module.py b/dimos/robot/drone/connection_module.py new file mode 100644 index 0000000000..51bbe59299 --- /dev/null +++ b/dimos/robot/drone/connection_module.py @@ -0,0 +1,489 @@ +#!/usr/bin/env python3 +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""DimOS module wrapper for drone connection.""" + +from collections.abc import Generator +import json +import threading +import time +from typing import Any + +from dimos_lcm.std_msgs import String # type: ignore[import-untyped] +from reactivex.disposable import CompositeDisposable, Disposable + +from dimos.core import In, Module, Out, rpc +from dimos.mapping.types import LatLon +from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Twist, Vector3 +from dimos.msgs.sensor_msgs import Image +from dimos.protocol.skill.skill import skill +from dimos.protocol.skill.type import Output +from dimos.robot.drone.dji_video_stream import DJIDroneVideoStream +from dimos.robot.drone.mavlink_connection import MavlinkConnection +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +def _add_disposable(composite: CompositeDisposable, item: Disposable | Any) -> None: + if isinstance(item, Disposable): + composite.add(item) + elif callable(item): + composite.add(Disposable(item)) + + +class DroneConnectionModule(Module): + """Module that handles drone sensor data and movement commands.""" + + # Inputs + movecmd: In[Vector3] + movecmd_twist: In[Twist] # Twist commands from tracking/navigation + gps_goal: In[LatLon] + tracking_status: In[Any] + + # Outputs + odom: Out[PoseStamped] + gps_location: Out[LatLon] + status: Out[Any] # JSON status + telemetry: Out[Any] # Full telemetry JSON + video: Out[Image] + follow_object_cmd: Out[Any] + + # Parameters + connection_string: str + + # Internal state + _odom: PoseStamped | None = None + _status: dict[str, Any] = {} + _latest_video_frame: Image | None = None + _latest_telemetry: dict[str, Any] | None = None + _latest_status: dict[str, Any] | None = None + _latest_status_lock: threading.RLock + + def __init__( + self, + connection_string: str = "udp:0.0.0.0:14550", + video_port: int = 5600, + outdoor: bool = False, + *args: Any, + **kwargs: Any, + ) -> None: + """Initialize drone connection 
module. + + Args: + connection_string: MAVLink connection string + video_port: UDP port for video stream + outdoor: Use GPS only mode (no velocity integration) + """ + self.connection_string = connection_string + self.video_port = video_port + self.outdoor = outdoor + self.connection: MavlinkConnection | None = None + self.video_stream: DJIDroneVideoStream | None = None + self._latest_video_frame = None + self._latest_telemetry = None + self._latest_status = None + self._latest_status_lock = threading.RLock() + self._running = False + self._telemetry_thread: threading.Thread | None = None + Module.__init__(self, *args, **kwargs) + + @rpc + def start(self) -> bool: + """Start the connection and subscribe to sensor streams.""" + # Check for replay mode + if self.connection_string == "replay": + from dimos.robot.drone.dji_video_stream import FakeDJIVideoStream + from dimos.robot.drone.mavlink_connection import FakeMavlinkConnection + + self.connection = FakeMavlinkConnection("replay") + self.video_stream = FakeDJIVideoStream(port=self.video_port) + else: + self.connection = MavlinkConnection(self.connection_string, outdoor=self.outdoor) + self.connection.connect() + + self.video_stream = DJIDroneVideoStream(port=self.video_port) + + if not self.connection.connected: + logger.error("Failed to connect to drone") + return False + + # Start video stream (already created above) + if self.video_stream.start(): + logger.info("Video stream started") + # Subscribe to video, store latest frame and publish it + _add_disposable( + self._disposables, + self.video_stream.get_stream().subscribe(self._store_and_publish_frame), + ) + # # TEMPORARY - DELETE AFTER RECORDING + # from dimos.utils.testing import TimedSensorStorage + # self._video_storage = TimedSensorStorage("drone/video") + # self._video_subscription = self._video_storage.save_stream(self.video_stream.get_stream()).subscribe() + # logger.info("Recording video to data/drone/video/") + else: + logger.warning("Video stream 
failed to start") + + # Subscribe to drone streams + _add_disposable( + self._disposables, self.connection.odom_stream().subscribe(self._publish_tf) + ) + _add_disposable( + self._disposables, self.connection.status_stream().subscribe(self._publish_status) + ) + _add_disposable( + self._disposables, self.connection.telemetry_stream().subscribe(self._publish_telemetry) + ) + + # Subscribe to movement commands + _add_disposable(self._disposables, self.movecmd.subscribe(self.move)) + + # Subscribe to Twist movement commands + if self.movecmd_twist.transport: + _add_disposable(self._disposables, self.movecmd_twist.subscribe(self._on_move_twist)) + + if self.gps_goal.transport: + _add_disposable(self._disposables, self.gps_goal.subscribe(self._on_gps_goal)) + + if self.tracking_status.transport: + _add_disposable( + self._disposables, self.tracking_status.subscribe(self._on_tracking_status) + ) + + # Start telemetry update thread + import threading + + self._running = True + self._telemetry_thread = threading.Thread(target=self._telemetry_loop, daemon=True) + self._telemetry_thread.start() + + logger.info("Drone connection module started") + return True + + def _store_and_publish_frame(self, frame: Image) -> None: + """Store the latest video frame and publish it.""" + self._latest_video_frame = frame + self.video.publish(frame) + + def _publish_tf(self, msg: PoseStamped) -> None: + """Publish odometry and TF transforms.""" + self._odom = msg + + # Publish odometry + self.odom.publish(msg) + + # Publish base_link transform + base_link = Transform( + translation=msg.position, + rotation=msg.orientation, + frame_id="world", + child_frame_id="base_link", + ts=msg.ts if hasattr(msg, "ts") else time.time(), + ) + self.tf.publish(base_link) + + # Publish camera_link transform (camera mounted on front of drone, no gimbal factored in yet) + camera_link = Transform( + translation=Vector3(0.1, 0.0, -0.05), # 10cm forward, 5cm down + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), # No 
rotation relative to base + frame_id="base_link", + child_frame_id="camera_link", + ts=time.time(), + ) + self.tf.publish(camera_link) + + def _publish_status(self, status: dict[str, Any]) -> None: + """Publish drone status as JSON string.""" + self._status = status + + status_str = String(json.dumps(status)) + self.status.publish(status_str) + + def _publish_telemetry(self, telemetry: dict[str, Any]) -> None: + """Publish full telemetry as JSON string.""" + telemetry_str = String(json.dumps(telemetry)) + self.telemetry.publish(telemetry_str) + self._latest_telemetry = telemetry + + if "GLOBAL_POSITION_INT" in telemetry: + tel = telemetry["GLOBAL_POSITION_INT"] + self.gps_location.publish(LatLon(lat=tel["lat"], lon=tel["lon"])) + + def _telemetry_loop(self) -> None: + """Continuously update telemetry at 30Hz.""" + frame_count = 0 + while self._running: + try: + # Update telemetry from drone + if self.connection is not None: + self.connection.update_telemetry(timeout=0.01) + + # Publish default odometry if we don't have real data yet + if frame_count % 10 == 0: # Every ~3Hz + if self._odom is None: + # Publish default pose + default_pose = PoseStamped( + position=Vector3(0, 0, 0), + orientation=Quaternion(0, 0, 0, 1), + frame_id="world", + ts=time.time(), + ) + self._publish_tf(default_pose) + logger.debug("Publishing default odometry") + + frame_count += 1 + time.sleep(0.033) # ~30Hz + except Exception as e: + logger.debug(f"Telemetry update error: {e}") + time.sleep(0.1) + + @rpc + def get_odom(self) -> PoseStamped | None: + """Get current odometry. + + Returns: + Current pose or None + """ + return self._odom + + @rpc + def get_status(self) -> dict[str, Any]: + """Get current drone status. + + Returns: + Status dictionary + """ + return self._status.copy() + + @skill() + def move(self, vector: Vector3, duration: float = 0.0) -> None: + """Send movement command to drone. 
+ + Args: + vector: Velocity vector [x, y, z] in m/s + duration: How long to move (0 = continuous) + """ + if self.connection: + # Convert dict/list to Vector3 + if isinstance(vector, dict): + vector = Vector3(vector.get("x", 0), vector.get("y", 0), vector.get("z", 0)) + elif isinstance(vector, (list, tuple)): + vector = Vector3( + vector[0] if len(vector) > 0 else 0, + vector[1] if len(vector) > 1 else 0, + vector[2] if len(vector) > 2 else 0, + ) + self.connection.move(vector, duration) + + @skill() + def takeoff(self, altitude: float = 3.0) -> bool: + """Takeoff to specified altitude. + + Args: + altitude: Target altitude in meters + + Returns: + True if takeoff initiated + """ + if self.connection: + return self.connection.takeoff(altitude) + return False + + @skill() + def land(self) -> bool: + """Land the drone. + + Returns: + True if land command sent + """ + if self.connection: + return self.connection.land() + return False + + @skill() + def arm(self) -> bool: + """Arm the drone. + + Returns: + True if armed successfully + """ + if self.connection: + return self.connection.arm() + return False + + @skill() + def disarm(self) -> bool: + """Disarm the drone. + + Returns: + True if disarmed successfully + """ + if self.connection: + return self.connection.disarm() + return False + + @skill() + def set_mode(self, mode: str) -> bool: + """Set flight mode. + + Args: + mode: Flight mode name + + Returns: + True if mode set successfully + """ + if self.connection: + return self.connection.set_mode(mode) + return False + + def move_twist(self, twist: Twist, duration: float = 0.0, lock_altitude: bool = True) -> bool: + """Move using ROS-style Twist commands. 
+ + Args: + twist: Twist message with linear velocities + duration: How long to move (0 = single command) + lock_altitude: If True, ignore Z velocity + + Returns: + True if command sent successfully + """ + if self.connection: + return self.connection.move_twist(twist, duration, lock_altitude) + return False + + @skill() + def is_flying_to_target(self) -> bool: + """Check if drone is currently flying to a GPS target. + + Returns: + True if flying to target, False otherwise + """ + if self.connection and hasattr(self.connection, "is_flying_to_target"): + return self.connection.is_flying_to_target + return False + + @skill() + def fly_to(self, lat: float, lon: float, alt: float) -> str: + """Fly drone to GPS coordinates (blocking operation). + + Args: + lat: Latitude in degrees + lon: Longitude in degrees + alt: Altitude in meters (relative to home) + + Returns: + String message indicating success or failure reason + """ + if self.connection: + return self.connection.fly_to(lat, lon, alt) + return "Failed: No connection to drone" + + @skill() + def follow_object( + self, object_description: str, duration: float = 120.0 + ) -> Generator[str, None, None]: + """Follow an object with visual servoing. + + Example: + + follow_object(object_description="red car", duration=120) + + Args: + object_description (str): A short and clear description of the object. + duration (float, optional): How long to track for. Defaults to 120.0. + """ + msg = {"object_description": object_description, "duration": duration} + self.follow_object_cmd.publish(String(json.dumps(msg))) + + yield "Started trying to track. First, trying to find the object." + + start_time = time.time() + + started_tracking = False + + while time.time() - start_time < duration: + time.sleep(0.01) + with self._latest_status_lock: + if not self._latest_status: + continue + match self._latest_status.get("status"): + case "not_found" | "failed": + yield f"The '{object_description}' object has not been found. 
Stopped tracking." + break + case "tracking": + # Only return tracking once. + if not started_tracking: + started_tracking = True + yield f"The '{object_description}' object is now being followed." + case "lost": + yield f"The '{object_description}' object has been lost. Stopped tracking." + break + case "stopped": + yield f"Tracking '{object_description}' has stopped." + break + else: + yield f"Stopped tracking '{object_description}'" + + def _on_move_twist(self, msg: Twist) -> None: + """Handle Twist movement commands from tracking/navigation. + + Args: + msg: Twist message with linear and angular velocities + """ + if self.connection: + # Use move_twist to properly handle Twist messages + self.connection.move_twist(msg, duration=0, lock_altitude=True) + + def _on_gps_goal(self, cmd: LatLon) -> None: + if self._latest_telemetry is None or self.connection is None: + return + current_alt = self._latest_telemetry.get("GLOBAL_POSITION_INT", {}).get("relative_alt", 0) + self.connection.fly_to(cmd.lat, cmd.lon, current_alt) + + def _on_tracking_status(self, status: String) -> None: + with self._latest_status_lock: + self._latest_status = json.loads(status.data) + + @rpc + def stop(self) -> None: + """Stop the module.""" + # Stop the telemetry loop + self._running = False + + # Wait for telemetry thread to finish + if self._telemetry_thread and self._telemetry_thread.is_alive(): + self._telemetry_thread.join(timeout=2.0) + + # Stop video stream + if self.video_stream: + self.video_stream.stop() + + # Disconnect from drone + if self.connection: + self.connection.disconnect() + + logger.info("Drone connection module stopped") + + # Call parent stop to clean up Module infrastructure (event loop, LCM, disposables, etc.) + super().stop() + + @skill(output=Output.image) + def observe(self) -> Image | None: + """Returns the latest video frame from the drone camera. Use this skill for any visual world queries. 
+ + This skill provides the current camera view for perception tasks. + Returns None if no frame has been captured yet. + """ + return self._latest_video_frame diff --git a/dimos/robot/drone/dji_video_stream.py b/dimos/robot/drone/dji_video_stream.py new file mode 100644 index 0000000000..de59770996 --- /dev/null +++ b/dimos/robot/drone/dji_video_stream.py @@ -0,0 +1,219 @@ +#!/usr/bin/env python3 +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2025 Dimensional Inc. + +"""Video streaming using GStreamer appsink for proper frame extraction.""" + +import functools +import subprocess +import threading +import time +from typing import Any + +import numpy as np +from reactivex import Observable, Subject + +from dimos.msgs.sensor_msgs import Image, ImageFormat +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class DJIDroneVideoStream: + """Capture drone video using GStreamer appsink.""" + + def __init__(self, port: int = 5600, width: int = 640, height: int = 360) -> None: + self.port = port + self.width = width + self.height = height + self._video_subject: Subject[Image] = Subject() + self._process: subprocess.Popen[bytes] | None = None + self._stop_event = threading.Event() + + def start(self) -> bool: + """Start video capture using gst-launch with appsink.""" + try: + # Use appsink to get properly formatted frames + # The ! 
at the end tells appsink to emit data on stdout in a parseable format + cmd = [ + "gst-launch-1.0", + "-q", + "udpsrc", + f"port={self.port}", + "!", + "application/x-rtp,encoding-name=H264,payload=96", + "!", + "rtph264depay", + "!", + "h264parse", + "!", + "avdec_h264", + "!", + "videoscale", + "!", + f"video/x-raw,width={self.width},height={self.height}", + "!", + "videoconvert", + "!", + "video/x-raw,format=RGB", + "!", + "filesink", + "location=/dev/stdout", + "buffer-mode=2", # Unbuffered output + ] + + logger.info(f"Starting video capture on UDP port {self.port}") + logger.debug(f"Pipeline: {' '.join(cmd)}") + + self._process = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=0 + ) + + self._stop_event.clear() + + # Start capture thread + self._capture_thread = threading.Thread(target=self._capture_loop, daemon=True) + self._capture_thread.start() + + # Start error monitoring + self._error_thread = threading.Thread(target=self._error_monitor, daemon=True) + self._error_thread.start() + + logger.info("Video stream started") + return True + + except Exception as e: + logger.error(f"Failed to start video stream: {e}") + return False + + def _capture_loop(self) -> None: + """Read frames with fixed size.""" + channels = 3 + frame_size = self.width * self.height * channels + + logger.info( + f"Capturing frames: {self.width}x{self.height} RGB ({frame_size} bytes per frame)" + ) + + frame_count = 0 + total_bytes = 0 + + while not self._stop_event.is_set(): + try: + # Read exactly one frame worth of data + frame_data = b"" + bytes_needed = frame_size + + while bytes_needed > 0 and not self._stop_event.is_set(): + if self._process is None or self._process.stdout is None: + break + chunk = self._process.stdout.read(bytes_needed) + if not chunk: + logger.warning("No data from GStreamer") + time.sleep(0.1) + break + frame_data += chunk + bytes_needed -= len(chunk) + + if len(frame_data) == frame_size: + # We have a complete frame + 
total_bytes += frame_size + + # Convert to numpy array + frame = np.frombuffer(frame_data, dtype=np.uint8) + frame = frame.reshape((self.height, self.width, channels)) + + # Create Image message (RGB format - matches GStreamer pipeline output) + img_msg = Image.from_numpy(frame, format=ImageFormat.RGB) + + # Publish + self._video_subject.on_next(img_msg) + + frame_count += 1 + if frame_count == 1: + logger.debug(f"First frame captured! Shape: {frame.shape}") + elif frame_count % 100 == 0: + logger.debug( + f"Captured {frame_count} frames ({total_bytes / 1024 / 1024:.1f} MB)" + ) + + except Exception as e: + if not self._stop_event.is_set(): + logger.error(f"Error in capture loop: {e}") + time.sleep(0.1) + + def _error_monitor(self) -> None: + """Monitor GStreamer stderr.""" + while not self._stop_event.is_set() and self._process is not None: + if self._process.stderr is None: + break + line = self._process.stderr.readline() + if line: + msg = line.decode("utf-8").strip() + if "ERROR" in msg or "WARNING" in msg: + logger.warning(f"GStreamer: {msg}") + else: + logger.debug(f"GStreamer: {msg}") + + def stop(self) -> None: + """Stop video stream.""" + self._stop_event.set() + + if self._process: + self._process.terminate() + try: + self._process.wait(timeout=2) + except subprocess.TimeoutExpired: + self._process.kill() + self._process = None + + logger.info("Video stream stopped") + + def get_stream(self) -> Subject[Image]: + """Get the video stream observable.""" + return self._video_subject + + +class FakeDJIVideoStream(DJIDroneVideoStream): + """Replay video for testing.""" + + def __init__(self, port: int = 5600) -> None: + super().__init__(port) + from dimos.utils.data import get_data + + # Ensure data is available + get_data("drone") + + def start(self) -> bool: + """Start replay of recorded video.""" + self._stop_event.clear() + logger.info("Video replay started") + return True + + @functools.cache + def get_stream(self) -> Observable[Image]: # type: 
ignore[override] + """Get the replay stream directly.""" + from dimos.utils.testing import TimedSensorReplay + + logger.info("Creating video replay stream") + video_store: Any = TimedSensorReplay("drone/video") + stream: Observable[Image] = video_store.stream() + return stream + + def stop(self) -> None: + """Stop replay.""" + self._stop_event.set() + logger.info("Video replay stopped") diff --git a/dimos/robot/drone/drone.py b/dimos/robot/drone/drone.py new file mode 100644 index 0000000000..7816d6a9aa --- /dev/null +++ b/dimos/robot/drone/drone.py @@ -0,0 +1,501 @@ +#!/usr/bin/env python3 +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2025 Dimensional Inc. 
+ +"""Main Drone robot class for DimOS.""" + +import functools +import logging +import os +import time +from typing import Any + +from dimos_lcm.sensor_msgs import CameraInfo # type: ignore[import-untyped] +from dimos_lcm.std_msgs import String # type: ignore[import-untyped] +from reactivex import Observable + +from dimos import core +from dimos.agents2.skills.google_maps_skill_container import GoogleMapsSkillContainer +from dimos.agents2.skills.osm import OsmSkill +from dimos.mapping.types import LatLon +from dimos.msgs.geometry_msgs import PoseStamped, Twist, Vector3 +from dimos.msgs.sensor_msgs import Image +from dimos.protocol import pubsub +from dimos.robot.drone.camera_module import DroneCameraModule +from dimos.robot.drone.connection_module import DroneConnectionModule +from dimos.robot.drone.drone_tracking_module import DroneTrackingModule +from dimos.robot.foxglove_bridge import FoxgloveBridge + +# LCM not needed in orchestrator - modules handle communication +from dimos.robot.robot import Robot +from dimos.types.robot_capabilities import RobotCapability +from dimos.utils.logging_config import setup_logger +from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule + +logger = setup_logger() + + +class Drone(Robot): + """Generic MAVLink-based drone with video and depth capabilities.""" + + def __init__( + self, + connection_string: str = "udp:0.0.0.0:14550", + video_port: int = 5600, + camera_intrinsics: list[float] | None = None, + output_dir: str | None = None, + outdoor: bool = False, + ) -> None: + """Initialize drone robot. 
+ + Args: + connection_string: MAVLink connection string + video_port: UDP port for video stream + camera_intrinsics: Camera intrinsics [fx, fy, cx, cy] + output_dir: Directory for outputs + outdoor: Use GPS only mode (no velocity integration) + """ + super().__init__() + + self.connection_string = connection_string + self.video_port = video_port + self.output_dir = output_dir or os.path.join(os.getcwd(), "assets", "output") + self.outdoor = outdoor + + if camera_intrinsics is None: + # Assuming 1920x1080 with typical FOV + self.camera_intrinsics = [1000.0, 1000.0, 960.0, 540.0] + else: + self.camera_intrinsics = camera_intrinsics + + self.capabilities = [ + RobotCapability.LOCOMOTION, # Aerial locomotion + RobotCapability.VISION, + ] + + self.dimos: core.DimosCluster | None = None + self.connection: DroneConnectionModule | None = None + self.camera: DroneCameraModule | None = None + self.tracking: DroneTrackingModule | None = None + self.foxglove_bridge: FoxgloveBridge | None = None + self.websocket_vis: WebsocketVisModule | None = None + + self._setup_directories() + + def _setup_directories(self) -> None: + """Setup output directories.""" + os.makedirs(self.output_dir, exist_ok=True) + logger.info(f"Drone outputs will be saved to: {self.output_dir}") + + def start(self) -> None: + """Start the drone system with all modules.""" + logger.info("Starting Drone robot system...") + + # Start DimOS cluster + self.dimos = core.start(4) + + # Deploy modules + self._deploy_connection() + self._deploy_camera() + self._deploy_tracking() + self._deploy_visualization() + self._deploy_navigation() + + # Start modules + self._start_modules() + + logger.info("Drone system initialized and started") + logger.info("Foxglove visualization available at http://localhost:8765") + + def _deploy_connection(self) -> None: + """Deploy and configure connection module.""" + assert self.dimos is not None + logger.info("Deploying connection module...") + + self.connection = self.dimos.deploy( 
# type: ignore[attr-defined] + DroneConnectionModule, + # connection_string="replay", + connection_string=self.connection_string, + video_port=self.video_port, + outdoor=self.outdoor, + ) + + # Configure LCM transports + self.connection.odom.transport = core.LCMTransport("/drone/odom", PoseStamped) + self.connection.gps_location.transport = core.pLCMTransport("/gps_location") + self.connection.gps_goal.transport = core.pLCMTransport("/gps_goal") + self.connection.status.transport = core.LCMTransport("/drone/status", String) + self.connection.telemetry.transport = core.LCMTransport("/drone/telemetry", String) + self.connection.video.transport = core.LCMTransport("/drone/video", Image) + self.connection.follow_object_cmd.transport = core.LCMTransport( + "/drone/follow_object_cmd", String + ) + self.connection.movecmd.transport = core.LCMTransport("/drone/cmd_vel", Vector3) + self.connection.movecmd_twist.transport = core.LCMTransport( + "/drone/tracking/cmd_vel", Twist + ) + + logger.info("Connection module deployed") + + def _deploy_camera(self) -> None: + """Deploy and configure camera module.""" + assert self.dimos is not None + assert self.connection is not None + logger.info("Deploying camera module...") + + self.camera = self.dimos.deploy( # type: ignore[attr-defined] + DroneCameraModule, camera_intrinsics=self.camera_intrinsics + ) + + # Configure LCM transports + self.camera.color_image.transport = core.LCMTransport("/drone/color_image", Image) + self.camera.depth_image.transport = core.LCMTransport("/drone/depth_image", Image) + self.camera.depth_colorized.transport = core.LCMTransport("/drone/depth_colorized", Image) + self.camera.camera_info.transport = core.LCMTransport("/drone/camera_info", CameraInfo) + self.camera.camera_pose.transport = core.LCMTransport("/drone/camera_pose", PoseStamped) + + # Connect video from connection module to camera module + self.camera.video.connect(self.connection.video) + + logger.info("Camera module deployed") + + def 
_deploy_tracking(self) -> None: + """Deploy and configure tracking module.""" + assert self.dimos is not None + assert self.connection is not None + logger.info("Deploying tracking module...") + + self.tracking = self.dimos.deploy( # type: ignore[attr-defined] + DroneTrackingModule, + outdoor=self.outdoor, + ) + + self.tracking.tracking_overlay.transport = core.LCMTransport( + "/drone/tracking_overlay", Image + ) + self.tracking.tracking_status.transport = core.LCMTransport( + "/drone/tracking_status", String + ) + self.tracking.cmd_vel.transport = core.LCMTransport("/drone/tracking/cmd_vel", Twist) + + self.tracking.video_input.connect(self.connection.video) + self.tracking.follow_object_cmd.connect(self.connection.follow_object_cmd) + + self.connection.movecmd_twist.connect(self.tracking.cmd_vel) + self.connection.tracking_status.connect(self.tracking.tracking_status) + + logger.info("Tracking module deployed") + + def _deploy_visualization(self) -> None: + """Deploy and configure visualization modules.""" + assert self.dimos is not None + assert self.connection is not None + self.websocket_vis = self.dimos.deploy(WebsocketVisModule) # type: ignore[attr-defined] + # self.websocket_vis.click_goal.transport = core.LCMTransport("/goal_request", PoseStamped) + self.websocket_vis.gps_goal.transport = core.pLCMTransport("/gps_goal") + # self.websocket_vis.explore_cmd.transport = core.LCMTransport("/explore_cmd", Bool) + # self.websocket_vis.stop_explore_cmd.transport = core.LCMTransport("/stop_explore_cmd", Bool) + self.websocket_vis.cmd_vel.transport = core.LCMTransport("/cmd_vel", Twist) + + self.websocket_vis.odom.connect(self.connection.odom) + self.websocket_vis.gps_location.connect(self.connection.gps_location) + # self.websocket_vis.path.connect(self.global_planner.path) + # self.websocket_vis.global_costmap.connect(self.mapper.global_costmap) + + self.foxglove_bridge = FoxgloveBridge() + + def _deploy_navigation(self) -> None: + assert self.websocket_vis is not 
None + assert self.connection is not None + # Connect In (subscriber) to Out (publisher) + self.connection.gps_goal.connect(self.websocket_vis.gps_goal) + + def _start_modules(self) -> None: + """Start all deployed modules.""" + assert self.connection is not None + assert self.camera is not None + assert self.tracking is not None + assert self.websocket_vis is not None + assert self.foxglove_bridge is not None + logger.info("Starting modules...") + + # Start connection first + result = self.connection.start() + if not result: + logger.warning("Connection module failed to start (no drone connected?)") + + # Start camera + result = self.camera.start() + if not result: + logger.warning("Camera module failed to start") + + result = self.tracking.start() + if result: + logger.info("Tracking module started successfully") + else: + logger.warning("Tracking module failed to start") + + self.websocket_vis.start() + + # Start Foxglove + self.foxglove_bridge.start() + + logger.info("All modules started") + + # Robot control methods + + def get_odom(self) -> PoseStamped | None: + """Get current odometry. + + Returns: + Current pose or None + """ + if self.connection is None: + return None + result: PoseStamped | None = self.connection.get_odom() + return result + + @functools.cached_property + def gps_position_stream(self) -> Observable[LatLon]: + assert self.connection is not None + return self.connection.gps_location.transport.pure_observable() + + def get_status(self) -> dict[str, Any]: + """Get drone status. + + Returns: + Status dictionary + """ + if self.connection is None: + return {} + result: dict[str, Any] = self.connection.get_status() + return result + + def move(self, vector: Vector3, duration: float = 0.0) -> None: + """Send movement command. 
+ + Args: + vector: Velocity vector [x, y, z] in m/s + duration: How long to move (0 = continuous) + """ + if self.connection is None: + return + self.connection.move(vector, duration) + + def takeoff(self, altitude: float = 3.0) -> bool: + """Takeoff to altitude. + + Args: + altitude: Target altitude in meters + + Returns: + True if takeoff initiated + """ + if self.connection is None: + return False + result: bool = self.connection.takeoff(altitude) + return result + + def land(self) -> bool: + """Land the drone. + + Returns: + True if land command sent + """ + if self.connection is None: + return False + result: bool = self.connection.land() + return result + + def arm(self) -> bool: + """Arm the drone. + + Returns: + True if armed successfully + """ + if self.connection is None: + return False + result: bool = self.connection.arm() + return result + + def disarm(self) -> bool: + """Disarm the drone. + + Returns: + True if disarmed successfully + """ + if self.connection is None: + return False + result: bool = self.connection.disarm() + return result + + def set_mode(self, mode: str) -> bool: + """Set flight mode. + + Args: + mode: Mode name (STABILIZE, GUIDED, LAND, RTL, etc.) + + Returns: + True if mode set successfully + """ + if self.connection is None: + return False + result: bool = self.connection.set_mode(mode) + return result + + def fly_to(self, lat: float, lon: float, alt: float) -> str: + """Fly to GPS coordinates. 
+ + Args: + lat: Latitude in degrees + lon: Longitude in degrees + alt: Altitude in meters (relative to home) + + Returns: + String message indicating success or failure + """ + if self.connection is None: + return "Failed: No connection" + result: str = self.connection.fly_to(lat, lon, alt) + return result + + def cleanup(self) -> None: + self.stop() + + def stop(self) -> None: + """Stop the drone system.""" + logger.info("Stopping drone system...") + + if self.connection: + self.connection.stop() + + if self.camera: + self.camera.stop() + + if self.foxglove_bridge: + self.foxglove_bridge.stop() + + if self.dimos: + self.dimos.close_all() # type: ignore[attr-defined] + + logger.info("Drone system stopped") + + +def main() -> None: + """Main entry point for drone system.""" + import argparse + + parser = argparse.ArgumentParser(description="DimOS Drone System") + parser.add_argument("--replay", action="store_true", help="Use recorded data for testing") + + parser.add_argument( + "--outdoor", + action="store_true", + help="Outdoor mode - use GPS only, no velocity integration", + ) + args = parser.parse_args() + + # Configure logging + setup_logger(level=logging.INFO) + + # Suppress verbose loggers + logging.getLogger("distributed").setLevel(logging.WARNING) + logging.getLogger("asyncio").setLevel(logging.WARNING) + + if args.replay: + connection = "replay" + print("\n🔄 REPLAY MODE - Using drone replay data") + else: + connection = os.getenv("DRONE_CONNECTION", "udp:0.0.0.0:14550") + video_port = int(os.getenv("DRONE_VIDEO_PORT", "5600")) + + print(f""" +╔══════════════════════════════════════════╗ +║ DimOS Mavlink Drone Runner ║ +╠══════════════════════════════════════════╣ +║ MAVLink: {connection:<30} ║ +║ Video: UDP port {video_port:<22}║ +║ Foxglove: http://localhost:8765 ║ +╚══════════════════════════════════════════╝ + """) + + pubsub.lcm.autoconf() # type: ignore[attr-defined] + + drone = Drone(connection_string=connection, video_port=video_port, 
outdoor=args.outdoor) + + drone.start() + + print("\n✓ Drone system started successfully!") + print("\nLCM Topics:") + print(" • /drone/odom - Odometry (PoseStamped)") + print(" • /drone/status - Status (String/JSON)") + print(" • /drone/telemetry - Full telemetry (String/JSON)") + print(" • /drone/color_image - RGB Video (Image)") + print(" • /drone/depth_image - Depth estimation (Image)") + print(" • /drone/depth_colorized - Colorized depth (Image)") + print(" • /drone/camera_info - Camera calibration") + print(" • /drone/cmd_vel - Movement commands (Vector3)") + print(" • /drone/tracking_overlay - Object tracking visualization (Image)") + print(" • /drone/tracking_status - Tracking status (String/JSON)") + + from dimos.agents2 import Agent # type: ignore[attr-defined] + from dimos.agents2.cli.human import HumanInput + from dimos.agents2.spec import Model, Provider # type: ignore[attr-defined] + + assert drone.dimos is not None + human_input = drone.dimos.deploy(HumanInput) # type: ignore[attr-defined] + google_maps = drone.dimos.deploy(GoogleMapsSkillContainer) # type: ignore[attr-defined] + osm_skill = drone.dimos.deploy(OsmSkill) # type: ignore[attr-defined] + + google_maps.gps_location.transport = core.pLCMTransport("/gps_location") + osm_skill.gps_location.transport = core.pLCMTransport("/gps_location") + + agent = Agent( + system_prompt="""You are controlling a DJI drone with MAVLink interface. + You have access to drone control skills you are already flying so only run move_twist, set_mode, and fly_to. + When the user gives commands, use the appropriate skills to control the drone. + Always confirm actions and report results. Send fly_to commands only at above 200 meters altitude to be safe. 
+ Here are some GPS locations to remember + 6th and Natoma intersection: 37.78019978319006, -122.40770815020853, + 454 Natoma (Office): 37.780967465525244, -122.40688342010769 + 5th and mission intersection: 37.782598539339695, -122.40649441875473 + 6th and mission intersection: 37.781007204789354, -122.40868447123661""", + model=Model.GPT_4O, # type: ignore[attr-defined] + provider=Provider.OPENAI, # type: ignore[attr-defined] + ) + + agent.register_skills(drone.connection) + agent.register_skills(human_input) + agent.register_skills(google_maps) + agent.register_skills(osm_skill) + agent.run_implicit_skill("human") + + agent.start() + agent.loop_thread() + + # Testing + # from dimos_lcm.geometry_msgs import Twist,Vector3 + # twist = Twist() + # twist.linear = Vector3(-0.5, 0.5, 0.5) + # drone.connection.move_twist(twist, duration=2.0, lock_altitude=True) + # time.sleep(10) + # drone.tracking.track_object("water bottle") + while True: + time.sleep(1) + + +if __name__ == "__main__": + main() diff --git a/dimos/robot/drone/drone_tracking_module.py b/dimos/robot/drone/drone_tracking_module.py new file mode 100644 index 0000000000..8d1f4c6ac8 --- /dev/null +++ b/dimos/robot/drone/drone_tracking_module.py @@ -0,0 +1,401 @@ +#!/usr/bin/env python3 +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Drone tracking module with visual servoing for object following.""" + +import json +import threading +import time +from typing import Any + +import cv2 +from dimos_lcm.std_msgs import String # type: ignore[import-untyped] +import numpy as np + +from dimos.core import In, Module, Out, rpc +from dimos.models.qwen.video_query import get_bbox_from_qwen_frame +from dimos.msgs.geometry_msgs import Twist, Vector3 +from dimos.msgs.sensor_msgs import Image, ImageFormat +from dimos.robot.drone.drone_visual_servoing_controller import ( + DroneVisualServoingController, + PIDParams, +) +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + +INDOOR_PID_PARAMS: PIDParams = (0.001, 0.0, 0.0001, (-1.0, 1.0), None, 30) +OUTDOOR_PID_PARAMS: PIDParams = (0.05, 0.0, 0.0003, (-5.0, 5.0), None, 10) +INDOOR_MAX_VELOCITY = 1.0 # m/s safety cap for indoor mode + + +class DroneTrackingModule(Module): + """Module for drone object tracking with visual servoing control.""" + + # Inputs + video_input: In[Image] + follow_object_cmd: In[Any] + + # Outputs + tracking_overlay: Out[Image] # Visualization with bbox and crosshairs + tracking_status: Out[Any] # JSON status updates + cmd_vel: Out[Twist] # Velocity commands for drone control + + def __init__( + self, + outdoor: bool = False, + x_pid_params: PIDParams | None = None, + y_pid_params: PIDParams | None = None, + z_pid_params: PIDParams | None = None, + ) -> None: + """Initialize the drone tracking module. + + Args: + outdoor: If True, use aggressive outdoor PID params (5 m/s max). + If False (default), use conservative indoor params (1 m/s max). + x_pid_params: PID parameters for forward/backward control. + If None, uses preset based on outdoor flag. + y_pid_params: PID parameters for left/right strafe control. + If None, uses preset based on outdoor flag. + z_pid_params: Optional PID parameters for altitude control. 
+ """ + super().__init__() + + default_params = OUTDOOR_PID_PARAMS if outdoor else INDOOR_PID_PARAMS + x_pid_params = x_pid_params if x_pid_params is not None else default_params + y_pid_params = y_pid_params if y_pid_params is not None else default_params + + self._outdoor = outdoor + self._max_velocity = None if outdoor else INDOOR_MAX_VELOCITY + + self.servoing_controller = DroneVisualServoingController( + x_pid_params=x_pid_params, y_pid_params=y_pid_params, z_pid_params=z_pid_params + ) + + # Tracking state + self._tracking_active = False + self._tracking_thread: threading.Thread | None = None + self._current_object: str | None = None + self._latest_frame: Image | None = None + self._frame_lock = threading.Lock() + + # Subscribe to video input when transport is set + # (will be done by connection module) + + def _on_new_frame(self, frame: Image) -> None: + """Handle new video frame.""" + with self._frame_lock: + self._latest_frame = frame + + def _on_follow_object_cmd(self, cmd: String) -> None: + msg = json.loads(cmd.data) + self.track_object(msg["object_description"], msg["duration"]) + + def _get_latest_frame(self) -> np.ndarray[Any, np.dtype[Any]] | None: + """Get the latest video frame as numpy array.""" + with self._frame_lock: + if self._latest_frame is None: + return None + # Convert Image to numpy array + data: np.ndarray[Any, np.dtype[Any]] = self._latest_frame.data + return data + + @rpc + def start(self) -> bool: + """Start the tracking module and subscribe to video input.""" + if self.video_input.transport: + self.video_input.subscribe(self._on_new_frame) + logger.info("DroneTrackingModule started - subscribed to video input") + else: + logger.warning("DroneTrackingModule: No video input transport configured") + + if self.follow_object_cmd.transport: + self.follow_object_cmd.subscribe(self._on_follow_object_cmd) + + return True + + @rpc + def stop(self) -> None: + self._stop_tracking() + super().stop() + + @rpc + def track_object(self, 
object_name: str | None = None, duration: float = 120.0) -> str: + """Track and follow an object using visual servoing. + + Args: + object_name: Name of object to track, or None for most prominent + duration: Maximum tracking duration in seconds + + Returns: + String status message + """ + if self._tracking_active: + return "Already tracking an object" + + # Get current frame + frame = self._get_latest_frame() + if frame is None: + return "Error: No video frame available" + + logger.info(f"Starting track_object for {object_name or 'any object'}") + + try: + # Detect object with Qwen + logger.info("Detecting object with Qwen...") + bbox = get_bbox_from_qwen_frame(frame, object_name) + + if bbox is None: + msg = f"No object detected{' for: ' + object_name if object_name else ''}" + logger.warning(msg) + self._publish_status({"status": "not_found", "object": self._current_object}) + return msg + + logger.info(f"Object detected at bbox: {bbox}") + + # Initialize CSRT tracker (use legacy for OpenCV 4) + try: + tracker = cv2.legacy.TrackerCSRT_create() # type: ignore[attr-defined] + except AttributeError: + tracker = cv2.TrackerCSRT_create() # type: ignore[attr-defined] + + # Convert bbox format from [x1, y1, x2, y2] to [x, y, w, h] + x1, y1, x2, y2 = bbox + x, y, w, h = x1, y1, x2 - x1, y2 - y1 + + # Initialize tracker + success = tracker.init(frame, (x, y, w, h)) + if not success: + self._publish_status({"status": "failed", "object": self._current_object}) + return "Failed to initialize tracker" + + self._current_object = object_name or "object" + self._tracking_active = True + + # Start tracking in thread (non-blocking - caller should poll get_status()) + self._tracking_thread = threading.Thread( + target=self._visual_servoing_loop, args=(tracker, duration), daemon=True + ) + self._tracking_thread.start() + + return f"Tracking started for {self._current_object}. Poll get_status() for updates." 
+ + except Exception as e: + logger.error(f"Tracking error: {e}") + self._stop_tracking() + return f"Tracking failed: {e!s}" + + def _visual_servoing_loop(self, tracker: Any, duration: float) -> None: + """Main visual servoing control loop. + + Args: + tracker: OpenCV tracker instance + duration: Maximum duration in seconds + """ + start_time = time.time() + frame_count = 0 + lost_track_count = 0 + max_lost_frames = 100 + + logger.info("Starting visual servoing loop") + + try: + while self._tracking_active and (time.time() - start_time < duration): + # Get latest frame + frame = self._get_latest_frame() + if frame is None: + time.sleep(0.01) + continue + + frame_count += 1 + + # Update tracker + success, bbox = tracker.update(frame) + + if not success: + lost_track_count += 1 + logger.warning(f"Lost track (count: {lost_track_count})") + + if lost_track_count >= max_lost_frames: + logger.error("Lost track of object") + self._publish_status( + {"status": "lost", "object": self._current_object, "frame": frame_count} + ) + break + continue + else: + lost_track_count = 0 + + # Calculate object center + x, y, w, h = bbox + current_x = x + w / 2 + current_y = y + h / 2 + + # Get frame dimensions + frame_height, frame_width = frame.shape[:2] + center_x = frame_width / 2 + center_y = frame_height / 2 + + # Compute velocity commands + vx, vy, vz = self.servoing_controller.compute_velocity_control( + target_x=current_x, + target_y=current_y, + center_x=center_x, + center_y=center_y, + dt=0.033, # ~30Hz + lock_altitude=True, + ) + + # Clamp velocity for indoor safety + if self._max_velocity is not None: + vx = max(-self._max_velocity, min(self._max_velocity, vx)) + vy = max(-self._max_velocity, min(self._max_velocity, vy)) + + # Publish velocity command via LCM + if self.cmd_vel.transport: + twist = Twist() + twist.linear = Vector3(vx, vy, 0) + twist.angular = Vector3(0, 0, 0) # No rotation for now + self.cmd_vel.publish(twist) + + # Publish visualization if transport is set + 
if self.tracking_overlay.transport: + overlay = self._draw_tracking_overlay( + frame, (int(x), int(y), int(w), int(h)), (int(current_x), int(current_y)) + ) + overlay_msg = Image.from_numpy(overlay, format=ImageFormat.BGR) + self.tracking_overlay.publish(overlay_msg) + + # Publish status + self._publish_status( + { + "status": "tracking", + "object": self._current_object, + "bbox": [int(x), int(y), int(w), int(h)], + "center": [int(current_x), int(current_y)], + "error": [int(current_x - center_x), int(current_y - center_y)], + "velocity": [float(vx), float(vy), float(vz)], + "frame": frame_count, + } + ) + + # Control loop rate + time.sleep(0.033) # ~30Hz + + except Exception as e: + logger.error(f"Error in servoing loop: {e}") + finally: + # Stop movement by publishing zero velocity + if self.cmd_vel.transport: + stop_twist = Twist() + stop_twist.linear = Vector3(0, 0, 0) + stop_twist.angular = Vector3(0, 0, 0) + self.cmd_vel.publish(stop_twist) + self._tracking_active = False + logger.info(f"Visual servoing loop ended after {frame_count} frames") + + def _draw_tracking_overlay( + self, + frame: np.ndarray[Any, np.dtype[Any]], + bbox: tuple[int, int, int, int], + center: tuple[int, int], + ) -> np.ndarray[Any, np.dtype[Any]]: + """Draw tracking visualization overlay. 
+ + Args: + frame: Current video frame + bbox: Bounding box (x, y, w, h) + center: Object center (x, y) + + Returns: + Frame with overlay drawn + """ + overlay = frame.copy() + x, y, w, h = bbox + + # Draw tracking box (green) + cv2.rectangle(overlay, (x, y), (x + w, y + h), (0, 255, 0), 2) + + # Draw object center (red crosshair) + cv2.drawMarker(overlay, center, (0, 0, 255), cv2.MARKER_CROSS, 20, 2) + + # Draw desired center (blue crosshair) + frame_h, frame_w = frame.shape[:2] + frame_center = (frame_w // 2, frame_h // 2) + cv2.drawMarker(overlay, frame_center, (255, 0, 0), cv2.MARKER_CROSS, 20, 2) + + # Draw line from object to desired center + cv2.line(overlay, center, frame_center, (255, 255, 0), 1) + + # Add status text + status_text = f"Tracking: {self._current_object}" + cv2.putText(overlay, status_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2) + + # Add error text + error_x = center[0] - frame_center[0] + error_y = center[1] - frame_center[1] + error_text = f"Error: ({error_x}, {error_y})" + cv2.putText( + overlay, error_text, (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 1 + ) + + return overlay + + def _publish_status(self, status: dict[str, Any]) -> None: + """Publish tracking status as JSON. 
+ + Args: + status: Status dictionary + """ + if self.tracking_status.transport: + status_msg = String(json.dumps(status)) + self.tracking_status.publish(status_msg) + + def _stop_tracking(self) -> None: + """Stop tracking and clean up.""" + self._tracking_active = False + if self._tracking_thread and self._tracking_thread.is_alive(): + self._tracking_thread.join(timeout=1) + + # Send stop command via LCM + if self.cmd_vel.transport: + stop_twist = Twist() + stop_twist.linear = Vector3(0, 0, 0) + stop_twist.angular = Vector3(0, 0, 0) + self.cmd_vel.publish(stop_twist) + + self._publish_status({"status": "stopped", "object": self._current_object}) + + self._current_object = None + logger.info("Tracking stopped") + + @rpc + def stop_tracking(self) -> str: + """Stop current tracking operation.""" + self._stop_tracking() + return "Tracking stopped" + + @rpc + def get_status(self) -> dict[str, Any]: + """Get current tracking status. + + Returns: + Status dictionary + """ + return { + "active": self._tracking_active, + "object": self._current_object, + "has_frame": self._latest_frame is not None, + } diff --git a/dimos/robot/drone/drone_visual_servoing_controller.py b/dimos/robot/drone/drone_visual_servoing_controller.py new file mode 100644 index 0000000000..68e39d5d7f --- /dev/null +++ b/dimos/robot/drone/drone_visual_servoing_controller.py @@ -0,0 +1,103 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Minimal visual servoing controller for drone with downward-facing camera.""" + +from typing import TypeAlias + +from dimos.utils.simple_controller import PIDController + +# Type alias for PID parameters tuple +PIDParams: TypeAlias = tuple[float, float, float, tuple[float, float], float | None, int] + + +class DroneVisualServoingController: + """Minimal visual servoing for downward-facing drone camera using velocity-only control.""" + + def __init__( + self, + x_pid_params: PIDParams, + y_pid_params: PIDParams, + z_pid_params: PIDParams | None = None, + ) -> None: + """ + Initialize drone visual servoing controller. + + Args: + x_pid_params: (kp, ki, kd, output_limits, integral_limit, deadband) for forward/back + y_pid_params: (kp, ki, kd, output_limits, integral_limit, deadband) for left/right + z_pid_params: Optional params for altitude control + """ + self.x_pid = PIDController(*x_pid_params) # type: ignore[no-untyped-call] + self.y_pid = PIDController(*y_pid_params) # type: ignore[no-untyped-call] + self.z_pid = PIDController(*z_pid_params) if z_pid_params else None # type: ignore[no-untyped-call] + + def compute_velocity_control( + self, + target_x: float, + target_y: float, # Target position in image (pixels or normalized) + center_x: float = 0.0, + center_y: float = 0.0, # Desired position (usually image center) + target_z: float | None = None, + desired_z: float | None = None, # Optional altitude control + dt: float = 0.1, + lock_altitude: bool = True, + ) -> tuple[float, float, float]: + """ + Compute velocity commands to center target in camera view. 
+ + For downward camera: + - Image X error -> Drone Y velocity (left/right strafe) + - Image Y error -> Drone X velocity (forward/backward) + + Args: + target_x: Target X position in image + target_y: Target Y position in image + center_x: Desired X position (default 0) + center_y: Desired Y position (default 0) + target_z: Current altitude (optional) + desired_z: Desired altitude (optional) + dt: Time step + lock_altitude: If True, vz will always be 0 + + Returns: + tuple: (vx, vy, vz) velocities in m/s + """ + # Compute errors (positive = target is to the right/below center) + error_x = target_x - center_x # Lateral error in image + error_y = target_y - center_y # Forward error in image + + # PID control (swap axes for downward camera) + # For downward camera: object below center (positive error_y) = object is behind drone + # Need to negate: positive error_y should give negative vx (move backward) + vy = self.y_pid.update(error_x, dt) # type: ignore[no-untyped-call] # Image X -> Drone Y (strafe) + vx = -self.x_pid.update(error_y, dt) # type: ignore[no-untyped-call] # Image Y -> Drone X (NEGATED) + + # Optional altitude control + vz = 0.0 + if not lock_altitude and self.z_pid and target_z is not None and desired_z is not None: + error_z = target_z - desired_z + vz = self.z_pid.update(error_z, dt) # type: ignore[no-untyped-call] + + return vx, vy, vz + + def reset(self) -> None: + """Reset all PID controllers.""" + self.x_pid.integral = 0.0 + self.x_pid.prev_error = 0.0 + self.y_pid.integral = 0.0 + self.y_pid.prev_error = 0.0 + if self.z_pid: + self.z_pid.integral = 0.0 + self.z_pid.prev_error = 0.0 diff --git a/dimos/robot/drone/mavlink_connection.py b/dimos/robot/drone/mavlink_connection.py new file mode 100644 index 0000000000..92bbcc0ec8 --- /dev/null +++ b/dimos/robot/drone/mavlink_connection.py @@ -0,0 +1,1109 @@ +#!/usr/bin/env python3 +# Copyright 2025 Dimensional Inc. 
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""MAVLink-based drone connection for DimOS."""

import functools
import logging
import time
from typing import Any

from pymavlink import mavutil  # type: ignore[import-untyped]
from reactivex import Subject

from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Twist, Vector3
from dimos.utils.logging_config import setup_logger

logger = setup_logger(level=logging.INFO)


class MavlinkConnection:
    """MAVLink connection for drone control."""

    def __init__(
        self,
        connection_string: str = "udp:0.0.0.0:14550",
        outdoor: bool = False,
        max_velocity: float = 5.0,
    ) -> None:
        """Initialize drone connection.

        Args:
            connection_string: MAVLink connection string
            outdoor: Use GPS only mode (no velocity integration)
            max_velocity: Maximum velocity in m/s
        """
        self.connection_string = connection_string
        self.outdoor = outdoor
        self.max_velocity = max_velocity
        self.mavlink: Any = None  # MAVLink connection object
        self.connected = False
        self.telemetry: dict[str, Any] = {}

        self._odom_subject: Subject[PoseStamped] = Subject()
        self._status_subject: Subject[dict[str, Any]] = Subject()
        self._telemetry_subject: Subject[dict[str, Any]] = Subject()
        self._raw_mavlink_subject: Subject[dict[str, Any]] = Subject()

        # Velocity tracking for smoothing
        self.prev_vx = 0.0
        self.prev_vy = 0.0
        self.prev_vz = 0.0

        # Flag to prevent concurrent fly_to commands
        self.flying_to_target = False

    def connect(self) -> bool:
        """Connect to drone via MAVLink."""
        try:
            logger.info(f"Connecting to {self.connection_string}")
            self.mavlink = mavutil.mavlink_connection(self.connection_string)
            self.mavlink.wait_heartbeat(timeout=30)
            self.connected = True
            logger.info(f"Connected to system {self.mavlink.target_system}")

            self.update_telemetry()
            return True
        except Exception as e:
            logger.error(f"Connection failed: {e}")
            return False

    def update_telemetry(self, timeout: float = 0.1) -> None:
        """Update telemetry data from available messages.

        Drains pending MAVLink messages for up to `timeout` seconds, storing
        each message dict in self.telemetry keyed by message type, applying
        unit conversions for known fields, and publishing odom/status/
        telemetry updates as relevant messages arrive.
        """
        if not self.connected:
            return

        end_time = time.time() + timeout
        while time.time() < end_time:
            msg = self.mavlink.recv_match(blocking=False)
            if not msg:
                time.sleep(0.001)
                continue
            msg_type = msg.get_type()
            msg_dict = msg.to_dict()
            # (debug hooks kept for reference)
            # print("MESSAGE", msg_dict)
            # print("MESSAGE TYPE", msg_type)
            # self._raw_mavlink_subject.on_next(msg_dict)

            # Store before converting: the _publish_* calls below read
            # self.telemetry[msg_type], and conversions mutate this same dict.
            self.telemetry[msg_type] = msg_dict

            # Apply unit conversions for known fields
            if msg_type == "GLOBAL_POSITION_INT":
                msg_dict["lat"] = msg_dict.get("lat", 0) / 1e7
                msg_dict["lon"] = msg_dict.get("lon", 0) / 1e7
                msg_dict["alt"] = msg_dict.get("alt", 0) / 1000.0
                msg_dict["relative_alt"] = msg_dict.get("relative_alt", 0) / 1000.0
                msg_dict["vx"] = msg_dict.get("vx", 0) / 100.0  # cm/s to m/s
                msg_dict["vy"] = msg_dict.get("vy", 0) / 100.0
                msg_dict["vz"] = msg_dict.get("vz", 0) / 100.0
                msg_dict["hdg"] = msg_dict.get("hdg", 0) / 100.0  # centidegrees to degrees
                self._publish_odom()

            elif msg_type == "GPS_RAW_INT":
                msg_dict["lat"] = msg_dict.get("lat", 0) / 1e7
                msg_dict["lon"] = msg_dict.get("lon", 0) / 1e7
                msg_dict["alt"] = msg_dict.get("alt", 0) / 1000.0
                msg_dict["vel"] = msg_dict.get("vel", 0) / 100.0
                msg_dict["cog"] = msg_dict.get("cog", 0) / 100.0

            elif msg_type == "SYS_STATUS":
                msg_dict["voltage_battery"] = msg_dict.get("voltage_battery", 0) / 1000.0
                msg_dict["current_battery"] = msg_dict.get("current_battery", 0) / 100.0
                self._publish_status()

            elif msg_type == "POWER_STATUS":
                msg_dict["Vcc"] = msg_dict.get("Vcc", 0) / 1000.0
                msg_dict["Vservo"] = msg_dict.get("Vservo", 0) / 1000.0

            elif msg_type == "HEARTBEAT":
                # Extract armed status.  (A previous standalone HEARTBEAT check
                # computed this same flag and discarded it — dead code removed.)
                base_mode = msg_dict.get("base_mode", 0)
                msg_dict["armed"] = bool(base_mode & mavutil.mavlink.MAV_MODE_FLAG_SAFETY_ARMED)
                self._publish_status()

            elif msg_type == "ATTITUDE":
                self._publish_odom()

            # NOTE: the original code re-assigned self.telemetry[msg_type]
            # here; that was redundant (same dict object) and has been removed.

            self._publish_telemetry()

    def _publish_odom(self) -> None:
        """Publish odometry data - GPS for outdoor mode, velocity integration for indoor mode."""
        attitude = self.telemetry.get("ATTITUDE", {})
        roll = attitude.get("roll", 0)
        pitch = attitude.get("pitch", 0)
        yaw = attitude.get("yaw", 0)

        # Use heading from GLOBAL_POSITION_INT if no ATTITUDE data
        if "roll" not in attitude and "GLOBAL_POSITION_INT" in self.telemetry:
            import math

            heading = self.telemetry["GLOBAL_POSITION_INT"].get("hdg", 0)
            yaw = math.radians(heading)

        if "roll" not in attitude and "GLOBAL_POSITION_INT" not in self.telemetry:
            logger.debug("No attitude or position data available")
            return

        # MAVLink --> ROS conversion
        # MAVLink: positive pitch = nose up, positive yaw = clockwise
        # ROS: positive pitch = nose down, positive yaw = counter-clockwise
        quaternion = Quaternion.from_euler(Vector3(roll, -pitch, -yaw))

        # Lazy init of integrated-position state on first publish.
        if not hasattr(self, "_position"):
            self._position = {"x": 0.0, "y": 0.0, "z": 0.0}
            self._last_update = time.time()
            if self.outdoor:
                self._gps_origin = None

        current_time = time.time()
        dt = current_time - self._last_update

        # Get position data from GLOBAL_POSITION_INT
        pos_data = self.telemetry.get("GLOBAL_POSITION_INT", {})

        # Outdoor mode: Use GPS coordinates
        if self.outdoor and pos_data:
            lat = pos_data.get("lat", 0)  # Already in degrees from update_telemetry
            lon = pos_data.get("lon", 0)  # Already in degrees from update_telemetry

            if lat != 0 and lon != 0:  # Valid GPS fix
                if self._gps_origin is None:
                    self._gps_origin = {"lat": lat, "lon": lon}
                    logger.debug(f"GPS origin set: lat={lat:.7f}, lon={lon:.7f}")

                # Convert GPS to local X/Y coordinates (small-angle
                # equirectangular approximation around the origin)
                import math

                R = 6371000  # Earth radius in meters
                dlat = math.radians(lat - self._gps_origin["lat"])
                dlon = math.radians(lon - self._gps_origin["lon"])

                # X = North, Y = West (ROS convention)
                self._position["x"] = dlat * R
                self._position["y"] = -dlon * R * math.cos(math.radians(self._gps_origin["lat"]))

        # Indoor mode: Use velocity integration
        elif pos_data and dt > 0:
            vx = pos_data.get("vx", 0)  # North velocity in m/s (already converted)
            vy = pos_data.get("vy", 0)  # East velocity in m/s (already converted)

            # +vx is North, +vy is East in NED mavlink frame
            # ROS/Foxglove: X=forward(North), Y=left(West), Z=up
            self._position["x"] += vx * dt  # North → X (forward)
            self._position["y"] += -vy * dt  # East → -Y (Y points left/West in ROS)

        # Altitude handling (same for both modes)
        if "ALTITUDE" in self.telemetry:
            self._position["z"] = self.telemetry["ALTITUDE"].get("altitude_relative", 0)
        elif pos_data:
            self._position["z"] = pos_data.get(
                "relative_alt", 0
            )  # Already in m from update_telemetry

        self._last_update = current_time

        # Debug logging
        mode = "GPS" if self.outdoor else "VELOCITY"
        logger.debug(
            f"[{mode}] Position: x={self._position['x']:.2f}m, y={self._position['y']:.2f}m, z={self._position['z']:.2f}m"
        )

        pose = PoseStamped(
            position=Vector3(self._position["x"], self._position["y"], self._position["z"]),
            orientation=quaternion,
            frame_id="world",
            ts=current_time,
        )

        self._odom_subject.on_next(pose)

    def _publish_status(self) -> None:
        """Publish drone status with key telemetry."""
        heartbeat = self.telemetry.get("HEARTBEAT", {})
        sys_status = self.telemetry.get("SYS_STATUS", {})
        gps_raw = self.telemetry.get("GPS_RAW_INT", {})
        global_pos = self.telemetry.get("GLOBAL_POSITION_INT", {})
        altitude = self.telemetry.get("ALTITUDE", {})

        status = {
            "armed": heartbeat.get("armed", False),
            "mode": heartbeat.get("custom_mode", -1),
            "battery_voltage": sys_status.get("voltage_battery", 0),
            "battery_current": sys_status.get("current_battery", 0),
            "battery_remaining": sys_status.get("battery_remaining", 0),
            "satellites": gps_raw.get("satellites_visible", 0),
            "altitude": altitude.get("altitude_relative", global_pos.get("relative_alt", 0)),
            "heading": global_pos.get("hdg", 0),
            "vx": global_pos.get("vx", 0),
            "vy": global_pos.get("vy", 0),
            "vz": global_pos.get("vz", 0),
            "lat": global_pos.get("lat", 0),
            "lon": global_pos.get("lon", 0),
            "ts": time.time(),
        }
        self._status_subject.on_next(status)

    def _publish_telemetry(self) -> None:
        """Publish full telemetry data."""
        telemetry_with_ts = self.telemetry.copy()
        telemetry_with_ts["timestamp"] = time.time()
        self._telemetry_subject.on_next(telemetry_with_ts)

    def move(self, velocity: Vector3, duration: float = 0.0) -> bool:
        """Send movement command to drone.

        Args:
            velocity: Velocity vector in m/s.  As written, velocity.y maps to
                body-frame forward, velocity.x to right, and velocity.z is
                passed straight through as NED "down" (positive descends).
            duration: How long to move (0 = continuous)

        Returns:
            True if command sent successfully
        """
        if not self.connected:
            return False

        # MAVLink body frame velocities.
        # NOTE(review): x/y are swapped relative to move_twist() and z is NOT
        # negated here (move_twist sends -twist.linear.z).  The previous
        # comment claimed "negative for DOWN, positive for UP", which
        # contradicts NED (positive z = down) — confirm the caller convention.
        forward = velocity.y  # Forward/backward
        right = velocity.x  # Left/right
        down = velocity.z  # NED down-positive: positive z descends

        logger.debug(f"Moving: forward={forward}, right={right}, down={down}")

        if duration > 0:
            # Send velocity for duration
            end_time = time.time() + duration
            while time.time() < end_time:
                self.mavlink.mav.set_position_target_local_ned_send(
                    0,  # time_boot_ms
                    self.mavlink.target_system,
                    self.mavlink.target_component,
                    mavutil.mavlink.MAV_FRAME_BODY_NED,
                    0b0000111111000111,  # type_mask (only velocities)
                    0,
                    0,
                    0,  # positions
                    forward,
                    right,
                    down,  # velocities
                    0,
                    0,
                    0,  # accelerations
                    0,
                    0,  # yaw, yaw_rate
                )
                time.sleep(0.1)
            self.stop()
        else:
            # Single velocity command
            self.mavlink.mav.set_position_target_local_ned_send(
                0,
                self.mavlink.target_system,
                self.mavlink.target_component,
                mavutil.mavlink.MAV_FRAME_BODY_NED,
                0b0000111111000111,
                0,
                0,
                0,
                forward,
                right,
                down,
                0,
                0,
                0,
                0,
                0,
            )

        return True

    def move_twist(self, twist: Twist, duration: float = 0.0, lock_altitude: bool = True) -> bool:
        """Move using ROS-style Twist commands.

        Args:
            twist: Twist message with linear velocities (angular.z ignored for now)
            duration: How long to move (0 = single command)
            lock_altitude: If True, ignore Z velocity and maintain current altitude

        Returns:
            True if command sent successfully
        """
        if not self.connected:
            return False

        # Extract velocities
        forward = twist.linear.x  # m/s forward (body frame)
        right = twist.linear.y  # m/s right (body frame)
        # ROS z-up → NED z-down requires the negation below.
        down = 0.0 if lock_altitude else -twist.linear.z  # Lock altitude by default

        if duration > 0:
            # Send velocity for duration
            end_time = time.time() + duration
            while time.time() < end_time:
                self.mavlink.mav.set_position_target_local_ned_send(
                    0,  # time_boot_ms
                    self.mavlink.target_system,
                    self.mavlink.target_component,
                    mavutil.mavlink.MAV_FRAME_BODY_NED,  # Body frame for strafing
                    0b0000111111000111,  # type_mask - velocities only, no rotation
                    0,
                    0,
                    0,  # positions (ignored)
                    forward,
                    right,
                    down,  # velocities in m/s
                    0,
                    0,
                    0,  # accelerations (ignored)
                    0,
                    0,  # yaw, yaw_rate (ignored)
                )
                time.sleep(0.05)  # 20Hz
            # Send stop command
            self.stop()
        else:
            # Send single command for continuous movement
            self.mavlink.mav.set_position_target_local_ned_send(
                0,  # time_boot_ms
                self.mavlink.target_system,
                self.mavlink.target_component,
                mavutil.mavlink.MAV_FRAME_BODY_NED,  # Body frame for strafing
                0b0000111111000111,  # type_mask - velocities only, no rotation
                0,
                0,
                0,  # positions (ignored)
                forward,
                right,
                down,  # velocities in m/s
                0,
                0,
                0,  # accelerations (ignored)
                0,
                0,  # yaw, yaw_rate (ignored)
            )

        return True

    def stop(self) -> bool:
        """Stop all movement by commanding zero velocity in the body frame."""
        if not self.connected:
            return False

        self.mavlink.mav.set_position_target_local_ned_send(
            0,
            self.mavlink.target_system,
            self.mavlink.target_component,
            mavutil.mavlink.MAV_FRAME_BODY_NED,
            0b0000111111000111,
            0,
            0,
            0,
            0,
            0,
            0,
            0,
            0,
            0,
            0,
            0,
        )
        return True

    def rotate_to(self, target_heading_deg: float, timeout: float = 60.0) -> bool:
        """Rotate drone to face a specific heading.

        Args:
            target_heading_deg: Target heading in degrees (0-360, 0=North, 90=East)
            timeout: Maximum time to spend rotating in seconds

        Returns:
            True if rotation completed successfully
        """
        if not self.connected:
            return False

        logger.info(f"Rotating to heading {target_heading_deg:.1f}°")

        import math
        import time

        start_time = time.time()
        loop_count = 0

        while time.time() - start_time < timeout:
            loop_count += 1

            # Don't call update_telemetry - let background thread handle it
            # Just read the current telemetry which should be continuously updated

            if "GLOBAL_POSITION_INT" not in self.telemetry:
                logger.warning("No GLOBAL_POSITION_INT in telemetry dict")
                time.sleep(0.1)
                continue

            gps_telem = self.telemetry["GLOBAL_POSITION_INT"]

            # Get current heading.  update_telemetry normally converts hdg to
            # degrees already; the >360 check is a defensive fallback in case a
            # raw centidegree value slips through.
            raw_hdg = gps_telem.get("hdg", 0)

            # Debug logging every 5th iteration to diagnose unit issues
            if loop_count % 5 == 0:
                logger.info(f"DEBUG TELEMETRY: raw hdg={raw_hdg}, type={type(raw_hdg)}")
                logger.info(f"DEBUG TELEMETRY keys: {list(gps_telem.keys())[:5]}")  # First 5 keys

                if raw_hdg > 360:
                    logger.info(f"HDG appears to be in centidegrees: {raw_hdg}")
                    current_heading_deg = raw_hdg / 100.0
                else:
                    logger.info(f"HDG appears to be in degrees already: {raw_hdg}")
                    current_heading_deg = raw_hdg
            else:
                # Normal conversion
                if raw_hdg > 360:
                    current_heading_deg = raw_hdg / 100.0
                else:
                    current_heading_deg = raw_hdg

            # Normalize to 0-360
            if current_heading_deg > 360:
                current_heading_deg = current_heading_deg % 360

            # Calculate heading error (shortest angular distance)
            heading_error = target_heading_deg - current_heading_deg
            if heading_error > 180:
                heading_error -= 360
            elif heading_error < -180:
                heading_error += 360

            logger.info(
                f"ROTATION: current={current_heading_deg:.1f}° → target={target_heading_deg:.1f}° (error={heading_error:.1f}°)"
            )

            # Check if we're close enough
            if abs(heading_error) < 10:  # Complete within 10 degrees
                logger.info(
                    f"ROTATION COMPLETE: current={current_heading_deg:.1f}° ≈ target={target_heading_deg:.1f}° (within {abs(heading_error):.1f}°)"
                )
                # Don't stop - let fly_to immediately transition to forward movement
                return True

            # Calculate yaw rate with minimum speed to avoid slow approach
            yaw_rate = heading_error * 0.3  # Higher gain for faster rotation
            # Ensure minimum rotation speed of 15 deg/s to avoid crawling near target
            if abs(yaw_rate) < 15.0:
                yaw_rate = 15.0 if heading_error > 0 else -15.0
            yaw_rate = max(-60.0, min(60.0, yaw_rate))  # Cap at 60 deg/s max
            yaw_rate_rad = math.radians(yaw_rate)

            logger.info(
                f"ROTATING: yaw_rate={yaw_rate:.1f} deg/s to go from {current_heading_deg:.1f}° → {target_heading_deg:.1f}°"
            )

            # Send rotation command
            self.mavlink.mav.set_position_target_local_ned_send(
                0,  # time_boot_ms
                self.mavlink.target_system,
                self.mavlink.target_component,
                mavutil.mavlink.MAV_FRAME_BODY_NED,  # Body frame for rotation
                0b0000011111111111,  # type_mask - ignore everything except yaw_rate
                0,
                0,
                0,  # positions (ignored)
                0,
                0,
                0,  # velocities (ignored)
                0,
                0,
                0,  # accelerations (ignored)
                0,  # yaw (ignored)
                yaw_rate_rad,  # yaw_rate in rad/s
            )

            time.sleep(0.1)  # 10Hz control loop

        logger.warning("Rotation timeout")
        self.stop()
        return False

    def arm(self) -> bool:
        """Arm the drone.

        Returns:
            True only when the armed flag is observed in a subsequent
            HEARTBEAT; False on rejection or if arming is never confirmed.
        """
        if not self.connected:
            return False

        logger.info("Arming motors...")
        self.update_telemetry()

        self.mavlink.mav.command_long_send(
            self.mavlink.target_system,
            self.mavlink.target_component,
            mavutil.mavlink.MAV_CMD_COMPONENT_ARM_DISARM,
            0,
            1,
            0,
            0,
            0,
            0,
            0,
            0,
        )

        # Wait for ACK
        ack = self.mavlink.recv_match(type="COMMAND_ACK", blocking=True, timeout=5)
        if ack and ack.command == mavutil.mavlink.MAV_CMD_COMPONENT_ARM_DISARM:
            if ack.result == mavutil.mavlink.MAV_RESULT_ACCEPTED:
                logger.info("Arm command accepted")

                # Verify armed status
                for _i in range(10):
                    msg = self.mavlink.recv_match(type="HEARTBEAT", blocking=True, timeout=1)
                    if msg:
                        armed = msg.base_mode & mavutil.mavlink.MAV_MODE_FLAG_SAFETY_ARMED
                        if armed:
                            logger.info("Motors ARMED successfully!")
                            return True
                    time.sleep(0.5)
            else:
                logger.error(f"Arm failed with result: {ack.result}")

        return False

    def disarm(self) -> bool:
        """Disarm the drone (fire-and-forget; the ACK is not checked)."""
        if not self.connected:
            return False

        logger.info("Disarming motors...")

        self.mavlink.mav.command_long_send(
            self.mavlink.target_system,
            self.mavlink.target_component,
            mavutil.mavlink.MAV_CMD_COMPONENT_ARM_DISARM,
            0,
            0,
            0,
            0,
            0,
            0,
            0,
            0,
        )

        time.sleep(1)
        return True

    def takeoff(self, altitude: float = 3.0) -> bool:
        """Takeoff to specified altitude (requires GUIDED mode)."""
        if not self.connected:
            return False

        logger.info(f"Taking off to {altitude}m...")

        # Set GUIDED mode
        if not self.set_mode("GUIDED"):
            logger.error("Failed to set GUIDED mode for takeoff")
            return False

        # Send takeoff command
        self.mavlink.mav.command_long_send(
            self.mavlink.target_system,
            self.mavlink.target_component,
            mavutil.mavlink.MAV_CMD_NAV_TAKEOFF,
            0,
            0,
            0,
            0,
            0,
            0,
            0,
            altitude,
        )

        logger.info(f"Takeoff command sent for {altitude}m altitude")
        return True

    def land(self) -> bool:
        """Land the drone at current position."""
        if not self.connected:
            return False

        logger.info("Landing...")

        # Send initial land command
        self.mavlink.mav.command_long_send(
            self.mavlink.target_system,
            self.mavlink.target_component,
            mavutil.mavlink.MAV_CMD_NAV_LAND,
            0,
            0,
            0,
            0,
            0,
            0,
            0,
            0,
        )

        # Wait for
disarm with confirmations + disarm_count = 0 + for _ in range(120): # 60 seconds max (120 * 0.5s) + # Keep sending land command + self.mavlink.mav.command_long_send( + self.mavlink.target_system, + self.mavlink.target_component, + mavutil.mavlink.MAV_CMD_NAV_LAND, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ) + + # Check armed status + msg = self.mavlink.recv_match(type="HEARTBEAT", blocking=True, timeout=0.5) + if msg: + msg_dict = msg.to_dict() + armed = bool( + msg_dict.get("base_mode", 0) & mavutil.mavlink.MAV_MODE_FLAG_SAFETY_ARMED + ) + logger.debug(f"HEARTBEAT: {msg_dict} ARMED: {armed}") + + disarm_count = 0 if armed else disarm_count + 1 + + if disarm_count >= 5: # 2.5 seconds of continuous disarm + logger.info("Drone landed and disarmed") + return True + + time.sleep(0.5) + + logger.warning("Land timeout") + return self.set_mode("LAND") + + def fly_to(self, lat: float, lon: float, alt: float) -> str: + """Fly to GPS coordinates - sends commands continuously until reaching target. + + Args: + lat: Latitude in degrees + lon: Longitude in degrees + alt: Altitude in meters (relative to home) + + Returns: + String message indicating success or failure reason + """ + if not self.connected: + return "Failed: Not connected to drone" + + # Check if already flying to a target + if self.flying_to_target: + logger.warning( + "Already flying to target, ignoring new fly_to command. Wait until completed to send new fly_to command." 
+ ) + return ( + "Already flying to target - wait for completion before sending new fly_to command" + ) + + self.flying_to_target = True + + # Ensure GUIDED mode for GPS navigation + if not self.set_mode("GUIDED"): + logger.error("Failed to set GUIDED mode for GPS navigation") + self.flying_to_target = False + return "Failed: Could not set GUIDED mode for GPS navigation" + + logger.info(f"Flying to GPS: lat={lat:.7f}, lon={lon:.7f}, alt={alt:.1f}m") + + # Reset velocity tracking for smooth start + self.prev_vx = 0.0 + self.prev_vy = 0.0 + self.prev_vz = 0.0 + + # Send velocity commands towards GPS target at 10Hz + acceptance_radius = 30.0 # meters + max_duration = 120 # seconds max flight time + start_time = time.time() + max_speed = self.max_velocity # m/s max speed + + import math + + loop_count = 0 + + try: + while time.time() - start_time < max_duration: + loop_start = time.time() + + # Don't update telemetry here - let background thread handle it + # self.update_telemetry(timeout=0.01) # Removed to prevent message conflicts + + # Check current position from telemetry + if "GLOBAL_POSITION_INT" in self.telemetry: + t1 = time.time() + + # Telemetry already has converted values (see update_telemetry lines 104-107) + current_lat = self.telemetry["GLOBAL_POSITION_INT"].get( + "lat", 0 + ) # Already in degrees + current_lon = self.telemetry["GLOBAL_POSITION_INT"].get( + "lon", 0 + ) # Already in degrees + current_alt = self.telemetry["GLOBAL_POSITION_INT"].get( + "relative_alt", 0 + ) # Already in meters + + t2 = time.time() + + logger.info( + f"DEBUG: Current GPS: lat={current_lat:.10f}, lon={current_lon:.10f}, alt={current_alt:.2f}m" + ) + logger.info( + f"DEBUG: Target GPS: lat={lat:.10f}, lon={lon:.10f}, alt={alt:.2f}m" + ) + + # Calculate vector to target with high precision + dlat = lat - current_lat + dlon = lon - current_lon + dalt = alt - current_alt + + logger.info( + f"DEBUG: Delta: dlat={dlat:.10f}, dlon={dlon:.10f}, dalt={dalt:.2f}m" + ) + + t3 = 
time.time() + + # Convert lat/lon difference to meters with high precision + # Using more accurate calculation + lat_rad = current_lat * math.pi / 180.0 + meters_per_degree_lat = ( + 111132.92 - 559.82 * math.cos(2 * lat_rad) + 1.175 * math.cos(4 * lat_rad) + ) + meters_per_degree_lon = 111412.84 * math.cos(lat_rad) - 93.5 * math.cos( + 3 * lat_rad + ) + + x_dist = dlat * meters_per_degree_lat # North distance in meters + y_dist = dlon * meters_per_degree_lon # East distance in meters + + logger.info( + f"DEBUG: Distance in meters: North={x_dist:.2f}m, East={y_dist:.2f}m, Up={dalt:.2f}m" + ) + + # Calculate total distance + distance = math.sqrt(x_dist**2 + y_dist**2 + dalt**2) + logger.info(f"DEBUG: Total distance to target: {distance:.2f}m") + + t4 = time.time() + + if distance < acceptance_radius: + logger.info(f"Reached GPS target (within {distance:.1f}m)") + self.stop() + # Return to manual control + self.set_mode("STABILIZE") + logger.info("Returned to STABILIZE mode for manual control") + self.flying_to_target = False + return f"Success: Reached target location (lat={lat:.7f}, lon={lon:.7f}, alt={alt:.1f}m)" + + # Only send velocity commands if we're far enough + if distance > 0.1: + # On first loop, rotate to face the target + if loop_count == 0: + # Calculate bearing to target + bearing_rad = math.atan2( + y_dist, x_dist + ) # East, North -> angle from North + target_heading_deg = math.degrees(bearing_rad) + if target_heading_deg < 0: + target_heading_deg += 360 + + logger.info( + f"Rotating to face target at heading {target_heading_deg:.1f}°" + ) + self.rotate_to(target_heading_deg, timeout=45.0) + logger.info("Rotation complete, starting movement") + + # Now just move towards target (no rotation) + t5 = time.time() + + # Calculate movement speed - maintain max speed until 20m from target + if distance > 20: + speed = max_speed # Full speed when far from target + else: + # Ramp down speed from 20m to target + speed = max( + 0.5, distance / 4.0 + ) # At 
20m: 5m/s, at 10m: 2.5m/s, at 2m: 0.5m/s + + # Calculate target velocities + target_vx = (x_dist / distance) * speed # North velocity + target_vy = (y_dist / distance) * speed # East velocity + target_vz = (dalt / distance) * speed # Up velocity (positive = up) + + # Direct velocity assignment (no acceleration limiting) + vx = target_vx + vy = target_vy + vz = target_vz + + # Store for next iteration + self.prev_vx = vx + self.prev_vy = vy + self.prev_vz = vz + + logger.info( + f"MOVING: vx={vx:.3f} vy={vy:.3f} vz={vz:.3f} m/s, distance={distance:.1f}m" + ) + + # Send velocity command in LOCAL_NED frame + self.mavlink.mav.set_position_target_local_ned_send( + 0, # time_boot_ms + self.mavlink.target_system, + self.mavlink.target_component, + mavutil.mavlink.MAV_FRAME_LOCAL_NED, # Local NED for movement + 0b0000111111000111, # type_mask - use velocities only + 0, + 0, + 0, # positions (not used) + vx, + vy, + vz, # velocities in m/s + 0, + 0, + 0, # accelerations (not used) + 0, # yaw (not used) + 0, # yaw_rate (not used) + ) + + # Log if stuck + if loop_count > 20 and loop_count % 10 == 0: + logger.warning( + f"STUCK? 
Been sending commands for {loop_count} iterations but distance still {distance:.1f}m" + ) + + t6 = time.time() + + # Log timing every 10 loops + loop_count += 1 + if loop_count % 10 == 0: + logger.info( + f"TIMING: telemetry_read={t2 - t1:.4f}s, delta_calc={t3 - t2:.4f}s, " + f"distance_calc={t4 - t3:.4f}s, velocity_calc={t5 - t4:.4f}s, " + f"mavlink_send={t6 - t5:.4f}s, total_loop={t6 - loop_start:.4f}s" + ) + else: + logger.info("DEBUG: Too close to send velocity commands") + + else: + logger.warning("DEBUG: No GLOBAL_POSITION_INT in telemetry!") + + time.sleep(0.1) # Send at 10Hz + + except Exception as e: + logger.error(f"Error during fly_to: {e}") + self.flying_to_target = False # Clear flag immediately + raise # Re-raise the exception so caller sees the error + finally: + # Always clear the flag when exiting + if self.flying_to_target: + logger.info("Stopped sending GPS velocity commands (timeout)") + self.flying_to_target = False + self.set_mode("BRAKE") + time.sleep(0.5) + # Return to manual control + self.set_mode("STABILIZE") + logger.info("Returned to STABILIZE mode for manual control") + + return "Failed: Timeout - did not reach target within 120 seconds" + + def set_mode(self, mode: str) -> bool: + """Set flight mode.""" + if not self.connected: + return False + + mode_mapping = { + "STABILIZE": 0, + "GUIDED": 4, + "LOITER": 5, + "RTL": 6, + "LAND": 9, + "POSHOLD": 16, + "BRAKE": 17, + } + + if mode not in mode_mapping: + logger.error(f"Unknown mode: {mode}") + return False + + mode_id = mode_mapping[mode] + logger.info(f"Setting mode to {mode}") + + self.update_telemetry() + + self.mavlink.mav.command_long_send( + self.mavlink.target_system, + self.mavlink.target_component, + mavutil.mavlink.MAV_CMD_DO_SET_MODE, + 0, + mavutil.mavlink.MAV_MODE_FLAG_CUSTOM_MODE_ENABLED, + mode_id, + 0, + 0, + 0, + 0, + 0, + ) + + ack = self.mavlink.recv_match(type="COMMAND_ACK", blocking=True, timeout=3) + if ack and ack.result == mavutil.mavlink.MAV_RESULT_ACCEPTED: + 
logger.info(f"Mode changed to {mode}") + self.telemetry["mode"] = mode_id + return True + + return False + + @functools.cache + def odom_stream(self) -> Subject[PoseStamped]: + """Get odometry stream.""" + return self._odom_subject + + @functools.cache + def status_stream(self) -> Subject[dict[str, Any]]: + """Get status stream.""" + return self._status_subject + + @functools.cache + def telemetry_stream(self) -> Subject[dict[str, Any]]: + """Get full telemetry stream.""" + return self._telemetry_subject + + def get_telemetry(self) -> dict[str, Any]: + """Get current telemetry.""" + # Update telemetry multiple times to ensure we get data + for _ in range(5): + self.update_telemetry(timeout=0.2) + return self.telemetry.copy() + + def disconnect(self) -> None: + """Disconnect from drone.""" + if self.mavlink: + self.mavlink.close() + self.connected = False + logger.info("Disconnected") + + @property + def is_flying_to_target(self) -> bool: + """Check if drone is currently flying to a GPS target.""" + return self.flying_to_target + + def get_video_stream(self, fps: int = 30) -> None: + """Get video stream (to be implemented with GStreamer).""" + # Will be implemented in camera module + return None + + +class FakeMavlinkConnection(MavlinkConnection): + """Replay MAVLink for testing.""" + + def __init__(self, connection_string: str) -> None: + # Call parent init (which no longer calls connect()) + super().__init__(connection_string) + + # Create fake mavlink object + class FakeMavlink: + def __init__(self) -> None: + from dimos.utils.data import get_data + from dimos.utils.testing import TimedSensorReplay + + get_data("drone") + + self.replay: Any = TimedSensorReplay("drone/mavlink") + self.messages: list[dict[str, Any]] = [] + # The stream() method returns an Observable that emits messages with timing + self.replay.stream().subscribe(self.messages.append) + + # Properties that get accessed + self.target_system = 1 + self.target_component = 1 + self.mav = self # 
self.mavlink.mav is used in many places + + def recv_match( + self, blocking: bool = False, type: Any = None, timeout: Any = None + ) -> Any: + """Return next replay message as fake message object.""" + if not self.messages: + return None + + msg_dict = self.messages.pop(0) + + # Create message object with ALL attributes that might be accessed + class FakeMsg: + def __init__(self, d: dict[str, Any]) -> None: + self._dict = d + # Set any direct attributes that get accessed + self.base_mode = d.get("base_mode", 0) + self.command = d.get("command", 0) + self.result = d.get("result", 0) + + def get_type(self) -> Any: + return self._dict.get("mavpackettype", "") + + def to_dict(self) -> dict[str, Any]: + return self._dict + + # Filter by type if requested + if type and msg_dict.get("type") != type: + return None + + return FakeMsg(msg_dict) + + def wait_heartbeat(self, timeout: int = 30) -> None: + """Fake heartbeat received.""" + pass + + def close(self) -> None: + """Fake close.""" + pass + + # Command methods that get called but don't need to do anything in replay + def command_long_send(self, *args: Any, **kwargs: Any) -> None: + pass + + def set_position_target_local_ned_send(self, *args: Any, **kwargs: Any) -> None: + pass + + def set_position_target_global_int_send(self, *args: Any, **kwargs: Any) -> None: + pass + + # Set up fake mavlink + self.mavlink = FakeMavlink() + self.connected = True + + # Initialize position tracking (parent __init__ doesn't do this since connect wasn't called) + self._position = {"x": 0.0, "y": 0.0, "z": 0.0} + self._last_update = time.time() + + def takeoff(self, altitude: float = 3.0) -> bool: + """Fake takeoff - return immediately without blocking.""" + logger.info(f"[FAKE] Taking off to {altitude}m...") + return True + + def land(self) -> bool: + """Fake land - return immediately without blocking.""" + logger.info("[FAKE] Landing...") + return True diff --git a/dimos/robot/drone/test_drone.py b/dimos/robot/drone/test_drone.py new 
file mode 100644 index 0000000000..385aef3e0c --- /dev/null +++ b/dimos/robot/drone/test_drone.py @@ -0,0 +1,1038 @@ +#!/usr/bin/env python3 +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2025 Dimensional Inc. + +"""Core unit tests for drone module.""" + +import json +import os +import time +import unittest +from unittest.mock import MagicMock, patch + +import numpy as np + +from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Vector3 +from dimos.msgs.sensor_msgs import Image, ImageFormat +from dimos.robot.drone.connection_module import DroneConnectionModule +from dimos.robot.drone.dji_video_stream import FakeDJIVideoStream +from dimos.robot.drone.drone import Drone +from dimos.robot.drone.mavlink_connection import FakeMavlinkConnection, MavlinkConnection + + +class TestMavlinkProcessing(unittest.TestCase): + """Test MAVLink message processing and coordinate conversions.""" + + def test_mavlink_message_processing(self) -> None: + """Test that MAVLink messages trigger correct odom/tf publishing.""" + conn = MavlinkConnection("udp:0.0.0.0:14550") + + # Mock the mavlink connection + conn.mavlink = MagicMock() + conn.connected = True + + # Track what gets published + published_odom = [] + conn._odom_subject.on_next = lambda x: published_odom.append(x) + + # Create ATTITUDE message and process it + attitude_msg = MagicMock() + attitude_msg.get_type.return_value = "ATTITUDE" + attitude_msg.to_dict.return_value = { + 
"mavpackettype": "ATTITUDE", + "roll": 0.1, + "pitch": 0.2, # Positive pitch = nose up in MAVLink + "yaw": 0.3, # Positive yaw = clockwise in MAVLink + } + + # Mock recv_match to return our message once then None + def recv_side_effect(*args, **kwargs): + if not hasattr(recv_side_effect, "called"): + recv_side_effect.called = True + return attitude_msg + return None + + conn.mavlink.recv_match = MagicMock(side_effect=recv_side_effect) + + # Process the message + conn.update_telemetry(timeout=0.01) + + # Check telemetry was updated + self.assertEqual(conn.telemetry["ATTITUDE"]["roll"], 0.1) + self.assertEqual(conn.telemetry["ATTITUDE"]["pitch"], 0.2) + self.assertEqual(conn.telemetry["ATTITUDE"]["yaw"], 0.3) + + # Check odom was published with correct coordinate conversion + self.assertEqual(len(published_odom), 1) + pose = published_odom[0] + + # Verify NED to ROS conversion happened + # ROS uses different conventions: positive pitch = nose down, positive yaw = counter-clockwise + # So we expect sign flips in the quaternion conversion + self.assertIsNotNone(pose.orientation) + + def test_position_integration(self) -> None: + """Test velocity integration for indoor flight positioning.""" + conn = MavlinkConnection("udp:0.0.0.0:14550") + conn.mavlink = MagicMock() + conn.connected = True + + # Initialize position tracking + conn._position = {"x": 0.0, "y": 0.0, "z": 0.0} + conn._last_update = time.time() + + # Create GLOBAL_POSITION_INT with velocities + pos_msg = MagicMock() + pos_msg.get_type.return_value = "GLOBAL_POSITION_INT" + pos_msg.to_dict.return_value = { + "mavpackettype": "GLOBAL_POSITION_INT", + "lat": 0, + "lon": 0, + "alt": 0, + "relative_alt": 1000, # 1m in mm + "vx": 100, # 1 m/s North in cm/s + "vy": 200, # 2 m/s East in cm/s + "vz": 0, + "hdg": 0, + } + + def recv_side_effect(*args, **kwargs): + if not hasattr(recv_side_effect, "called"): + recv_side_effect.called = True + return pos_msg + return None + + conn.mavlink.recv_match = 
MagicMock(side_effect=recv_side_effect) + + # Process with known dt + old_time = conn._last_update + conn.update_telemetry(timeout=0.01) + dt = conn._last_update - old_time + + # Check position was integrated from velocities + # vx=1m/s North → +X in ROS + # vy=2m/s East → -Y in ROS (Y points West) + expected_x = 1.0 * dt # North velocity + expected_y = -2.0 * dt # East velocity (negated for ROS) + + self.assertAlmostEqual(conn._position["x"], expected_x, places=2) + self.assertAlmostEqual(conn._position["y"], expected_y, places=2) + + def test_ned_to_ros_coordinate_conversion(self) -> None: + """Test NED to ROS coordinate system conversion for all axes.""" + conn = MavlinkConnection("udp:0.0.0.0:14550") + conn.mavlink = MagicMock() + conn.connected = True + + # Initialize position + conn._position = {"x": 0.0, "y": 0.0, "z": 0.0} + conn._last_update = time.time() + + # Test with velocities in all directions + # NED: North-East-Down + # ROS: X(forward/North), Y(left/West), Z(up) + pos_msg = MagicMock() + pos_msg.get_type.return_value = "GLOBAL_POSITION_INT" + pos_msg.to_dict.return_value = { + "mavpackettype": "GLOBAL_POSITION_INT", + "lat": 0, + "lon": 0, + "alt": 5000, # 5m altitude in mm + "relative_alt": 5000, + "vx": 300, # 3 m/s North (NED) + "vy": 400, # 4 m/s East (NED) + "vz": -100, # 1 m/s Up (negative in NED for up) + "hdg": 0, + } + + def recv_side_effect(*args, **kwargs): + if not hasattr(recv_side_effect, "called"): + recv_side_effect.called = True + return pos_msg + return None + + conn.mavlink.recv_match = MagicMock(side_effect=recv_side_effect) + + # Process message + old_time = conn._last_update + conn.update_telemetry(timeout=0.01) + dt = conn._last_update - old_time + + # Verify coordinate conversion: + # NED North (vx=3) → ROS +X + # NED East (vy=4) → ROS -Y (ROS Y points West/left) + # NED Down (vz=-1, up) → ROS +Z (ROS Z points up) + + # Position should integrate with converted velocities + self.assertGreater(conn._position["x"], 0) # North → 
positive X + self.assertLess(conn._position["y"], 0) # East → negative Y + self.assertEqual(conn._position["z"], 5.0) # Altitude from relative_alt (5000mm = 5m) + + # Check X,Y velocity integration (Z is set from altitude, not integrated) + self.assertAlmostEqual(conn._position["x"], 3.0 * dt, places=2) + self.assertAlmostEqual(conn._position["y"], -4.0 * dt, places=2) + + +class TestReplayMode(unittest.TestCase): + """Test replay mode functionality.""" + + def test_fake_mavlink_connection(self) -> None: + """Test FakeMavlinkConnection replays messages correctly.""" + with patch("dimos.utils.testing.TimedSensorReplay") as mock_replay: + # Mock the replay stream + MagicMock() + mock_messages = [ + {"mavpackettype": "ATTITUDE", "roll": 0.1, "pitch": 0.2, "yaw": 0.3}, + {"mavpackettype": "HEARTBEAT", "type": 2, "base_mode": 193}, + ] + + # Make stream emit our messages + mock_replay.return_value.stream.return_value.subscribe = lambda callback: [ + callback(msg) for msg in mock_messages + ] + + conn = FakeMavlinkConnection("replay") + + # Check messages are available + msg1 = conn.mavlink.recv_match() + self.assertIsNotNone(msg1) + self.assertEqual(msg1.get_type(), "ATTITUDE") + + msg2 = conn.mavlink.recv_match() + self.assertIsNotNone(msg2) + self.assertEqual(msg2.get_type(), "HEARTBEAT") + + def test_fake_video_stream_no_throttling(self) -> None: + """Test FakeDJIVideoStream returns replay stream directly.""" + with patch("dimos.utils.testing.TimedSensorReplay") as mock_replay: + mock_stream = MagicMock() + mock_replay.return_value.stream.return_value = mock_stream + + stream = FakeDJIVideoStream(port=5600) + result_stream = stream.get_stream() + + # Verify stream is returned directly without throttling + self.assertEqual(result_stream, mock_stream) + + def test_connection_module_replay_mode(self) -> None: + """Test connection module uses Fake classes in replay mode.""" + with patch("dimos.robot.drone.mavlink_connection.FakeMavlinkConnection") as mock_fake_conn: + 
with patch("dimos.robot.drone.dji_video_stream.FakeDJIVideoStream") as mock_fake_video: + # Mock the fake connection + mock_conn_instance = MagicMock() + mock_conn_instance.connected = True + mock_conn_instance.odom_stream.return_value.subscribe = MagicMock( + return_value=lambda: None + ) + mock_conn_instance.status_stream.return_value.subscribe = MagicMock( + return_value=lambda: None + ) + mock_conn_instance.telemetry_stream.return_value.subscribe = MagicMock( + return_value=lambda: None + ) + mock_conn_instance.disconnect = MagicMock() + mock_fake_conn.return_value = mock_conn_instance + + # Mock the fake video + mock_video_instance = MagicMock() + mock_video_instance.start.return_value = True + mock_video_instance.get_stream.return_value.subscribe = MagicMock( + return_value=lambda: None + ) + mock_video_instance.stop = MagicMock() + mock_fake_video.return_value = mock_video_instance + + # Create module with replay connection string + module = DroneConnectionModule(connection_string="replay") + module.video = MagicMock() + module.movecmd = MagicMock() + module.movecmd.subscribe = MagicMock(return_value=lambda: None) + module.tf = MagicMock() + + try: + # Start should use Fake classes + result = module.start() + + self.assertTrue(result) + mock_fake_conn.assert_called_once_with("replay") + mock_fake_video.assert_called_once() + finally: + # Always clean up + module.stop() + + def test_connection_module_replay_with_messages(self) -> None: + """Test connection module in replay mode receives and processes messages.""" + + os.environ["DRONE_CONNECTION"] = "replay" + + with patch("dimos.utils.testing.TimedSensorReplay") as mock_replay: + # Set up MAVLink replay stream + mavlink_messages = [ + {"mavpackettype": "HEARTBEAT", "type": 2, "base_mode": 193}, + {"mavpackettype": "ATTITUDE", "roll": 0.1, "pitch": 0.2, "yaw": 0.3}, + { + "mavpackettype": "GLOBAL_POSITION_INT", + "lat": 377810501, + "lon": -1224069671, + "alt": 0, + "relative_alt": 1000, + "vx": 100, + "vy": 
0, + "vz": 0, + "hdg": 0, + }, + ] + + # Set up video replay stream + video_frames = [ + np.random.randint(0, 255, (1080, 1920, 3), dtype=np.uint8), + np.random.randint(0, 255, (1080, 1920, 3), dtype=np.uint8), + ] + + def create_mavlink_stream(): + stream = MagicMock() + + def subscribe(callback) -> None: + print("\n[TEST] MAVLink replay stream subscribed") + for msg in mavlink_messages: + print(f"[TEST] Replaying MAVLink: {msg['mavpackettype']}") + callback(msg) + + stream.subscribe = subscribe + return stream + + def create_video_stream(): + stream = MagicMock() + + def subscribe(callback) -> None: + print("[TEST] Video replay stream subscribed") + for i, frame in enumerate(video_frames): + print( + f"[TEST] Replaying video frame {i + 1}/{len(video_frames)}, shape: {frame.shape}" + ) + callback(frame) + + stream.subscribe = subscribe + return stream + + # Configure mock replay to return appropriate streams + def replay_side_effect(store_name: str): + print(f"[TEST] TimedSensorReplay created for: {store_name}") + mock = MagicMock() + if "mavlink" in store_name: + mock.stream.return_value = create_mavlink_stream() + elif "video" in store_name: + mock.stream.return_value = create_video_stream() + return mock + + mock_replay.side_effect = replay_side_effect + + # Create and start connection module + module = DroneConnectionModule(connection_string="replay") + + # Mock publishers to track what gets published + published_odom = [] + published_video = [] + published_status = [] + + module.odom = MagicMock( + publish=lambda x: ( + published_odom.append(x), + print( + f"[TEST] Published odom: position=({x.position.x:.2f}, {x.position.y:.2f}, {x.position.z:.2f})" + ), + ) + ) + module.video = MagicMock( + publish=lambda x: ( + published_video.append(x), + print( + f"[TEST] Published video frame with shape: {x.data.shape if hasattr(x, 'data') else 'unknown'}" + ), + ) + ) + module.status = MagicMock( + publish=lambda x: ( + published_status.append(x), + print( + f"[TEST] 
Published status: {x.data[:50]}..." + if hasattr(x, "data") + else "[TEST] Published status" + ), + ) + ) + module.telemetry = MagicMock() + module.tf = MagicMock() + module.movecmd = MagicMock() + + try: + print("\n[TEST] Starting connection module in replay mode...") + result = module.start() + + # Give time for messages to process + import time + + time.sleep(0.1) + + print(f"\n[TEST] Module started: {result}") + print(f"[TEST] Total odom messages published: {len(published_odom)}") + print(f"[TEST] Total video frames published: {len(published_video)}") + print(f"[TEST] Total status messages published: {len(published_status)}") + + # Verify module started and is processing messages + self.assertTrue(result) + self.assertIsNotNone(module.connection) + self.assertIsNotNone(module.video_stream) + + # Should have published some messages + self.assertGreater( + len(published_odom) + len(published_video) + len(published_status), + 0, + "No messages were published in replay mode", + ) + finally: + # Clean up + module.stop() + + +class TestDroneFullIntegration(unittest.TestCase): + """Full integration test of Drone class with replay mode.""" + + def setUp(self) -> None: + """Set up test environment.""" + # Mock the DimOS core module + self.mock_dimos = MagicMock() + self.mock_dimos.deploy.return_value = MagicMock() + + # Mock pubsub.lcm.autoconf + self.pubsub_patch = patch("dimos.protocol.pubsub.lcm.autoconf") + self.pubsub_patch.start() + + # Mock FoxgloveBridge + self.foxglove_patch = patch("dimos.robot.drone.drone.FoxgloveBridge") + self.mock_foxglove = self.foxglove_patch.start() + + def tearDown(self) -> None: + """Clean up patches.""" + self.pubsub_patch.stop() + self.foxglove_patch.stop() + + @patch("dimos.robot.drone.drone.core.start") + @patch("dimos.utils.testing.TimedSensorReplay") + def test_full_system_with_replay(self, mock_replay, mock_core_start) -> None: + """Test full drone system initialization and operation with replay mode.""" + # Set up mock replay 
data + mavlink_messages = [ + {"mavpackettype": "HEARTBEAT", "type": 2, "base_mode": 193, "armed": True}, + {"mavpackettype": "ATTITUDE", "roll": 0.1, "pitch": 0.2, "yaw": 0.3}, + { + "mavpackettype": "GLOBAL_POSITION_INT", + "lat": 377810501, + "lon": -1224069671, + "alt": 5000, + "relative_alt": 5000, + "vx": 100, # 1 m/s North + "vy": 200, # 2 m/s East + "vz": -50, # 0.5 m/s Up + "hdg": 9000, # 90 degrees + }, + { + "mavpackettype": "BATTERY_STATUS", + "voltages": [3800, 3800, 3800, 3800], + "battery_remaining": 75, + }, + ] + + video_frames = [ + Image( + data=np.random.randint(0, 255, (360, 640, 3), dtype=np.uint8), + format=ImageFormat.BGR, + ) + ] + + def replay_side_effect(store_name: str): + mock = MagicMock() + if "mavlink" in store_name: + # Create stream that emits MAVLink messages + stream = MagicMock() + stream.subscribe = lambda callback: [callback(msg) for msg in mavlink_messages] + mock.stream.return_value = stream + elif "video" in store_name: + # Create stream that emits video frames + stream = MagicMock() + stream.subscribe = lambda callback: [callback(frame) for frame in video_frames] + mock.stream.return_value = stream + return mock + + mock_replay.side_effect = replay_side_effect + + # Mock DimOS core + mock_core_start.return_value = self.mock_dimos + + # Create drone in replay mode + drone = Drone(connection_string="replay", video_port=5600) + + # Mock the deployed modules + mock_connection = MagicMock() + mock_camera = MagicMock() + + # Set up return values for module methods + mock_connection.start.return_value = True + mock_connection.get_odom.return_value = PoseStamped( + position=Vector3(1.0, 2.0, 3.0), orientation=Quaternion(0, 0, 0, 1), frame_id="world" + ) + mock_connection.get_status.return_value = { + "armed": True, + "battery_voltage": 15.2, + "battery_remaining": 75, + "altitude": 5.0, + } + + mock_camera.start.return_value = True + + # Configure deploy to return our mocked modules + def deploy_side_effect(module_class, 
**kwargs): + if "DroneConnectionModule" in str(module_class): + return mock_connection + elif "DroneCameraModule" in str(module_class): + return mock_camera + return MagicMock() + + self.mock_dimos.deploy.side_effect = deploy_side_effect + + # Start the drone system + drone.start() + + # Verify modules were deployed + self.assertEqual(self.mock_dimos.deploy.call_count, 4) + + # Test get_odom + odom = drone.get_odom() + self.assertIsNotNone(odom) + self.assertEqual(odom.position.x, 1.0) + self.assertEqual(odom.position.y, 2.0) + self.assertEqual(odom.position.z, 3.0) + + # Test get_status + status = drone.get_status() + self.assertIsNotNone(status) + self.assertTrue(status["armed"]) + self.assertEqual(status["battery_remaining"], 75) + + # Test movement command + drone.move(Vector3(1.0, 0.0, 0.5), duration=2.0) + mock_connection.move.assert_called_once_with(Vector3(1.0, 0.0, 0.5), 2.0) + + # Test control commands + drone.arm() + mock_connection.arm.assert_called_once() + + drone.takeoff(altitude=10.0) + mock_connection.takeoff.assert_called_once_with(10.0) + + drone.land() + mock_connection.land.assert_called_once() + + drone.disarm() + mock_connection.disarm.assert_called_once() + + # Test mode setting + drone.set_mode("GUIDED") + mock_connection.set_mode.assert_called_once_with("GUIDED") + + # Clean up + drone.stop() + + # Verify cleanup was called + mock_connection.stop.assert_called_once() + mock_camera.stop.assert_called_once() + self.mock_dimos.close_all.assert_called_once() + + +class TestDroneControlCommands(unittest.TestCase): + """Test drone control commands with FakeMavlinkConnection.""" + + @patch("dimos.utils.testing.TimedSensorReplay") + @patch("dimos.utils.data.get_data") + def test_arm_disarm_commands(self, mock_get_data, mock_replay) -> None: + """Test arm and disarm commands work with fake connection.""" + # Set up mock replay + mock_stream = MagicMock() + mock_stream.subscribe = lambda callback: None + mock_replay.return_value.stream.return_value 
= mock_stream + + conn = FakeMavlinkConnection("replay") + + # Test arm + result = conn.arm() + self.assertIsInstance(result, bool) # Should return bool without crashing + + # Test disarm + result = conn.disarm() + self.assertIsInstance(result, bool) # Should return bool without crashing + + @patch("dimos.utils.testing.TimedSensorReplay") + @patch("dimos.utils.data.get_data") + def test_takeoff_land_commands(self, mock_get_data, mock_replay) -> None: + """Test takeoff and land commands with fake connection.""" + mock_stream = MagicMock() + mock_stream.subscribe = lambda callback: None + mock_replay.return_value.stream.return_value = mock_stream + + conn = FakeMavlinkConnection("replay") + + # Test takeoff + result = conn.takeoff(altitude=15.0) + # In fake mode, should accept but may return False if no ACK simulation + self.assertIsNotNone(result) + + # Test land + result = conn.land() + self.assertIsNotNone(result) + + @patch("dimos.utils.testing.TimedSensorReplay") + @patch("dimos.utils.data.get_data") + def test_set_mode_command(self, mock_get_data, mock_replay) -> None: + """Test flight mode setting with fake connection.""" + mock_stream = MagicMock() + mock_stream.subscribe = lambda callback: None + mock_replay.return_value.stream.return_value = mock_stream + + conn = FakeMavlinkConnection("replay") + + # Test various flight modes + modes = ["STABILIZE", "GUIDED", "LAND", "RTL", "LOITER"] + for mode in modes: + result = conn.set_mode(mode) + # Should return True or False but not crash + self.assertIsInstance(result, bool) + + +class TestDronePerception(unittest.TestCase): + """Test drone perception capabilities.""" + + @patch("dimos.utils.testing.TimedSensorReplay") + @patch("dimos.utils.data.get_data") + def test_video_stream_replay(self, mock_get_data, mock_replay) -> None: + """Test video stream works with replay data.""" + # Set up video frames - create a test pattern instead of random noise + import cv2 + + # Create a test pattern image with some structure 
+ test_frame = np.zeros((360, 640, 3), dtype=np.uint8) + # Add some colored rectangles to make it visually obvious + cv2.rectangle(test_frame, (50, 50), (200, 150), (255, 0, 0), -1) # Blue + cv2.rectangle(test_frame, (250, 50), (400, 150), (0, 255, 0), -1) # Green + cv2.rectangle(test_frame, (450, 50), (600, 150), (0, 0, 255), -1) # Red + cv2.putText( + test_frame, + "DRONE TEST FRAME", + (150, 250), + cv2.FONT_HERSHEY_SIMPLEX, + 1.5, + (255, 255, 255), + 2, + ) + + video_frames = [test_frame, test_frame.copy()] + + # Mock replay stream + mock_stream = MagicMock() + received_frames = [] + + def subscribe_side_effect(callback) -> None: + for frame in video_frames: + img = Image(data=frame, format=ImageFormat.BGR) + callback(img) + received_frames.append(img) + + mock_stream.subscribe = subscribe_side_effect + mock_replay.return_value.stream.return_value = mock_stream + + # Create fake video stream + video_stream = FakeDJIVideoStream(port=5600) + stream = video_stream.get_stream() + + # Subscribe to stream + captured_frames = [] + stream.subscribe(captured_frames.append) + + # Verify frames were captured + self.assertEqual(len(received_frames), 2) + for i, frame in enumerate(received_frames): + self.assertIsInstance(frame, Image) + self.assertEqual(frame.data.shape, (360, 640, 3)) + + # Save first frame to file for visual inspection + if i == 0: + import os + + output_path = "/tmp/drone_test_frame.png" + cv2.imwrite(output_path, frame.data) + print(f"\n[TEST] Saved test frame to {output_path} for visual inspection") + if os.path.exists(output_path): + print(f"[TEST] File size: {os.path.getsize(output_path)} bytes") + + +class TestDroneMovementAndOdometry(unittest.TestCase): + """Test drone movement commands and odometry.""" + + @patch("dimos.utils.testing.TimedSensorReplay") + @patch("dimos.utils.data.get_data") + def test_movement_command_conversion(self, mock_get_data, mock_replay) -> None: + """Test movement commands are properly converted from ROS to NED.""" + 
mock_stream = MagicMock() + mock_stream.subscribe = lambda callback: None + mock_replay.return_value.stream.return_value = mock_stream + + conn = FakeMavlinkConnection("replay") + + # Test movement in ROS frame + # ROS: X=forward, Y=left, Z=up + velocity_ros = Vector3(2.0, -1.0, 0.5) # Forward 2m/s, right 1m/s, up 0.5m/s + + result = conn.move(velocity_ros, duration=1.0) + self.assertTrue(result) + + # Movement should be converted to NED internally + # The fake connection doesn't actually send commands, but it should not crash + + @patch("dimos.utils.testing.TimedSensorReplay") + @patch("dimos.utils.data.get_data") + def test_odometry_from_replay(self, mock_get_data, mock_replay) -> None: + """Test odometry is properly generated from replay messages.""" + # Set up replay messages + messages = [ + {"mavpackettype": "ATTITUDE", "roll": 0.1, "pitch": 0.2, "yaw": 0.3}, + { + "mavpackettype": "GLOBAL_POSITION_INT", + "lat": 377810501, + "lon": -1224069671, + "alt": 10000, + "relative_alt": 5000, + "vx": 200, # 2 m/s North + "vy": 100, # 1 m/s East + "vz": -50, # 0.5 m/s Up + "hdg": 18000, # 180 degrees + }, + ] + + def replay_stream_subscribe(callback) -> None: + for msg in messages: + callback(msg) + + mock_stream = MagicMock() + mock_stream.subscribe = replay_stream_subscribe + mock_replay.return_value.stream.return_value = mock_stream + + conn = FakeMavlinkConnection("replay") + + # Collect published odometry + published_odom = [] + conn._odom_subject.subscribe(published_odom.append) + + # Process messages + for _ in range(5): + conn.update_telemetry(timeout=0.01) + + # Should have published odometry + self.assertGreater(len(published_odom), 0) + + # Check odometry message + odom = published_odom[0] + self.assertIsInstance(odom, PoseStamped) + self.assertIsNotNone(odom.orientation) + self.assertEqual(odom.frame_id, "world") + + @patch("dimos.utils.testing.TimedSensorReplay") + @patch("dimos.utils.data.get_data") + def test_position_integration_indoor(self, 
mock_get_data, mock_replay) -> None: + """Test position integration for indoor flight without GPS.""" + messages = [ + {"mavpackettype": "ATTITUDE", "roll": 0, "pitch": 0, "yaw": 0}, + { + "mavpackettype": "GLOBAL_POSITION_INT", + "lat": 0, # Invalid GPS + "lon": 0, + "alt": 0, + "relative_alt": 2000, # 2m altitude + "vx": 100, # 1 m/s North + "vy": 0, + "vz": 0, + "hdg": 0, + }, + ] + + def replay_stream_subscribe(callback) -> None: + for msg in messages: + callback(msg) + + mock_stream = MagicMock() + mock_stream.subscribe = replay_stream_subscribe + mock_replay.return_value.stream.return_value = mock_stream + + conn = FakeMavlinkConnection("replay") + + # Process messages multiple times to integrate position + initial_time = time.time() + conn._last_update = initial_time + + for _i in range(3): + conn.update_telemetry(timeout=0.01) + time.sleep(0.1) # Let some time pass for integration + + # Position should have been integrated + self.assertGreater(conn._position["x"], 0) # Moving North + self.assertEqual(conn._position["z"], 2.0) # Altitude from relative_alt + + +class TestDroneStatusAndTelemetry(unittest.TestCase): + """Test drone status and telemetry reporting.""" + + @patch("dimos.utils.testing.TimedSensorReplay") + @patch("dimos.utils.data.get_data") + def test_status_extraction(self, mock_get_data, mock_replay) -> None: + """Test status is properly extracted from MAVLink messages.""" + messages = [ + {"mavpackettype": "HEARTBEAT", "type": 2, "base_mode": 193}, # Armed + { + "mavpackettype": "BATTERY_STATUS", + "voltages": [3700, 3700, 3700, 3700], + "current_battery": -1500, + "battery_remaining": 65, + }, + {"mavpackettype": "GPS_RAW_INT", "satellites_visible": 12, "fix_type": 3}, + {"mavpackettype": "GLOBAL_POSITION_INT", "relative_alt": 8000, "hdg": 27000}, + ] + + def replay_stream_subscribe(callback) -> None: + for msg in messages: + callback(msg) + + mock_stream = MagicMock() + mock_stream.subscribe = replay_stream_subscribe + 
mock_replay.return_value.stream.return_value = mock_stream + + conn = FakeMavlinkConnection("replay") + + # Collect published status + published_status = [] + conn._status_subject.subscribe(published_status.append) + + # Process messages + for _ in range(5): + conn.update_telemetry(timeout=0.01) + + # Should have published status + self.assertGreater(len(published_status), 0) + + # Check status fields + status = published_status[-1] # Get latest + self.assertIn("armed", status) + self.assertIn("battery_remaining", status) + self.assertIn("satellites", status) + self.assertIn("altitude", status) + self.assertIn("heading", status) + + @patch("dimos.utils.testing.TimedSensorReplay") + @patch("dimos.utils.data.get_data") + def test_telemetry_json_publishing(self, mock_get_data, mock_replay) -> None: + """Test full telemetry is published as JSON.""" + messages = [ + {"mavpackettype": "ATTITUDE", "roll": 0.1, "pitch": 0.2, "yaw": 0.3}, + {"mavpackettype": "GLOBAL_POSITION_INT", "lat": 377810501, "lon": -1224069671}, + ] + + def replay_stream_subscribe(callback) -> None: + for msg in messages: + callback(msg) + + mock_stream = MagicMock() + mock_stream.subscribe = replay_stream_subscribe + mock_replay.return_value.stream.return_value = mock_stream + + # Create connection module with replay + module = DroneConnectionModule(connection_string="replay") + + # Mock publishers + published_telemetry = [] + module.telemetry = MagicMock(publish=lambda x: published_telemetry.append(x)) + module.status = MagicMock() + module.odom = MagicMock() + module.tf = MagicMock() + module.video = MagicMock() + module.movecmd = MagicMock() + + # Start module + result = module.start() + self.assertTrue(result) + + # Give time for processing + time.sleep(0.2) + + # Stop module + module.stop() + + # Check telemetry was published + self.assertGreater(len(published_telemetry), 0) + + # Telemetry should be JSON string + telem_msg = published_telemetry[0] + self.assertIsNotNone(telem_msg) + + # If 
it's a String message, check the data + if hasattr(telem_msg, "data"): + telem_dict = json.loads(telem_msg.data) + self.assertIn("timestamp", telem_dict) + + +class TestFlyToErrorHandling(unittest.TestCase): + """Test fly_to() error handling paths.""" + + @patch("dimos.utils.testing.TimedSensorReplay") + @patch("dimos.utils.data.get_data") + def test_concurrency_lock(self, mock_get_data, mock_replay) -> None: + """flying_to_target=True rejects concurrent fly_to() calls.""" + mock_stream = MagicMock() + mock_stream.subscribe = lambda callback: None + mock_replay.return_value.stream.return_value = mock_stream + + conn = FakeMavlinkConnection("replay") + conn.flying_to_target = True + + result = conn.fly_to(37.0, -122.0, 10.0) + self.assertIn("Already flying to target", result) + + @patch("dimos.utils.testing.TimedSensorReplay") + @patch("dimos.utils.data.get_data") + def test_error_when_not_connected(self, mock_get_data, mock_replay) -> None: + """connected=False returns error immediately.""" + mock_stream = MagicMock() + mock_stream.subscribe = lambda callback: None + mock_replay.return_value.stream.return_value = mock_stream + + conn = FakeMavlinkConnection("replay") + conn.connected = False + + result = conn.fly_to(37.0, -122.0, 10.0) + self.assertIn("Not connected", result) + + +class TestVisualServoingEdgeCases(unittest.TestCase): + """Test DroneVisualServoingController edge cases.""" + + def test_output_clamping(self) -> None: + """Large errors are clamped to max_velocity.""" + from dimos.robot.drone.drone_visual_servoing_controller import ( + DroneVisualServoingController, + ) + + # PID params: (kp, ki, kd, output_limits, integral_limit, deadband) + max_vel = 2.0 + controller = DroneVisualServoingController( + x_pid_params=(1.0, 0.0, 0.0, (-max_vel, max_vel), None, 0), + y_pid_params=(1.0, 0.0, 0.0, (-max_vel, max_vel), None, 0), + ) + + # Large error should be clamped + vx, vy, vz = controller.compute_velocity_control( + target_x=1000, target_y=1000, 
center_x=0, center_y=0, dt=0.1 + ) + self.assertLessEqual(abs(vx), max_vel) + self.assertLessEqual(abs(vy), max_vel) + + def test_deadband_prevents_integral_windup(self) -> None: + """Deadband prevents integral accumulation for small errors.""" + from dimos.robot.drone.drone_visual_servoing_controller import ( + DroneVisualServoingController, + ) + + deadband = 10 # pixels + controller = DroneVisualServoingController( + x_pid_params=(0.0, 1.0, 0.0, (-2.0, 2.0), None, deadband), # integral only + y_pid_params=(0.0, 1.0, 0.0, (-2.0, 2.0), None, deadband), + ) + + # With error inside deadband, integral should stay at zero + for _ in range(10): + controller.compute_velocity_control( + target_x=5, target_y=5, center_x=0, center_y=0, dt=0.1 + ) + + # Integral should be zero since error < deadband + self.assertEqual(controller.x_pid.integral, 0.0) + self.assertEqual(controller.y_pid.integral, 0.0) + + def test_reset_clears_integral(self) -> None: + """reset() clears accumulated integral to prevent windup.""" + from dimos.robot.drone.drone_visual_servoing_controller import ( + DroneVisualServoingController, + ) + + controller = DroneVisualServoingController( + x_pid_params=(0.0, 1.0, 0.0, (-10.0, 10.0), None, 0), # Only integral + y_pid_params=(0.0, 1.0, 0.0, (-10.0, 10.0), None, 0), + ) + + # Accumulate integral by calling multiple times with error + for _ in range(10): + controller.compute_velocity_control( + target_x=100, target_y=100, center_x=0, center_y=0, dt=0.1 + ) + + # Integral should be non-zero + self.assertNotEqual(controller.x_pid.integral, 0.0) + + # Reset should clear it + controller.reset() + self.assertEqual(controller.x_pid.integral, 0.0) + self.assertEqual(controller.y_pid.integral, 0.0) + + +class TestVisualServoingVelocity(unittest.TestCase): + """Test visual servoing velocity calculations.""" + + def test_velocity_from_bbox_center_error(self) -> None: + """Bbox center offset produces proportional velocity command.""" + from 
dimos.robot.drone.drone_visual_servoing_controller import ( + DroneVisualServoingController, + ) + + controller = DroneVisualServoingController( + x_pid_params=(0.01, 0.0, 0.0, (-2.0, 2.0), None, 0), + y_pid_params=(0.01, 0.0, 0.0, (-2.0, 2.0), None, 0), + ) + + # Image center at (320, 180), bbox center at (400, 180) = 80px right + frame_center = (320, 180) + bbox_center = (400, 180) + + vx, vy, vz = controller.compute_velocity_control( + target_x=bbox_center[0], + target_y=bbox_center[1], + center_x=frame_center[0], + center_y=frame_center[1], + dt=0.1, + ) + + # Object to the right -> drone should strafe right (positive vy) + self.assertGreater(vy, 0) + # No vertical offset -> vx should be ~0 + self.assertAlmostEqual(vx, 0, places=1) + + +if __name__ == "__main__": + unittest.main() diff --git a/dimos/web/websocket_vis/websocket_vis_module.py b/dimos/web/websocket_vis/websocket_vis_module.py index 5a054ed9f8..341c19a66a 100644 --- a/dimos/web/websocket_vis/websocket_vis_module.py +++ b/dimos/web/websocket_vis/websocket_vis_module.py @@ -93,10 +93,13 @@ def __init__(self, port: int = 7779, **kwargs) -> None: # type: ignore[no-untyp self.vis_state = {} # type: ignore[var-annotated] self.state_lock = threading.Lock() - self.costmap_encoder = OptimizedCostmapEncoder(chunk_size=64) - logger.info(f"WebSocket visualization module initialized on port {port}") + # Track GPS goal points for visualization + self.gps_goal_points: list[dict[str, float]] = [] + logger.info( + f"WebSocket visualization module initialized on port {port}, GPS goal tracking enabled" + ) def _start_broadcast_loop(self) -> None: def websocket_vis_loop() -> None: @@ -141,8 +144,11 @@ def start(self) -> None: except Exception: ... - unsub = self.global_costmap.subscribe(self._on_global_costmap) - self._disposables.add(Disposable(unsub)) + try: + unsub = self.global_costmap.subscribe(self._on_global_costmap) + self._disposables.add(Disposable(unsub)) + except Exception: + ... 
@rpc def stop(self) -> None: @@ -191,10 +197,17 @@ async def connect(sid, environ) -> None: # type: ignore[no-untyped-def] with self.state_lock: current_state = dict(self.vis_state) + # Include GPS goal points in the initial state + if self.gps_goal_points: + current_state["gps_travel_goal_points"] = self.gps_goal_points + # Force full costmap update on new connection self.costmap_encoder.last_full_grid = None await self.sio.emit("full_state", current_state, room=sid) # type: ignore[union-attr] + logger.info( + f"Client {sid} connected, sent state with {len(self.gps_goal_points)} GPS goal points" + ) @self.sio.event # type: ignore[misc, untyped-decorator] async def click(sid, position) -> None: # type: ignore[no-untyped-def] @@ -207,12 +220,25 @@ async def click(sid, position) -> None: # type: ignore[no-untyped-def] logger.info(f"Click goal published: ({goal.position.x:.2f}, {goal.position.y:.2f})") @self.sio.event # type: ignore[misc, untyped-decorator] - async def gps_goal(sid, goal) -> None: # type: ignore[no-untyped-def] - logger.info(f"Set GPS goal: {goal}") + async def gps_goal(sid: str, goal: dict[str, float]) -> None: + logger.info(f"Received GPS goal: {goal}") + + # Publish the goal to LCM self.gps_goal.publish(LatLon(lat=goal["lat"], lon=goal["lon"])) + # Add to goal points list for visualization + self.gps_goal_points.append(goal) + logger.info(f"Added GPS goal to list. 
Total goals: {len(self.gps_goal_points)}") + + # Emit updated goal points back to all connected clients + if self.sio is not None: + await self.sio.emit("gps_travel_goal_points", self.gps_goal_points) + logger.debug( + f"Emitted gps_travel_goal_points with {len(self.gps_goal_points)} points: {self.gps_goal_points}" + ) + @self.sio.event # type: ignore[misc, untyped-decorator] - async def start_explore(sid) -> None: # type: ignore[no-untyped-def] + async def start_explore(sid: str) -> None: logger.info("Starting exploration") self.explore_cmd.publish(Bool(data=True)) @@ -222,7 +248,15 @@ async def stop_explore(sid) -> None: # type: ignore[no-untyped-def] self.stop_explore_cmd.publish(Bool(data=True)) @self.sio.event # type: ignore[misc, untyped-decorator] - async def move_command(sid, data) -> None: # type: ignore[no-untyped-def] + async def clear_gps_goals(sid: str) -> None: + logger.info("Clearing all GPS goal points") + self.gps_goal_points.clear() + if self.sio is not None: + await self.sio.emit("gps_travel_goal_points", self.gps_goal_points) + logger.info("GPS goal points cleared and updated clients") + + @self.sio.event # type: ignore[misc, untyped-decorator] + async def move_command(sid: str, data: dict[str, Any]) -> None: # Publish Twist if transport is configured if self.cmd_vel and self.cmd_vel.transport: twist = Twist( diff --git a/docker/python/Dockerfile b/docker/python/Dockerfile index c12f7ea5d9..50f021a9a1 100644 --- a/docker/python/Dockerfile +++ b/docker/python/Dockerfile @@ -49,4 +49,4 @@ COPY . 
/app/ # Install dependencies with UV (10-100x faster than pip) RUN uv pip install --upgrade 'pip>=24' 'setuptools>=70' 'wheel' 'packaging>=24' && \ - uv pip install '.[cpu,sim]' + uv pip install '.[cpu,sim,drone]' diff --git a/pyproject.toml b/pyproject.toml index cbd8c46c67..7ed6b7502a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -236,6 +236,10 @@ jetson-jp6-cuda126 = [ "xformers @ https://pypi.jetson-ai-lab.io/jp6/cu126/+f/731/15133b0ebb2b3/xformers-0.0.33+ac00641.d20250830-cp39-abi3-linux_aarch64.whl", ] +drone = [ + "pymavlink" +] + [tool.ruff] line-length = 100 exclude = [