Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 53 additions & 9 deletions dimos/perception/detection2d/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,20 @@
from dimos_lcm.foxglove_msgs.SceneUpdate import SceneUpdate
from dimos_lcm.visualization_msgs.MarkerArray import MarkerArray

from dimos.core import LCMTransport
from dimos.msgs.geometry_msgs import Transform
from dimos.msgs.sensor_msgs import CameraInfo
from dimos.msgs.sensor_msgs.Image import Image
from dimos.msgs.sensor_msgs import CameraInfo, Image, PointCloud2
from dimos.msgs.vision_msgs import Detection2DArray
from dimos.perception.detection2d.module2D import Detection2DModule
from dimos.perception.detection2d.module3D import Detection3DModule
from dimos.perception.detection2d.moduleDB import ObjectDBModule
from dimos.perception.detection2d.type import (
Detection2D,
Detection3D,
Detection3DPC,
ImageDetections2D,
ImageDetections3D,
ImageDetections3DPC,
)
from dimos.protocol.tf import TF
from dimos.robot.unitree_webrtc.modular.connection_module import ConnectionModule
Expand All @@ -47,7 +50,7 @@ class Moment(TypedDict, total=False):
transforms: list[Transform]
tf: TF
annotations: Optional[ImageAnnotations]
detections: Optional[ImageDetections3D]
detections: Optional[ImageDetections3DPC]
markers: Optional[MarkerArray]
scene_update: Optional[SceneUpdate]

Expand All @@ -57,7 +60,7 @@ class Moment2D(Moment):


class Moment3D(Moment):
    """Moment that additionally carries 3D detection results."""

    detections3d: ImageDetections3D
    # The value stored under this key comes from the 3D module's
    # process_frame, which builds an ImageDetections3DPC, so the annotation
    # names that container rather than ImageDetections3D.
    detections3dpc: ImageDetections3DPC


@pytest.fixture
Expand Down Expand Up @@ -102,6 +105,47 @@ def moment_provider(**kwargs) -> Moment:
return moment_provider


@pytest.fixture
def publish_moment():
    """Provide a callable that publishes the contents of a moment over LCM.

    The returned ``publisher(moment)`` publishes, in order:

    * ``/annotations`` and ``/detections`` when the moment has ``detections2d``
    * ``/scene_update`` when the moment has ``detections3dpc``
    * ``/lidar``, ``/image`` and ``/camera_info`` unconditionally
    * the moment's ``transforms`` through the moment's own ``tf`` object

    Every transport is opened for a single message and stopped again, even
    when publishing raises, so a failing test does not leave it running.
    """

    def _publish_once(topic, msg_type, message):
        """Open a transport on *topic*, publish *message*, then stop it."""
        transport = LCMTransport(topic, msg_type)
        try:
            transport.publish(message)
        finally:
            # Runs on the error path too; previously an exception in publish
            # skipped the stop call entirely.
            transport.lcm.stop()

    def publisher(moment: Moment | Moment2D | Moment3D):
        detections2d = moment.get("detections2d")
        if detections2d:
            # 2d annotations
            _publish_once(
                "/annotations", ImageAnnotations, detections2d.to_foxglove_annotations()
            )
            _publish_once(
                "/detections", Detection2DArray, detections2d.to_ros_detection2d_array()
            )

        detections3dpc = moment.get("detections3dpc")
        if detections3dpc:
            # 3d scene update
            _publish_once(
                "/scene_update", SceneUpdate, detections3dpc.to_foxglove_scene_update()
            )

        _publish_once("/lidar", PointCloud2, moment.get("lidar_frame"))
        _publish_once("/image", Image, moment.get("image_frame"))
        _publish_once("/camera_info", CameraInfo, moment.get("camera_info"))

        # The TF object is owned by the moment, so it is used as-is and not
        # stopped here.
        tf = moment.get("tf")
        tf.publish(*moment.get("transforms"))

    return publisher


@pytest.fixture
def detection2d(get_moment_2d) -> Detection2D:
moment = get_moment_2d(seek=10.0)
Expand All @@ -110,11 +154,11 @@ def detection2d(get_moment_2d) -> Detection2D:


@pytest.fixture
def detection3d(get_moment_3d) -> Detection3D:
def detection3dpc(get_moment_3d) -> Detection3DPC:
moment = get_moment_3d(seek=10.0)
assert len(moment["detections3d"]) > 0, "No detections found in the moment"
print(moment["detections3d"])
return moment["detections3d"][0]
assert len(moment["detections3dpc"]) > 0, "No detections found in the moment"
print(moment["detections3dpc"])
return moment["detections3dpc"][0]


@pytest.fixture
Expand Down Expand Up @@ -150,7 +194,7 @@ def moment_provider(**kwargs) -> Moment2D:

return {
**moment,
"detections3d": module.process_frame(
"detections3dpc": module.process_frame(
moment["detections2d"], moment["lidar_frame"], camera_transform
),
}
Expand Down
11 changes: 5 additions & 6 deletions dimos/perception/detection2d/module3D.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@
from dimos.perception.detection2d.type import (
ImageDetections2D,
ImageDetections3D,
ImageDetections3DPC,
)
from dimos.perception.detection2d.type.detection3d import Detection3D
from dimos.perception.detection2d.type.detection3dpc import Detection3DPC
from dimos.types.timestamped import align_timestamped
from dimos.utils.reactive import backpressure

Expand All @@ -40,7 +41,7 @@ class Detection3DModule(Detection2DModule):
detected_pointcloud_1: Out[PointCloud2] = None # type: ignore
detected_pointcloud_2: Out[PointCloud2] = None # type: ignore

detection_3d_stream: Observable[ImageDetections3D] = None
detection_3d_stream: Observable[ImageDetections3DPC] = None

def __init__(self, camera_info: CameraInfo, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand All @@ -55,10 +56,9 @@ def process_frame(
if not transform:
return ImageDetections3D(detections.image, [])

print("3d projection", detections, pointcloud, transform)
detection3d_list = []
for detection in detections:
detection3d = Detection3D.from_2d(
detection3d = Detection3DPC.from_2d(
detection,
world_pointcloud=pointcloud,
camera_info=self.camera_info,
Expand All @@ -67,8 +67,7 @@ def process_frame(
if detection3d is not None:
detection3d_list.append(detection3d)

ret = ImageDetections3D(detections.image, detection3d_list)
print("3d projection finished", ret)
ret = ImageDetections3DPC(detections.image, detection3d_list)
return ret

@rpc
Expand Down
9 changes: 8 additions & 1 deletion dimos/perception/detection2d/type/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,12 @@
ImageDetections2D,
InconvinientDetectionFormat,
)
from dimos.perception.detection2d.type.detection3d import Detection3D, ImageDetections3D
from dimos.perception.detection2d.type.detection3d import (
Detection3D,
ImageDetections3D,
)
from dimos.perception.detection2d.type.detection3dpc import (
Detection3DPC,
ImageDetections3DPC,
)
from dimos.perception.detection2d.type.imageDetections import ImageDetections, TableStr
180 changes: 3 additions & 177 deletions dimos/perception/detection2d/type/detection3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,195 +32,21 @@
from dimos.perception.detection2d.type.imageDetections import ImageDetections
from dimos.types.timestamped import to_ros_stamp

# Signature shared by the point-cloud filters used by Detection3D.from_2d:
# (detection, current point cloud, camera info, world-to-optical transform)
# -> filtered point cloud, or None to reject the detection.
# The return type is Optional[PointCloud2] (not Optional["Detection3D"])
# because from_2d feeds each filter's result into the next filter as the
# point cloud argument.
Detection3DFilter = Callable[
    [Detection2D, PointCloud2, CameraInfo, Transform], Optional[PointCloud2]
]


def height_filter(height=0.1) -> Detection3DFilter:
    """Build a filter that delegates to the point cloud's ``filter_by_height``.

    Args:
        height: Value forwarded unchanged to ``filter_by_height``; its exact
            meaning is defined by that method, which is not visible here.

    Returns:
        A callable taking ``(det, pc, ci, tf)`` that returns
        ``pc.filter_by_height(height)``.
    """

    def filter_func(det, pc, ci, tf):
        # det, ci and tf exist only to satisfy the shared filter signature.
        return pc.filter_by_height(height)

    return filter_func


def statistical(nb_neighbors=40, std_ratio=0.5) -> Detection3DFilter:
    """Build a best-effort statistical outlier-removal filter.

    Args:
        nb_neighbors: Forwarded to ``remove_statistical_outlier``.
        std_ratio: Forwarded to ``remove_statistical_outlier``.

    Returns:
        A callable taking ``(det, pc, ci, tf)`` that returns a new PointCloud2
        holding the surviving points (same ``frame_id`` and ``ts`` as ``pc``),
        or None when the filter fails for any reason.
    """
    # Imported locally because the file's import block is outside this view.
    import logging

    log = logging.getLogger(__name__)

    def filter_func(
        det: Detection2D, pc: PointCloud2, ci: CameraInfo, tf: Transform
    ) -> Optional[PointCloud2]:
        try:
            # pc.pointcloud is presumably an Open3D point cloud -- TODO confirm.
            # The second return value is not needed here.
            inliers, _ = pc.pointcloud.remove_statistical_outlier(
                nb_neighbors=nb_neighbors, std_ratio=std_ratio
            )
            return PointCloud2(inliers, pc.frame_id, pc.ts)
        except Exception:
            # Deliberately best-effort: a failed filter rejects the detection
            # rather than aborting the pipeline, but the cause is now logged.
            log.debug("statistical filter failed", exc_info=True)
            return None

    return filter_func


def raycast() -> Detection3DFilter:
    """Build a best-effort hidden-point-removal filter.

    The viewpoint is the translation of the inverse of the transform passed
    to the filter; points reported as visible from it are kept.

    Returns:
        A callable taking ``(det, pc, ci, tf)`` that returns a new PointCloud2
        holding the visible points (same ``frame_id`` and ``ts`` as ``pc``),
        or None when the filter fails for any reason.
    """
    # Imported locally because the file's import block is outside this view.
    import logging

    log = logging.getLogger(__name__)

    def filter_func(
        det: Detection2D, pc: PointCloud2, ci: CameraInfo, tf: Transform
    ) -> Optional[PointCloud2]:
        try:
            viewpoint = tf.inverse().translation.to_numpy()
            # pc.pointcloud is presumably an Open3D point cloud -- TODO confirm.
            _, visible_indices = pc.pointcloud.hidden_point_removal(viewpoint, radius=100.0)
            visible_pcd = pc.pointcloud.select_by_index(visible_indices)
            return PointCloud2(visible_pcd, pc.frame_id, pc.ts)
        except Exception:
            # Deliberately best-effort: a failed filter rejects the detection
            # rather than aborting the pipeline, but the cause is now logged.
            log.debug("raycast filter failed", exc_info=True)
            return None

    return filter_func


def radius_outlier(min_neighbors: int = 20, radius: float = 0.3) -> Detection3DFilter:
    """Build a filter that drops isolated points.

    A point survives only when at least ``min_neighbors`` neighbors lie
    within ``radius`` meters of it (same units as the point cloud).

    Returns:
        A callable taking ``(det, pc, ci, tf)`` that returns a new PointCloud2
        holding the surviving points, with the ``frame_id`` and ``ts`` of ``pc``.
    """

    def filter_func(
        det: Detection2D, pc: PointCloud2, ci: CameraInfo, tf: Transform
    ) -> Optional[PointCloud2]:
        # The second return value is not needed here.
        kept, _ = pc.pointcloud.remove_radius_outlier(nb_points=min_neighbors, radius=radius)
        return PointCloud2(kept, pc.frame_id, pc.ts)

    return filter_func


@dataclass
class Detection3D(Detection2D):
pointcloud: PointCloud2
transform: Transform
frame_id: str = "unknown"
frame_id: str

@classmethod
def from_2d(
cls,
det: Detection2D,
world_pointcloud: PointCloud2,
distance: float,
camera_info: CameraInfo,
world_to_optical_transform: Transform,
# filters are to be adjusted based on the sensor noise characteristics if feeding
# sensor data directly
filters: list[Callable[[PointCloud2], PointCloud2]] = [
# height_filter(0.1),
raycast(),
radius_outlier(),
statistical(),
],
) -> Optional["Detection3D"]:
"""Create a Detection3D from a 2D detection by projecting world pointcloud.

This method handles:
1. Projecting world pointcloud to camera frame
2. Filtering points within the 2D detection bounding box
3. Cleaning up the pointcloud (height filter, outlier removal)
4. Hidden point removal from camera perspective

Args:
det: The 2D detection
world_pointcloud: Full pointcloud in world frame
camera_info: Camera calibration info
world_to_optical_transform: Transform from world to camera frame
filters: List of functions to apply to the pointcloud for filtering
Returns:
Detection3D with filtered pointcloud, or None if no valid points
"""
# Extract camera parameters
fx, fy = camera_info.K[0], camera_info.K[4]
cx, cy = camera_info.K[2], camera_info.K[5]
image_width = camera_info.width
image_height = camera_info.height

camera_matrix = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])

# Convert pointcloud to numpy array
world_points = world_pointcloud.as_numpy()

# Project points to camera frame
points_homogeneous = np.hstack([world_points, np.ones((world_points.shape[0], 1))])
extrinsics_matrix = world_to_optical_transform.to_matrix()
points_camera = (extrinsics_matrix @ points_homogeneous.T).T

# Filter out points behind the camera
valid_mask = points_camera[:, 2] > 0
points_camera = points_camera[valid_mask]
world_points = world_points[valid_mask]

if len(world_points) == 0:
return None

# Project to 2D
points_2d_homogeneous = (camera_matrix @ points_camera[:, :3].T).T
points_2d = points_2d_homogeneous[:, :2] / points_2d_homogeneous[:, 2:3]

# Filter points within image bounds
in_image_mask = (
(points_2d[:, 0] >= 0)
& (points_2d[:, 0] < image_width)
& (points_2d[:, 1] >= 0)
& (points_2d[:, 1] < image_height)
)
points_2d = points_2d[in_image_mask]
world_points = world_points[in_image_mask]

if len(world_points) == 0:
return None

# Extract bbox from Detection2D
x_min, y_min, x_max, y_max = det.bbox

# Find points within this detection box (with small margin)
margin = 5 # pixels
in_box_mask = (
(points_2d[:, 0] >= x_min - margin)
& (points_2d[:, 0] <= x_max + margin)
& (points_2d[:, 1] >= y_min - margin)
& (points_2d[:, 1] <= y_max + margin)
)

detection_points = world_points[in_box_mask]

if detection_points.shape[0] == 0:
# print(f"No points found in detection bbox after projection. {det.name}")
return None

# Create initial pointcloud for this detection
initial_pc = PointCloud2.from_numpy(
detection_points,
frame_id=world_pointcloud.frame_id,
timestamp=world_pointcloud.ts,
)

# Apply filters - each filter needs all 4 arguments
detection_pc = initial_pc
for filter_func in filters:
result = filter_func(det, detection_pc, camera_info, world_to_optical_transform)
if result is None:
return None
detection_pc = result

# Final check for empty pointcloud
if len(detection_pc.pointcloud.points) == 0:
return None

# Create Detection3D with filtered pointcloud
return Detection3D(
image=det.image,
bbox=det.bbox,
track_id=det.track_id,
class_id=det.class_id,
confidence=det.confidence,
name=det.name,
ts=det.ts,
pointcloud=detection_pc,
transform=world_to_optical_transform,
frame_id=world_pointcloud.frame_id,
)
raise NotImplementedError()

@functools.cached_property
def center(self) -> Vector3:
Expand Down
Loading
Loading