Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 33 additions & 11 deletions dimos/agents/skills/person_follow.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from threading import Event, RLock
from threading import Event, RLock, Thread
import time
from typing import TYPE_CHECKING

from langchain_core.messages import HumanMessage
import numpy as np
from reactivex.disposable import Disposable

from dimos.agents.agent import AgentSpec
from dimos.agents.annotation import skill
from dimos.core.core import rpc
from dimos.core.global_config import GlobalConfig
Expand Down Expand Up @@ -54,6 +56,7 @@ class PersonFollowSkillContainer(Module):
global_map: In[PointCloud2]
cmd_vel: Out[Twist]

_agent_spec: AgentSpec
_frequency: float = 20.0 # Hz - control loop frequency
_max_lost_frames: int = 15 # number of frames to wait before declaring person lost

Expand All @@ -70,6 +73,7 @@ def __init__(
self._latest_pointcloud: PointCloud2 | None = None
self._vl_model: VlModel = QwenVlModel()
self._tracker: EdgeTAMProcessor | None = None
self._thread: Thread | None = None
self._should_stop: Event = Event()
self._lock = RLock()

Expand Down Expand Up @@ -139,7 +143,7 @@ def follow_person(self, query: str) -> str:
if initial_bbox is None:
return f"Could not find '{query}' in the current view."

return self._follow_loop(query, initial_bbox)
return self._follow_person(query, initial_bbox)

@skill
def stop_following(self) -> str:
Expand All @@ -152,6 +156,10 @@ def stop_following(self) -> str:

self.cmd_vel.publish(Twist.zero())

if self._thread is not None:
self._thread.join(timeout=2)
self._thread = None

return "Stopped following."

def _on_color_image(self, image: Image) -> None:
Expand All @@ -162,7 +170,7 @@ def _on_pointcloud(self, pointcloud: PointCloud2) -> None:
with self._lock:
self._latest_pointcloud = pointcloud

def _follow_loop(self, query: str, initial_bbox: BBox) -> str:
def _follow_person(self, query: str, initial_bbox: BBox) -> str:
x1, y1, x2, y2 = initial_bbox
box = np.array([x1, y1, x2, y2], dtype=np.float32)

Expand All @@ -186,6 +194,15 @@ def _follow_loop(self, query: str, initial_bbox: BBox) -> str:

logger.info(f"EdgeTAM initialized with {len(initial_detections)} detections")

self._thread = Thread(target=self._follow_loop, args=(tracker, query))
self._thread.start()

return (
"Found the person. Starting to follow. You can stop following by calling "
"the 'stop_following' tool."
)

def _follow_loop(self, tracker: EdgeTAMProcessor, query: str) -> None:
lost_count = 0
period = 1.0 / self._frequency
next_time = time.monotonic()
Expand All @@ -204,8 +221,8 @@ def _follow_loop(self, query: str, initial_bbox: BBox) -> str:

lost_count += 1
if lost_count > self._max_lost_frames:
self.cmd_vel.publish(Twist.zero())
return f"Lost track of '{query}'. Stopping."
self._send_stop_reason(query, "lost track of the person")
return
else:
lost_count = 0
best_detection = max(detections.detections, key=lambda d: d.bbox_2d_volume())
Expand All @@ -214,16 +231,16 @@ def _follow_loop(self, query: str, initial_bbox: BBox) -> str:
with self._lock:
pointcloud = self._latest_pointcloud
if pointcloud is None:
self.cmd_vel.publish(Twist.zero())
return "No pointcloud available for 3D navigation. Stopping."
self._send_stop_reason(query, "no pointcloud available for 3D navigation")
return
twist = self._detection_navigation.compute_twist_for_detection_3d(
pointcloud,
best_detection,
latest_image,
)
if twist is None:
self.cmd_vel.publish(Twist.zero())
return f"3D navigation failed for '{query}'. Stopping."
self._send_stop_reason(query, "3D navigation failed")
return
else:
twist = self._visual_servo.compute_twist(
best_detection.bbox,
Expand All @@ -236,12 +253,17 @@ def _follow_loop(self, query: str, initial_bbox: BBox) -> str:
if sleep_duration > 0:
time.sleep(sleep_duration)

self.cmd_vel.publish(Twist.zero())
return "Stopped following as requested."
self._send_stop_reason(query, "it was requested to stop following")

def _stop_following(self) -> None:
self._should_stop.set()

def _send_stop_reason(self, query: str, reason: str) -> None:
    """Publish a zero twist and report why person-following ended.

    A zero ``Twist`` is published on ``cmd_vel``, then a ``HumanMessage``
    describing the stop is handed to ``_agent_spec.add_message``, and the
    event is logged.

    Args:
        query: The text that identified the person being followed; it is
            quoted verbatim in the message.
        reason: Short explanation of why following stopped; it is inserted
            after ``Reason:`` in the message.
    """
    # Zero the velocity command before doing any reporting.
    self.cmd_vel.publish(Twist.zero())

    notice = HumanMessage(f"Person follow stopped for '{query}'. Reason: {reason}.")
    self._agent_spec.add_message(notice)

    logger.info("Person follow stopped", query=query, reason=reason)


# Module-level alias for the container's ``blueprint`` attribute.
person_follow_skill = PersonFollowSkillContainer.blueprint

Expand Down