-
Notifications
You must be signed in to change notification settings - Fork 160
Person follow skill with EdgeTAM #1042
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
745dc0f
b9d3971
c4702b1
f2301e0
2a247ab
5d02818
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,248 @@ | ||
| # Copyright 2025-2026 Dimensional Inc. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from threading import Event, RLock | ||
| import time | ||
| from typing import TYPE_CHECKING | ||
|
|
||
| import numpy as np | ||
| from reactivex.disposable import Disposable | ||
|
|
||
| from dimos.core.core import rpc | ||
| from dimos.core.global_config import GlobalConfig | ||
| from dimos.core.skill_module import SkillModule | ||
| from dimos.core.stream import In, Out | ||
| from dimos.models.qwen.video_query import BBox | ||
| from dimos.models.segmentation.edge_tam import EdgeTAMProcessor | ||
| from dimos.models.vl.qwen import QwenVlModel | ||
| from dimos.msgs.geometry_msgs import Twist | ||
| from dimos.msgs.sensor_msgs import CameraInfo, Image, PointCloud2 | ||
| from dimos.navigation.visual.query import get_object_bbox_from_image | ||
| from dimos.navigation.visual_servoing.detection_navigation import DetectionNavigation | ||
| from dimos.navigation.visual_servoing.visual_servoing_2d import VisualServoing2D | ||
| from dimos.protocol.skill.skill import skill | ||
| from dimos.utils.logging_config import setup_logger | ||
|
|
||
| if TYPE_CHECKING: | ||
| from dimos.models.vl.base import VlModel | ||
|
|
||
| logger = setup_logger() | ||
|
|
||
|
|
||
| class PersonFollowSkillContainer(SkillModule): | ||
| """Skill container for following a person. | ||
|
|
||
| This skill uses: | ||
| - A VL model (QwenVlModel) to initially detect a person from a text description. | ||
| - EdgeTAM for continuous tracking across frames. | ||
| - Visual servoing OR 3D navigation to control robot movement towards the person. | ||
| - Does not do obstacle avoidance; assumes a clear path. | ||
| """ | ||
|
|
||
| color_image: In[Image] | ||
| global_map: In[PointCloud2] | ||
| cmd_vel: Out[Twist] | ||
|
|
||
| _frequency: float = 20.0 # Hz - control loop frequency | ||
| _max_lost_frames: int = 15 # number of frames to wait before declaring person lost | ||
|
|
||
    def __init__(
        self,
        camera_info: CameraInfo,
        global_config: GlobalConfig,
        use_3d_navigation: bool = False,
    ) -> None:
        """Initialize the person-follow skill container.

        Args:
            camera_info: Intrinsics of the camera used for servoing / projection.
                In simulation mode this is overridden with the MuJoCo intrinsics.
            global_config: Global configuration; used here to detect simulation mode.
            use_3d_navigation: If True, drive via 3D detection navigation using the
                global pointcloud; otherwise use 2D visual servoing.
        """
        super().__init__()
        self._global_config: GlobalConfig = global_config
        self._use_3d_navigation: bool = use_3d_navigation
        # Latest sensor data cached by the subscription callbacks; guarded by _lock.
        self._latest_image: Image | None = None
        self._latest_pointcloud: PointCloud2 | None = None
        self._vl_model: VlModel = QwenVlModel()
        # Created lazily on the first follow request (see _follow_loop).
        self._tracker: EdgeTAMProcessor | None = None
        # Signals the follow loop to exit; set by _stop_following().
        self._should_stop: Event = Event()
        self._lock = RLock()

        # Use MuJoCo camera intrinsics in simulation mode
        if self._global_config.simulation:
            from dimos.robot.unitree_webrtc.mujoco_connection import MujocoConnection

            camera_info = MujocoConnection.camera_info_static

        self._camera_info = camera_info
        self._visual_servo = VisualServoing2D(camera_info, self._global_config.simulation)
        self._detection_navigation = DetectionNavigation(self.tf, camera_info)
|
|
||
| @rpc | ||
| def start(self) -> None: | ||
| super().start() | ||
| self._disposables.add(Disposable(self.color_image.subscribe(self._on_color_image))) | ||
| if self._use_3d_navigation: | ||
| self._disposables.add(Disposable(self.global_map.subscribe(self._on_pointcloud))) | ||
|
|
||
| @rpc | ||
| def stop(self) -> None: | ||
| self._stop_following() | ||
|
|
||
| with self._lock: | ||
| if self._tracker is not None: | ||
| self._tracker.stop() | ||
| self._tracker = None | ||
|
|
||
| self._vl_model.stop() | ||
| super().stop() | ||
|
|
||
| @skill() | ||
| def follow_person(self, query: str) -> str: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. generally I'd much prefer this as a reactive pipeline in which you self.image_frame.observable().pipe(
ops.map(track)
ops.map(detection_to_twist)
).subscribe(self.twist.publish)so you don't need to do this self._image thing, thread management etc etc. much cleaner, just letting you know, this is an ok first test
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. But doing it that way ties the frequency of the cmd_vel to the frequency of the camera, right? And you still need some state like counting how many frames are without a detection, or knowing when to stop.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. there are nice tools for all of this - like
yeah you either start propagading more complex types (worst case but still nice) {
Image :image
Detection2d: ...
frames: n
}and this gives you full power, and clear typing also, since you can transform between those complex types at each step
but as an also easier shortcut before going for complex types you can have a map that has a sideeffect that counts empty Detection2Ds and starts returning false/exiting early.
to be clear Im not asking for a rewrite, just mentioning how this can be structured in a way in which it's easy and fun to quickly iterate (add twist easing etc), rotation towards a person and walking towards a person can be separate maps() that layer onto the same Twist, etc |
||
| """Follow a person matching the given description using visual servoing. | ||
|
|
||
| The robot will continuously track and follow the person, while keeping | ||
| them centered in the camera view. | ||
|
|
||
| Args: | ||
| query: Description of the person to follow (e.g., "man with blue shirt") | ||
|
|
||
| Returns: | ||
| Status message indicating the result of the following action. | ||
|
|
||
| Example: | ||
| follow_person("man with blue shirt") | ||
| follow_person("person in the doorway") | ||
| """ | ||
|
|
||
| self._stop_following() | ||
|
|
||
| self._should_stop.clear() | ||
|
|
||
| with self._lock: | ||
| latest_image = self._latest_image | ||
|
|
||
| if latest_image is None: | ||
| return "No image available to detect person." | ||
|
|
||
| initial_bbox = get_object_bbox_from_image( | ||
| self._vl_model, | ||
| latest_image, | ||
| query, | ||
| ) | ||
|
|
||
| if initial_bbox is None: | ||
| return f"Could not find '{query}' in the current view." | ||
|
|
||
| return self._follow_loop(query, initial_bbox) | ||
|
|
||
| @skill() | ||
| def stop_following(self) -> str: | ||
| """Stop following the current person. | ||
|
|
||
| Returns: | ||
| Confirmation message. | ||
| """ | ||
| self._stop_following() | ||
|
|
||
| self.cmd_vel.publish(Twist.zero()) | ||
|
|
||
| return "Stopped following." | ||
|
|
||
| def _on_color_image(self, image: Image) -> None: | ||
| with self._lock: | ||
| self._latest_image = image | ||
|
|
||
| def _on_pointcloud(self, pointcloud: PointCloud2) -> None: | ||
| with self._lock: | ||
| self._latest_pointcloud = pointcloud | ||
|
|
||
| def _follow_loop(self, query: str, initial_bbox: BBox) -> str: | ||
| x1, y1, x2, y2 = initial_bbox | ||
| box = np.array([x1, y1, x2, y2], dtype=np.float32) | ||
|
|
||
| with self._lock: | ||
| if self._tracker is None: | ||
| self._tracker = EdgeTAMProcessor() | ||
| tracker = self._tracker | ||
| latest_image = self._latest_image | ||
| if latest_image is None: | ||
| return "No image available to start tracking." | ||
|
|
||
| initial_detections = tracker.init_track( | ||
| image=latest_image, | ||
| box=box, | ||
| obj_id=1, | ||
| ) | ||
|
|
||
| if len(initial_detections) == 0: | ||
| self.cmd_vel.publish(Twist.zero()) | ||
| return f"EdgeTAM failed to segment '{query}'." | ||
|
|
||
| logger.info(f"EdgeTAM initialized with {len(initial_detections)} detections") | ||
|
|
||
| lost_count = 0 | ||
| period = 1.0 / self._frequency | ||
| next_time = time.monotonic() | ||
|
|
||
| while not self._should_stop.is_set(): | ||
| next_time += period | ||
|
|
||
| with self._lock: | ||
| latest_image = self._latest_image | ||
| assert latest_image is not None | ||
paul-nechifor marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| detections = tracker.process_image(latest_image) | ||
|
|
||
| if len(detections) == 0: | ||
| self.cmd_vel.publish(Twist.zero()) | ||
|
|
||
| lost_count += 1 | ||
| if lost_count > self._max_lost_frames: | ||
| self.cmd_vel.publish(Twist.zero()) | ||
| return f"Lost track of '{query}'. Stopping." | ||
| else: | ||
| lost_count = 0 | ||
| best_detection = max(detections.detections, key=lambda d: d.bbox_2d_volume()) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. logic: Calling |
||
|
|
||
| if self._use_3d_navigation: | ||
| with self._lock: | ||
| pointcloud = self._latest_pointcloud | ||
| if pointcloud is None: | ||
| self.cmd_vel.publish(Twist.zero()) | ||
| return "No pointcloud available for 3D navigation. Stopping." | ||
| twist = self._detection_navigation.compute_twist_for_detection_3d( | ||
| pointcloud, | ||
| best_detection, | ||
| latest_image, | ||
| ) | ||
| if twist is None: | ||
| self.cmd_vel.publish(Twist.zero()) | ||
| return f"3D navigation failed for '{query}'. Stopping." | ||
| else: | ||
| twist = self._visual_servo.compute_twist( | ||
| best_detection.bbox, | ||
| latest_image.width, | ||
| ) | ||
| self.cmd_vel.publish(twist) | ||
|
|
||
| now = time.monotonic() | ||
| sleep_duration = next_time - now | ||
| if sleep_duration > 0: | ||
| time.sleep(sleep_duration) | ||
|
|
||
| self.cmd_vel.publish(Twist.zero()) | ||
| return "Stopped following as requested." | ||
|
|
||
    def _stop_following(self) -> None:
        """Ask the active follow loop (if any) to exit at its next iteration."""
        self._should_stop.set()
|
|
||
|
|
||
# Blueprint used by the runtime to instantiate this skill container.
person_follow_skill = PersonFollowSkillContainer.blueprint

__all__ = ["PersonFollowSkillContainer", "person_follow_skill"]
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,85 @@ | ||
| # Copyright 2025-2026 Dimensional Inc. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from collections.abc import Callable, Generator | ||
| import os | ||
| import threading | ||
| import time | ||
|
|
||
| import pytest | ||
|
|
||
| from dimos.e2e_tests.dimos_cli_call import DimosCliCall | ||
| from dimos.e2e_tests.lcm_spy import LcmSpy | ||
| from dimos.simulation.mujoco.person_on_track import PersonTrackPublisher | ||
|
|
||
| StartPersonTrack = Callable[[list[tuple[float, float]]], None] | ||
|
|
||
|
|
||
@pytest.fixture
def start_person_track() -> Generator[StartPersonTrack, None, None]:
    """Provide a callable that animates a simulated person along a 2D track.

    The returned callable starts a daemon thread that ticks a
    ``PersonTrackPublisher`` at roughly 60 Hz. The thread is signalled to stop
    and joined, and the publisher stopped, when the fixture is finalized.

    Yields:
        A function taking a list of ``(x, y)`` waypoints for the person to walk.
    """
    publisher: PersonTrackPublisher | None = None
    stop_event = threading.Event()
    thread: threading.Thread | None = None

    def start(track: list[tuple[float, float]]) -> None:
        nonlocal publisher, thread
        publisher = PersonTrackPublisher(track)

        def run_person_track(pub: PersonTrackPublisher) -> None:
            # ~60 Hz tick until teardown signals the stop event.
            while not stop_event.is_set():
                pub.tick()
                time.sleep(1 / 60)

        # FIX: bind the publisher via args= at thread creation instead of
        # closing over the nonlocal `publisher`, so a later call to start()
        # cannot re-point an already-running thread at a different publisher.
        # NOTE(review): a second start() call still leaks the first thread
        # until teardown — confirm the fixture is only started once per test.
        thread = threading.Thread(target=run_person_track, args=(publisher,), daemon=True)
        thread.start()

    yield start

    stop_event.set()
    if thread is not None:
        thread.join(timeout=1.0)
    if publisher is not None:
        publisher.stop()
|
|
||
|
|
||
@pytest.mark.skipif(bool(os.getenv("CI")), reason="LCM spy doesn't work in CI.")
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set.")
@pytest.mark.e2e
def test_person_follow(
    lcm_spy: LcmSpy,
    start_blueprint: Callable[[str], DimosCliCall],
    human_input: Callable[[str], None],
    start_person_track: StartPersonTrack,
) -> None:
    """End-to-end test: the robot follows a simulated person along a track.

    Starts the agentic Go2 blueprint in MuJoCo, waits for the agent to come
    up, animates a person along a waypoint track, asks the agent to follow
    them, and asserts the robot's odometry ends up near a track waypoint.
    """
    # Spawn the robot at a fixed offset so it must actually move to follow.
    start_blueprint("--mujoco-start-pos", "-6.18 0.96", "run", "unitree-go2-agentic")

    # Wait until the HumanInput RPC is up and the agent has produced a message.
    lcm_spy.save_topic("/rpc/HumanInput/start/res")
    lcm_spy.wait_for_saved_topic("/rpc/HumanInput/start/res", timeout=120.0)
    lcm_spy.save_topic("/agent")
    lcm_spy.wait_for_saved_topic_content("/agent", b"AIMessage", timeout=120.0)

    # Let the simulation settle before the person starts moving.
    time.sleep(5)

    # Waypoints (x, y) that the simulated person walks through in order.
    start_person_track(
        [
            (-2.60, 1.28),
            (4.80, 0.21),
            (4.14, -6.0),
            (0.59, -3.79),
            (-3.35, -0.51),
        ]
    )
    human_input("follow the person in beige pants")

    # Success criterion: odometry reaches the vicinity of a late waypoint.
    lcm_spy.wait_until_odom_position(4.2, -3, threshold=1.5)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we have this fancy system of hot/cold getters:
https://github.com/dimensionalOS/dimos/blob/dev/docs/api/sensor_streams/advanced_streams.md#getter_hot---background-subscription-instant-reads
so you can do
getter = getter_hot(self.image().observable())
I don't know if you want this — just letting you know it exists. Also, the getter is not disposable; it should be made one.