Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions data/.lfs/person.tar.gz
Git LFS file not shown
248 changes: 248 additions & 0 deletions dimos/agents/skills/person_follow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
# Copyright 2025-2026 Dimensional Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from threading import Event, RLock
import time
from typing import TYPE_CHECKING

import numpy as np
from reactivex.disposable import Disposable

from dimos.core.core import rpc
from dimos.core.global_config import GlobalConfig
from dimos.core.skill_module import SkillModule
from dimos.core.stream import In, Out
from dimos.models.qwen.video_query import BBox
from dimos.models.segmentation.edge_tam import EdgeTAMProcessor
from dimos.models.vl.qwen import QwenVlModel
from dimos.msgs.geometry_msgs import Twist
from dimos.msgs.sensor_msgs import CameraInfo, Image, PointCloud2
from dimos.navigation.visual.query import get_object_bbox_from_image
from dimos.navigation.visual_servoing.detection_navigation import DetectionNavigation
from dimos.navigation.visual_servoing.visual_servoing_2d import VisualServoing2D
from dimos.protocol.skill.skill import skill
from dimos.utils.logging_config import setup_logger

if TYPE_CHECKING:
from dimos.models.vl.base import VlModel

logger = setup_logger()


class PersonFollowSkillContainer(SkillModule):
    """Skill container that lets the robot follow a person.

    Pipeline:
    - A VL model (QwenVlModel) resolves a free-text person description to an
      initial bounding box.
    - EdgeTAM then tracks that person frame-to-frame.
    - Either 2D visual servoing or 3D detection navigation turns the tracked
      detection into velocity commands.

    No obstacle avoidance is performed; a clear path is assumed.
    """

    color_image: In[Image]
    global_map: In[PointCloud2]
    cmd_vel: Out[Twist]

    _frequency: float = 20.0  # Hz - control loop frequency
    _max_lost_frames: int = 15  # consecutive detection misses before giving up

    def __init__(
        self,
        camera_info: CameraInfo,
        global_config: GlobalConfig,
        use_3d_navigation: bool = False,
    ) -> None:
        super().__init__()

        # Configuration and cross-thread state; the mutable fields below are
        # guarded by self._lock.
        self._global_config: GlobalConfig = global_config
        self._use_3d_navigation: bool = use_3d_navigation
        self._latest_image: Image | None = None
        self._latest_pointcloud: PointCloud2 | None = None
        self._should_stop: Event = Event()
        self._lock = RLock()

        # Heavy models: the VL detector is built eagerly, the tracker lazily
        # (created on first follow in _follow_loop).
        self._vl_model: VlModel = QwenVlModel()
        self._tracker: EdgeTAMProcessor | None = None

        # In simulation, the MuJoCo camera intrinsics replace the supplied ones.
        if self._global_config.simulation:
            from dimos.robot.unitree_webrtc.mujoco_connection import MujocoConnection

            camera_info = MujocoConnection.camera_info_static

        self._camera_info = camera_info
        self._visual_servo = VisualServoing2D(camera_info, self._global_config.simulation)
        self._detection_navigation = DetectionNavigation(self.tf, camera_info)

@rpc
def start(self) -> None:
super().start()
self._disposables.add(Disposable(self.color_image.subscribe(self._on_color_image)))
if self._use_3d_navigation:
self._disposables.add(Disposable(self.global_map.subscribe(self._on_pointcloud)))

@rpc
def stop(self) -> None:
self._stop_following()

with self._lock:
if self._tracker is not None:
self._tracker.stop()
self._tracker = None

self._vl_model.stop()
super().stop()

@skill()
def follow_person(self, query: str) -> str:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

generally I'd much prefer this as a reactive pipeline in which you

self.image_frame.observable().pipe(
  ops.map(track)
  ops.map(detection_to_twist)
).subscribe(self.twist.publish)

so you don't need to do this self._image thing, thread management etc etc. much cleaner, just letting you know, this is an ok first test

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But doing it that way ties the frequency of the cmd_vel to the frequency of the camera, right? And you still need some state like counting how many frames are without a detection, or knowing when to stop.

Copy link
Contributor

@leshy leshy Jan 18, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there are nice tools for all of this - like sample() maybe? samples latest value at fixed hz, so you can emit twist at a faster fixed rate then input image

https://github.com/dimensionalOS/dimos/blob/dev/docs/api/sensor_streams/reactivex.md#sampleinterval---emit-latest-value-every-n-seconds

you still need some state like counting how many frames are without a detection, or knowing when to stop.

yeah you either start propagading more complex types (worst case but still nice)

{
  Image :image
  Detection2d: ...
  frames: n
}

and this gives you full power, and clear typing also, since you can transform between those complex types at each step

def bla(in: ComplexType1) -> ComplexType2

ops.map(bla)

but as an also easier shortcut before going for complex types you can have a map that has a sideeffect that counts empty Detection2Ds and starts returning false/exiting early.

detection_to_twist in above example can have a sideeffect of counting empty detections it got, and starting to output twist(0,0,0) - after initial few twist(0,0,0) you can also choose to terminate tracking loop as well, observable just finishes

to be clear Im not asking for a rewrite, just mentioning how this can be structured in a way in which it's easy and fun to quickly iterate (add twist easing etc), rotation towards a person and walking towards a person can be separate maps() that layer onto the same Twist, etc

"""Follow a person matching the given description using visual servoing.

The robot will continuously track and follow the person, while keeping
them centered in the camera view.

Args:
query: Description of the person to follow (e.g., "man with blue shirt")

Returns:
Status message indicating the result of the following action.

Example:
follow_person("man with blue shirt")
follow_person("person in the doorway")
"""

self._stop_following()

self._should_stop.clear()

with self._lock:
latest_image = self._latest_image

if latest_image is None:
return "No image available to detect person."

initial_bbox = get_object_bbox_from_image(
self._vl_model,
latest_image,
query,
)

if initial_bbox is None:
return f"Could not find '{query}' in the current view."

return self._follow_loop(query, initial_bbox)

    @skill()
    def stop_following(self) -> str:
        """Stop following the current person.

        Returns:
            Confirmation message.
        """
        # Signal the follow loop to exit at its next iteration check.
        self._stop_following()

        # Publish a zero twist immediately so the robot halts even if no
        # follow loop is currently running to do so itself.
        self.cmd_vel.publish(Twist.zero())

        return "Stopped following."

    def _on_color_image(self, image: Image) -> None:
        # Stream callback: cache the most recent camera frame for the follow loop.
        with self._lock:
            self._latest_image = image

    def _on_pointcloud(self, pointcloud: PointCloud2) -> None:
        # Stream callback: cache the most recent map pointcloud (3D navigation only).
        with self._lock:
            self._latest_pointcloud = pointcloud

def _follow_loop(self, query: str, initial_bbox: BBox) -> str:
x1, y1, x2, y2 = initial_bbox
box = np.array([x1, y1, x2, y2], dtype=np.float32)

with self._lock:
if self._tracker is None:
self._tracker = EdgeTAMProcessor()
tracker = self._tracker
latest_image = self._latest_image
if latest_image is None:
return "No image available to start tracking."

initial_detections = tracker.init_track(
image=latest_image,
box=box,
obj_id=1,
)

if len(initial_detections) == 0:
self.cmd_vel.publish(Twist.zero())
return f"EdgeTAM failed to segment '{query}'."

logger.info(f"EdgeTAM initialized with {len(initial_detections)} detections")

lost_count = 0
period = 1.0 / self._frequency
next_time = time.monotonic()

while not self._should_stop.is_set():
next_time += period

with self._lock:
latest_image = self._latest_image
assert latest_image is not None

detections = tracker.process_image(latest_image)

if len(detections) == 0:
self.cmd_vel.publish(Twist.zero())

lost_count += 1
if lost_count > self._max_lost_frames:
self.cmd_vel.publish(Twist.zero())
return f"Lost track of '{query}'. Stopping."
else:
lost_count = 0
best_detection = max(detections.detections, key=lambda d: d.bbox_2d_volume())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic: Calling max(detections.detections, ...) without checking if the list is empty will raise ValueError. Even though len(detections) == 0 is checked, the len() behavior on ImageDetections2D may differ from checking detections.detections length. Add explicit check: if detections.detections: before calling max().


if self._use_3d_navigation:
with self._lock:
pointcloud = self._latest_pointcloud
if pointcloud is None:
self.cmd_vel.publish(Twist.zero())
return "No pointcloud available for 3D navigation. Stopping."
twist = self._detection_navigation.compute_twist_for_detection_3d(
pointcloud,
best_detection,
latest_image,
)
if twist is None:
self.cmd_vel.publish(Twist.zero())
return f"3D navigation failed for '{query}'. Stopping."
else:
twist = self._visual_servo.compute_twist(
best_detection.bbox,
latest_image.width,
)
self.cmd_vel.publish(twist)

now = time.monotonic()
sleep_duration = next_time - now
if sleep_duration > 0:
time.sleep(sleep_duration)

self.cmd_vel.publish(Twist.zero())
return "Stopped following as requested."

    def _stop_following(self) -> None:
        # Request the follow loop (if any) to exit at its next iteration check.
        self._should_stop.set()


# Blueprint handle used by launch/demo configurations to instantiate the skill.
person_follow_skill = PersonFollowSkillContainer.blueprint

__all__ = ["PersonFollowSkillContainer", "person_follow_skill"]
4 changes: 2 additions & 2 deletions dimos/e2e_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ def fun(*, points: list[tuple[float, float, float]], fail_message: str) -> None:
def start_blueprint() -> Iterator[Callable[[str], DimosCliCall]]:
dimos_robot_call = DimosCliCall()

def set_name_and_start(demo_name: str) -> DimosCliCall:
dimos_robot_call.demo_name = demo_name
def set_name_and_start(*demo_args: str) -> DimosCliCall:
dimos_robot_call.demo_args = list(demo_args)
dimos_robot_call.start()
return dimos_robot_call

Expand Down
10 changes: 4 additions & 6 deletions dimos/e2e_tests/dimos_cli_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,16 @@

class DimosCliCall:
process: subprocess.Popen[bytes] | None
demo_name: str | None = None
demo_args: list[str] | None = None

def __init__(self) -> None:
self.process = None

def start(self) -> None:
if self.demo_name is None:
raise ValueError("Demo name must be set before starting the process.")
if self.demo_args is None:
raise ValueError("Demo args must be set before starting the process.")

self.process = subprocess.Popen(
["dimos", "--simulation", "run", self.demo_name],
)
self.process = subprocess.Popen(["dimos", "--simulation", *self.demo_args])

def stop(self) -> None:
if self.process is None:
Expand Down
2 changes: 1 addition & 1 deletion dimos/e2e_tests/test_dimos_cli_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_dimos_skills(lcm_spy, start_blueprint, human_input) -> None:
lcm_spy.save_topic("/rpc/DemoCalculatorSkill/sum_numbers/req")
lcm_spy.save_topic("/rpc/DemoCalculatorSkill/sum_numbers/res")

start_blueprint("demo-skill")
start_blueprint("run", "demo-skill")

lcm_spy.wait_for_saved_topic("/rpc/DemoCalculatorSkill/set_AgentSpec_register_skills/res")
lcm_spy.wait_for_saved_topic("/rpc/HumanInput/start/res")
Expand Down
85 changes: 85 additions & 0 deletions dimos/e2e_tests/test_person_follow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright 2025-2026 Dimensional Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Callable, Generator
import os
import threading
import time

import pytest

from dimos.e2e_tests.dimos_cli_call import DimosCliCall
from dimos.e2e_tests.lcm_spy import LcmSpy
from dimos.simulation.mujoco.person_on_track import PersonTrackPublisher

StartPersonTrack = Callable[[list[tuple[float, float]]], None]


@pytest.fixture
def start_person_track() -> Generator[StartPersonTrack, None, None]:
    """Yield a starter that runs a simulated person along a track in the background.

    Teardown signals the pump thread to stop, joins it, and stops the publisher.
    """
    state: dict = {"publisher": None, "thread": None}
    halt = threading.Event()

    def start(track: list[tuple[float, float]]) -> None:
        state["publisher"] = PersonTrackPublisher(track)

        def pump() -> None:
            # Reads the publisher through `state` each tick, matching the
            # late-binding capture of the original closure.
            while not halt.is_set():
                state["publisher"].tick()
                time.sleep(1 / 60)

        worker = threading.Thread(target=pump, daemon=True)
        state["thread"] = worker
        worker.start()

    yield start

    halt.set()
    if state["thread"] is not None:
        state["thread"].join(timeout=1.0)
    if state["publisher"] is not None:
        state["publisher"].stop()


@pytest.mark.skipif(bool(os.getenv("CI")), reason="LCM spy doesn't work in CI.")
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set.")
@pytest.mark.e2e
def test_person_follow(
    lcm_spy: LcmSpy,
    start_blueprint: Callable[[str], DimosCliCall],
    human_input: Callable[[str], None],
    start_person_track: StartPersonTrack,
) -> None:
    """End-to-end: the agent tracks and follows a simulated moving person."""
    # Launch the agentic robot demo at a fixed start pose in simulation.
    start_blueprint("--mujoco-start-pos", "-6.18 0.96", "run", "unitree-go2-agentic")

    # Wait for the human-input RPC and the agent to come up before interacting.
    lcm_spy.save_topic("/rpc/HumanInput/start/res")
    lcm_spy.wait_for_saved_topic("/rpc/HumanInput/start/res", timeout=120.0)
    lcm_spy.save_topic("/agent")
    lcm_spy.wait_for_saved_topic_content("/agent", b"AIMessage", timeout=120.0)

    # Settle time; presumably lets the simulation/camera streams warm up — TODO confirm.
    time.sleep(5)

    # Waypoints (x, y) of the simulated person's walking loop.
    start_person_track(
        [
            (-2.60, 1.28),
            (4.80, 0.21),
            (4.14, -6.0),
            (0.59, -3.79),
            (-3.35, -0.51),
        ]
    )
    human_input("follow the person in beige pants")

    # Success criterion: the robot gets within 1.5 m of a point on the track.
    lcm_spy.wait_until_odom_position(4.2, -3, threshold=1.5)
2 changes: 1 addition & 1 deletion dimos/e2e_tests/test_spatial_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_spatial_memory_navigation(
human_input: Callable[[str], None],
follow_points: Callable[..., None],
) -> None:
start_blueprint("unitree-go2-agentic")
start_blueprint("run", "unitree-go2-agentic")

lcm_spy.save_topic("/rpc/HumanInput/start/res")
lcm_spy.wait_for_saved_topic("/rpc/HumanInput/start/res", timeout=120.0)
Expand Down
Loading
Loading