13 changes: 13 additions & 0 deletions dimos/core/core.py
@@ -202,6 +202,19 @@ class RemoteOut(RemoteStream[T]):
def connect(self, other: RemoteIn[T]):
return other.connect(self)

def observable(self):
"""Create an Observable stream from this remote output."""
from reactivex import create

def subscribe(observer, scheduler=None):
def on_msg(msg):
observer.on_next(msg)

self._transport.subscribe(self, on_msg)
return lambda: None

return create(subscribe)


class In(Stream[T]):
connection: Optional[RemoteOut[T]] = None
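The new `observable()` bridge lets callers apply reactivex operators directly to a remote output. A minimal usage sketch, assuming a proxy object `tracker_proxy` whose `tracking_data` port is exposed as a `RemoteOut` (the proxy and port names are illustrative, not part of this diff):

```python
from reactivex import operators as ops

# Hypothetical RemoteOut[Dict] obtained from a connected module proxy.
remote_tracking = tracker_proxy.tracking_data

subscription = remote_tracking.observable().pipe(
    ops.map(lambda msg: len(msg["targets"])),   # count tracked people per message
).subscribe(lambda n: print(f"{n} person(s) in view"))
```

Note that the inner `subscribe` above returns a no-op disposer, so calling `subscription.dispose()` stops the reactivex chain but does not detach the callback registered on the transport.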
389 changes: 241 additions & 148 deletions dimos/perception/object_tracker.py

Large diffs are not rendered by default.

262 changes: 183 additions & 79 deletions dimos/perception/person_tracker.py
@@ -15,13 +15,28 @@
from dimos.perception.detection2d.yolo_2d_det import Yolo2DDetector
from dimos.perception.detection2d.utils import filter_detections
from dimos.perception.common.ibvs import PersonDistanceEstimator
from reactivex import Observable
from reactivex import Observable, interval
from reactivex import operators as ops
import numpy as np
import cv2
from typing import Dict, Optional

from dimos.core import In, Out, Module, rpc
from dimos.msgs.sensor_msgs import Image
from dimos.utils.logging_config import setup_logger

logger = setup_logger("dimos.perception.person_tracker")


class PersonTrackingStream(Module):
"""Module for person tracking with LCM input/output."""

# LCM inputs
video: In[Image] = None

# LCM outputs
tracking_data: Out[Dict] = None

class PersonTrackingStream:
def __init__(
self,
camera_intrinsics=None,
@@ -40,6 +55,13 @@ def __init__(
camera_pitch: Camera pitch angle in radians (positive is up)
camera_height: Height of the camera from the ground in meters
"""
# Call parent Module init
super().__init__()

self.camera_intrinsics = camera_intrinsics
self.camera_pitch = camera_pitch
self.camera_height = camera_height

self.detector = Yolo2DDetector()

# Initialize distance estimator
@@ -61,6 +83,161 @@ def __init__(
K=K, camera_pitch=camera_pitch, camera_height=camera_height
)

# For tracking latest frame data
self._latest_frame: Optional[np.ndarray] = None
self._process_interval = 0.1 # Process at 10Hz

# Tracking state - starts disabled
self._tracking_enabled = False

@rpc
def start(self):
"""Start the person tracking module and subscribe to LCM streams."""

# Subscribe to video stream
def set_video(image_msg: Image):
if hasattr(image_msg, "data"):
self._latest_frame = image_msg.data
else:
logger.warning("Received image message without data attribute")

self.video.subscribe(set_video)

# Start periodic processing
interval(self._process_interval).subscribe(lambda _: self._process_frame())

logger.info("PersonTracking module started and subscribed to LCM streams")

def _process_frame(self):
"""Process the latest frame if available."""
if self._latest_frame is None:
return

# Only process and publish if tracking is enabled
if not self._tracking_enabled:
return

# Process frame through tracking pipeline
result = self._process_tracking(self._latest_frame)

# Publish result to LCM
if result:
self.tracking_data.publish(result)

def _process_tracking(self, frame):
"""Process a single frame for person tracking."""
# Detect people in the frame
bboxes, track_ids, class_ids, confidences, names = self.detector.process_image(frame)

# Filter to keep only person detections using filter_detections
(
filtered_bboxes,
filtered_track_ids,
filtered_class_ids,
filtered_confidences,
filtered_names,
) = filter_detections(
bboxes,
track_ids,
class_ids,
confidences,
names,
class_filter=[0], # 0 is the class_id for person
name_filter=["person"],
)

# Create visualization
viz_frame = self.detector.visualize_results(
frame,
filtered_bboxes,
filtered_track_ids,
filtered_class_ids,
filtered_confidences,
filtered_names,
)

# Calculate distance and angle for each person
targets = []
for i, bbox in enumerate(filtered_bboxes):
target_data = {
"target_id": filtered_track_ids[i] if i < len(filtered_track_ids) else -1,
"bbox": bbox,
"confidence": filtered_confidences[i] if i < len(filtered_confidences) else None,
}

distance, angle = self.distance_estimator.estimate_distance_angle(bbox)
target_data["distance"] = distance
target_data["angle"] = angle

# Add text to visualization
x1, y1, x2, y2 = map(int, bbox)
dist_text = f"{distance:.2f}m, {np.rad2deg(angle):.1f} deg"

# Add black background for better visibility
text_size = cv2.getTextSize(dist_text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0]
# Position at top-right corner
cv2.rectangle(
viz_frame, (x2 - text_size[0], y1 - text_size[1] - 5), (x2, y1), (0, 0, 0), -1
)

# Draw text in white at top-right
cv2.putText(
viz_frame,
dist_text,
(x2 - text_size[0], y1 - 5),
cv2.FONT_HERSHEY_SIMPLEX,
0.5,
(255, 255, 255),
2,
)

targets.append(target_data)

# Create the result dictionary
return {"frame": frame, "viz_frame": viz_frame, "targets": targets}

@rpc
def enable_tracking(self) -> bool:
"""Enable person tracking.

Returns:
bool: True if tracking was enabled successfully
"""
self._tracking_enabled = True
logger.info("Person tracking enabled")
return True

@rpc
def disable_tracking(self) -> bool:
"""Disable person tracking.

Returns:
bool: True if tracking was disabled successfully
"""
self._tracking_enabled = False
logger.info("Person tracking disabled")
return True

@rpc
def is_tracking_enabled(self) -> bool:
"""Check if tracking is currently enabled.

Returns:
bool: True if tracking is enabled
"""
return self._tracking_enabled

@rpc
def get_tracking_data(self) -> Dict:
"""Get the latest tracking data.

Returns:
Dictionary containing tracking results
"""
if self._latest_frame is not None:
return self._process_tracking(self._latest_frame)
return {"frame": None, "viz_frame": None, "targets": []}

def create_stream(self, video_stream: Observable) -> Observable:
"""
Create an Observable stream of person tracking results from a video stream.
@@ -72,83 +249,10 @@ def create_stream(self, video_stream: Observable) -> Observable:
Observable that emits dictionaries containing tracking results and visualizations
"""

def process_frame(frame):
# Detect people in the frame
bboxes, track_ids, class_ids, confidences, names = self.detector.process_image(frame)

# Filter to keep only person detections using filter_detections
(
filtered_bboxes,
filtered_track_ids,
filtered_class_ids,
filtered_confidences,
filtered_names,
) = filter_detections(
bboxes,
track_ids,
class_ids,
confidences,
names,
class_filter=[0], # 0 is the class_id for person
name_filter=["person"],
)

# Create visualization
viz_frame = self.detector.visualize_results(
frame,
filtered_bboxes,
filtered_track_ids,
filtered_class_ids,
filtered_confidences,
filtered_names,
)

# Calculate distance and angle for each person
targets = []
for i, bbox in enumerate(filtered_bboxes):
target_data = {
"target_id": filtered_track_ids[i] if i < len(filtered_track_ids) else -1,
"bbox": bbox,
"confidence": filtered_confidences[i]
if i < len(filtered_confidences)
else None,
}

distance, angle = self.distance_estimator.estimate_distance_angle(bbox)
target_data["distance"] = distance
target_data["angle"] = angle

# Add text to visualization
x1, y1, x2, y2 = map(int, bbox)
dist_text = f"{distance:.2f}m, {np.rad2deg(angle):.1f} deg"

# Add black background for better visibility
text_size = cv2.getTextSize(dist_text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0]
# Position at top-right corner
cv2.rectangle(
viz_frame, (x2 - text_size[0], y1 - text_size[1] - 5), (x2, y1), (0, 0, 0), -1
)

# Draw text in white at top-right
cv2.putText(
viz_frame,
dist_text,
(x2 - text_size[0], y1 - 5),
cv2.FONT_HERSHEY_SIMPLEX,
0.5,
(255, 255, 255),
2,
)

targets.append(target_data)

# Create the result dictionary
result = {"frame": frame, "viz_frame": viz_frame, "targets": targets}

return result

return video_stream.pipe(ops.map(process_frame))
return video_stream.pipe(ops.map(self._process_tracking))

@rpc
def cleanup(self):
"""Clean up resources."""
pass # No specific cleanup needed for now
# CUDA cleanup is now handled by WorkerPlugin in dimos.core
pass
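For reviewers, a rough sketch of how the reworked module is expected to be driven over LCM. Everything outside `PersonTrackingStream` itself is an assumption for illustration: the camera module, its `video` Out[Image] port, the `connect` wiring call, and the intrinsics layout are not part of this diff.

```python
from dimos.perception.person_tracker import PersonTrackingStream

# Assumed [fx, fy, cx, cy] intrinsics layout; adjust to the real format.
K = [600.0, 600.0, 320.0, 240.0]
tracker = PersonTrackingStream(
    camera_intrinsics=K,
    camera_pitch=0.0,      # radians, positive is up
    camera_height=1.2,     # metres above the ground
)

camera.video.connect(tracker.video)   # hypothetical camera module's Out[Image]
tracker.start()                       # subscribe to video, begin the 10 Hz loop

tracker.enable_tracking()             # tracking starts disabled; enable to publish
print(tracker.is_tracking_enabled())  # -> True
latest = tracker.get_tracking_data()  # poll the latest result once over RPC
tracker.disable_tracking()
```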
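The in-process reactivex path is preserved: `create_stream` now simply maps `_process_tracking` over incoming frames, so each emission is still a dict with `frame`, `viz_frame`, and `targets`, where each target carries `target_id`, `bbox`, `confidence`, `distance` (metres), and `angle` (radians). A short consumer sketch, assuming `video_stream` is an Observable of BGR frames:

```python
import numpy as np
from reactivex import operators as ops

from dimos.perception.person_tracker import PersonTrackingStream

tracker = PersonTrackingStream()

def nearest_person(result):
    """Return the closest tracked person in one result, or None."""
    targets = result["targets"]
    return min(targets, key=lambda t: t["distance"]) if targets else None

tracker.create_stream(video_stream).pipe(   # video_stream: Observable of frames (assumed)
    ops.map(nearest_person),
    ops.filter(lambda t: t is not None),
).subscribe(
    lambda t: print(
        f"id={t['target_id']}: {t['distance']:.2f} m, {np.rad2deg(t['angle']):.1f} deg"
    )
)
```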