diff --git a/data/.lfs/unitree_go2_lidar_corrected.tar.gz b/data/.lfs/unitree_go2_lidar_corrected.tar.gz new file mode 100644 index 0000000000..013f6b3fe1 --- /dev/null +++ b/data/.lfs/unitree_go2_lidar_corrected.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:51a817f2b5664c9e2f2856293db242e030f0edce276e21da0edc2821d947aad2 +size 1212727745 diff --git a/data/.lfs/unitree_go2_office_walk2.tar.gz b/data/.lfs/unitree_go2_office_walk2.tar.gz new file mode 100644 index 0000000000..ea392c4b4c --- /dev/null +++ b/data/.lfs/unitree_go2_office_walk2.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d208cdf537ad01eed2068a4665e454ed30b30894bd9b35c14b4056712faeef5d +size 1693876005 diff --git a/dimos/agents2/test_mock_agent.py b/dimos/agents2/test_mock_agent.py index 298e1c968b..4331b48c30 100644 --- a/dimos/agents2/test_mock_agent.py +++ b/dimos/agents2/test_mock_agent.py @@ -27,8 +27,6 @@ from dimos.msgs.foxglove_msgs import ImageAnnotations from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Vector3 from dimos.msgs.sensor_msgs import Image -from dimos.msgs.vision_msgs import Detection2DArray -from dimos.perception.detection2d import Detect2DModule from dimos.protocol.skill.test_coordinator import SkillContainerTest from dimos.robot.unitree_webrtc.modular.connection_module import ConnectionModule from dimos.robot.unitree_webrtc.type.lidar import LidarMessage @@ -167,15 +165,8 @@ def test_tool_call_implicit_detections(): robot_connection.camera_info.transport = LCMTransport("/camera_info", CameraInfo) robot_connection.start() - detect2d = dimos.deploy(Detect2DModule) - detect2d.detections.transport = LCMTransport("/detections", Detection2DArray) - detect2d.annotations.transport = LCMTransport("/annotations", ImageAnnotations) - detect2d.image.connect(robot_connection.video) - detect2d.start() - test_skill_module = dimos.deploy(SkillContainerTest) - agent.register_skills(detect2d) agent.register_skills(test_skill_module) agent.start() @@ -208,5 +199,4 @@ def test_tool_call_implicit_detections(): agent.stop() test_skill_module.stop() robot_connection.stop() - detect2d.stop() dimos.stop() diff --git a/dimos/core/__init__.py b/dimos/core/__init__.py index 5da06e787d..12043300ae 100644 --- a/dimos/core/__init__.py +++ b/dimos/core/__init__.py @@ -8,13 +8,13 @@ import dimos.core.colors as colors from dimos.core.core import rpc -from dimos.core.module import Module, ModuleBase +from dimos.core.module import Module, ModuleBase, ModuleConfig from dimos.core.stream import In, Out, RemoteIn, RemoteOut, Transport from dimos.core.transport import ( LCMTransport, + SHMTransport, ZenohTransport, pLCMTransport, - SHMTransport, pSHMTransport, ) from dimos.protocol.rpc.lcmrpc import LCMRPC @@ -99,7 +99,10 @@ def rpc_call(*args, **kwargs): return self.actor_instance.__getattr__(name) -def patchdask(dask_client: Client, local_cluster: LocalCluster) -> Client: +DimosCluster = Client + + +def patchdask(dask_client: Client, local_cluster: LocalCluster) -> DimosCluster: def deploy( actor_class, *args, diff --git a/dimos/core/stream.py b/dimos/core/stream.py index 9d9852d400..fab190e07f 100644 --- a/dimos/core/stream.py +++ b/dimos/core/stream.py @@ -79,12 +79,12 @@ class Transport(ObservableMixin[T]): # used by local Output def broadcast(self, selfstream: Out[T], value: T): ... + def publish(self, msg: T): + self.broadcast(None, msg) + # used by local Input def subscribe(self, selfstream: In[T], callback: Callable[[T], any]) -> None: ... - def publish(self, *args, **kwargs): - return self.broadcast(*args, **kwargs) - class Stream(Generic[T]): _transport: Optional[Transport] diff --git a/dimos/msgs/foxglove_msgs/ImageAnnotations.py b/dimos/msgs/foxglove_msgs/ImageAnnotations.py index 6a1c668a4b..1f58b09d73 100644 --- a/dimos/msgs/foxglove_msgs/ImageAnnotations.py +++ b/dimos/msgs/foxglove_msgs/ImageAnnotations.py @@ -16,6 +16,17 @@ class ImageAnnotations(FoxgloveImageAnnotations): + def __add__(self, other: "ImageAnnotations") -> "ImageAnnotations": + points = self.points + other.points + texts = self.texts + other.texts + + return ImageAnnotations( + texts=texts, + texts_length=len(texts), + points=points, + points_length=len(points), + ) + def agent_encode(self) -> str: if len(self.texts) == 0: return None diff --git a/dimos/msgs/geometry_msgs/Transform.py b/dimos/msgs/geometry_msgs/Transform.py index a47c58337c..60695f702d 100644 --- a/dimos/msgs/geometry_msgs/Transform.py +++ b/dimos/msgs/geometry_msgs/Transform.py @@ -192,6 +192,37 @@ def to_pose(self) -> "PoseStamped": frame_id=self.frame_id, ) + def to_matrix(self) -> "np.ndarray": + """Convert Transform to a 4x4 transformation matrix. + + Returns a homogeneous transformation matrix that represents both + the rotation and translation of this transform. + + Returns: + np.ndarray: A 4x4 homogeneous transformation matrix + """ + import numpy as np + + # Extract quaternion components + x, y, z, w = self.rotation.x, self.rotation.y, self.rotation.z, self.rotation.w + + # Build rotation matrix from quaternion using standard formula + # This avoids numerical issues compared to converting to axis-angle first + rotation_matrix = np.array( + [ + [1 - 2 * (y * y + z * z), 2 * (x * y - z * w), 2 * (x * z + y * w)], + [2 * (x * y + z * w), 1 - 2 * (x * x + z * z), 2 * (y * z - x * w)], + [2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2 * (x * x + y * y)], + ] + ) + + # Build 4x4 homogeneous transformation matrix + matrix = np.eye(4) + matrix[:3, :3] = rotation_matrix + matrix[:3, 3] = [self.translation.x, self.translation.y, self.translation.z] + + return matrix + def lcm_encode(self) -> bytes: # we get a circular import otherwise from dimos.msgs.tf2_msgs.TFMessage import TFMessage diff --git a/dimos/msgs/geometry_msgs/Vector3.py b/dimos/msgs/geometry_msgs/Vector3.py index 4b255d082c..757fab4d6e 100644 --- a/dimos/msgs/geometry_msgs/Vector3.py +++ b/dimos/msgs/geometry_msgs/Vector3.py @@ -203,6 +203,10 @@ def cross(self, other: VectorConvertable | Vector3) -> Vector3: self.x * other_vector.y - self.y * other_vector.x, ) + def magnitude(self) -> float: + """Alias for length().""" + return self.length() + def length(self) -> float: """Compute the Euclidean length (magnitude) of the vector.""" return float(np.sqrt(self.x * self.x + self.y * self.y + self.z * self.z)) diff --git a/dimos/msgs/sensor_msgs/Image.py b/dimos/msgs/sensor_msgs/Image.py index e6d6ae2d40..b79e07639f 100644 --- a/dimos/msgs/sensor_msgs/Image.py +++ b/dimos/msgs/sensor_msgs/Image.py @@ -30,7 +30,7 @@ from reactivex.observable import Observable from reactivex.scheduler import ThreadPoolScheduler -from dimos.types.timestamped import Timestamped, TimestampedBufferCollection +from dimos.types.timestamped import Timestamped, TimestampedBufferCollection, to_human_readable class ImageFormat(Enum): @@ -65,6 +65,9 @@ class Image(Timestamped): frame_id: str = field(default="") ts: float = field(default_factory=time.time) + def __str__(self): + return f"Image(shape={self.shape}, format={self.format}, dtype={self.dtype}, ts={to_human_readable(self.ts)})" + def __post_init__(self): """Validate image data and format.""" if self.data is None: @@ -527,8 +530,10 @@ def sharpness_window(target_frequency: float, source: Observable[Image]) -> Obse def find_best(*argv): if not window._items: return None - return max(window._items, key=lambda x: x.sharpness()) + + found = max(window._items, key=lambda x: x.sharpness()) + return found return rx.interval(1.0 / target_frequency).pipe( - ops.observe_on(thread_scheduler), ops.map(find_best) + ops.observe_on(thread_scheduler), ops.map(find_best), ops.filter(lambda x: x is not None) ) diff --git a/dimos/msgs/sensor_msgs/PointCloud2.py b/dimos/msgs/sensor_msgs/PointCloud2.py index 2238b31025..da270699f0 100644 --- a/dimos/msgs/sensor_msgs/PointCloud2.py +++ b/dimos/msgs/sensor_msgs/PointCloud2.py @@ -38,13 +38,34 @@ class PointCloud2(Timestamped): def __init__( self, pointcloud: o3d.geometry.PointCloud = None, - frame_id: str = "", + frame_id: str = "world", ts: Optional[float] = None, ): - self.ts = ts if ts is not None else time.time() + self.ts = ts self.pointcloud = pointcloud if pointcloud is not None else o3d.geometry.PointCloud() self.frame_id = frame_id + @classmethod + def from_numpy( + cls, points: np.ndarray, frame_id: str = "world", timestamp: Optional[float] = None + ) -> PointCloud2: + """Create PointCloud2 from numpy array of shape (N, 3). + + Args: + points: Nx3 numpy array of 3D points + frame_id: Frame ID for the point cloud + timestamp: Timestamp for the point cloud (defaults to current time) + + Returns: + PointCloud2 instance + """ + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(points) + return cls(pointcloud=pcd, ts=timestamp, frame_id=frame_id) + + def points(self): + return self.pointcloud.points + # TODO what's the usual storage here? is it already numpy? def as_numpy(self) -> np.ndarray: """Get points as numpy array.""" @@ -208,6 +229,78 @@ def __len__(self) -> int: """Return number of points.""" return len(self.pointcloud.points) + def filter_by_height( + self, + min_height: Optional[float] = None, + max_height: Optional[float] = None, + ) -> "PointCloud2": + """Filter points based on their height (z-coordinate). + + This method creates a new PointCloud2 containing only points within the specified + height range. All metadata (frame_id, timestamp) is preserved. + + Args: + min_height: Optional minimum height threshold. Points with z < min_height are filtered out. + If None, no lower limit is applied. + max_height: Optional maximum height threshold. Points with z > max_height are filtered out. + If None, no upper limit is applied. + + Returns: + New PointCloud2 instance containing only the filtered points. + + Raises: + ValueError: If both min_height and max_height are None (no filtering would occur). + + Example: + # Remove ground points below 0.1m height + filtered_pc = pointcloud.filter_by_height(min_height=0.1) + + # Keep only points between ground level and 2m height + filtered_pc = pointcloud.filter_by_height(min_height=0.0, max_height=2.0) + + # Remove points above 1.5m (e.g., ceiling) + filtered_pc = pointcloud.filter_by_height(max_height=1.5) + """ + # Validate that at least one threshold is provided + if min_height is None and max_height is None: + raise ValueError("At least one of min_height or max_height must be specified") + + # Get points as numpy array + points = self.as_numpy() + + if len(points) == 0: + # Empty pointcloud - return a copy + return PointCloud2( + pointcloud=o3d.geometry.PointCloud(), + frame_id=self.frame_id, + ts=self.ts, + ) + + # Extract z-coordinates (height values) - column index 2 + heights = points[:, 2] + + # Create boolean mask for filtering based on height thresholds + # Start with all True values + mask = np.ones(len(points), dtype=bool) + + # Apply minimum height filter if specified + if min_height is not None: + mask &= heights >= min_height + + # Apply maximum height filter if specified + if max_height is not None: + mask &= heights <= max_height + + # Apply mask to filter points + filtered_points = points[mask] + + # Create new PointCloud2 with filtered points + return PointCloud2.from_numpy( + points=filtered_points, + frame_id=self.frame_id, + timestamp=self.ts, + ) + def __repr__(self) -> str: """String representation.""" return f"PointCloud(points={len(self)}, frame_id='{self.frame_id}', ts={self.ts})" diff --git a/dimos/perception/detection2d/__init__.py b/dimos/perception/detection2d/__init__.py index b64461b493..bdcf9ca827 100644 --- a/dimos/perception/detection2d/__init__.py +++ b/dimos/perception/detection2d/__init__.py @@ -1,3 +1,8 @@ -from dimos.perception.detection2d.module import Detect2DModule +from dimos.perception.detection2d.module2D import ( + Detection2DModule, +) +from dimos.perception.detection2d.module3D import ( + Detection3DModule, +) from dimos.perception.detection2d.utils import * from dimos.perception.detection2d.yolo_2d_det import * diff --git a/dimos/perception/detection2d/conftest.py b/dimos/perception/detection2d/conftest.py new file mode 100644 index 0000000000..93d771b373 --- /dev/null +++ b/dimos/perception/detection2d/conftest.py @@ -0,0 +1,142 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +from typing import Optional, TypedDict + +import pytest +from dimos_lcm.foxglove_msgs.ImageAnnotations import ImageAnnotations +from dimos_lcm.sensor_msgs import CameraInfo, PointCloud2 + +from dimos.core import start +from dimos.core.transport import LCMTransport +from dimos.msgs.geometry_msgs import Transform +from dimos.msgs.sensor_msgs.Image import Image +from dimos.perception.detection2d.module2D import Detection2DModule +from dimos.perception.detection2d.module3D import Detection3DModule +from dimos.perception.detection2d.type import ImageDetections3D +from dimos.protocol.service import lcmservice as lcm +from dimos.protocol.tf import TF +from dimos.robot.unitree_webrtc.modular.connection_module import ConnectionModule +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.odometry import Odometry +from dimos.utils.data import get_data +from dimos.utils.testing import TimedSensorReplay + + +class Moment(TypedDict, total=False): + odom_frame: Odometry + lidar_frame: LidarMessage + image_frame: Image + camera_info: CameraInfo + transforms: list[Transform] + tf: TF + annotations: Optional[ImageAnnotations] + detections: Optional[ImageDetections3D] + + +@pytest.fixture +def dimos_cluster(): + dimos = start(5) + yield dimos + dimos.stop() + + +@pytest.fixture(scope="session") +def moment(): + data_dir = "unitree_go2_lidar_corrected" + get_data(data_dir) + + seek = 10 + + lidar_frame = TimedSensorReplay(f"{data_dir}/lidar").find_closest_seek(seek) + + image_frame = TimedSensorReplay( + f"{data_dir}/video", + ).find_closest(lidar_frame.ts) + + image_frame.frame_id = "camera_optical" + + odom_frame = TimedSensorReplay(f"{data_dir}/odom", autocast=Odometry.from_msg).find_closest( + lidar_frame.ts + ) + + transforms = ConnectionModule._odom_to_tf(odom_frame) + + tf = TF() + tf.publish(*transforms) + + return { + "odom_frame": odom_frame, + "lidar_frame": lidar_frame, + "image_frame": image_frame, + "camera_info": ConnectionModule._camera_info(), + "transforms": transforms, + "tf": tf, + } + + +@pytest.fixture(scope="session") +def publish_lcm(): + def publish(moment: Moment): + lcm.autoconf() + + lidar_frame_transport: LCMTransport = LCMTransport("/lidar", LidarMessage) + lidar_frame_transport.publish(moment.get("lidar_frame")) + + image_frame_transport: LCMTransport = LCMTransport("/image", Image) + image_frame_transport.publish(moment.get("image_frame")) + + odom_frame_transport: LCMTransport = LCMTransport("/odom", Odometry) + odom_frame_transport.publish(moment.get("odom_frame")) + + camera_info_transport: LCMTransport = LCMTransport("/camera_info", CameraInfo) + camera_info_transport.publish(moment.get("camera_info")) + + annotations = moment.get("annotations") + if annotations: + annotations_transport: LCMTransport = LCMTransport("/annotations", ImageAnnotations) + annotations_transport.publish(annotations) + + detections = moment.get("detections") + if detections: + for i, detection in enumerate(detections): + detections_transport: LCMTransport = LCMTransport( + f"/detected/pointcloud/{i}", PointCloud2 + ) + detections_transport.publish(detection.pointcloud) + + detections_image_transport: LCMTransport = LCMTransport( + f"/detected/image/{i}", Image + ) + detections_image_transport.publish(detection.cropped_image()) + + return publish + + +@pytest.fixture(scope="session") +def detections2d(moment: Moment): + return Detection2DModule().process_image_frame(moment["image_frame"]) + + +@pytest.fixture(scope="session") +def detections3d(moment: Moment): + detections2d = Detection2DModule().process_image_frame(moment["image_frame"]) + camera_transform = moment["tf"].get("camera_optical", "world") + if camera_transform is None: + raise ValueError("No camera_optical transform in tf") + + return Detection3DModule(camera_info=moment["camera_info"]).process_frame( + detections2d, moment["lidar_frame"], camera_transform + ) diff --git a/dimos/perception/detection2d/detic_2d_det.py b/dimos/perception/detection2d/detic_2d_det.py index 8bc4f9c4b0..44b77cb397 100644 --- a/dimos/perception/detection2d/detic_2d_det.py +++ b/dimos/perception/detection2d/detic_2d_det.py @@ -27,7 +27,7 @@ import PIL.Image if not hasattr(PIL.Image, "LINEAR") and hasattr(PIL.Image, "BILINEAR"): - PIL.Image.LINEAR = PIL.Image.BILINEAR + PIL.Image.LINEAR = PIL.Image.BILINEAR # type: ignore[attr-defined] # Detectron2 imports from detectron2.config import get_cfg diff --git a/dimos/perception/detection2d/module.py b/dimos/perception/detection2d/module.py deleted file mode 100644 index f8bd3a340a..0000000000 --- a/dimos/perception/detection2d/module.py +++ /dev/null @@ -1,239 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import functools -import queue -from typing import Any, Callable, Generator, List, Optional, Tuple - -from dimos_lcm.foxglove_msgs import ( - PointsAnnotation, - TextAnnotation, -) -from dimos_lcm.foxglove_msgs.Color import Color -from dimos_lcm.foxglove_msgs.Point2 import Point2 -from dimos.msgs.vision_msgs import Detection2DArray -from dimos_lcm.vision_msgs import ( - BoundingBox2D, - Detection2D, - ObjectHypothesis, - ObjectHypothesisWithPose, - Point2D, - Pose2D, -) -from reactivex import operators as ops -from reactivex.observable import Observable - -from dimos.core import In, Module, Out, rpc -from dimos.msgs.foxglove_msgs import ImageAnnotations -from dimos.msgs.sensor_msgs import Image -from dimos.msgs.std_msgs import Header -from dimos.perception.detection2d.yolo_2d_det import Yolo2DDetector -from dimos.protocol.skill.skill import skill -from dimos.protocol.skill.type import Output, Reducer, Stream -from dimos.types.timestamped import to_ros_stamp - - -Bbox = Tuple[float, float, float, float] -CenteredBbox = Tuple[float, float, float, float] -# yolo and detic have bad output formats -InconvinientDetectionFormat = Tuple[List[Bbox], List[int], List[int], List[float], List[List[str]]] - - -Detection = Tuple[Bbox, int, int, float, List[str]] -Detections = List[Detection] -ImageDetections = Tuple[Image, Detections] -ImageDetection = Tuple[Image, Detection] - - -def get_bbox_center(bbox: Bbox) -> CenteredBbox: - x1, y1, x2, y2 = bbox - center_x = (x1 + x2) / 2.0 - center_y = (y1 + y2) / 2.0 - width = float(x2 - x1) - height = float(y2 - y1) - return [center_x, center_y, width, height] - - -def build_bbox(bbox: Bbox) -> BoundingBox2D: - center_x, center_y, width, height = get_bbox_center(bbox) - - return BoundingBox2D( - center=Pose2D( - position=Point2D(x=center_x, y=center_y), - theta=0.0, - ), - size_x=width, - size_y=height, - ) - - -def build_detection2d(image, detection) -> Detection2D: - [bbox, track_id, class_id, confidence, name] = detection - - return Detection2D( - header=Header(image.ts, "camera_link"), - bbox=build_bbox(bbox), - results=[ - ObjectHypothesisWithPose( - ObjectHypothesis( - class_id=class_id, - score=1.0, - ) - ) - ], - ) - - -def build_detection2d_array(imageDetections: ImageDetections) -> Detection2DArray: - [image, detections] = imageDetections - return Detection2DArray( - detections_length=len(detections), - header=Header(image.ts, "camera_link"), - detections=list( - map( - functools.partial(build_detection2d, image), - detections, - ) - ), - ) - - -# yolo and detic have bad formats this translates into list of detections -def better_detection_format(inconvinient_detections: InconvinientDetectionFormat) -> Detections: - bboxes, track_ids, class_ids, confidences, names = inconvinient_detections - return [ - [bbox, track_id, class_id, confidence, name] - for bbox, track_id, class_id, confidence, name in zip( - bboxes, track_ids, class_ids, confidences, names - ) - ] - - -def build_imageannotation_text(image: Image, detection: Detection) -> ImageAnnotations: - [bbox, track_id, class_id, confidence, name] = detection - - x1, y1, x2, y2 = bbox - - font_size = int(image.height / 35) - return [ - TextAnnotation( - timestamp=to_ros_stamp(image.ts), - position=Point2(x=x1, y=y2 + font_size), - text=f"confidence: {confidence:.3f}", - font_size=font_size, - text_color=Color(r=1.0, g=1.0, b=1.0, a=1), - background_color=Color(r=0, g=0, b=0, a=1), - ), - TextAnnotation( - timestamp=to_ros_stamp(image.ts), - position=Point2(x=x1, y=y1), - text=f"{name}_{class_id}_{track_id}", - font_size=font_size, - text_color=Color(r=1.0, g=1.0, b=1.0, a=1), - background_color=Color(r=0, g=0, b=0, a=1), - ), - ] - - -def build_imageannotation_box(image: Image, detection: Detection) -> ImageAnnotations: - [bbox, track_id, class_id, confidence, name] = detection - - x1, y1, x2, y2 = bbox - - thickness = image.height / 720 - - return PointsAnnotation( - timestamp=to_ros_stamp(image.ts), - outline_color=Color(r=0.0, g=0.0, b=0.0, a=1.0), - fill_color=Color(r=1.0, g=0.0, b=0.0, a=0.15), - thickness=thickness, - points_length=4, - points=[ - Point2(x1, y1), - Point2(x1, y2), - Point2(x2, y2), - Point2(x2, y1), - ], - type=PointsAnnotation.LINE_LOOP, - ) - - -def build_imageannotations(image_detections: [Image, Detections]) -> ImageAnnotations: - [image, detections] = image_detections - - def flatten(xss): - return [x for xs in xss for x in xs] - - points = list(map(functools.partial(build_imageannotation_box, image), detections)) - texts = list(flatten(map(functools.partial(build_imageannotation_text, image), detections))) - - return ImageAnnotations( - texts=texts, - texts_length=len(texts), - points=points, - points_length=len(points), - ) - - -class Detect2DModule(Module): - image: In[Image] = None - detections: Out[Detection2DArray] = None - annotations: Out[ImageAnnotations] = None - - # _initDetector = Detic2DDetector - _initDetector = Yolo2DDetector - - def __init__(self, *args, detector=Optional[Callable[[Any], Any]], **kwargs): - if detector: - self._detectorClass = detector - super().__init__(*args, **kwargs) - - def detect(self, image: Image) -> Detections: - return [image, better_detection_format(self.detector.process_image(image.to_opencv()))] - - @rpc - def start(self): - self.detector = self._initDetector() - self.detection2d_stream().subscribe(self.detections.publish) - self.annotation_stream().subscribe(self.annotations.publish) - - @functools.cache - def detection2d_stream(self) -> Observable[Detection2DArray]: - return self.image.observable().pipe(ops.map(self.detect), ops.map(build_detection2d_array)) - - @functools.cache - def annotation_stream(self) -> Observable[ImageAnnotations]: - return self.image.observable().pipe(ops.map(self.detect), ops.map(build_imageannotations)) - - @functools.cache - def detection_stream(self) -> Observable[ImageDetections]: - return self.image.observable().pipe(ops.map(self.detect)) - - @skill(stream=Stream.passive, reducer=Reducer.accumulate_dict) - def get_detections(self) -> Generator[ImageAnnotations, None, None]: - """Provides latest image detections""" - - blocking_queue = queue.Queue() - self.detection_stream().subscribe(blocking_queue.put) - - while True: - # dealing with a dumb format from detic and yolo - # probably needs to be abstracted earlier in the pipeline so it's more convinient to use - [image, detections] = blocking_queue.get() - - detection_dict = {} - for detection in detections: - [bbox, track_id, class_id, confidence, name] = detection - detection_dict[name] = f"{confidence:.3f}" - - yield detection_dict diff --git a/dimos/perception/detection2d/module2D.py b/dimos/perception/detection2d/module2D.py new file mode 100644 index 0000000000..92d1c0612b --- /dev/null +++ b/dimos/perception/detection2d/module2D.py @@ -0,0 +1,66 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import functools +from typing import Any, Callable, List, Optional + +from dimos_lcm.foxglove_msgs.ImageAnnotations import ( + ImageAnnotations, +) +from reactivex import operators as ops +from reactivex.observable import Observable + +from dimos.core import In, Module, Out, rpc +from dimos.msgs.sensor_msgs import Image +from dimos.msgs.vision_msgs import Detection2DArray +from dimos.perception.detection2d.type import Detection2D, ImageDetections2D +from dimos.perception.detection2d.yolo_2d_det import Yolo2DDetector +from dimos.utils.reactive import backpressure + + +class Detection2DModule(Module): + image: In[Image] = None # type: ignore + + detections: Out[Detection2DArray] = None # type: ignore + annotations: Out[ImageAnnotations] = None # type: ignore + + _initDetector = Yolo2DDetector + + def __init__(self, *args, detector=Optional[Callable[[Any], Any]], **kwargs): + super().__init__(*args, **kwargs) + if detector: + self._detectorClass = detector + self.detector = self._initDetector() + + def process_image_frame(self, image: Image) -> ImageDetections2D: + detections = ImageDetections2D.from_detector( + image, self.detector.process_image(image.to_opencv()) + ) + return detections + + @functools.cache + def detection_stream(self) -> Observable[ImageDetections2D]: + return backpressure(self.image.observable().pipe(ops.map(self.process_image_frame))) + + @rpc + def start(self): + self.detection_stream().subscribe( + lambda det: self.detections.publish(det.to_ros_detection2d_array()) + ) + + self.detection_stream().subscribe( + lambda det: self.annotations.publish(det.to_foxglove_annotations()) + ) + + @rpc + def stop(self): ... diff --git a/dimos/perception/detection2d/module3D.py b/dimos/perception/detection2d/module3D.py new file mode 100644 index 0000000000..98a1dc14af --- /dev/null +++ b/dimos/perception/detection2d/module3D.py @@ -0,0 +1,266 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from typing import List, Optional, Tuple + +import numpy as np +from dimos_lcm.foxglove_msgs.ImageAnnotations import ( + ImageAnnotations, +) +from dimos_lcm.sensor_msgs import CameraInfo +from reactivex import operators as ops + +from dimos.core import In, Out, rpc +from dimos.msgs.geometry_msgs import Transform +from dimos.msgs.sensor_msgs import Image, PointCloud2 +from dimos.msgs.vision_msgs import Detection2DArray +from dimos.perception.detection2d.module2D import Detection2DModule + +# from dimos.perception.detection2d.detic import Detic2DDetector +from dimos.perception.detection2d.type import ( + Detection2D, + ImageDetections2D, + ImageDetections3D, +) + +# Type aliases for clarity +ImageDetections = Tuple[Image, List[Detection2D]] +ImageDetection = Tuple[Image, Detection2D] + + +class Detection3DModule(Detection2DModule): + camera_info: CameraInfo + + image: In[Image] = None # type: ignore + pointcloud: In[PointCloud2] = None # type: ignore + # type: ignore + detections: Out[Detection2DArray] = None # type: ignore + annotations: Out[ImageAnnotations] = None # type: ignore + + detected_pointcloud_0: Out[PointCloud2] = None # type: ignore + detected_pointcloud_1: Out[PointCloud2] = None # type: ignore + detected_pointcloud_2: Out[PointCloud2] = None # type: ignore + detected_image_0: Out[Image] = None # type: ignore + detected_image_1: Out[Image] = None # type: ignore + detected_image_2: Out[Image] = None # type: ignore + + def __init__(self, camera_info: CameraInfo, *args, **kwargs): + self.camera_info = camera_info + super().__init__(*args, **kwargs) + + def detect(self, image: Image) -> ImageDetections: + detections = Detection2D.from_detector( + self.detector.process_image(image.to_opencv()), ts=image.ts + ) + return (image, detections) + + def project_points_to_camera( + self, + points_3d: np.ndarray, + camera_matrix: np.ndarray, + extrinsics: Transform, + ) -> Tuple[np.ndarray, np.ndarray]: + """Project 3D points to 2D camera coordinates.""" + # Transform points from world to camera_optical frame + points_homogeneous = np.hstack([points_3d, np.ones((points_3d.shape[0], 1))]) + extrinsics_matrix = extrinsics.to_matrix() + points_camera = (extrinsics_matrix @ points_homogeneous.T).T + + # Filter out points behind the camera + valid_mask = points_camera[:, 2] > 0 + points_camera = points_camera[valid_mask] + + # Project to 2D + points_2d_homogeneous = (camera_matrix @ points_camera[:, :3].T).T + points_2d = points_2d_homogeneous[:, :2] / points_2d_homogeneous[:, 2:3] + + return points_2d, valid_mask + + def filter_points_in_detections( + self, + detections: ImageDetections2D, + pointcloud: PointCloud2, + world_to_camera_transform: Transform, + ) -> List[Optional[PointCloud2]]: + """Filter lidar points that fall within detection bounding boxes.""" + # Extract camera parameters + camera_info = self.camera_info + fx, fy, cx = camera_info.K[0], camera_info.K[4], camera_info.K[2] + cy = camera_info.K[5] + image_width = camera_info.width + image_height = camera_info.height + + camera_matrix = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]]) + + # Convert pointcloud to numpy array + lidar_points = pointcloud.as_numpy() + + # Project all points to camera frame + points_2d_all, valid_mask = self.project_points_to_camera( + lidar_points, camera_matrix, world_to_camera_transform + ) + valid_3d_points = lidar_points[valid_mask] + points_2d = points_2d_all.copy() + + # Filter points within image bounds + in_image_mask = ( + (points_2d[:, 0] >= 0) + & (points_2d[:, 0] < image_width) + & (points_2d[:, 1] >= 0) + & (points_2d[:, 1] < image_height) + ) + points_2d = points_2d[in_image_mask] + valid_3d_points = valid_3d_points[in_image_mask] + + filtered_pointclouds: List[Optional[PointCloud2]] = [] + + for detection in detections: + # Extract bbox from Detection2D object + bbox = detection.bbox + x_min, y_min, x_max, y_max = bbox + + # Find points within this detection box (with small margin) + margin = 5 # pixels + in_box_mask = ( + (points_2d[:, 0] >= x_min - margin) + & (points_2d[:, 0] <= x_max + margin) + & (points_2d[:, 1] >= y_min - margin) + & (points_2d[:, 1] <= y_max + margin) + ) + + detection_points = valid_3d_points[in_box_mask] + + # Create PointCloud2 message for this detection + if detection_points.shape[0] > 0: + detection_pointcloud = PointCloud2.from_numpy( + detection_points, + frame_id=pointcloud.frame_id, + timestamp=pointcloud.ts, + ) + filtered_pointclouds.append(detection_pointcloud) + else: + filtered_pointclouds.append(None) + + return filtered_pointclouds + + def combine_pointclouds(self, pointcloud_list: List[PointCloud2]) -> PointCloud2: + """Combine multiple pointclouds into a single one.""" + # Filter out None values + valid_pointclouds = [pc for pc in pointcloud_list if pc is not None] + + if not valid_pointclouds: + # Return empty pointcloud if no valid pointclouds + return PointCloud2.from_numpy( + np.array([]).reshape(0, 3), frame_id="world", timestamp=time.time() + ) + + # Combine all point arrays + all_points = np.vstack([pc.as_numpy() for pc in valid_pointclouds]) + + # Use frame_id and timestamp from first pointcloud + combined_pointcloud = PointCloud2.from_numpy( + all_points, + frame_id=valid_pointclouds[0].frame_id, + timestamp=valid_pointclouds[0].ts, + ) + + return combined_pointcloud + + def hidden_point_removal( + self, camera_transform: Transform, pc: PointCloud2, radius: float = 100.0 + ): + camera_position = camera_transform.inverse().translation + camera_pos_np = camera_position.to_numpy().reshape(3, 1) + + pcd = pc.pointcloud + try: + _, visible_indices = pcd.hidden_point_removal(camera_pos_np, radius) + visible_pcd = pcd.select_by_index(visible_indices) + + return PointCloud2(visible_pcd, frame_id=pc.frame_id, ts=pc.ts) + except Exception as e: + return pc + + def cleanup_pointcloud(self, pc: PointCloud2) -> PointCloud2: + height = pc.filter_by_height(-0.05) + statistical, _ = height.pointcloud.remove_statistical_outlier( + nb_neighbors=20, std_ratio=2.0 + ) + return PointCloud2(statistical, pc.frame_id, pc.ts) + + def process_frame( + self, + # these have to be timestamp aligned + detections: ImageDetections2D, + pointcloud: PointCloud2, + transform: Transform, + ) -> ImageDetections3D: + if not transform: + return ImageDetections3D(detections.image, []) + + pointcloud_list = self.filter_points_in_detections(detections, pointcloud, transform) + + detection3d_list = [] + for detection, pc in zip(detections, pointcloud_list): + if pc is None: + continue + pc = self.hidden_point_removal(transform, self.cleanup_pointcloud(pc)) + if pc is None: + continue + + detection3d_list.append(detection.to_3d(pointcloud=pc, transform=transform)) + + return ImageDetections3D(detections.image, detection3d_list) + + @rpc + def start(self): + time_tolerance = 5.0 # seconds + + def detection2d_to_3d(args): + detections, pc = args + transform = self.tf.get("camera_optical", "world", detections.image.ts, time_tolerance) + return self.process_frame(detections, pc, transform) + + combined_stream = self.detection_stream().pipe( + ops.with_latest_from(self.pointcloud.observable()), ops.map(detection2d_to_3d) + ) + + self.detection_stream().subscribe( + lambda det: self.detections.publish(det.to_ros_detection2d_array()) + ) + + self.detection_stream().subscribe( + lambda det: self.annotations.publish(det.to_image_annotations()) + ) + + combined_stream.subscribe(self._handle_combined_detections) + + def _handle_combined_detections(self, detections: ImageDetections3D): + if not detections: + return + + print(detections) + + if len(detections) > 0: + self.detected_pointcloud_0.publish(detections[0].pointcloud) + self.detected_image_0.publish(detections[0].cropped_image()) + + if len(detections) > 1: + self.detected_pointcloud_1.publish(detections[1].pointcloud) + self.detected_image_1.publish(detections[1].cropped_image()) + + if len(detections) > 3: + self.detected_pointcloud_2.publish(detections[2].pointcloud) + self.detected_image_2.publish(detections[2].cropped_image()) diff --git a/dimos/perception/detection2d/moduleDB.py b/dimos/perception/detection2d/moduleDB.py new file mode 100644 index 0000000000..8e9205c4f3 --- /dev/null +++ b/dimos/perception/detection2d/moduleDB.py @@ -0,0 +1,56 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import functools +import time +from typing import List, Optional, Tuple + +import numpy as np +from dimos_lcm.foxglove_msgs.ImageAnnotations import ( + ImageAnnotations, +) +from dimos_lcm.sensor_msgs import CameraInfo +from dimos_lcm.vision_msgs import Detection2D as ROSDetection2D +from reactivex import operators as ops + +from dimos.core import In, Out, rpc +from dimos.msgs.geometry_msgs import Transform +from dimos.msgs.sensor_msgs import Image, PointCloud2 +from dimos.msgs.vision_msgs import Detection2DArray, Detection3DArray +from dimos.perception.detection2d.module2D import Detection2DModule +from dimos.perception.detection2d.module3D import Detection3DModule +from dimos.perception.detection2d.type import ( + Detection2D, + Detection3D, + ImageDetections2D, + ImageDetections3D, +) +from dimos.protocol.skill import skill + + +class DetectionDBModule(Detection3DModule): + @rpc + def start(self): + super().start() + self.pointcloud_stream().subscribe(self.add_detections) + + def add_detections(self, detections: List[Detection3DArray]): + for det in detections: + self.add_detection(det) + + def add_detection(self, detection: Detection3D): + print("DETECTION", detection) + + def lookup(self, label: str) -> List[Detection3D]: + """Look up a detection by label.""" + return [] diff --git a/dimos/perception/detection2d/test_module.py b/dimos/perception/detection2d/test_module.py new file mode 100644 index 0000000000..4e938a6fa6 --- /dev/null +++ b/dimos/perception/detection2d/test_module.py @@ -0,0 +1,168 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +from dimos_lcm.foxglove_msgs.ImageAnnotations import ( + ImageAnnotations, +) +from dimos_lcm.sensor_msgs import Image, PointCloud2 + +from dimos.core import LCMTransport +from dimos.msgs.geometry_msgs import PoseStamped, Transform, Vector3 +from dimos.msgs.nav_msgs import OccupancyGrid +from dimos.msgs.sensor_msgs import PointCloud2 as PointCloud2Msg +from dimos.msgs.vision_msgs import Detection2DArray +from dimos.perception.detection2d.conftest import Moment +from dimos.perception.detection2d.module2D import Detection2DModule +from dimos.perception.detection2d.module3D import Detection3DModule +from dimos.perception.detection2d.type import ( + Detection2D, + Detection3D, + ImageDetections2D, + ImageDetections3D, +) +from dimos.robot.unitree_webrtc.modular import deploy_connection, deploy_navigation +from dimos.robot.unitree_webrtc.modular.connection_module import ConnectionModule +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.map import Map + + +def test_module2d(moment: Moment, publish_lcm): + detections2d = Detection2DModule().process_image_frame(moment["image_frame"]) + + print(detections2d) + + # Assertions for test_module2d + assert isinstance(detections2d, ImageDetections2D) + assert len(detections2d) == 1 + assert detections2d.image.ts == 1757960670.490248 + assert detections2d.image.shape == (720, 1280, 3) + assert detections2d.image.frame_id == "camera_optical" + + # Check first detection + det = detections2d.detections[0] + assert isinstance(det, Detection2D) + assert det.name == "suitcase" + assert det.class_id == 28 + assert det.track_id == 1 + assert det.confidence == 0.8145349025726318 + + # Check bbox values + assert det.bbox == [503.437255859375, 249.89385986328125, 655.950439453125, 469.82879638671875] + + annotations = detections2d.to_image_annotations() + publish_lcm({"annotations": annotations, **moment}) + + +def test_module3d(moment: Moment, publish_lcm): + detections2d = Detection2DModule().process_image_frame(moment["image_frame"]) + pointcloud = moment["lidar_frame"] + camera_transform = moment["tf"].get("camera_optical", "world") + if camera_transform is None: + raise ValueError("No camera_optical transform in tf") + annotations = detections2d.to_image_annotations() + + detections3d = Detection3DModule(camera_info=moment["camera_info"]).process_frame( + detections2d, pointcloud, camera_transform + ) + + publish_lcm( + { + **moment, + "annotations": annotations, + "detections": detections3d, + } + ) + + print(detections3d) + + # Assertions for test_module3d + assert isinstance(detections3d, ImageDetections3D) + assert len(detections3d) == 1 + assert detections3d.image.ts == 1757960670.490248 + assert detections3d.image.shape == (720, 1280, 3) + assert detections3d.image.frame_id == "camera_optical" + + # Check first 3D detection + det = detections3d.detections[0] + assert isinstance(det, Detection3D) + assert det.name == "suitcase" + assert det.class_id == 28 + assert det.track_id == 1 + assert det.confidence == 0.8145349025726318 + + # Check bbox values (should match 2D) + assert det.bbox == [503.437255859375, 249.89385986328125, 655.950439453125, 469.82879638671875] + + # 3D-specific assertions + assert isinstance(det.pointcloud, PointCloud2Msg) + assert len(det.pointcloud) == 81 + assert det.pointcloud.frame_id == "world" + assert isinstance(det.transform, Transform) + + # Check center + center = det.center + assert isinstance(center, Vector3) + # Values from output: Vector([ -3.3565 -0.26265 0.18549]) + assert abs(center.x - (-3.3565)) < 1e-4 + assert abs(center.y - (-0.26265)) < 1e-4 + assert abs(center.z - 0.18549) < 1e-4 + + # Check pose + pose = det.pose + assert isinstance(pose, PoseStamped) + assert pose.frame_id == "world" + assert pose.ts == det.ts + + # Check repr dict values + repr_dict = det.to_repr_dict() + assert repr_dict["dist"] == "0.88m" + assert repr_dict["points"] == "81" + + +@pytest.mark.tool +def test_module3d_replay(dimos_cluster): + connection = deploy_connection(dimos_cluster, loop=False, speed=1.0) + # mapper = deploy_navigation(dimos_cluster, connection) + mapper = dimos_cluster.deploy( + Map, voxel_size=0.5, cost_resolution=0.05, global_publish_interval=1.0 + ) + mapper.lidar.connect(connection.lidar) + mapper.global_map.transport = LCMTransport("/global_map", LidarMessage) + mapper.global_costmap.transport = LCMTransport("/global_costmap", OccupancyGrid) + mapper.local_costmap.transport = LCMTransport("/local_costmap", OccupancyGrid) + + mapper.start() + + module3D = dimos_cluster.deploy(Detection3DModule, camera_info=ConnectionModule._camera_info()) + + module3D.image.connect(connection.video) + module3D.pointcloud.connect(connection.lidar) + + module3D.annotations.transport = LCMTransport("/annotations", ImageAnnotations) + module3D.detections.transport = LCMTransport("/detections", Detection2DArray) + + module3D.detected_pointcloud_0.transport = LCMTransport("/detected/pointcloud/0", PointCloud2) + module3D.detected_pointcloud_1.transport = LCMTransport("/detected/pointcloud/1", PointCloud2) + module3D.detected_pointcloud_2.transport = LCMTransport("/detected/pointcloud/2", PointCloud2) + + module3D.detected_image_0.transport = LCMTransport("/detected/image/0", Image) + module3D.detected_image_1.transport = LCMTransport("/detected/image/1", Image) + module3D.detected_image_2.transport = LCMTransport("/detected/image/2", Image) + + module3D.start() + connection.start() + import time + + while True: + time.sleep(1) diff --git a/dimos/perception/detection2d/test_type.py b/dimos/perception/detection2d/test_type.py new file mode 100644 index 0000000000..25264ef727 --- /dev/null +++ b/dimos/perception/detection2d/test_type.py @@ -0,0 +1,239 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos.perception.detection2d.conftest import detections2d, detections3d +from dimos.perception.detection2d.type import ( + Detection2D, + Detection3D, + ImageDetections2D, + ImageDetections3D, +) +from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.msgs.geometry_msgs import Transform, Vector3, PoseStamped + + +def test_detections2d(detections2d): + print(f"\n=== ImageDetections2D Test ===") + print(f"Type: {type(detections2d)}") + print(f"Number of detections: {len(detections2d)}") + print(f"Image timestamp: {detections2d.image.ts}") + print(f"Image shape: {detections2d.image.shape}") + print(f"Image frame_id: {detections2d.image.frame_id}") + + print(f"\nFull detections object:") + print(detections2d) + + # Basic type assertions + assert isinstance(detections2d, ImageDetections2D) + assert hasattr(detections2d, "image") + assert hasattr(detections2d, "detections") + + # Image assertions + assert detections2d.image is not None + assert detections2d.image.ts == 1757960670.490248 + assert detections2d.image.shape == (720, 1280, 3) + assert detections2d.image.frame_id == "camera_optical" + + # Detection count assertions + assert len(detections2d) == 1 + assert isinstance(detections2d.detections, list) + assert len(detections2d.detections) == 1 + + # Test first detection with literal checks + det = detections2d.detections[0] + print(f"\n--- Detection 0 (literal checks) ---") + print(f"Type: {type(det)}") + print(f"Name: {det.name}") + print(f"Class ID: {det.class_id}") + print(f"Track ID: {det.track_id}") + print(f"Confidence: {det.confidence}") + print(f"Bbox: {det.bbox}") + print(f"Timestamp: {det.ts}") + + # Detection type assertions + assert isinstance(det, Detection2D) + + # Literal value assertions + assert det.name == "suitcase" + assert det.class_id == 28 # COCO class 28 is suitcase + assert det.track_id == 1 + assert 0.814 < det.confidence < 0.815 # Allow small floating point variance + + # Data type assertions + assert isinstance(det.name, str) + assert isinstance(det.class_id, int) + assert isinstance(det.track_id, int) + assert isinstance(det.confidence, float) + assert isinstance(det.bbox, (tuple, list)) and len(det.bbox) == 4 + assert isinstance(det.ts, float) + + # Bbox literal checks (with tolerance for float precision) + x1, y1, x2, y2 = det.bbox + assert 503.4 < x1 < 503.5 + assert 249.8 < y1 < 250.0 + assert 655.9 < x2 < 656.0 + assert 469.8 < y2 < 470.0 + + # Bbox format assertions + assert all(isinstance(coord, (int, float)) for coord in det.bbox) + assert x2 > x1, f"x2 ({x2}) should be greater than x1 ({x1})" + assert y2 > y1, f"y2 ({y2}) should be greater than y1 ({y1})" + assert x1 >= 0 and y1 >= 0, "Bbox coordinates should be non-negative" + + # Bbox width/height checks + width = x2 - x1 + height = y2 - y1 + assert 152.0 < width < 153.0 # Expected width ~152.5 + assert 219.0 < height < 221.0 # Expected height ~219.9 + + # Confidence assertions + assert 0.0 <= det.confidence <= 1.0, ( + f"Confidence should be between 0 and 1, got {det.confidence}" + ) + + # Image reference assertion + assert det.image is detections2d.image, "Detection should reference the same image" + + # Timestamp consistency + assert det.ts == detections2d.image.ts + assert det.ts == 1757960670.490248 + + +def test_detections3d(detections3d): + print(f"\n=== ImageDetections3D Test ===") + print(f"Type: {type(detections3d)}") + print(f"Number of detections: {len(detections3d)}") + print(f"Image timestamp: {detections3d.image.ts}") + print(f"Image shape: {detections3d.image.shape}") + print(f"Image frame_id: {detections3d.image.frame_id}") + + print(f"\nFull detections object:") + print(detections3d) + + # Basic type assertions + assert isinstance(detections3d, ImageDetections3D) + assert hasattr(detections3d, "image") + assert hasattr(detections3d, "detections") + + # Image assertions + assert detections3d.image is not None + assert detections3d.image.ts == 1757960670.490248 + assert detections3d.image.shape == (720, 1280, 3) + assert detections3d.image.frame_id == "camera_optical" + + # Detection count assertions + assert len(detections3d) == 1 + assert isinstance(detections3d.detections, list) + assert len(detections3d.detections) == 1 + + # Test first 3D detection with literal checks + det = detections3d.detections[0] + print(f"\n--- Detection3D 0 (literal checks) ---") + print(f"Type: {type(det)}") + print(f"Name: {det.name}") + print(f"Class ID: {det.class_id}") + print(f"Track ID: {det.track_id}") + print(f"Confidence: {det.confidence}") + print(f"Bbox: {det.bbox}") + print(f"Timestamp: {det.ts}") + print(f"Has pointcloud: {hasattr(det, 'pointcloud')}") + print(f"Has transform: {hasattr(det, 'transform')}") + if hasattr(det, "pointcloud"): + print(f"Pointcloud points: {len(det.pointcloud)}") + print(f"Pointcloud frame_id: {det.pointcloud.frame_id}") + + # Detection type assertions + assert isinstance(det, Detection3D) + + # Detection3D should have all Detection2D fields plus pointcloud and transform + assert hasattr(det, "bbox") + assert hasattr(det, "track_id") + assert hasattr(det, "class_id") + assert hasattr(det, "confidence") + assert hasattr(det, "name") + assert hasattr(det, "ts") + assert hasattr(det, "image") + assert hasattr(det, "pointcloud") + assert hasattr(det, "transform") + + # Literal value assertions (should match Detection2D) + assert det.name == "suitcase" + assert det.class_id == 28 # COCO class 28 is suitcase + assert det.track_id == 1 + assert 0.814 < det.confidence < 0.815 # Allow small floating point variance + + # Data type assertions + assert isinstance(det.name, str) + assert isinstance(det.class_id, int) + assert isinstance(det.track_id, int) + assert isinstance(det.confidence, float) + assert isinstance(det.bbox, (tuple, list)) and len(det.bbox) == 4 + assert isinstance(det.ts, float) + + # Bbox literal checks (should match Detection2D) + x1, y1, x2, y2 = det.bbox + assert 503.4 < x1 < 503.5 + assert 249.8 < y1 < 250.0 + assert 655.9 < x2 < 656.0 + assert 469.8 < y2 < 470.0 + + # 3D-specific assertions + assert isinstance(det.pointcloud, PointCloud2) + assert isinstance(det.transform, Transform) + + # Pointcloud assertions + assert len(det.pointcloud) == 81 # Based on the output we saw + assert det.pointcloud.frame_id == "world" # Pointcloud should be in world frame + + # Test center calculation + center = det.center + print(f"\nDetection center: {center}") + assert isinstance(center, Vector3) + assert hasattr(center, "x") + assert hasattr(center, "y") + assert hasattr(center, "z") + + # Test pose property + pose = det.pose + print(f"Detection pose: {pose}") + assert isinstance(pose, PoseStamped) + assert pose.frame_id == "world" + assert pose.ts == det.ts + assert pose.position == center # Pose position should match center + + # Check distance calculation (from to_repr_dict) + repr_dict = det.to_repr_dict() + print(f"\nRepr dict: {repr_dict}") + assert "dist" in repr_dict + assert repr_dict["dist"] == "0.88m" # Based on the output + assert repr_dict["points"] == "81" + assert repr_dict["name"] == "suitcase" + assert repr_dict["class"] == "28" + assert repr_dict["track"] == "1" + + # Image reference assertion + assert det.image is detections3d.image, "Detection should reference the same image" + + # Timestamp consistency + assert det.ts == detections3d.image.ts + assert det.ts == 1757960670.490248 + + +def test_detection3d_to_pose(detections3d): + det = detections3d[0] + pose = det.pose + + # Check that pose is valid + assert pose.frame_id == "world" + assert pose.ts == det.ts diff --git a/dimos/perception/detection2d/type.py b/dimos/perception/detection2d/type.py new file mode 100644 index 0000000000..b37828e8e8 --- /dev/null +++ b/dimos/perception/detection2d/type.py @@ -0,0 +1,491 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import functools +import hashlib +from dataclasses import dataclass +from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar + +import numpy as np +from dimos_lcm.foxglove_msgs.Color import Color +from dimos_lcm.foxglove_msgs.ImageAnnotations import ( + PointsAnnotation, + TextAnnotation, +) +from dimos_lcm.foxglove_msgs.Point2 import Point2 +from dimos_lcm.vision_msgs import ( + BoundingBox2D, + ObjectHypothesis, + ObjectHypothesisWithPose, + Point2D, + Pose2D, +) +from dimos_lcm.vision_msgs import ( + Detection2D as ROSDetection2D, +) +from rich.console import Console +from rich.table import Table +from rich.text import Text + +from dimos.msgs.foxglove_msgs import ImageAnnotations +from dimos.msgs.geometry_msgs import PoseStamped, Transform, Vector3 +from dimos.msgs.sensor_msgs import Image, PointCloud2 +from dimos.msgs.std_msgs import Header +from dimos.msgs.vision_msgs import Detection2DArray +from dimos.types.timestamped import Timestamped, to_ros_stamp, to_timestamp + +Bbox = Tuple[float, float, float, float] +CenteredBbox = Tuple[float, float, float, float] +# yolo and detic have bad output formats +InconvinientDetectionFormat = Tuple[List[Bbox], List[int], List[int], List[float], List[str]] + +Detection = Tuple[Bbox, int, int, float, str] +Detections = List[Detection] + + +# yolo and detic have bad formats this translates into list of detections +def better_detection_format(inconvinient_detections: InconvinientDetectionFormat) -> Detections: + bboxes, track_ids, class_ids, confidences, names = inconvinient_detections + return [ + (bbox, track_id, class_id, confidence, name if name else "") + for bbox, track_id, class_id, confidence, name in zip( + bboxes, track_ids, class_ids, confidences, names + ) + ] + + +@dataclass +class Detection2D(Timestamped): + bbox: Bbox + track_id: int + class_id: int + confidence: float + name: str + ts: float + image: Image + + def to_repr_dict(self) -> Dict[str, Any]: + """Return a dictionary representation of the detection for display purposes.""" + x1, y1, x2, y2 = self.bbox + return { + "name": self.name, + "class": str(self.class_id), + "track": str(self.track_id), + "conf": self.confidence, + "bbox": f"[{x1:.0f},{y1:.0f},{x2:.0f},{y2:.0f}]", + } + + def to_image(self) -> Image: + return self.image + + # return focused image, only on the bbox + def cropped_image(self, padding: int = 20) -> Image: + """Return a cropped version of the image focused on the bounding box. + + Args: + padding: Pixels to add around the bounding box (default: 20) + + Returns: + Cropped Image containing only the detection area plus padding + """ + x1, y1, x2, y2 = map(int, self.bbox) + return self.image.crop( + x1 - padding, y1 - padding, x2 - x1 + 2 * padding, y2 - y1 + 2 * padding + ) + + def __str__(self): + console = Console(force_terminal=True, legacy_windows=False) + d = self.to_repr_dict() + + # Create confidence text with color based on value + conf_color = "green" if d["conf"] > 0.8 else "yellow" if d["conf"] > 0.5 else "red" + conf_text = Text(f"{d['conf']:.1%}", style=conf_color) + + # Build the string representation + parts = [ + Text(f"{self.__class__.__name__}("), + Text(d["name"], style="bold cyan"), + Text(f" cls={d['class']} trk={d['track']} "), + conf_text, + Text(f" {d['bbox']}"), + ] + + # Add any extra fields (e.g., points for Detection3D) + extra_keys = [k for k in d.keys() if k not in ["name", "class", "track", "conf", "bbox"]] + for key in extra_keys: + if d[key] == "None": + parts.append(Text(f" {key}={d[key]}", style="dim")) + else: + parts.append(Text(f" {key}={d[key]}", style="blue")) + + parts.append(Text(")")) + + # Render to string + with console.capture() as capture: + console.print(*parts, end="") + return capture.get().strip() + + @classmethod + def from_detector( + cls, raw_detections: InconvinientDetectionFormat, **kwargs + ) -> List["Detection2D"]: + return [ + cls.from_detection(raw, **kwargs) for raw in better_detection_format(raw_detections) + ] + + @classmethod + def from_detection(cls, raw_detection: Detection, **kwargs) -> "Detection2D": + bbox, track_id, class_id, confidence, name = raw_detection + + return cls( + bbox=bbox, + track_id=track_id, + class_id=class_id, + confidence=confidence, + name=name, + **kwargs, + ) + + def get_bbox_center(self) -> CenteredBbox: + x1, y1, x2, y2 = self.bbox + center_x = (x1 + x2) / 2.0 + center_y = (y1 + y2) / 2.0 + width = float(x2 - x1) + height = float(y2 - y1) + return (center_x, center_y, width, height) + + def to_ros_bbox(self) -> BoundingBox2D: + center_x, center_y, width, height = self.get_bbox_center() + return BoundingBox2D( + center=Pose2D( + position=Point2D(x=center_x, y=center_y), + theta=0.0, + ), + size_x=width, + size_y=height, + ) + + def lcm_encode(self): + return self.to_imageannotations().lcm_encode() + + def to_text_annotation(self) -> List[TextAnnotation]: + x1, y1, x2, y2 = self.bbox + + font_size = 20 + + return [ + TextAnnotation( + timestamp=to_ros_stamp(self.ts), + position=Point2(x=x1, y=y2 + font_size), + text=f"confidence: {self.confidence:.3f}", + font_size=font_size, + text_color=Color(r=1.0, g=1.0, b=1.0, a=1), + background_color=Color(r=0, g=0, b=0, a=1), + ), + TextAnnotation( + timestamp=to_ros_stamp(self.ts), + position=Point2(x=x1, y=y1), + text=f"{self.name}_{self.class_id}_{self.track_id}", + font_size=font_size, + text_color=Color(r=1.0, g=1.0, b=1.0, a=1), + background_color=Color(r=0, g=0, b=0, a=1), + ), + ] + + def to_points_annotation(self) -> List[PointsAnnotation]: + x1, y1, x2, y2 = self.bbox + + thickness = 1 + + return [ + PointsAnnotation( + timestamp=to_ros_stamp(self.ts), + outline_color=Color(r=0.0, g=0.0, b=0.0, a=1.0), + fill_color=Color(r=1.0, g=0.0, b=0.0, a=0.15), + thickness=thickness, + points_length=4, + points=[ + Point2(x1, y1), + Point2(x1, y2), + Point2(x2, y2), + Point2(x2, y1), + ], + type=PointsAnnotation.LINE_LOOP, + ) + ] + + # this is almost never called directly since this is a single detection + # and ImageAnnotations message normally contains multiple detections annotations + # so ImageDetections2D and ImageDetections3D normally implements this for whole image + def to_annotations(self) -> ImageAnnotations: + points = self.to_points_annotation() + texts = self.to_text_annotation() + + return ImageAnnotations( + texts=texts, + texts_length=len(texts), + points=points, + points_length=len(points), + ) + + @classmethod + def from_ros_detection2d(cls, ros_det: ROSDetection2D, **kwargs) -> "Detection2D": + """Convert from ROS Detection2D message to Detection2D object.""" + # Extract bbox from ROS format + center_x = ros_det.bbox.center.position.x + center_y = ros_det.bbox.center.position.y + width = ros_det.bbox.size_x + height = ros_det.bbox.size_y + + # Convert centered bbox to corner format + x1 = center_x - width / 2.0 + y1 = center_y - height / 2.0 + x2 = center_x + width / 2.0 + y2 = center_y + height / 2.0 + bbox = (x1, y1, x2, y2) + + # Extract hypothesis info + class_id = 0 + confidence = 0.0 + if ros_det.results: + hypothesis = ros_det.results[0].hypothesis + class_id = hypothesis.class_id + confidence = hypothesis.score + + # Extract track_id + track_id = int(ros_det.id) if ros_det.id.isdigit() else 0 + + # Extract timestamp + ts = to_timestamp(ros_det.header.stamp) + + # Name is not stored in ROS Detection2D, so we'll use a placeholder + # Remove 'name' from kwargs if present to avoid duplicate + name = kwargs.pop("name", f"class_{class_id}") + + return cls( + bbox=bbox, + track_id=track_id, + class_id=class_id, + confidence=confidence, + name=name, + ts=ts, + **kwargs, + ) + + def to_ros_detection2d(self) -> ROSDetection2D: + return ROSDetection2D( + header=Header(self.ts, "camera_link"), + bbox=self.to_ros_bbox(), + results=[ + ObjectHypothesisWithPose( + ObjectHypothesis( + class_id=self.class_id, + score=self.confidence, + ) + ) + ], + id=str(self.track_id), + ) + + def to_3d(self, **kwargs) -> "Detection3D": + return Detection3D( + image=self.image, + bbox=self.bbox, + track_id=self.track_id, + class_id=self.class_id, + confidence=self.confidence, + name=self.name, + ts=self.ts, + **kwargs, + ) + + +@dataclass +class Detection3D(Detection2D): + pointcloud: PointCloud2 + transform: Transform + + def localize(self, pointcloud: PointCloud2) -> Detection3D: + self.pointcloud = pointcloud + return self + + @functools.cached_property + def center(self) -> Vector3: + """Calculate the center of the pointcloud in world frame.""" + points = np.asarray(self.pointcloud.pointcloud.points) + center = points.mean(axis=0) + return Vector3(*center) + + @functools.cached_property + def pose(self) -> PoseStamped: + """Convert detection to a PoseStamped using pointcloud center. + + Returns pose in world frame with identity rotation. + The pointcloud is already in world frame. + """ + return PoseStamped( + ts=self.ts, + frame_id="world", + position=self.center, + orientation=(0.0, 0.0, 0.0, 1.0), # Identity quaternion + ) + + def to_repr_dict(self) -> Dict[str, Any]: + d = super().to_repr_dict() + + # Add pointcloud info + d["points"] = str(len(self.pointcloud)) + + # Calculate distance from camera + # The pointcloud is in world frame, and transform gives camera position in world + center_world = self.center + # Camera position in world frame is the translation part of the transform + camera_pos = self.transform.translation + # Use Vector3 subtraction and magnitude + distance = (center_world - camera_pos).magnitude() + d["dist"] = f"{distance:.2f}m" + + return d + + +T = TypeVar("T", bound="Detection2D") + + +def _hash_to_color(name: str) -> str: + """Generate a consistent color for a given name using hash.""" + # List of rich colors to choose from + colors = [ + "cyan", + "magenta", + "yellow", + "blue", + "green", + "red", + "bright_cyan", + "bright_magenta", + "bright_yellow", + "bright_blue", + "bright_green", + "bright_red", + "purple", + "white", + "pink", + ] + + # Hash the name and pick a color + hash_value = hashlib.md5(name.encode()).digest()[0] + return colors[hash_value % len(colors)] + + +class ImageDetections(Generic[T]): + image: Image + detections: List[T] + + def __init__(self, image: Image, detections: List[T]): + self.image = image + self.detections = detections + for det in self.detections: + if not det.ts: + det.ts = image.ts + + def __str__(self): + console = Console(force_terminal=True, legacy_windows=False) + + # Dynamically build columns based on the first detection's dict keys + if not self.detections: + return "Empty ImageDetections" + + # Create a table for detections + table = Table( + title=f"{self.__class__.__name__} [{len(self.detections)} detections @ {to_timestamp(self.image.ts):.3f}]", + show_header=True, + show_edge=True, + ) + + # Cache all repr_dicts to avoid double computation + detection_dicts = [det.to_repr_dict() for det in self.detections] + + first_dict = detection_dicts[0] + table.add_column("#", style="dim") + for col in first_dict.keys(): + color = _hash_to_color(col) + table.add_column(col.title(), style=color) + + # Add each detection to the table + for i, d in enumerate(detection_dicts): + row = [str(i)] + + for key in first_dict.keys(): + if key == "conf": + # Color-code confidence + conf_color = "green" if d[key] > 0.8 else "yellow" if d[key] > 0.5 else "red" + row.append(Text(f"{d[key]:.1%}", style=conf_color)) + elif key == "points" and d.get(key) == "None": + row.append(Text(d.get(key, ""), style="dim")) + else: + row.append(str(d.get(key, ""))) + table.add_row(*row) + + with console.capture() as capture: + console.print(table) + return capture.get().strip() + + def __len__(self): + return len(self.detections) + + def __iter__(self): + return iter(self.detections) + + def __getitem__(self, index): + return self.detections[index] + + def to_ros_detection2d_array(self) -> Detection2DArray: + return Detection2DArray( + detections_length=len(self.detections), + header=Header(self.image.ts, "camera_optical"), + detections=[det.to_ros_detection2d() for det in self.detections], + ) + + def to_image_annotations(self) -> ImageAnnotations: + def flatten(xss): + return [x for xs in xss for x in xs] + + texts = flatten(det.to_text_annotation() for det in self.detections) + points = flatten(det.to_points_annotation() for det in self.detections) + + return ImageAnnotations( + texts=texts, + texts_length=len(texts), + points=points, + points_length=len(points), + ) + + +class ImageDetections2D(ImageDetections[Detection2D]): + @classmethod + def from_detector( + cls, image: Image, raw_detections: InconvinientDetectionFormat, **kwargs + ) -> "ImageDetections2D": + return cls( + image=image, + detections=Detection2D.from_detector(raw_detections, image=image, ts=image.ts), + ) + + +class ImageDetections3D(ImageDetections[Detection3D]): + """Specialized class for 3D detections in an image.""" + + ... diff --git a/dimos/protocol/tf/tf.py b/dimos/protocol/tf/tf.py index 1c7bb32101..9d1db5ed16 100644 --- a/dimos/protocol/tf/tf.py +++ b/dimos/protocol/tf/tf.py @@ -80,13 +80,12 @@ def __init__(self, buffer_size: float = 10.0): def add(self, transform: Transform) -> None: super().add(transform) - self._prune_old_transforms() + self._prune_old_transforms(transform.ts) - def _prune_old_transforms(self) -> None: + def _prune_old_transforms(self, current_time) -> None: if not self._items: return - current_time = time.time() cutoff_time = current_time - self.buffer_size while self._items and self._items[0].ts < cutoff_time: @@ -113,8 +112,10 @@ def __str__(self) -> str: time_range = self.time_range() if time_range: - start_time = time.strftime("%H:%M:%S", time.localtime(time_range[0])) - end_time = time.strftime("%H:%M:%S", time.localtime(time_range[1])) + from dimos.types.timestamped import to_human_readable + + start_time = to_human_readable(time_range[0]) + end_time = to_human_readable(time_range[1]) duration = time_range[1] - time_range[0] frame_str = ( diff --git a/dimos/robot/unitree_webrtc/connection.py b/dimos/robot/unitree_webrtc/connection.py index faad61833d..b57fed8e92 100644 --- a/dimos/robot/unitree_webrtc/connection.py +++ b/dimos/robot/unitree_webrtc/connection.py @@ -33,6 +33,7 @@ from dimos.core import In, Module, Out, rpc from dimos.msgs.geometry_msgs import Pose, Transform, Twist, Vector3 from dimos.msgs.sensor_msgs import Image +from dimos.robot.connection_interface import ConnectionInterface from dimos.robot.unitree_webrtc.type.lidar import LidarMessage from dimos.robot.unitree_webrtc.type.lowstate import LowStateMsg from dimos.robot.unitree_webrtc.type.odometry import Odometry @@ -197,7 +198,7 @@ def raw_odom_stream(self) -> Subject[Pose]: def lidar_stream(self) -> Subject[LidarMessage]: return backpressure( self.raw_lidar_stream().pipe( - ops.map(lambda raw_frame: LidarMessage.from_msg(raw_frame)) + ops.map(lambda raw_frame: LidarMessage.from_msg(raw_frame, ts=time.time())) ) ) @@ -210,6 +211,20 @@ def tf_stream(self) -> Subject[Transform]: def odom_stream(self) -> Subject[Pose]: return backpressure(self.raw_odom_stream().pipe(ops.map(Odometry.from_msg))) + @functools.cache + def video_stream(self) -> Observable[Image]: + return backpressure( + self.raw_video_stream().pipe( + ops.filter(lambda frame: frame is not None), + ops.map( + lambda frame: Image.from_numpy( + frame.to_ndarray(format="rgb24"), + frame_id="camera_optical", + ) + ), + ) + ) + @functools.cache def lowstate_stream(self) -> Subject[LowStateMsg]: return backpressure(self.unitree_sub_stream(RTC_TOPIC["LOW_STATE"])) @@ -253,8 +268,8 @@ def color(self, color: VUI_COLOR = VUI_COLOR.RED, colortime: int = 60) -> bool: }, ) - @functools.cache - def raw_video_stream(self) -> Subject[VideoMessage]: + @functools.lru_cache(maxsize=None) + def raw_video_stream(self) -> Observable[VideoMessage]: subject: Subject[VideoMessage] = Subject() stop_event = threading.Event() @@ -286,14 +301,6 @@ def switch_video_channel_off(): return subject.pipe(ops.finally_action(stop)) - @functools.cache - def video_stream(self) -> Observable[VideoMessage]: - return backpressure( - self.raw_video_stream().pipe( - ops.map(lambda frame: Image.from_numpy(frame.to_ndarray(format="rgb24"))) - ) - ) - def get_video_stream(self, fps: int = 30) -> Observable[VideoMessage]: """Get the video stream from the robot's camera. diff --git a/dimos/robot/unitree_webrtc/modular/__init__.py b/dimos/robot/unitree_webrtc/modular/__init__.py new file mode 100644 index 0000000000..d823cd796e --- /dev/null +++ b/dimos/robot/unitree_webrtc/modular/__init__.py @@ -0,0 +1,2 @@ +from dimos.robot.unitree_webrtc.modular.connection_module import deploy_connection +from dimos.robot.unitree_webrtc.modular.navigation import deploy_navigation diff --git a/dimos/robot/unitree_webrtc/modular/connection_module.py b/dimos/robot/unitree_webrtc/modular/connection_module.py index 289cc622e0..c6214c4f2c 100644 --- a/dimos/robot/unitree_webrtc/modular/connection_module.py +++ b/dimos/robot/unitree_webrtc/modular/connection_module.py @@ -15,27 +15,31 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import functools import logging +import os import time import warnings +from dataclasses import dataclass +from typing import List, Optional import reactivex as rx from dimos_lcm.sensor_msgs import CameraInfo from reactivex import operators as ops from reactivex.observable import Observable -from dimos.core import In, Module, Out, rpc +from dimos.core import In, LCMTransport, Module, ModuleConfig, Out, rpc, DimosCluster +from dimos.msgs.foxglove_msgs import ImageAnnotations from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Twist, Vector3 from dimos.msgs.sensor_msgs.Image import Image, sharpness_window from dimos.msgs.std_msgs import Header +from dimos.robot.foxglove_bridge import FoxgloveBridge from dimos.robot.unitree_webrtc.connection import UnitreeWebRTCConnection from dimos.robot.unitree_webrtc.type.lidar import LidarMessage -from dimos.robot.unitree_webrtc.type.odometry import Odometry from dimos.utils.data import get_data from dimos.utils.logging_config import setup_logger -from dimos.utils.testing import TimedSensorReplay +from dimos.utils.reactive import backpressure +from dimos.utils.testing import TimedSensorReplay, TimedSensorStorage logger = setup_logger("dimos.robot.unitree_webrtc.unitree_go2", level=logging.INFO) @@ -47,19 +51,29 @@ logging.getLogger("asyncio").setLevel(logging.ERROR) logging.getLogger("root").setLevel(logging.WARNING) + # Suppress warnings warnings.filterwarnings("ignore", message="coroutine.*was never awaited") warnings.filterwarnings("ignore", message="H264Decoder.*failed to decode") image_resize_factor = 1 originalwidth, originalheight = (1280, 720) -get_data("unitree_raw_webrtc_replay") class FakeRTC(UnitreeWebRTCConnection): + dir_name = "unitree_go2_office_walk2" + # we don't want UnitreeWebRTCConnection to init - def __init__(self, *args, **kwargs): - pass + def __init__( + self, + **kwargs, + ): + get_data(self.dir_name) + self.replay_config = { + "loop": kwargs.get("loop"), + "seek": kwargs.get("seek"), + "duration": kwargs.get("duration"), + } def connect(self): pass @@ -74,30 +88,26 @@ def liedown(self): print("liedown suppressed") @functools.cache - def raw_lidar_stream(self): + def lidar_stream(self): print("lidar stream start") - lidar_store = TimedSensorReplay("unitree_raw_webrtc_replay/lidar") - return lidar_store.stream() + lidar_store = TimedSensorReplay(f"{self.dir_name}/lidar") + return lidar_store.stream(**self.replay_config) @functools.cache - def raw_odom_stream(self): + def odom_stream(self): print("odom stream start") - odom_store = TimedSensorReplay("unitree_raw_webrtc_replay/odom") - return odom_store.stream() + odom_store = TimedSensorReplay(f"{self.dir_name}/odom") + return odom_store.stream(**self.replay_config) # we don't have raw video stream in the data set @functools.cache - def raw_video_stream(self): + def video_stream(self): print("video stream start") video_store = TimedSensorReplay( - "unitree_raw_webrtc_replay/video", - autocast=lambda f: Image.from_numpy(f.to_ndarray() if hasattr(f, "to_ndarray") else f), + f"{self.dir_name}/video", ) - return video_store.stream() - @functools.cache - def video_stream(self): - return self.raw_video_stream() + return video_store.stream(**self.replay_config) def move(self, vector: Twist, duration: float = 0.0): pass @@ -107,82 +117,88 @@ def publish_request(self, topic: str, data: dict): return {"status": "ok", "message": "Fake publish"} +@dataclass +class ConnectionModuleConfig(ModuleConfig): + ip: Optional[str] = None + connection_type: str = "fake" # or "fake" or "mujoco" + loop: bool = False # For fake connection + speed: float = 1.0 # For fake connection + + class ConnectionModule(Module): - ip: str - connection_type: str = "webrtc" camera_info: Out[CameraInfo] = None odom: Out[PoseStamped] = None lidar: Out[LidarMessage] = None video: Out[Image] = None movecmd: In[Twist] = None - def __init__(self, ip: str = None, connection_type: str = "webrtc", *args, **kwargs): - self.ip = ip + connection = None + + default_config = ConnectionModuleConfig + + def __init__(self, connection_type: str = "webrtc", *args, **kwargs): + self.connection_config = kwargs self.connection_type = connection_type - self.connection = None Module.__init__(self, *args, **kwargs) @rpc def record(self, recording_name: str): - from dimos.utils.testing import TimedSensorStorage - - lidar_store = TimedSensorStorage(f"{recording_name}/lidar") - lidar_store.save_stream(self.connection.raw_lidar_stream()).subscribe(lambda x: x) + lidar_store: TimedSensorStorage = TimedSensorStorage(f"{recording_name}/lidar") + lidar_store.save_stream(self.connection.lidar_stream()).subscribe(lambda x: x) - odom_store = TimedSensorStorage(f"{recording_name}/odom") - odom_store.save_stream(self.connection.raw_odom_stream()).subscribe(lambda x: x) + odom_store: TimedSensorStorage = TimedSensorStorage(f"{recording_name}/odom") + odom_store.save_stream(self.connection.odom_stream()).subscribe(lambda x: x) - video_store = TimedSensorStorage(f"{recording_name}/video") - video_store.save_stream(self.connection.raw_video_stream()).subscribe(lambda x: x) + video_store: TimedSensorStorage = TimedSensorStorage(f"{recording_name}/video") + video_store.save_stream(self.connection.video_stream()).subscribe(lambda x: x) @rpc def start(self): """Start the connection and subscribe to sensor streams.""" match self.connection_type: case "webrtc": - self.connection = UnitreeWebRTCConnection(self.ip) + self.connection = UnitreeWebRTCConnection(**self.connection_config) case "fake": - self.connection = FakeRTC() + self.connection = FakeRTC(**self.connection_config) case "mujoco": from dimos.robot.unitree_webrtc.mujoco_connection import MujocoConnection - self.connection = MujocoConnection() + self.connection = MujocoConnection(**self.connection_config) self.connection.start() case _: raise ValueError(f"Unknown connection type: {self.connection_type}") - - def image_pub(img): - self.video.publish(img) - - # Connect sensor streams to outputs - self.connection.lidar_stream().subscribe(self.lidar.publish) self.connection.odom_stream().subscribe( lambda odom: self._publish_tf(odom) and self.odom.publish(odom) ) - def attach_frame_id(image: Image) -> Image: - image.frame_id = "camera_optical" + # Connect sensor streams to outputs + self.connection.lidar_stream().subscribe(self.lidar.publish) + + # self.connection.lidar_stream().subscribe(lambda lidar: print("LIDAR", lidar.ts)) + # self.connection.video_stream().subscribe(lambda video: print("IMAGE", video.ts)) + # self.connection.odom_stream().subscribe(lambda odom: print("ODOM", odom.ts)) + def resize(image: Image) -> Image: return image.resize( int(originalwidth / image_resize_factor), int(originalheight / image_resize_factor) ) - # sharpness_window( - # 10, self.connection.video_stream().pipe(ops.map(attach_frame_id)) - # ).subscribe(image_pub) - self.connection.video_stream().pipe(ops.map(attach_frame_id)).subscribe(image_pub) + sharpness = sharpness_window(10, self.connection.video_stream()) + sharpness.subscribe(self.video.publish) + # self.connection.video_stream().subscribe(self.video.publish) + + # self.connection.video_stream().pipe(ops.map(resize)).subscribe(self.video.publish) self.camera_info_stream().subscribe(self.camera_info.publish) self.movecmd.subscribe(self.connection.move) - def _publish_tf(self, msg): - self.odom.publish(msg) - + @classmethod + def _odom_to_tf(self, odom: PoseStamped) -> List[Transform]: camera_link = Transform( translation=Vector3(0.3, 0.0, 0.0), rotation=Quaternion(0.0, 0.0, 0.0, 1.0), frame_id="base_link", child_frame_id="camera_link", - ts=time.time(), + ts=odom.ts, ) camera_optical = Transform( @@ -190,14 +206,18 @@ def _publish_tf(self, msg): rotation=Quaternion(-0.5, 0.5, -0.5, 0.5), frame_id="camera_link", child_frame_id="camera_optical", - ts=camera_link.ts, + ts=odom.ts, ) - self.tf.publish( - Transform.from_pose("base_link", msg), + return [ + Transform.from_pose("base_link", odom), camera_link, camera_optical, - ) + ] + + def _publish_tf(self, msg): + self.odom.publish(msg) + self.tf.publish(*self._odom_to_tf(msg)) @rpc def publish_request(self, topic: str, data: dict): @@ -210,8 +230,8 @@ def publish_request(self, topic: str, data: dict): """ return self.connection.publish_request(topic, data) - @functools.cache - def camera_info_stream(self) -> Observable[CameraInfo]: + @classmethod + def _camera_info(self) -> Out[CameraInfo]: fx, fy, cx, cy = list( map( lambda x: int(x / image_resize_factor), @@ -250,11 +270,28 @@ def camera_info_stream(self) -> Observable[CameraInfo]: "binning_y": 0, } - return rx.interval(1).pipe( - ops.map( - lambda x: CameraInfo( - **base_msg, - header=Header("camera_optical"), - ) - ) - ) + return CameraInfo(**base_msg, header=Header("camera_optical")) + + @functools.cache + def camera_info_stream(self) -> Observable[CameraInfo]: + return rx.interval(1).pipe(ops.map(lambda _: self._camera_info())) + + +def deploy_connection(dimos: DimosCluster, **kwargs): + # foxglove_bridge = dimos.deploy(FoxgloveBridge) + # foxglove_bridge.start() + + connection = dimos.deploy( + ConnectionModule, + ip=os.getenv("ROBOT_IP"), + connection_type=os.getenv("CONNECTION_TYPE", "fake"), + **kwargs, + ) + + connection.lidar.transport = LCMTransport("/lidar", LidarMessage) + connection.odom.transport = LCMTransport("/odom", PoseStamped) + connection.video.transport = LCMTransport("/image", Image) + connection.movecmd.transport = LCMTransport("/cmd_vel", Vector3) + connection.camera_info.transport = LCMTransport("/camera_info", CameraInfo) + + return connection diff --git a/dimos/robot/unitree_webrtc/modular/detect.py b/dimos/robot/unitree_webrtc/modular/detect.py new file mode 100644 index 0000000000..7d0ded7ac8 --- /dev/null +++ b/dimos/robot/unitree_webrtc/modular/detect.py @@ -0,0 +1,180 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pickle + +from dimos_lcm.sensor_msgs import CameraInfo + +from dimos.msgs.sensor_msgs import Image +from dimos.msgs.std_msgs import Header +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.odometry import Odometry + +image_resize_factor = 1 +originalwidth, originalheight = (1280, 720) + + +def camera_info() -> CameraInfo: + fx, fy, cx, cy = list( + map( + lambda x: int(x / image_resize_factor), + [819.553492, 820.646595, 625.284099, 336.808987], + ) + ) + width, height = tuple( + map( + lambda x: int(x / image_resize_factor), + [originalwidth, originalheight], + ) + ) + + # Camera matrix K (3x3) + K = [fx, 0, cx, 0, fy, cy, 0, 0, 1] + + # No distortion coefficients for now + D = [0.0, 0.0, 0.0, 0.0, 0.0] + + # Identity rotation matrix + R = [1, 0, 0, 0, 1, 0, 0, 0, 1] + + # Projection matrix P (3x4) + P = [fx, 0, cx, 0, 0, fy, cy, 0, 0, 0, 1, 0] + + base_msg = { + "D_length": len(D), + "height": height, + "width": width, + "distortion_model": "plumb_bob", + "D": D, + "K": K, + "R": R, + "P": P, + "binning_x": 0, + "binning_y": 0, + } + + return CameraInfo( + **base_msg, + header=Header("camera_optical"), + ) + + +def transform_chain(odom_frame: Odometry) -> list: + from dimos.msgs.geometry_msgs import Quaternion, Transform, Vector3 + from dimos.protocol.tf import TF + + camera_link = Transform( + translation=Vector3(0.3, 0.0, 0.0), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + frame_id="base_link", + child_frame_id="camera_link", + ts=odom_frame.ts, + ) + + camera_optical = Transform( + translation=Vector3(0.0, 0.0, 0.0), + rotation=Quaternion(-0.5, 0.5, -0.5, 0.5), + frame_id="camera_link", + child_frame_id="camera_optical", + ts=camera_link.ts, + ) + + tf = TF() + tf.publish( + Transform.from_pose("base_link", odom_frame), + camera_link, + camera_optical, + ) + + return tf + + +def broadcast( + timestamp: float, + lidar_frame: LidarMessage, + video_frame: Image, + odom_frame: Odometry, + detections, + annotations, +): + from dimos_lcm.foxglove_msgs.ImageAnnotations import ImageAnnotations + + from dimos.core import LCMTransport + from dimos.msgs.geometry_msgs import PoseStamped + + lidar_transport = LCMTransport("/lidar", LidarMessage) + odom_transport = LCMTransport("/odom", PoseStamped) + video_transport = LCMTransport("/image", Image) + camera_info_transport = LCMTransport("/camera_info", CameraInfo) + + lidar_transport.broadcast(None, lidar_frame) + video_transport.broadcast(None, video_frame) + odom_transport.broadcast(None, odom_frame) + camera_info_transport.broadcast(None, camera_info()) + + transform_chain(odom_frame) + + print(lidar_frame) + print(video_frame) + print(odom_frame) + video_transport = LCMTransport("/image", Image) + annotations_transport = LCMTransport("/annotations", ImageAnnotations) + annotations_transport.broadcast(None, annotations) + + +def process_data(): + from dimos.msgs.sensor_msgs import Image + from dimos.perception.detection2d.module import Detect2DModule, build_imageannotations + from dimos.robot.unitree_webrtc.type.lidar import LidarMessage + from dimos.robot.unitree_webrtc.type.odometry import Odometry + from dimos.utils.data import get_data + from dimos.utils.testing import TimedSensorReplay + + get_data("unitree_office_walk") + target = 1751591272.9654856 + lidar_store = TimedSensorReplay("unitree_office_walk/lidar", autocast=LidarMessage.from_msg) + video_store = TimedSensorReplay("unitree_office_walk/video", autocast=Image.from_numpy) + odom_store = TimedSensorReplay("unitree_office_walk/odom", autocast=Odometry.from_msg) + + def attach_frame_id(image: Image) -> Image: + image.frame_id = "camera_optical" + return image + + lidar_frame = lidar_store.find_closest(target, tolerance=1) + video_frame = attach_frame_id(video_store.find_closest(target, tolerance=1)) + odom_frame = odom_store.find_closest(target, tolerance=1) + + detector = Detect2DModule() + detections = detector.detect(video_frame) + annotations = build_imageannotations(detections) + + data = (target, lidar_frame, video_frame, odom_frame, detections, annotations) + + with open("filename.pkl", "wb") as file: + pickle.dump(data, file) + + return data + + +def main(): + try: + with open("filename.pkl", "rb") as file: + data = pickle.load(file) + except FileNotFoundError: + print("Processing data and creating pickle file...") + data = process_data() + broadcast(*data) + + +main() diff --git a/dimos/robot/unitree_webrtc/modular/ivan_unitree.py b/dimos/robot/unitree_webrtc/modular/ivan_unitree.py index b96f1e7af5..5893248530 100644 --- a/dimos/robot/unitree_webrtc/modular/ivan_unitree.py +++ b/dimos/robot/unitree_webrtc/modular/ivan_unitree.py @@ -13,150 +13,65 @@ # limitations under the License. import logging -import os import time -from typing import Optional from dimos_lcm.sensor_msgs import CameraInfo -from dimos_lcm.std_msgs import Bool, String from dimos.core import LCMTransport, start + +# from dimos.msgs.detection2d import Detection2DArray from dimos.msgs.foxglove_msgs import ImageAnnotations -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Twist, Vector3 -from dimos.msgs.nav_msgs import OccupancyGrid, Path -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs import Image, PointCloud2 from dimos.msgs.vision_msgs import Detection2DArray -from dimos.navigation.bt_navigator.navigator import BehaviorTreeNavigator, NavigatorState -from dimos.navigation.frontier_exploration import WavefrontFrontierExplorer -from dimos.navigation.global_planner import AstarPlanner -from dimos.navigation.local_planner.holonomic_local_planner import HolonomicLocalPlanner -from dimos.perception.detection2d import Detect2DModule +from dimos.perception.detection2d import Detection3DModule from dimos.protocol.pubsub import lcm -from dimos.robot.foxglove_bridge import FoxgloveBridge +from dimos.robot.unitree_webrtc.modular import deploy_connection, deploy_navigation from dimos.robot.unitree_webrtc.modular.connection_module import ConnectionModule -from dimos.robot.unitree_webrtc.type.lidar import LidarMessage -from dimos.robot.unitree_webrtc.type.map import Map from dimos.utils.logging_config import setup_logger -from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule logger = setup_logger("dimos.robot.unitree_webrtc.unitree_go2", level=logging.INFO) -def deploy_foxglove(dimos, connection, mapper, global_planner): - """Deploy and configure visualization modules.""" - websocket_vis = dimos.deploy(WebsocketVisModule, port=7779) - websocket_vis.click_goal.transport = LCMTransport("/goal_request", PoseStamped) - websocket_vis.explore_cmd.transport = LCMTransport("/explore_cmd", Bool) - websocket_vis.stop_explore_cmd.transport = LCMTransport("/stop_explore_cmd", Bool) - websocket_vis.movecmd.transport = LCMTransport("/cmd_vel", Twist) - - websocket_vis.robot_pose.connect(connection.odom) - websocket_vis.path.connect(global_planner.path) - websocket_vis.global_costmap.connect(mapper.global_costmap) - - connection.movecmd.connect(websocket_vis.movecmd) - foxglove_bridge = FoxgloveBridge() - - websocket_vis.start() - foxglove_bridge.start() - return websocket_vis, foxglove_bridge - - -def deploy_navigation(dimos, connection): - mapper = dimos.deploy(Map, voxel_size=0.5, cost_resolution=0.05, global_publish_interval=1.0) - mapper.lidar.connect(connection.lidar) - mapper.global_map.transport = LCMTransport("/global_map", LidarMessage) - mapper.global_costmap.transport = LCMTransport("/global_costmap", OccupancyGrid) - mapper.local_costmap.transport = LCMTransport("/local_costmap", OccupancyGrid) - - """Deploy and configure navigation modules.""" - global_planner = dimos.deploy(AstarPlanner) - local_planner = dimos.deploy(HolonomicLocalPlanner) - navigator = dimos.deploy( - BehaviorTreeNavigator, - reset_local_planner=local_planner.reset, - check_goal_reached=local_planner.is_goal_reached, - ) - frontier_explorer = dimos.deploy(WavefrontFrontierExplorer) - - navigator.goal.transport = LCMTransport("/navigation_goal", PoseStamped) - navigator.goal_request.transport = LCMTransport("/goal_request", PoseStamped) - navigator.goal_reached.transport = LCMTransport("/goal_reached", Bool) - navigator.navigation_state.transport = LCMTransport("/navigation_state", String) - navigator.global_costmap.transport = LCMTransport("/global_costmap", OccupancyGrid) - global_planner.path.transport = LCMTransport("/global_path", Path) - local_planner.cmd_vel.transport = LCMTransport("/cmd_vel", Twist) - frontier_explorer.goal_request.transport = LCMTransport("/goal_request", PoseStamped) - frontier_explorer.goal_reached.transport = LCMTransport("/goal_reached", Bool) - frontier_explorer.explore_cmd.transport = LCMTransport("/explore_cmd", Bool) - frontier_explorer.stop_explore_cmd.transport = LCMTransport("/stop_explore_cmd", Bool) - - global_planner.target.connect(navigator.goal) - - global_planner.global_costmap.connect(mapper.global_costmap) - global_planner.odom.connect(connection.odom) - - local_planner.path.connect(global_planner.path) - local_planner.local_costmap.connect(mapper.local_costmap) - local_planner.odom.connect(connection.odom) - - connection.movecmd.connect(local_planner.cmd_vel) - - navigator.odom.connect(connection.odom) - - frontier_explorer.costmap.connect(mapper.global_costmap) - frontier_explorer.odometry.connect(connection.odom) - mapper.start() - global_planner.start() - local_planner.start() - navigator.start() - - return mapper, global_planner - - -class UnitreeGo2: - def __init__( - self, - ip: str, - connection_type: Optional[str] = "webrtc", - ): - dimos = start(3) - - connection = dimos.deploy(ConnectionModule, ip, connection_type) - connection.lidar.transport = LCMTransport("/lidar", LidarMessage) - connection.odom.transport = LCMTransport("/odom", PoseStamped) - connection.video.transport = LCMTransport("/image", Image) - connection.movecmd.transport = LCMTransport("/cmd_vel", Twist) - connection.camera_info.transport = LCMTransport("/camera_info", CameraInfo) - connection.start() - - # connection.record("unitree_raw_webrtc_replay") - - detection = dimos.deploy(Detect2DModule) - detection.image.connect(connection.video) - detection.detections.transport = LCMTransport("/detections", Detection2DArray) - detection.annotations.transport = LCMTransport("/annotations", ImageAnnotations) - detection.start() - - mapper, global_planner = deploy_navigation(dimos, connection) - deploy_foxglove(dimos, connection, mapper, global_planner) - - def stop(): ... +def detection_unitree(): + dimos = start(6) + connection = deploy_connection(dimos) + connection.start() + # connection.record("unitree_go2_office_walk2") + # mapper = deploy_navigation(dimos, connection) -def main(): - lcm.autoconf() - robot = UnitreeGo2( - ip=os.getenv("ROBOT_IP"), connection_type=os.getenv("CONNECTION_TYPE", "fake") - ) + module3D = dimos.deploy(Detection3DModule, camera_info=ConnectionModule._camera_info()) + + module3D.image.connect(connection.video) + module3D.pointcloud.connect(connection.lidar) + + module3D.annotations.transport = LCMTransport("/annotations", ImageAnnotations) + module3D.detections.transport = LCMTransport("/detections", Detection2DArray) + + module3D.detected_pointcloud_0.transport = LCMTransport("/detected/pointcloud/0", PointCloud2) + module3D.detected_pointcloud_1.transport = LCMTransport("/detected/pointcloud/1", PointCloud2) + module3D.detected_pointcloud_2.transport = LCMTransport("/detected/pointcloud/2", PointCloud2) + + module3D.detected_image_0.transport = LCMTransport("/detected/image/0", Image) + module3D.detected_image_1.transport = LCMTransport("/detected/image/1", Image) + module3D.detected_image_2.transport = LCMTransport("/detected/image/2", Image) + module3D.start() + # detection.start() try: while True: time.sleep(1) except KeyboardInterrupt: - robot.stop() + connection.stop() + # mapper.stop() + # detection.stop() logger.info("Shutting down...") +def main(): + lcm.autoconf() + detection_unitree() + + if __name__ == "__main__": main() diff --git a/dimos/robot/unitree_webrtc/modular/navigation.py b/dimos/robot/unitree_webrtc/modular/navigation.py new file mode 100644 index 0000000000..8ceaf0e195 --- /dev/null +++ b/dimos/robot/unitree_webrtc/modular/navigation.py @@ -0,0 +1,86 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos_lcm.std_msgs import Bool, String + +from dimos.core import LCMTransport +from dimos.msgs.geometry_msgs import PoseStamped, Vector3 +from dimos.msgs.nav_msgs import OccupancyGrid, Path +from dimos.navigation.bt_navigator.navigator import BehaviorTreeNavigator +from dimos.navigation.frontier_exploration import WavefrontFrontierExplorer +from dimos.navigation.global_planner import AstarPlanner +from dimos.navigation.local_planner.holonomic_local_planner import HolonomicLocalPlanner +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.map import Map +from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule + + +def deploy_navigation(dimos, connection): + mapper = dimos.deploy(Map, voxel_size=0.5, cost_resolution=0.05, global_publish_interval=1.0) + mapper.lidar.connect(connection.lidar) + mapper.global_map.transport = LCMTransport("/global_map", LidarMessage) + mapper.global_costmap.transport = LCMTransport("/global_costmap", OccupancyGrid) + mapper.local_costmap.transport = LCMTransport("/local_costmap", OccupancyGrid) + + """Deploy and configure navigation modules.""" + global_planner = dimos.deploy(AstarPlanner) + local_planner = dimos.deploy(HolonomicLocalPlanner) + navigator = dimos.deploy( + BehaviorTreeNavigator, + reset_local_planner=local_planner.reset, + check_goal_reached=local_planner.is_goal_reached, + ) + frontier_explorer = dimos.deploy(WavefrontFrontierExplorer) + + navigator.goal.transport = LCMTransport("/navigation_goal", PoseStamped) + navigator.goal_request.transport = LCMTransport("/goal_request", PoseStamped) + navigator.goal_reached.transport = LCMTransport("/goal_reached", Bool) + navigator.navigation_state.transport = LCMTransport("/navigation_state", String) + navigator.global_costmap.transport = LCMTransport("/global_costmap", OccupancyGrid) + global_planner.path.transport = LCMTransport("/global_path", Path) + local_planner.cmd_vel.transport = LCMTransport("/cmd_vel", Vector3) + frontier_explorer.goal_request.transport = LCMTransport("/goal_request", PoseStamped) + frontier_explorer.goal_reached.transport = LCMTransport("/goal_reached", Bool) + frontier_explorer.explore_cmd.transport = LCMTransport("/explore_cmd", Bool) + frontier_explorer.stop_explore_cmd.transport = LCMTransport("/stop_explore_cmd", Bool) + + global_planner.target.connect(navigator.goal) + + global_planner.global_costmap.connect(mapper.global_costmap) + global_planner.odom.connect(connection.odom) + + local_planner.path.connect(global_planner.path) + local_planner.local_costmap.connect(mapper.local_costmap) + local_planner.odom.connect(connection.odom) + + connection.movecmd.connect(local_planner.cmd_vel) + + navigator.odom.connect(connection.odom) + + frontier_explorer.costmap.connect(mapper.global_costmap) + frontier_explorer.odometry.connect(connection.odom) + websocket_vis = dimos.deploy(WebsocketVisModule, port=7779) + websocket_vis.click_goal.transport = LCMTransport("/goal_request", PoseStamped) + + websocket_vis.robot_pose.connect(connection.odom) + websocket_vis.path.connect(global_planner.path) + websocket_vis.global_costmap.connect(mapper.global_costmap) + + mapper.start() + global_planner.start() + local_planner.start() + navigator.start() + websocket_vis.start() + + return mapper diff --git a/dimos/robot/unitree_webrtc/type/lidar.py b/dimos/robot/unitree_webrtc/type/lidar.py index fec56f9f44..aefd9654e1 100644 --- a/dimos/robot/unitree_webrtc/type/lidar.py +++ b/dimos/robot/unitree_webrtc/type/lidar.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import time from copy import copy from typing import List, Optional, TypedDict @@ -20,7 +21,7 @@ from dimos.msgs.geometry_msgs import Vector3 from dimos.msgs.sensor_msgs import PointCloud2 -from dimos.robot.unitree_webrtc.type.timeseries import to_human_readable +from dimos.types.timestamped import to_human_readable class RawLidarPoints(TypedDict): @@ -64,7 +65,7 @@ def __init__(self, **kwargs): self.resolution = kwargs.get("resolution", 0.05) @classmethod - def from_msg(cls: "LidarMessage", raw_message: RawLidarMsg) -> "LidarMessage": + def from_msg(cls: "LidarMessage", raw_message: RawLidarMsg, **kwargs) -> "LidarMessage": data = raw_message["data"] points = data["data"]["points"] pointcloud = o3d.geometry.PointCloud() @@ -75,14 +76,16 @@ def from_msg(cls: "LidarMessage", raw_message: RawLidarMsg) -> "LidarMessage": # to shift the pointcloud by it's origin # # pointcloud.translate((origin / 2).to_tuple()) - - return cls( - origin=origin, - resolution=data["resolution"], - pointcloud=pointcloud, - ts=data["stamp"], - raw_msg=raw_message, - ) + cls_data = { + "origin": origin, + "resolution": data["resolution"], + "pointcloud": pointcloud, + # - this is broken in unitree webrtc api "stamp":1.758148e+09 + "ts": time.time(), # data["stamp"], + "raw_msg": raw_message, + **kwargs, + } + return cls(**cls_data) def __repr__(self): return f"LidarMessage(ts={to_human_readable(self.ts)}, origin={self.origin}, resolution={self.resolution}, {self.pointcloud})" diff --git a/dimos/robot/unitree_webrtc/type/odometry.py b/dimos/robot/unitree_webrtc/type/odometry.py index 27d59f2cb8..c307929a00 100644 --- a/dimos/robot/unitree_webrtc/type/odometry.py +++ b/dimos/robot/unitree_webrtc/type/odometry.py @@ -20,6 +20,7 @@ from dimos.robot.unitree_webrtc.type.timeseries import ( Timestamped, ) +from dimos.types.timestamped import to_human_readable, to_timestamp raw_odometry_msg_sample = { "type": "msg", @@ -97,6 +98,9 @@ def from_msg(cls, msg: RawOdometryMessage) -> "Odometry": pose["orientation"].get("w"), ) + # ts = to_timestamp(msg["data"]["header"]["stamp"]) + # lidar / video timestamps are not available from the robot + # so we are deferring to local time for everything ts = time.time() return Odometry(position=pos, orientation=rot, ts=ts, frame_id="world") diff --git a/dimos/types/timestamped.py b/dimos/types/timestamped.py index 39b3f606b5..36f86b2ebb 100644 --- a/dimos/types/timestamped.py +++ b/dimos/types/timestamped.py @@ -42,6 +42,13 @@ def to_timestamp(ts: TimeLike) -> float: return float(ts) if isinstance(ts, dict) and "sec" in ts and "nanosec" in ts: return ts["sec"] + ts["nanosec"] / 1e9 + # Check for ROS Time-like objects by attributes + if hasattr(ts, "sec") and (hasattr(ts, "nanosec") or hasattr(ts, "nsec")): + # Handle both std_msgs.Time (nsec) and builtin_interfaces.Time (nanosec) + if hasattr(ts, "nanosec"): + return ts.sec + ts.nanosec / 1e9 + else: # has nsec + return ts.sec + ts.nsec / 1e9 raise TypeError("unsupported timestamp type") @@ -56,6 +63,13 @@ def to_ros_stamp(ts: TimeLike) -> ROSTime: return ROSTime(sec=sec, nanosec=nanosec) +def to_human_readable(ts: float) -> str: + """Convert timestamp to human-readable format with date and time.""" + import time + + return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(ts)) + + def to_datetime(ts: TimeLike, tz=None) -> datetime: if isinstance(ts, datetime): if ts.tzinfo is None: @@ -85,8 +99,8 @@ def __init__(self, ts: float): def dt(self) -> datetime: return datetime.fromtimestamp(self.ts, tz=timezone.utc).astimezone() - def ros_timestamp(self) -> dict[str, int]: - """Convert timestamp to ROS-style dictionary.""" + def ros_timestamp(self) -> list[int]: + """Convert timestamp to ROS-style list [sec, nanosec].""" sec = int(self.ts) nanosec = int((self.ts - sec) * 1_000_000_000) return [sec, nanosec] @@ -265,3 +279,62 @@ def dispose(): return dispose return create(subscribe) + + +def align_timestamped_multiple( + primary_observable: Observable[PRIMARY], + *secondary_observables: Observable[SECONDARY], + buffer_size: float = 1.0, # seconds + match_tolerance: float = 0.05, # seconds +) -> Observable[Tuple[PRIMARY, ...]]: + """Align a primary observable with multiple secondary observables. + + Args: + primary_observable: The primary stream to align against + *secondary_observables: Secondary streams to align + buffer_size: Time window to keep secondary messages in seconds + match_tolerance: Maximum time difference for matching in seconds + + Returns: + Observable that emits tuples of (primary_item, secondary1, secondary2, ...) + where each secondary item is the closest match from the corresponding + secondary observable, or None if no match within tolerance. + """ + from reactivex import create + + def subscribe(observer, scheduler=None): + # Create a buffer collection for each secondary observable + secondary_collections: list[TimestampedBufferCollection[SECONDARY]] = [ + TimestampedBufferCollection(buffer_size) for _ in secondary_observables + ] + + # Subscribe to all secondary observables + secondary_subs = [] + for i, secondary_obs in enumerate(secondary_observables): + sub = secondary_obs.subscribe(secondary_collections[i].add) + secondary_subs.append(sub) + + def on_primary(primary_item: PRIMARY): + # Find closest match from each secondary collection + secondary_items = [] + for collection in secondary_collections: + secondary_item = collection.find_closest(primary_item.ts, tolerance=match_tolerance) + secondary_items.append(secondary_item) + + # Emit the aligned tuple (flatten into single tuple) + observer.on_next((primary_item, *secondary_items)) + + # Subscribe to primary and emit aligned tuples + primary_sub = primary_observable.subscribe( + on_next=on_primary, on_error=observer.on_error, on_completed=observer.on_completed + ) + + # Return cleanup function + def dispose(): + for sub in secondary_subs: + sub.dispose() + primary_sub.dispose() + + return dispose + + return create(subscribe) diff --git a/dimos/utils/test_testing.py b/dimos/utils/test_testing.py index bf69209617..017b267c1b 100644 --- a/dimos/utils/test_testing.py +++ b/dimos/utils/test_testing.py @@ -14,10 +14,12 @@ import hashlib import os +import re import subprocess -from reactivex import operators as ops import reactivex as rx +from reactivex import operators as ops + from dimos.robot.unitree_webrtc.type.lidar import LidarMessage from dimos.robot.unitree_webrtc.type.odometry import Odometry from dimos.utils import testing @@ -43,7 +45,7 @@ def test_sensor_replay_cast(): def test_timed_sensor_replay(): - data = get_data("unitree_office_walk") + get_data("unitree_office_walk") odom_store = testing.TimedSensorReplay("unitree_office_walk/odom", autocast=Odometry.from_msg) itermsgs = [] @@ -66,3 +68,216 @@ def test_timed_sensor_replay(): for i in range(10): print(itermsgs[i], timed_msgs[i]) assert itermsgs[i] == timed_msgs[i] + + +def test_iterate_ts_no_seek(): + """Test iterate_ts without seek (start_timestamp=None)""" + odom_store = testing.TimedSensorReplay("unitree_office_walk/odom", autocast=Odometry.from_msg) + + # Test without seek + ts_msgs = [] + for ts, msg in odom_store.iterate_ts(): + ts_msgs.append((ts, msg)) + if len(ts_msgs) >= 5: + break + + assert len(ts_msgs) == 5 + # Check that we get tuples of (timestamp, data) + for ts, msg in ts_msgs: + assert isinstance(ts, float) + assert isinstance(msg, Odometry) + + +def test_iterate_ts_with_from_timestamp(): + """Test iterate_ts with from_timestamp (absolute timestamp)""" + odom_store = testing.TimedSensorReplay("unitree_office_walk/odom", autocast=Odometry.from_msg) + + # First get all messages to find a good seek point + all_msgs = [] + for ts, msg in odom_store.iterate_ts(): + all_msgs.append((ts, msg)) + if len(all_msgs) >= 10: + break + + # Seek to timestamp of 5th message + seek_timestamp = all_msgs[4][0] + + # Test with from_timestamp + seeked_msgs = [] + for ts, msg in odom_store.iterate_ts(from_timestamp=seek_timestamp): + seeked_msgs.append((ts, msg)) + if len(seeked_msgs) >= 5: + break + + assert len(seeked_msgs) == 5 + # First message should be at or after seek timestamp + assert seeked_msgs[0][0] >= seek_timestamp + # Should match the data from position 5 onward + assert seeked_msgs[0][1] == all_msgs[4][1] + + +def test_iterate_ts_with_relative_seek(): + """Test iterate_ts with seek (relative seconds after first timestamp)""" + odom_store = testing.TimedSensorReplay("unitree_office_walk/odom", autocast=Odometry.from_msg) + + # Get first few messages to understand timing + all_msgs = [] + for ts, msg in odom_store.iterate_ts(): + all_msgs.append((ts, msg)) + if len(all_msgs) >= 10: + break + + # Calculate relative seek time (e.g., 0.5 seconds after start) + first_ts = all_msgs[0][0] + seek_seconds = 0.5 + expected_start_ts = first_ts + seek_seconds + + # Test with relative seek + seeked_msgs = [] + for ts, msg in odom_store.iterate_ts(seek=seek_seconds): + seeked_msgs.append((ts, msg)) + if len(seeked_msgs) >= 5: + break + + # First message should be at or after expected timestamp + assert seeked_msgs[0][0] >= expected_start_ts + # Make sure we're actually skipping some messages + assert seeked_msgs[0][0] > first_ts + + +def test_stream_with_seek(): + """Test stream method with seek parameters""" + odom_store = testing.TimedSensorReplay("unitree_office_walk/odom", autocast=Odometry.from_msg) + + # Test stream with relative seek + msgs_with_seek = [] + for msg in odom_store.stream(seek=0.2).pipe(ops.take(5), ops.to_list()).run(): + msgs_with_seek.append(msg) + + assert len(msgs_with_seek) == 5 + + # Test stream with from_timestamp + # First get a reference timestamp + first_msgs = [] + for msg in odom_store.stream().pipe(ops.take(3), ops.to_list()).run(): + first_msgs.append(msg) + + # Now test from_timestamp (would need actual timestamps from iterate_ts to properly test) + # This is a basic test to ensure the parameter is accepted + msgs_with_timestamp = [] + for msg in ( + odom_store.stream(from_timestamp=1000000000.0).pipe(ops.take(3), ops.to_list()).run() + ): + msgs_with_timestamp.append(msg) + + +def test_duration_with_loop(): + """Test duration parameter with looping in TimedSensorReplay""" + odom_store = testing.TimedSensorReplay("unitree_office_walk/odom", autocast=Odometry.from_msg) + + # Collect timestamps from a small duration window + collected_ts = [] + duration = 0.3 # 300ms window + + # First pass: collect timestamps in the duration window + for ts, msg in odom_store.iterate_ts(duration=duration): + collected_ts.append(ts) + if len(collected_ts) >= 100: # Safety limit + break + + # Should have some messages but not too many + assert len(collected_ts) > 0 + assert len(collected_ts) < 20 # Assuming ~30Hz data + + # Test looping with duration - should repeat the same window + loop_count = 0 + prev_ts = None + + for ts, msg in odom_store.iterate_ts(duration=duration, loop=True): + if prev_ts is not None and ts < prev_ts: + # We've looped back to the beginning + loop_count += 1 + if loop_count >= 2: # Stop after 2 full loops + break + prev_ts = ts + + assert loop_count >= 2 # Verify we actually looped + + +def test_first_methods(): + """Test first() and first_timestamp() methods""" + + # Test SensorReplay.first() + lidar_replay = testing.SensorReplay("office_lidar", autocast=LidarMessage.from_msg) + + print("first file", lidar_replay.files[0]) + # Verify the first file ends with 000.pickle using regex + assert re.search(r"000\.pickle$", str(lidar_replay.files[0])), ( + f"Expected first file to end with 000.pickle, got {lidar_replay.files[0]}" + ) + + first_msg = lidar_replay.first() + assert first_msg is not None + assert isinstance(first_msg, LidarMessage) + + # Verify it's the same type as first item from iterate() + first_from_iterate = next(lidar_replay.iterate()) + print("DONE") + assert type(first_msg) is type(first_from_iterate) + # Since LidarMessage.from_msg uses time.time(), timestamps will be slightly different + assert abs(first_msg.ts - first_from_iterate.ts) < 1.0 # Within 1 second tolerance + + # Test TimedSensorReplay.first_timestamp() + odom_store = testing.TimedSensorReplay("unitree_office_walk/odom", autocast=Odometry.from_msg) + first_ts = odom_store.first_timestamp() + assert first_ts is not None + assert isinstance(first_ts, float) + + # Verify it matches the timestamp from iterate_ts + ts_from_iterate, _ = next(odom_store.iterate_ts()) + assert first_ts == ts_from_iterate + + # Test that first() returns just the data + first_data = odom_store.first() + assert first_data is not None + assert isinstance(first_data, Odometry) + + +def test_find_closest(): + """Test find_closest method in TimedSensorReplay""" + odom_store = testing.TimedSensorReplay("unitree_office_walk/odom", autocast=Odometry.from_msg) + + # Get some reference timestamps + timestamps = [] + for ts, msg in odom_store.iterate_ts(): + timestamps.append(ts) + if len(timestamps) >= 10: + break + + # Test exact match + target_ts = timestamps[5] + result = odom_store.find_closest(target_ts) + assert result is not None + assert isinstance(result, Odometry) + + # Test between timestamps + mid_ts = (timestamps[3] + timestamps[4]) / 2 + result = odom_store.find_closest(mid_ts) + assert result is not None + + # Test with tolerance + far_future = timestamps[-1] + 100.0 + result = odom_store.find_closest(far_future, tolerance=1.0) + assert result is None # Too far away + + result = odom_store.find_closest(timestamps[0] - 0.001, tolerance=0.01) + assert result is not None # Within tolerance + + # Test find_closest_seek + result = odom_store.find_closest_seek(0.5) # 0.5 seconds from start + assert result is not None + assert isinstance(result, Odometry) + + # Test with negative seek (before start) + result = odom_store.find_closest_seek(-1.0) + assert result is not None # Should still return closest (first frame) diff --git a/dimos/utils/testing.py b/dimos/utils/testing.py index 26a6517fff..8930b2f0e9 100644 --- a/dimos/utils/testing.py +++ b/dimos/utils/testing.py @@ -11,26 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import functools import glob +import logging import os import pickle +import re import time from pathlib import Path from typing import Any, Callable, Generic, Iterator, Optional, Tuple, TypeVar, Union from reactivex import ( - concat, - concat_with_iterable, - empty, from_iterable, interval, - just, - merge, - timer, ) from reactivex import operators as ops -from reactivex import timer as rx_timer from reactivex.observable import Observable from reactivex.scheduler import TimeoutScheduler @@ -71,23 +66,41 @@ def load_one(self, name: Union[int, str, Path]) -> Union[T, Any]: return self.autocast(data) return data - def iterate(self) -> Iterator[Union[T, Any]]: - pattern = os.path.join(self.root_dir, "*") - for file_path in sorted( - glob.glob(pattern), - key=lambda x: int(os.path.basename(x).split(".")[0]) - if os.path.basename(x).split(".")[0].isdigit() - else 0, - ): - yield self.load_one(Path(file_path)) - - def stream(self, rate_hz: Optional[float] = None) -> Observable[Union[T, Any]]: + def first(self) -> Optional[Union[T, Any]]: + try: + return next(self.iterate()) + except StopIteration: + return None + + @functools.cached_property + def files(self) -> list[Path]: + def extract_number(filepath): + """Extract last digits before .pickle extension""" + basename = os.path.basename(filepath) + match = re.search(r"(\d+)\.pickle$", basename) + return int(match.group(1)) if match else 0 + + return sorted( + glob.glob(os.path.join(self.root_dir, "*")), + key=extract_number, + ) + + def iterate(self, loop: bool = False) -> Iterator[Union[T, Any]]: + while True: + for file_path in self.files: + yield self.load_one(Path(file_path)) + if not loop: + break + + def stream( + self, rate_hz: Optional[float] = None, loop: bool = False + ) -> Observable[Union[T, Any]]: if rate_hz is None: - return from_iterable(self.iterate()) + return from_iterable(self.iterate(loop=loop)) sleep_time = 1.0 / rate_hz - return from_iterable(self.iterate()).pipe( + return from_iterable(self.iterate(loop=loop)).pipe( ops.zip(interval(sleep_time)), ops.map(lambda x: x[0] if isinstance(x, tuple) else x), ) @@ -176,19 +189,118 @@ def load_one(self, name: Union[int, str, Path]) -> Union[T, Any]: return (data[0], self.autocast(data[1])) return data - def iterate(self) -> Iterator[Union[T, Any]]: - return (x[1] for x in super().iterate()) - - def iterate_ts(self) -> Iterator[Union[Tuple[float, T], Any]]: - return super().iterate() - - def stream(self, speed=1.0) -> Observable[Union[T, Any]]: + def find_closest( + self, timestamp: float, tolerance: Optional[float] = None + ) -> Optional[Union[T, Any]]: + """Find the frame closest to the given timestamp. + + Args: + timestamp: The target timestamp to search for + tolerance: Optional maximum time difference allowed + + Returns: + The data frame closest to the timestamp, or None if no match within tolerance + """ + closest_data = None + closest_diff = float("inf") + + # Check frames before and after the timestamp + for ts, data in self.iterate_ts(): + diff = abs(ts - timestamp) + + if diff < closest_diff: + closest_diff = diff + closest_data = data + elif diff > closest_diff: + # We're moving away from the target, can stop + break + + if tolerance is not None and closest_diff > tolerance: + return None + + return closest_data + + def find_closest_seek( + self, relative_seconds: float, tolerance: Optional[float] = None + ) -> Optional[Union[T, Any]]: + """Find the frame closest to a time relative to the start. + + Args: + relative_seconds: Seconds from the start of the dataset + tolerance: Optional maximum time difference allowed + + Returns: + The data frame closest to the relative timestamp, or None if no match within tolerance + """ + # Get the first timestamp + first_ts = self.first_timestamp() + if first_ts is None: + return None + + # Calculate absolute timestamp and use find_closest + target_timestamp = first_ts + relative_seconds + return self.find_closest(target_timestamp, tolerance) + + def first_timestamp(self) -> Optional[float]: + """Get the timestamp of the first item in the dataset. + + Returns: + The first timestamp, or None if dataset is empty + """ + try: + ts, _ = next(self.iterate_ts()) + return ts + except StopIteration: + return None + + def iterate(self, loop: bool = False) -> Iterator[Union[T, Any]]: + return (x[1] for x in super().iterate(loop=loop)) + + def iterate_ts( + self, + seek: Optional[float] = None, + duration: Optional[float] = None, + from_timestamp: Optional[float] = None, + loop: bool = False, + ) -> Iterator[Union[Tuple[float, T], Any]]: + first_ts = None + if (seek is not None) or (duration is not None): + first_ts = self.first_timestamp() + if first_ts is None: + return + + if seek is not None: + from_timestamp = first_ts + seek + + end_timestamp = None + if duration is not None: + end_timestamp = (from_timestamp if from_timestamp else first_ts) + duration + + while True: + for ts, data in super().iterate(): + if from_timestamp is None or ts >= from_timestamp: + if end_timestamp is not None and ts >= end_timestamp: + break + yield (ts, data) + if not loop: + break + + def stream( + self, + speed=1.0, + seek: Optional[float] = None, + duration: Optional[float] = None, + from_timestamp: Optional[float] = None, + loop: bool = False, + ) -> Observable[Union[T, Any]]: def _subscribe(observer, scheduler=None): from reactivex.disposable import CompositeDisposable, Disposable scheduler = scheduler or TimeoutScheduler() # default thread-based - iterator = self.iterate_ts() + iterator = self.iterate_ts( + seek=seek, duration=duration, from_timestamp=from_timestamp, loop=loop + ) try: prev_ts, first_data = next(iterator) diff --git a/flake.nix b/flake.nix index 7101de506f..0061153089 100644 --- a/flake.nix +++ b/flake.nix @@ -36,6 +36,11 @@ ### GTK / OpenCV helpers glib gtk3 gdk-pixbuf gobject-introspection + + ### GStreamer + gst_all_1.gstreamer gst_all_1.gst-plugins-base gst_all_1.gst-plugins-good + gst_all_1.gst-plugins-bad gst_all_1.gst-plugins-ugly + python312Packages.gst-python ### Open3D & build-time eigen cmake ninja jsoncpp libjpeg libpng @@ -57,9 +62,11 @@ pkgs.xorg.libXrender pkgs.xorg.libXdamage pkgs.xorg.libXcomposite pkgs.xorg.libxcb pkgs.xorg.libXScrnSaver pkgs.xorg.libXxf86vm pkgs.udev pkgs.portaudio pkgs.SDL2.dev pkgs.zlib pkgs.glib pkgs.gtk3 - pkgs.gdk-pixbuf pkgs.gobject-introspection pkgs.lcm pkgs.pcre2]}:$LD_LIBRARY_PATH" + pkgs.gdk-pixbuf pkgs.gobject-introspection pkgs.lcm pkgs.pcre2 + pkgs.gst_all_1.gstreamer pkgs.gst_all_1.gst-plugins-base]}:$LD_LIBRARY_PATH" export DISPLAY=:0 + export GI_TYPELIB_PATH="${pkgs.gst_all_1.gstreamer}/lib/girepository-1.0:${pkgs.gst_all_1.gst-plugins-base}/lib/girepository-1.0:$GI_TYPELIB_PATH" PROJECT_ROOT=$(git rev-parse --show-toplevel 2>/dev/null || echo "$PWD") if [ -f "$PROJECT_ROOT/env/bin/activate" ]; then