diff --git a/.gitmodules b/.gitmodules index 85734c1fff..ae48e66391 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,13 +1,9 @@ [submodule "dimos/external/openMVS"] path = dimos/external/openMVS url = https://github.com/cdcseacave/openMVS.git - [submodule "dimos/external/vcpkg"] path = dimos/external/vcpkg url = https://github.com/microsoft/vcpkg.git -[submodule "dimos/robot/unitree/external/go2_webrtc_connect"] - path = dimos/robot/unitree/external/go2_webrtc_connect - url = https://github.com/spomichter/go2_webrtc_connect [submodule "dimos/robot/unitree/external/go2_ros2_sdk"] path = dimos/robot/unitree/external/go2_ros2_sdk url = https://github.com/dimensionalOS/go2_ros2_sdk diff --git a/dimos/robot/abstract_robot.py b/dimos/robot/abstract_robot.py new file mode 100644 index 0000000000..4c6345f1b0 --- /dev/null +++ b/dimos/robot/abstract_robot.py @@ -0,0 +1,66 @@ +"""Abstract base class for all DIMOS robot implementations. + +This module defines the AbstractRobot class which serves as the foundation for +all robot implementations in DIMOS, establishing a common interface regardless +of the underlying hardware or communication protocol (ROS, WebRTC, etc). +""" + +from abc import ABC, abstractmethod +from typing import Any, Union, Optional +from reactivex.observable import Observable +import numpy as np + + +class AbstractRobot(ABC): + """Abstract base class for all robot implementations. + + This class defines the minimal interface that all robot implementations + must provide, regardless of whether they use ROS, WebRTC, or other + communication protocols. + """ + + @abstractmethod + def connect(self) -> bool: + """Establish a connection to the robot. + + This method should handle all necessary setup to establish + communication with the robot hardware. + + Returns: + bool: True if connection was successful, False otherwise. + """ + pass + + @abstractmethod + def move(self, *args, **kwargs) -> bool: + """Move the robot. 
+ + This is a generic movement interface that should be implemented + by all robot classes. The exact parameters will depend on the + specific robot implementation. + + Returns: + bool: True if movement command was successfully sent. + """ + pass + + @abstractmethod + def get_video_stream(self, fps: int = 30) -> Observable: + """Get a video stream from the robot's camera. + + Args: + fps: Frames per second for the video stream. Defaults to 30. + + Returns: + Observable: An observable stream of video frames. + """ + pass + + @abstractmethod + def stop(self) -> None: + """Clean up resources and stop the robot. + + This method should handle all necessary cleanup when shutting down + the robot connection, including stopping any ongoing movements. + """ + pass diff --git a/dimos/robot/global_planner/planner.py b/dimos/robot/global_planner/planner.py index 96de4cca4c..0a59428917 100644 --- a/dimos/robot/global_planner/planner.py +++ b/dimos/robot/global_planner/planner.py @@ -37,11 +37,11 @@ def plan(self, goal: VectorLike) -> Path: ... 
def set_goal( self, goal: VectorLike, goal_theta: Optional[float] = None, stop_event: Optional[threading.Event] = None ): - goal = to_vector(goal).to_2d() path = self.plan(goal) if not path: logger.warning("No path found to the goal.") return False + print("pathing success", path) return self.set_local_nav(path, stop_event=stop_event, goal_theta=goal_theta) @@ -53,12 +53,14 @@ class AstarPlanner(Planner): conservativism: int = 8 def plan(self, goal: VectorLike) -> Path: - pos = self.get_robot_pos() + goal = to_vector(goal).to_2d() + pos = self.get_robot_pos().to_2d() costmap = self.get_costmap().smudge(preserve_unknown=False) - self.vis("planner_costmap", costmap) + # self.vis("costmap", costmap) self.vis("target", goal) + print("ASTAR ", costmap, goal, pos) path = astar(costmap, goal, pos) if path: diff --git a/dimos/robot/unitree/unitree_go2.py b/dimos/robot/unitree/unitree_go2.py index cb97b82e52..54ba4c3327 100644 --- a/dimos/robot/unitree/unitree_go2.py +++ b/dimos/robot/unitree/unitree_go2.py @@ -30,8 +30,7 @@ from dimos.utils.logging_config import setup_logger from dimos.perception.person_tracker import PersonTrackingStream from dimos.perception.object_tracker import ObjectTrackingStream -from dimos.robot.local_planner.vfh_local_planner import VFHPurePursuitPlanner -from dimos.robot.local_planner.local_planner import navigate_path_local +from dimos.robot.local_planner import VFHPurePursuitPlanner, navigate_path_local from dimos.robot.global_planner.planner import AstarPlanner from dimos.types.path import Path from dimos.types.costmap import Costmap @@ -170,7 +169,7 @@ def __init__( robot_width=0.36, # Unitree Go2 width in meters robot_length=0.6, # Unitree Go2 length in meters max_linear_vel=0.5, - lookahead_distance=1.0, + lookahead_distance=2.0, visualization_size=500, # 500x500 pixel visualization ) diff --git a/dimos/robot/unitree/unitree_skills.py b/dimos/robot/unitree/unitree_skills.py index 0449e16afa..9c2e406c97 100644 --- 
a/dimos/robot/unitree/unitree_skills.py +++ b/dimos/robot/unitree/unitree_skills.py @@ -237,7 +237,10 @@ def __call__(self): # region Class-based Skills class Move(AbstractRobotSkill): - """Move the robot using direct velocity commands.""" + """Move the robot using direct velocity commands. + + This skill works with both ROS and WebRTC robot implementations. + """ x: float = Field(..., description="Forward velocity (m/s).") y: float = Field(default=0.0, description="Left/right velocity (m/s)") @@ -246,10 +249,52 @@ class Move(AbstractRobotSkill): def __call__(self): super().__call__() - return self._robot.move_vel(x=self.x, y=self.y, yaw=self.yaw, duration=self.duration) + + from dimos.types.vector import Vector + vector = Vector(self.x, self.y, self.yaw) + + # Handle duration for continuous movement + if self.duration > 0: + import time + import threading + import asyncio + + # Create a stop event + stop_event = threading.Event() + + # Function to continuously send movement commands + async def continuous_move(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + start_time = time.time() + try: + while not stop_event.is_set() and (time.time() - start_time) < self.duration: + self._robot.move(vector) + await asyncio.sleep(0.001) # Send commands at 1000Hz + # Always stop at the end + self._robot.move(Vector(0, 0, 0)) + finally: + loop.close() + + # Run movement in a separate thread with asyncio event loop + move_thread = threading.Thread(target=lambda: asyncio.run(continuous_move())) + move_thread.daemon = True + move_thread.start() + + # Wait for the full duration + time.sleep(self.duration) + stop_event.set() + move_thread.join(timeout=0.5) # Wait for thread to finish with timeout + else: + # Just execute the move command once for continuous movement + self._robot.move(vector) + return True class Reverse(AbstractRobotSkill): - """Reverse the robot using direct velocity commands.""" + """Reverse the robot using direct velocity commands. 
+ + This skill works with both ROS and WebRTC robot implementations. + """ x: float = Field(..., description="Backward velocity (m/s). Positive values move backward.") y: float = Field(default=0.0, description="Left/right velocity (m/s)") @@ -258,8 +303,46 @@ class Reverse(AbstractRobotSkill): def __call__(self): super().__call__() - # Use move_vel with negative x for backward movement - return self._robot.move_vel(x=-self.x, y=self.y, yaw=self.yaw, duration=self.duration) + from dimos.types.vector import Vector + # Use negative x for backward movement + vector = Vector(-self.x, self.y, self.yaw) + + # Handle duration for continuous movement + if self.duration > 0: + import time + import threading + import asyncio + + # Create a stop event + stop_event = threading.Event() + + # Function to continuously send movement commands + async def continuous_move(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + start_time = time.time() + try: + while not stop_event.is_set() and (time.time() - start_time) < self.duration: + self._robot.move(vector) + await asyncio.sleep(0.001) # Send commands at 1000Hz + # Always stop at the end + self._robot.move(Vector(0, 0, 0)) + finally: + loop.close() + + # Run movement in a separate thread with asyncio event loop + move_thread = threading.Thread(target=lambda: asyncio.run(continuous_move())) + move_thread.daemon = True + move_thread.start() + + # Wait for the full duration + time.sleep(self.duration) + stop_event.set() + move_thread.join(timeout=0.5) # Wait for thread to finish with timeout + else: + # Just execute the move command once for continuous movement + self._robot.move(vector) + return True class SpinLeft(AbstractRobotSkill): """Spin the robot left using degree commands.""" diff --git a/dimos/robot/unitree_webrtc/__init__.py b/dimos/robot/unitree_webrtc/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dimos/robot/unitree_webrtc/connection.py 
b/dimos/robot/unitree_webrtc/connection.py new file mode 100644 index 0000000000..48ea276883 --- /dev/null +++ b/dimos/robot/unitree_webrtc/connection.py @@ -0,0 +1,198 @@ +import functools +import asyncio +import threading +from typing import TypeAlias, Literal +from dimos.utils.reactive import backpressure, callback_to_observable +from dimos.types.vector import Vector +from dimos.types.position import Position +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.odometry import Odometry +from go2_webrtc_driver.webrtc_driver import Go2WebRTCConnection, WebRTCConnectionMethod # type: ignore[import-not-found] +from go2_webrtc_driver.constants import RTC_TOPIC, VUI_COLOR, SPORT_CMD +from reactivex.subject import Subject +from reactivex.observable import Observable +import numpy as np +from reactivex import operators as ops +from aiortc import MediaStreamTrack +from dimos.robot.unitree_webrtc.type.lowstate import LowStateMsg +from dimos.robot.abstract_robot import AbstractRobot + + +VideoMessage: TypeAlias = np.ndarray[tuple[int, int, Literal[3]], np.uint8] + + +class WebRTCRobot(AbstractRobot): + def __init__(self, ip: str, mode: str = "ai"): + self.ip = ip + self.mode = mode + self.conn = Go2WebRTCConnection(WebRTCConnectionMethod.LocalSTA, ip=self.ip) + self.connect() + + def connect(self): + self.loop = asyncio.new_event_loop() + self.task = None + self.connected_event = asyncio.Event() + self.connection_ready = threading.Event() + + async def async_connect(): + await self.conn.connect() + await self.conn.datachannel.disableTrafficSaving(True) + + self.conn.datachannel.set_decoder(decoder_type="native") + + await self.conn.datachannel.pub_sub.publish_request_new( + RTC_TOPIC["MOTION_SWITCHER"], {"api_id": 1002, "parameter": {"name": self.mode}} + ) + + self.connected_event.set() + self.connection_ready.set() + + while True: + await asyncio.sleep(1) + + def start_background_loop(): + asyncio.set_event_loop(self.loop) 
+ self.task = self.loop.create_task(async_connect()) + self.loop.run_forever() + + self.loop = asyncio.new_event_loop() + self.thread = threading.Thread(target=start_background_loop, daemon=True) + self.thread.start() + self.connection_ready.wait() + + def move(self, vector: Vector): + self.conn.datachannel.pub_sub.publish_without_callback( + RTC_TOPIC["WIRELESS_CONTROLLER"], + data={"lx": vector.x, "ly": vector.y, "rx": vector.z, "ry": 0}, + ) + + # Generic conversion of unitree subscription to Subject (used for all subs) + def unitree_sub_stream(self, topic_name: str): + return callback_to_observable( + start=lambda cb: self.conn.datachannel.pub_sub.subscribe(topic_name, cb), + stop=lambda: self.conn.datachannel.pub_sub.unsubscribe(topic_name), + ) + + # Generic sync API call (we jump into the client thread) + def publish_request(self, topic: str, data: dict): + future = asyncio.run_coroutine_threadsafe( + self.conn.datachannel.pub_sub.publish_request_new(topic, data), self.loop + ) + return future.result() + + @functools.cache + def raw_lidar_stream(self) -> Subject[LidarMessage]: + return backpressure(self.unitree_sub_stream(RTC_TOPIC["ULIDAR_ARRAY"])) + + @functools.cache + def raw_odom_stream(self) -> Subject[Position]: + return backpressure(self.unitree_sub_stream(RTC_TOPIC["ROBOTODOM"])) + + @functools.cache + def lidar_stream(self) -> Subject[LidarMessage]: + return backpressure(self.raw_lidar_stream().pipe(ops.map(lambda raw_frame: LidarMessage.from_msg(raw_frame)))) + + @functools.cache + def odom_stream(self) -> Subject[Position]: + return backpressure(self.raw_odom_stream().pipe(ops.map(Odometry.from_msg))) + + @functools.cache + def lowstate_stream(self) -> Subject[LowStateMsg]: + return backpressure(self.unitree_sub_stream(RTC_TOPIC["LOW_STATE"])) + + def standup_ai(self): + return self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["BalanceStand"]}) + + def standup_normal(self): + return self.publish_request(RTC_TOPIC["SPORT_MOD"], 
{"api_id": SPORT_CMD["StandUp"]}) + + def standup(self): + if self.mode == "ai": + return self.standup_ai() + else: + return self.standup_normal() + + def liedown(self): + return self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["StandDown"]}) + + async def handstand(self): + return self.publish_request( + RTC_TOPIC["SPORT_MOD"], + {"api_id": SPORT_CMD["Standup"], "parameter": {"data": True}}, + ) + + def color(self, color: VUI_COLOR = VUI_COLOR.RED, colortime: int = 60) -> bool: + return self.publish_request( + RTC_TOPIC["VUI"], + { + "api_id": 1001, + "parameter": { + "color": color, + "time": colortime, + }, + }, + ) + + @functools.lru_cache(maxsize=None) + def video_stream(self) -> Observable[VideoMessage]: + subject: Subject[VideoMessage] = Subject() + stop_event = threading.Event() + + async def accept_track(track: MediaStreamTrack) -> VideoMessage: + while True: + if stop_event.is_set(): + return + frame = await track.recv() + subject.on_next(frame.to_ndarray(format="bgr24")) + + self.conn.video.add_track_callback(accept_track) + self.conn.video.switchVideoChannel(True) + + def stop(cb): + stop_event.set() # Signal the loop to stop + self.conn.video.track_callbacks.remove(accept_track) + self.conn.video.switchVideoChannel(False) + + return backpressure(subject.pipe(ops.finally_action(stop))) + + def get_video_stream(self, fps: int = 30) -> Observable[VideoMessage]: + """Get the video stream from the robot's camera. + + Implements the AbstractRobot interface method. + + Args: + fps: Frames per second. This parameter is included for API compatibility, + but doesn't affect the actual frame rate which is determined by the camera. + + Returns: + Observable: An observable stream of video frames or None if video is not available. 
+ """ + try: + print("Starting WebRTC video stream...") + stream = self.video_stream() + if stream is None: + print("Warning: Video stream is not available") + return stream + except Exception as e: + print(f"Error getting video stream: {e}") + return None + + def stop(self): + if hasattr(self, "task") and self.task: + self.task.cancel() + if hasattr(self, "conn"): + + async def disconnect(): + try: + await self.conn.disconnect() + except: + pass + + if hasattr(self, "loop") and self.loop.is_running(): + asyncio.run_coroutine_threadsafe(disconnect(), self.loop) + + if hasattr(self, "loop") and self.loop.is_running(): + self.loop.call_soon_threadsafe(self.loop.stop) + + if hasattr(self, "thread") and self.thread.is_alive(): + self.thread.join(timeout=2.0) diff --git a/dimos/robot/unitree_webrtc/test_tooling.py b/dimos/robot/unitree_webrtc/test_tooling.py new file mode 100644 index 0000000000..917cca69a0 --- /dev/null +++ b/dimos/robot/unitree_webrtc/test_tooling.py @@ -0,0 +1,71 @@ +import os +import sys +import time + +import pytest +from dotenv import load_dotenv +import reactivex.operators as ops + +from dimos.robot.unitree_webrtc.unitree_go2 import UnitreeGo2 +from dimos.robot.unitree_webrtc.testing.multimock import Multimock +from dimos.robot.unitree_webrtc.testing.helpers import show3d_stream +from dimos.robot.unitree_webrtc.type.map import Map +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.odometry import position_from_odom + + +@pytest.mark.tool +def test_record_lidar(): + load_dotenv() + robot = UnitreeGo2(ip=os.getenv("ROBOT_IP"), mode="ai") + + print("Robot is standing up...") + robot.standup() + + lidar_store = Multimock("athens_lidar") + odom_store = Multimock("athens_odom") + lidar_store.consume(robot.raw_lidar_stream()).subscribe(print) + odom_store.consume(robot.raw_odom_stream()).subscribe(print) + + print("Recording, CTRL+C to kill") + + try: + while True: + time.sleep(0.1) + + except 
KeyboardInterrupt: + print("Robot is lying down...") + robot.liedown() + print("Exit") + sys.exit(0) + + +@pytest.mark.tool +def test_replay_recording(): + odom_stream = Multimock("athens_odom").stream().pipe(ops.map(position_from_odom)) + odom_stream.subscribe(lambda x: print(x)) + + map = Map() + + def lidarmsg(msg): + frame = LidarMessage.from_msg(msg) + map.add_frame(frame) + return [map, map.costmap.smudge()] + + global_map_stream = Multimock("athens_lidar").stream().pipe(ops.map(lidarmsg)) + show3d_stream(global_map_stream.pipe(ops.map(lambda x: x[0])), clearframe=True).run() + + +@pytest.mark.tool +def compare_events(): + odom_events = Multimock("athens_odom").list() + + map = Map() + + def lidarmsg(msg): + frame = LidarMessage.from_msg(msg) + map.add_frame(frame) + return [map, map.costmap.smudge()] + + global_map_stream = Multimock("athens_lidar").stream().pipe(ops.map(lidarmsg)) + show3d_stream(global_map_stream.pipe(ops.map(lambda x: x[0])), clearframe=True).run() diff --git a/dimos/robot/unitree_webrtc/testing/__init__.py b/dimos/robot/unitree_webrtc/testing/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dimos/robot/unitree_webrtc/testing/helpers.py b/dimos/robot/unitree_webrtc/testing/helpers.py new file mode 100644 index 0000000000..6f815abd56 --- /dev/null +++ b/dimos/robot/unitree_webrtc/testing/helpers.py @@ -0,0 +1,149 @@ +import time +import open3d as o3d +from typing import Callable, Union, Any, Protocol, Iterable +from reactivex.observable import Observable + +color1 = [1, 0.706, 0] +color2 = [0, 0.651, 0.929] +color3 = [0.8, 0.196, 0.6] +color4 = [0.235, 0.702, 0.443] +color = [color1, color2, color3, color4] + + +# benchmarking function can return int, which will be applied to the time. 
+# +# (in case there is some preparation within the fuction and this time needs to be subtracted +# from the benchmark target) +def benchmark(calls: int, targetf: Callable[[], Union[int, None]]) -> float: + start = time.time() + timemod = 0 + for _ in range(calls): + res = targetf() + if res is not None: + timemod += res + end = time.time() + return (end - start + timemod) * 1000 / calls + + +O3dDrawable = o3d.geometry.Geometry | o3d.geometry.LineSet | o3d.geometry.TriangleMesh | o3d.geometry.PointCloud + + +class ReturnsDrawable(Protocol): + def o3d_geometry(self) -> O3dDrawable: ... + + +Drawable = O3dDrawable | ReturnsDrawable + + +def show3d(*components: Iterable[Drawable], title: str = "open3d") -> o3d.visualization.Visualizer: + vis = o3d.visualization.Visualizer() + vis.create_window(window_name=title) + for component in components: + # our custom drawable components should return an open3d geometry + if hasattr(component, "o3d_geometry"): + vis.add_geometry(component.o3d_geometry) + else: + vis.add_geometry(component) + + opt = vis.get_render_option() + opt.background_color = [0, 0, 0] + opt.point_size = 10 + vis.poll_events() + vis.update_renderer() + return vis + + +def multivis(*vis: o3d.visualization.Visualizer) -> None: + while True: + for v in vis: + v.poll_events() + v.update_renderer() + + +def show3d_stream( + geometry_observable: Observable[Any], + clearframe: bool = False, + title: str = "open3d", +) -> o3d.visualization.Visualizer: + """ + Visualize a stream of geometries using Open3D. The first geometry initializes the visualizer. + Subsequent geometries update the visualizer. If no new geometry, just poll events. 
+ geometry_observable: Observable of objects with .o3d_geometry or Open3D geometry + """ + import threading + import queue + import time + from typing import Any + + q: queue.Queue[Any] = queue.Queue() + stop_flag = threading.Event() + + def on_next(geometry: O3dDrawable) -> None: + q.put(geometry) + + def on_error(e: Exception) -> None: + print(f"Visualization error: {e}") + stop_flag.set() + + def on_completed() -> None: + print("Geometry stream completed") + stop_flag.set() + + subscription = geometry_observable.subscribe( + on_next=on_next, + on_error=on_error, + on_completed=on_completed, + ) + + def geom(geometry: Drawable) -> O3dDrawable: + """Extracts the Open3D geometry from the given object.""" + return geometry.o3d_geometry if hasattr(geometry, "o3d_geometry") else geometry + + # Wait for the first geometry + first_geometry = None + while first_geometry is None and not stop_flag.is_set(): + try: + first_geometry = q.get(timeout=100) + except queue.Empty: + print("No geometry received to visualize.") + return + + scene_geometries = [] + first_geom_obj = geom(first_geometry) + + scene_geometries.append(first_geom_obj) + + vis = show3d(first_geom_obj, title=title) + + try: + while not stop_flag.is_set(): + try: + geometry = q.get_nowait() + geom_obj = geom(geometry) + if clearframe: + scene_geometries = [] + vis.clear_geometries() + + vis.add_geometry(geom_obj) + scene_geometries.append(geom_obj) + else: + if geom_obj in scene_geometries: + print("updating existing geometry") + vis.update_geometry(geom_obj) + else: + print("new geometry") + vis.add_geometry(geom_obj) + scene_geometries.append(geom_obj) + except queue.Empty: + pass + vis.poll_events() + vis.update_renderer() + time.sleep(0.1) + + except KeyboardInterrupt: + print("closing visualizer...") + stop_flag.set() + vis.destroy_window() + subscription.dispose() + + return vis diff --git a/dimos/robot/unitree_webrtc/testing/mock.py b/dimos/robot/unitree_webrtc/testing/mock.py new file mode 100644 
index 0000000000..ab28fcce02 --- /dev/null +++ b/dimos/robot/unitree_webrtc/testing/mock.py @@ -0,0 +1,77 @@ +import os +import pickle +import glob +from typing import Union, Iterator, cast, overload +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage, RawLidarMsg + +from reactivex import operators as ops +from reactivex import interval, from_iterable +from reactivex.observable import Observable + + +class Mock: + def __init__(self, root="office", autocast: bool = True): + current_dir = os.path.dirname(os.path.abspath(__file__)) + self.root = os.path.join(current_dir, f"mockdata/{root}") + self.autocast = autocast + self.cnt = 0 + + @overload + def load(self, name: Union[int, str], /) -> LidarMessage: ... + @overload + def load(self, *names: Union[int, str]) -> list[LidarMessage]: ... + + def load(self, *names: Union[int, str]) -> Union[LidarMessage, list[LidarMessage]]: + if len(names) == 1: + return self.load_one(names[0]) + return list(map(lambda name: self.load_one(name), names)) + + def load_one(self, name: Union[int, str]) -> LidarMessage: + if isinstance(name, int): + file_name = f"/lidar_data_{name:03d}.pickle" + else: + file_name = f"/{name}.pickle" + + full_path = self.root + file_name + with open(full_path, "rb") as f: + return LidarMessage.from_msg(cast(RawLidarMsg, pickle.load(f))) + + def iterate(self) -> Iterator[LidarMessage]: + pattern = os.path.join(self.root, "lidar_data_*.pickle") + print("loading data", pattern) + for file_path in sorted(glob.glob(pattern)): + basename = os.path.basename(file_path) + filename = os.path.splitext(basename)[0] + yield self.load_one(filename) + + def stream(self, rate_hz=10.0): + sleep_time = 1.0 / rate_hz + + return from_iterable(self.iterate()).pipe( + ops.zip(interval(sleep_time)), + ops.map(lambda x: x[0] if isinstance(x, tuple) else x), + ) + + def save_stream(self, observable: Observable[LidarMessage]): + return observable.pipe(ops.map(lambda frame: self.save_one(frame))) + + def save(self, 
*frames): + [self.save_one(frame) for frame in frames] + return self.cnt + + def save_one(self, frame): + file_name = f"/lidar_data_{self.cnt:03d}.pickle" + full_path = self.root + file_name + + self.cnt += 1 + + if os.path.isfile(full_path): + raise Exception(f"file {full_path} exists") + + if frame.__class__ == LidarMessage: + frame = frame.raw_msg + + with open(full_path, "wb") as f: + pickle.dump(frame, f) + + return self.cnt diff --git a/dimos/robot/unitree_webrtc/testing/multimock.py b/dimos/robot/unitree_webrtc/testing/multimock.py new file mode 100644 index 0000000000..5049e669da --- /dev/null +++ b/dimos/robot/unitree_webrtc/testing/multimock.py @@ -0,0 +1,128 @@ +#!/usr/bin/env python3 +"""Multimock – lightweight persistence & replay helper built on RxPy. + +A directory of pickle files acts as a tiny append-only log of (timestamp, data) +pairs. You can: + • save() / consume(): append new frames + • iterate(): read them back lazily + • interval_stream(): emit at a fixed cadence + • stream(): replay with original timing (optionally scaled) + +The implementation keeps memory usage constant by relying on reactive +operators instead of pre-materialising lists. Timing is reproduced via +`rx.timer`, and drift is avoided with `concat_map`. 
+""" + +from __future__ import annotations + +import glob +import os +import pickle +import time +from typing import Any, Generic, Iterator, List, Tuple, TypeVar, Union, Optional +from reactivex.scheduler import ThreadPoolScheduler + +from reactivex import from_iterable, interval, operators as ops +from reactivex.observable import Observable +from dimos.utils.threadpool import get_scheduler +from dimos.robot.unitree_webrtc.type.timeseries import TEvent, Timeseries + +T = TypeVar("T") + + +class Multimock(Generic[T], Timeseries[TEvent[T]]): + """Persist frames as pickle files and replay them with RxPy.""" + + def __init__(self, root: str = "office", file_prefix: str = "msg") -> None: + current_dir = os.path.dirname(os.path.abspath(__file__)) + self.root = os.path.join(current_dir, f"multimockdata/{root}") + self.file_prefix = file_prefix + + os.makedirs(self.root, exist_ok=True) + self.cnt: int = 0 + + def save(self, *frames: Any) -> int: + """Persist one or more frames; returns the new counter value.""" + for frame in frames: + self.save_one(frame) + return self.cnt + + def save_one(self, frame: Any) -> int: + """Persist a single frame and return the running count.""" + file_name = f"/{self.file_prefix}_{self.cnt:03d}.pickle" + full_path = os.path.join(self.root, file_name.lstrip("/")) + self.cnt += 1 + + if os.path.isfile(full_path): + raise FileExistsError(f"file {full_path} exists") + + # Optional convinience magic to extract raw messages from advanced types + # trying to deprecate for now + # if hasattr(frame, "raw_msg"): + # frame = frame.raw_msg # type: ignore[attr-defined] + + with open(full_path, "wb") as f: + pickle.dump([time.time(), frame], f) + + return self.cnt + + def load(self, *names: Union[int, str]) -> List[Tuple[float, T]]: + """Load multiple items by name or index.""" + return list(map(self.load_one, names)) + + def load_one(self, name: Union[int, str]) -> TEvent[T]: + """Load a single item by name or index.""" + if isinstance(name, int): + 
file_name = f"/{self.file_prefix}_{name:03d}.pickle" + else: + file_name = f"/{name}.pickle" + + full_path = os.path.join(self.root, file_name.lstrip("/")) + + with open(full_path, "rb") as f: + timestamp, data = pickle.load(f) + + return TEvent(timestamp, data) + + def iterate(self) -> Iterator[TEvent[T]]: + """Yield all persisted TEvent(timestamp, data) pairs lazily in order.""" + pattern = os.path.join(self.root, f"{self.file_prefix}_*.pickle") + for file_path in sorted(glob.glob(pattern)): + with open(file_path, "rb") as f: + timestamp, data = pickle.load(f) + yield TEvent(timestamp, data) + + def list(self) -> List[TEvent[T]]: + return list(self.iterate()) + + def interval_stream(self, rate_hz: float = 10.0) -> Observable[T]: + """Emit frames at a fixed rate, ignoring recorded timing.""" + sleep_time = 1.0 / rate_hz + return from_iterable(self.iterate()).pipe( + ops.zip(interval(sleep_time)), + ops.map(lambda pair: pair[1]), # keep only the frame + ) + + def stream( + self, + replay_speed: float = 1.0, + scheduler: Optional[ThreadPoolScheduler] = None, + ) -> Observable[T]: + def _generator(): + prev_ts: float | None = None + for event in self.iterate(): + if prev_ts is not None: + delay = (event.ts - prev_ts).total_seconds() / replay_speed + time.sleep(delay) + prev_ts = event.ts + yield event.data + + return from_iterable(_generator(), scheduler=scheduler or get_scheduler()) + + def consume(self, observable: Observable[Any]) -> Observable[int]: + """Side-effect: save every frame that passes through.""" + return observable.pipe(ops.map(self.save_one)) + + def __iter__(self) -> Iterator[TEvent[T]]: + """Allow iteration over the Multimock instance to yield TEvent(timestamp, data) pairs.""" + return self.iterate() diff --git a/dimos/robot/unitree_webrtc/testing/test_mock.py b/dimos/robot/unitree_webrtc/testing/test_mock.py new file mode 100644 index 0000000000..fce99e6b77 --- /dev/null +++ b/dimos/robot/unitree_webrtc/testing/test_mock.py @@ -0,0 +1,44 @@ 
+#!/usr/bin/env python3 +import time +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.testing.mock import Mock + + +def test_mock_load_cast(): + mock = Mock("test") + + # Load a frame with type casting + frame = mock.load("a") + + # Verify it's a LidarMessage object + assert frame.__class__.__name__ == "LidarMessage" + assert hasattr(frame, "timestamp") + assert hasattr(frame, "origin") + assert hasattr(frame, "resolution") + assert hasattr(frame, "pointcloud") + + # Verify pointcloud has points + assert frame.pointcloud.has_points() + assert len(frame.pointcloud.points) > 0 + + +def test_mock_iterate(): + """Test the iterate method of the Mock class.""" + mock = Mock("office") + + # Test iterate method + frames = list(mock.iterate()) + assert len(frames) > 0 + for frame in frames: + assert isinstance(frame, LidarMessage) + assert frame.pointcloud.has_points() + + +def test_mock_stream(): + frames = [] + sub1 = Mock("office").stream(rate_hz=30.0).subscribe(on_next=frames.append) + time.sleep(0.1) + sub1.dispose() + + assert len(frames) >= 2 + assert isinstance(frames[0], LidarMessage) diff --git a/dimos/robot/unitree_webrtc/testing/test_multimock.py b/dimos/robot/unitree_webrtc/testing/test_multimock.py new file mode 100644 index 0000000000..230d960c58 --- /dev/null +++ b/dimos/robot/unitree_webrtc/testing/test_multimock.py @@ -0,0 +1,84 @@ +import time +import pytest + +from reactivex import operators as ops + +from dimos.utils.reactive import backpressure +from dimos.robot.unitree_webrtc.testing.helpers import show3d_stream +from dimos.web.websocket_vis.server import WebsocketVis +from dimos.robot.unitree_webrtc.type.map import Map +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.odometry import Odometry +from dimos.robot.unitree_webrtc.type.timeseries import to_datetime +from dimos.robot.unitree_webrtc.testing.multimock import Multimock + + +@pytest.mark.vis +def 
test_multimock_stream(): + backpressure(Multimock("athens_odom").stream().pipe(ops.map(Odometry.from_msg))).subscribe(lambda x: print(x)) + map = Map() + + def lidarmsg(msg): + frame = LidarMessage.from_msg(msg) + map.add_frame(frame) + return [map, map.costmap.smudge()] + + mapstream = Multimock("athens_lidar").stream().pipe(ops.map(lidarmsg)) + show3d_stream(mapstream.pipe(ops.map(lambda x: x[0])), clearframe=True).run() + time.sleep(5) + + +def test_clock_mismatch(): + for odometry_raw in Multimock("athens_odom").iterate(): + print( + odometry_raw.ts - to_datetime(odometry_raw.data["data"]["header"]["stamp"]), + odometry_raw.data["data"]["header"]["stamp"], + ) + + +def test_odom_stream(): + for odometry_raw in Multimock("athens_odom").iterate(): + print(Odometry.from_msg(odometry_raw.data)) + + +def test_lidar_stream(): + for lidar_raw in Multimock("athens_lidar").iterate(): + lidarmsg = LidarMessage.from_msg(lidar_raw.data) + print(lidarmsg) + print(lidar_raw) + + +def test_multimock_timeseries(): + odom = Odometry.from_msg(Multimock("athens_odom").load_one(1).data) + lidar_raw = Multimock("athens_lidar").load_one(1).data + lidar = LidarMessage.from_msg(lidar_raw) + map = Map() + map.add_frame(lidar) + print(odom) + print(lidar) + print(lidar_raw) + print(map.costmap) + + +def test_origin_changes(): + for lidar_raw in Multimock("athens_lidar").iterate(): + print(LidarMessage.from_msg(lidar_raw.data).origin) + + +@pytest.mark.vis +def test_webui_multistream(): + websocket_vis = WebsocketVis() + websocket_vis.start() + + odom_stream = Multimock("athens_odom").stream().pipe(ops.map(Odometry.from_msg)) + lidar_stream = backpressure(Multimock("athens_lidar").stream().pipe(ops.map(LidarMessage.from_msg))) + + map = Map() + map_stream = map.consume(lidar_stream) + + costmap_stream = map_stream.pipe(ops.map(lambda x: ["costmap", map.costmap.smudge(preserve_unknown=False)])) + + websocket_vis.connect(costmap_stream) + 
DTYPE2STR = {
    np.float32: "f32",
    np.float64: "f64",
    np.int32: "i32",
    np.int8: "i8",
}

STR2DTYPE = {v: k for k, v in DTYPE2STR.items()}


def encode_ndarray(arr: np.ndarray, compress: bool = False) -> dict:
    """Encode *arr* as a JSON-serializable dict with a base64 payload.

    Args:
        arr: Array whose dtype is one of the keys of ``DTYPE2STR``.
        compress: If True, zlib-compress the raw bytes before base64
            encoding and mark the result with ``"compressed": True``.
            (Previously this flag was accepted but silently ignored.)

    Returns:
        dict with keys ``type``, ``shape``, ``dtype``, ``data`` and,
        only when *compress* is used, ``compressed``.

    Raises:
        ValueError: If the array dtype is not supported (previously this
            surfaced as an opaque ``KeyError``).
    """
    arr_c = np.ascontiguousarray(arr)
    if arr_c.dtype.type not in DTYPE2STR:
        raise ValueError(f"unsupported dtype for encoding: {arr_c.dtype}")

    payload = arr_c.tobytes()
    encoded = {
        "type": "grid",
        "shape": arr_c.shape,
        "dtype": DTYPE2STR[arr_c.dtype.type],
    }
    if compress:
        import zlib  # local import: only needed on the opt-in path

        payload = zlib.compress(payload)
        # Decoders must check this flag and zlib-decompress before reshaping.
        encoded["compressed"] = True
    encoded["data"] = base64.b64encode(payload).decode("ascii")
    return encoded
+ + Args: + pickle_path: Path to save the pickle file + """ + with open(pickle_path, "wb") as f: + pickle.dump(self, f) + + @classmethod + def create_empty(cls, width: int = 100, height: int = 100, resolution: float = 0.1) -> "Costmap": + """Create an empty costmap with specified dimensions.""" + return cls( + grid=np.zeros((height, width), dtype=np.int8), + resolution=resolution, + origin=(0.0, 0.0), + ) + + def world_to_grid(self, point: VectorLike) -> Vector: + """Convert world coordinates to grid coordinates. + + Args: + point: A vector-like object containing X,Y coordinates + + Returns: + Vector containing grid_x and grid_y coordinates + """ + return (to_vector(point) - self.origin) / self.resolution + + def grid_to_world(self, grid_point: VectorLike) -> Vector: + return to_vector(grid_point) * self.resolution + self.origin + + def is_occupied(self, point: VectorLike, threshold: int = 50) -> bool: + """Check if a position in world coordinates is occupied. + + Args: + point: Vector-like object containing X,Y coordinates + threshold: Cost threshold above which a cell is considered occupied (0-100) + + Returns: + True if position is occupied or out of bounds, False otherwise + """ + grid_pos = self.world_to_grid(point) + + if 0 <= grid_pos.x < self.width and 0 <= grid_pos.y < self.height: + # Consider unknown (-1) as unoccupied for navigation purposes + # Convert to int coordinates for grid indexing + grid_y, grid_x = int(grid_pos.y), int(grid_pos.x) + value = self.grid[grid_y, grid_x] + return bool(value > 0 and value >= threshold) + return True # Consider out-of-bounds as occupied + + def get_value(self, point: VectorLike) -> Optional[int]: + grid_pos = self.world_to_grid(point) + + if 0 <= grid_pos.x < self.width and 0 <= grid_pos.y < self.height: + grid_y, grid_x = int(grid_pos.y), int(grid_pos.x) + return int(self.grid[grid_y, grid_x]) + return None + + def set_value(self, point: VectorLike, value: int = 0) -> bool: + grid_pos = self.world_to_grid(point) + + 
if 0 <= grid_pos.x < self.width and 0 <= grid_pos.y < self.height: + grid_y, grid_x = int(grid_pos.y), int(grid_pos.x) + self.grid[grid_y, grid_x] = value + return True + return False + + def smudge( + self, + kernel_size: int = 3, + iterations: int = 20, + decay_factor: float = 0.9, + threshold: int = 90, + preserve_unknown: bool = False, + ) -> "Costmap": + """ + Creates a new costmap with expanded obstacles (smudged). + + Args: + kernel_size: Size of the convolution kernel for dilation (must be odd) + iterations: Number of dilation iterations + decay_factor: Factor to reduce cost as distance increases (0.0-1.0) + threshold: Minimum cost value to consider as an obstacle for expansion + preserve_unknown: Whether to keep unknown (-1) cells as unknown + + Returns: + A new Costmap instance with expanded obstacles + """ + # Make sure kernel size is odd + if kernel_size % 2 == 0: + kernel_size += 1 + + # Create a copy of the grid for processing + grid_copy = self.grid.copy() + + # Create a mask of unknown cells if needed + unknown_mask = None + if preserve_unknown: + unknown_mask = grid_copy == -1 + # Temporarily replace unknown cells with 0 for processing + # This allows smudging to go over unknown areas + grid_copy[unknown_mask] = 0 + + # Create a mask of cells that are above the threshold + obstacle_mask = grid_copy >= threshold + + # Create a binary map of obstacles + binary_map = obstacle_mask.astype(np.uint8) * 100 + + # Create a circular kernel for dilation (instead of square) + y, x = np.ogrid[ + -kernel_size // 2 : kernel_size // 2 + 1, + -kernel_size // 2 : kernel_size // 2 + 1, + ] + kernel = (x * x + y * y <= (kernel_size // 2) * (kernel_size // 2)).astype(np.uint8) + + # Create distance map using dilation + # Each iteration adds one 'ring' of cells around obstacles + dilated_map = binary_map.copy() + + # Store each layer of dilation with decreasing values + layers = [] + + # First layer is the original obstacle cells + layers.append(binary_map.copy()) + + 
for i in range(iterations): + # Dilate the binary map + dilated = ndimage.binary_dilation(dilated_map > 0, structure=kernel, iterations=1).astype(np.uint8) + + # Calculate the new layer (cells that were just added in this iteration) + new_layer = (dilated - (dilated_map > 0).astype(np.uint8)) * 100 + + # Apply decay factor based on distance from obstacle + new_layer = new_layer * (decay_factor ** (i + 1)) + + # Add to layers list + layers.append(new_layer) + + # Update dilated map for next iteration + dilated_map = dilated * 100 + + # Combine all layers to create a distance-based cost map + smudged_map = np.zeros_like(grid_copy) + for layer in layers: + # For each cell, keep the maximum value across all layers + smudged_map = np.maximum(smudged_map, layer) + + # Preserve original obstacles + smudged_map[obstacle_mask] = grid_copy[obstacle_mask] + + # When preserve_unknown is true, restore all original unknown cells + # This overlays unknown cells on top of the smudged map + if preserve_unknown and unknown_mask is not None: + smudged_map[unknown_mask] = -1 + + # Ensure cost values are in valid range (0-100) except for unknown (-1) + if preserve_unknown and unknown_mask is not None: + valid_cells = ~unknown_mask + smudged_map[valid_cells] = np.clip(smudged_map[valid_cells], 0, 100) + else: + smudged_map = np.clip(smudged_map, 0, 100) + + # Create a new costmap with the smudged grid + return Costmap( + grid=smudged_map.astype(np.int8), + resolution=self.resolution, + origin=self.origin, + ) + + def __str__(self) -> str: + """ + Create a string representation of the Costmap. 
+ + Returns: + A formatted string with key costmap information + """ + # Calculate occupancy statistics + total_cells = self.width * self.height + occupied_cells = np.sum(self.grid >= 0.1) + unknown_cells = np.sum(self.grid == -1) + free_cells = total_cells - occupied_cells - unknown_cells + + # Calculate percentages + occupied_percent = (occupied_cells / total_cells) * 100 + unknown_percent = (unknown_cells / total_cells) * 100 + free_percent = (free_cells / total_cells) * 100 + + cell_info = [ + "▦ Costmap", + f"{self.width}x{self.height}", + f"({self.width * self.resolution:.1f}x{self.height * self.resolution:.1f}m @", + f"{1 / self.resolution:.0f}cm res)", + f"Origin: ({x(self.origin):.2f}, {y(self.origin):.2f})", + f"▣ {occupied_percent:.1f}%", + f"□ {free_percent:.1f}%", + f"◌ {unknown_percent:.1f}%", + ] + + return " ".join(cell_info) + + @property + def o3d_geometry(self): + return self.pointcloud + + @property + def pointcloud(self, *, res: float = 0.25, origin=(0.0, 0.0), show_unknown: bool = False): + """ + Visualise a 2-D costmap (int8, −1…100) as an Open3D PointCloud. + + • −1 → ‘unknown’ (optionally drawn as mid-grey, or skipped) + • 0 → free + • 1-99→ graduated cost (turbo colour-ramp) + • 100 → lethal / obstacle (red end of ramp) + + Parameters + ---------- + res : float + Cell size in metres. + origin : (float, float) + World-space coord of costmap [row0,col0] centre. + show_unknown : bool + If true, draw unknown cells in grey; otherwise omit them. 
+ """ + cost = np.asarray(self.grid, dtype=np.int16) + if cost.ndim != 2: + raise ValueError("cost map must be 2-D (H×W)") + + H, W = cost.shape + ys, xs = np.mgrid[0:H, 0:W] + + # ---------- flatten & mask -------------------------------------------------- + xs = xs.ravel() + ys = ys.ravel() + vals = cost.ravel() + + unknown_mask = vals == -1 + if not show_unknown: + keep = ~unknown_mask + xs, ys, vals = xs[keep], ys[keep], vals[keep] + + # ---------- 3-D points ------------------------------------------------------ + xyz = np.column_stack( + ( + (xs + 0.5) * res + origin[0], # X + (ys + 0.5) * res + origin[1], # Y + np.zeros_like(xs, dtype=np.float32), # Z = 0 + ) + ) + + # ---------- colours --------------------------------------------------------- + rgb = np.empty((len(vals), 3), dtype=np.float32) + + if show_unknown: + # mid-grey for unknown + rgb[unknown_mask[~unknown_mask if not show_unknown else slice(None)]] = ( + 0.4, + 0.4, + 0.4, + ) + + # normalise valid costs: 0…100 → 0…1 + norm = np.clip(vals.astype(np.float32), 0, 100) / 100.0 + rgb_valid = cm.turbo(norm)[:, :3] # type: ignore[attr-defined] # strip alpha + rgb[:] = rgb_valid # unknown already set if needed + + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(xyz) + pcd.colors = o3d.utility.Vector3dVector(rgb) + + return pcd diff --git a/dimos/robot/unitree_webrtc/type/lidar.py b/dimos/robot/unitree_webrtc/type/lidar.py new file mode 100644 index 0000000000..37a51b702a --- /dev/null +++ b/dimos/robot/unitree_webrtc/type/lidar.py @@ -0,0 +1,139 @@ +from dimos.robot.unitree_webrtc.testing.helpers import color +from datetime import datetime +from dimos.robot.unitree_webrtc.type.timeseries import Timestamped, to_datetime, to_human_readable +from dimos.types.vector import Vector +from dataclasses import dataclass, field +from typing import List, TypedDict +import numpy as np +import open3d as o3d +from copy import copy + + +class RawLidarPoints(TypedDict): + points: np.ndarray # 
Shape (N, 3) array of 3D points [x, y, z] + + +class RawLidarData(TypedDict): + """Data portion of the LIDAR message""" + + frame_id: str + origin: List[float] + resolution: float + src_size: int + stamp: float + width: List[int] + data: RawLidarPoints + + +class RawLidarMsg(TypedDict): + """Static type definition for raw LIDAR message""" + + type: str + topic: str + data: RawLidarData + + +@dataclass +class LidarMessage(Timestamped): + ts: datetime + origin: Vector + resolution: float + pointcloud: o3d.geometry.PointCloud + raw_msg: RawLidarMsg = field(repr=False, default=None) + + @classmethod + def from_msg(cls, raw_message: RawLidarMsg) -> "LidarMessage": + data = raw_message["data"] + points = data["data"]["points"] + point_cloud = o3d.geometry.PointCloud() + point_cloud.points = o3d.utility.Vector3dVector(points) + return cls( + ts=to_datetime(data["stamp"]), + origin=Vector(data["origin"]), + resolution=data["resolution"], + pointcloud=point_cloud, + raw_msg=raw_message, + ) + + def __repr__(self): + return f"LidarMessage(ts={to_human_readable(self.ts)}, origin={self.origin}, resolution={self.resolution}, {self.pointcloud})" + + def __iadd__(self, other: "LidarMessage") -> "LidarMessage": + self.pointcloud += other.pointcloud + return self + + def __add__(self, other: "LidarMessage") -> "LidarMessage": + # Create a new point cloud combining both + + # Determine which message is more recent + if self.timestamp >= other.timestamp: + timestamp = self.timestamp + origin = self.origin + resolution = self.resolution + else: + timestamp = other.timestamp + origin = other.origin + resolution = other.resolution + + # Return a new LidarMessage with combined data + return LidarMessage( + timestamp=timestamp, + origin=origin, + resolution=resolution, + pointcloud=self.pointcloud + other.pointcloud, + ).estimate_normals() + + @property + def o3d_geometry(self): + return self.pointcloud + + def icp(self, other: "LidarMessage") -> 
o3d.pipelines.registration.RegistrationResult: + self.estimate_normals() + other.estimate_normals() + + reg_p2l = o3d.pipelines.registration.registration_icp( + self.pointcloud, + other.pointcloud, + 0.1, + np.identity(4), + o3d.pipelines.registration.TransformationEstimationPointToPlane(), + o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration=100), + ) + + return reg_p2l + + def transform(self, transform) -> "LidarMessage": + self.pointcloud.transform(transform) + return self + + def clone(self) -> "LidarMessage": + return self.copy() + + def copy(self) -> "LidarMessage": + return LidarMessage( + ts=self.ts, + origin=copy(self.origin), + resolution=self.resolution, + # TODO: seems to work, but will it cause issues because of the shallow copy? + pointcloud=copy(self.pointcloud), + ) + + def icptransform(self, other): + return self.transform(self.icp(other).transformation) + + def estimate_normals(self) -> "LidarMessage": + # Check if normals already exist by testing if the normals attribute has data + if not self.pointcloud.has_normals() or len(self.pointcloud.normals) == 0: + self.pointcloud.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.1, max_nn=30)) + return self + + def color(self, color_choice) -> "LidarMessage": + def get_color(color_choice): + if isinstance(color_choice, int): + return color[color_choice] + return color_choice + + self.pointcloud.paint_uniform_color(get_color(color_choice)) + # Looks like we'll be displaying so might as well? 
+ self.estimate_normals() + return self diff --git a/dimos/robot/unitree_webrtc/type/lowstate.py b/dimos/robot/unitree_webrtc/type/lowstate.py new file mode 100644 index 0000000000..48c0d23a5f --- /dev/null +++ b/dimos/robot/unitree_webrtc/type/lowstate.py @@ -0,0 +1,79 @@ +from typing import TypedDict, List, Literal + +raw_odom_msg_sample = { + "type": "msg", + "topic": "rt/lf/lowstate", + "data": { + "imu_state": {"rpy": [0.008086, -0.007515, 2.981771]}, + "motor_state": [ + {"q": 0.098092, "temperature": 40, "lost": 0, "reserve": [0, 674]}, + {"q": 0.757921, "temperature": 32, "lost": 0, "reserve": [0, 674]}, + {"q": -1.490911, "temperature": 38, "lost": 6, "reserve": [0, 674]}, + {"q": -0.072477, "temperature": 42, "lost": 0, "reserve": [0, 674]}, + {"q": 1.020276, "temperature": 32, "lost": 5, "reserve": [0, 674]}, + {"q": -2.007172, "temperature": 38, "lost": 5, "reserve": [0, 674]}, + {"q": 0.071382, "temperature": 50, "lost": 5, "reserve": [0, 674]}, + {"q": 0.963379, "temperature": 36, "lost": 6, "reserve": [0, 674]}, + {"q": -1.978311, "temperature": 40, "lost": 5, "reserve": [0, 674]}, + {"q": -0.051066, "temperature": 48, "lost": 0, "reserve": [0, 674]}, + {"q": 0.73103, "temperature": 34, "lost": 10, "reserve": [0, 674]}, + {"q": -1.466473, "temperature": 38, "lost": 6, "reserve": [0, 674]}, + {"q": 0, "temperature": 0, "lost": 0, "reserve": [0, 0]}, + {"q": 0, "temperature": 0, "lost": 0, "reserve": [0, 0]}, + {"q": 0, "temperature": 0, "lost": 0, "reserve": [0, 0]}, + {"q": 0, "temperature": 0, "lost": 0, "reserve": [0, 0]}, + {"q": 0, "temperature": 0, "lost": 0, "reserve": [0, 0]}, + {"q": 0, "temperature": 0, "lost": 0, "reserve": [0, 0]}, + {"q": 0, "temperature": 0, "lost": 0, "reserve": [0, 0]}, + {"q": 0, "temperature": 0, "lost": 0, "reserve": [0, 0]}, + ], + "bms_state": { + "version_high": 1, + "version_low": 18, + "soc": 55, + "current": -2481, + "cycle": 56, + "bq_ntc": [30, 29], + "mcu_ntc": [33, 32], + }, + "foot_force": [97, 84, 81, 
@dataclass
class Map:
    """Incrementally built global pointcloud map with a costmap projection."""

    # A bare `o3d.geometry.PointCloud()` default would be evaluated ONCE at
    # class-definition time and shared by every Map() instance, so frames
    # added to one map leaked into all others. Use None + __post_init__ to
    # give each instance its own cloud (backward compatible: callers may
    # still pass an explicit pointcloud).
    pointcloud: Optional[o3d.geometry.PointCloud] = None
    voxel_size: float = 0.05
    cost_resolution: float = 0.05

    def __post_init__(self):
        # Per-instance default; see field comment above.
        if self.pointcloud is None:
            self.pointcloud = o3d.geometry.PointCloud()

    def add_frame(self, frame: LidarMessage) -> "Map":
        """Voxelise *frame* and splice it into the running map."""
        new_pct = frame.pointcloud.voxel_down_sample(voxel_size=self.voxel_size)
        self.pointcloud = splice_cylinder(self.pointcloud, new_pct, shrink=0.5)
        return self

    def consume(self, observable: Observable[LidarMessage]) -> Observable["Map"]:
        """Reactive operator that folds a stream of `LidarMessage` into the map."""
        return observable.pipe(ops.map(self.add_frame))

    @property
    def o3d_geometry(self) -> o3d.geometry.PointCloud:
        return self.pointcloud

    @property
    def costmap(self) -> Costmap:
        """Return a fully inflated cost-map in a `Costmap` wrapper."""
        # Only inflate when the map voxels are coarser than the costmap grid.
        inflate_radius_m = 0.5 * self.voxel_size if self.voxel_size > self.cost_resolution else 0.0
        grid, origin_xy = pointcloud_to_costmap(
            self.pointcloud,
            resolution=self.cost_resolution,
            inflate_radius_m=inflate_radius_m,
        )
        return Costmap(grid=grid, origin=[*origin_xy, 0.0], resolution=self.cost_resolution)
`Costmap` wrapper.""" + inflate_radius_m = 0.5 * self.voxel_size if self.voxel_size > self.cost_resolution else 0.0 + grid, origin_xy = pointcloud_to_costmap( + self.pointcloud, + resolution=self.cost_resolution, + inflate_radius_m=inflate_radius_m, + ) + return Costmap(grid=grid, origin=[*origin_xy, 0.0], resolution=self.cost_resolution) + + +def splice_sphere( + map_pcd: o3d.geometry.PointCloud, + patch_pcd: o3d.geometry.PointCloud, + shrink: float = 0.95, +) -> o3d.geometry.PointCloud: + center = patch_pcd.get_center() + radius = np.linalg.norm(np.asarray(patch_pcd.points) - center, axis=1).max() * shrink + dists = np.linalg.norm(np.asarray(map_pcd.points) - center, axis=1) + victims = np.nonzero(dists < radius)[0] + survivors = map_pcd.select_by_index(victims, invert=True) + return survivors + patch_pcd + + +def splice_cylinder( + map_pcd: o3d.geometry.PointCloud, + patch_pcd: o3d.geometry.PointCloud, + axis: int = 2, + shrink: float = 0.95, +) -> o3d.geometry.PointCloud: + center = patch_pcd.get_center() + patch_pts = np.asarray(patch_pcd.points) + + # Axes perpendicular to cylinder + axes = [0, 1, 2] + axes.remove(axis) + + planar_dists = np.linalg.norm(patch_pts[:, axes] - center[axes], axis=1) + radius = planar_dists.max() * shrink + + axis_min = (patch_pts[:, axis].min() - center[axis]) * shrink + center[axis] + axis_max = (patch_pts[:, axis].max() - center[axis]) * shrink + center[axis] + + map_pts = np.asarray(map_pcd.points) + planar_dists_map = np.linalg.norm(map_pts[:, axes] - center[axes], axis=1) + + victims = np.nonzero((planar_dists_map < radius) & (map_pts[:, axis] >= axis_min) & (map_pts[:, axis] <= axis_max))[ + 0 + ] + + survivors = map_pcd.select_by_index(victims, invert=True) + return survivors + patch_pcd + + +def _inflate_lethal(costmap: np.ndarray, radius: int, lethal_val: int = 100) -> np.ndarray: + """Return *costmap* with lethal cells dilated by *radius* grid steps (circular).""" + if radius <= 0 or not np.any(costmap == lethal_val): + 
return costmap + + mask = costmap == lethal_val + dilated = mask.copy() + for dy in range(-radius, radius + 1): + for dx in range(-radius, radius + 1): + if dx * dx + dy * dy > radius * radius or (dx == 0 and dy == 0): + continue + dilated |= np.roll(mask, shift=(dy, dx), axis=(0, 1)) + + out = costmap.copy() + out[dilated] = lethal_val + return out + + +def pointcloud_to_costmap( + pcd: o3d.geometry.PointCloud, + *, + resolution: float = 0.05, + ground_z: float = 0.0, + obs_min_height: float = 0.15, + max_height: Optional[float] = 0.5, + inflate_radius_m: Optional[float] = None, + default_unknown: int = -1, + cost_free: int = 0, + cost_lethal: int = 100, +) -> Tuple[np.ndarray, np.ndarray]: + """Rasterise *pcd* into a 2-D int8 cost-map with optional obstacle inflation. + + Grid origin is **aligned** to the `resolution` lattice so that when + `resolution == voxel_size` every voxel centroid lands squarely inside a cell + (no alternating blank lines). + """ + + pts = np.asarray(pcd.points, dtype=np.float32) + if pts.size == 0: + return np.full((1, 1), default_unknown, np.int8), np.zeros(2, np.float32) + + # 0. Ceiling filter -------------------------------------------------------- + if max_height is not None: + pts = pts[pts[:, 2] <= max_height] + if pts.size == 0: + return np.full((1, 1), default_unknown, np.int8), np.zeros(2, np.float32) + + # 1. Bounding box & aligned origin --------------------------------------- + xy_min = pts[:, :2].min(axis=0) + xy_max = pts[:, :2].max(axis=0) + + # Align origin to the resolution grid (anchor = 0,0) + origin = np.floor(xy_min / resolution) * resolution + + # Grid dimensions (inclusive) ------------------------------------------- + Nx, Ny = (np.ceil((xy_max - origin) / resolution).astype(int) + 1).tolist() + + # 2. 
Bin points ------------------------------------------------------------ + idx_xy = np.floor((pts[:, :2] - origin) / resolution).astype(np.int32) + np.clip(idx_xy[:, 0], 0, Nx - 1, out=idx_xy[:, 0]) + np.clip(idx_xy[:, 1], 0, Ny - 1, out=idx_xy[:, 1]) + + lin = idx_xy[:, 1] * Nx + idx_xy[:, 0] + z_max = np.full(Nx * Ny, -np.inf, np.float32) + np.maximum.at(z_max, lin, pts[:, 2]) + z_max = z_max.reshape(Ny, Nx) + + # 3. Cost rules ----------------------------------------------------------- + costmap = np.full_like(z_max, default_unknown, np.int8) + known = z_max != -np.inf + costmap[known] = cost_free + + lethal = z_max >= (ground_z + obs_min_height) + costmap[lethal] = cost_lethal + + # 4. Optional inflation ---------------------------------------------------- + if inflate_radius_m and inflate_radius_m > 0: + cells = int(np.ceil(inflate_radius_m / resolution)) + costmap = _inflate_lethal(costmap, cells, lethal_val=cost_lethal) + + return costmap, origin.astype(np.float32) diff --git a/dimos/robot/unitree_webrtc/type/odometry.py b/dimos/robot/unitree_webrtc/type/odometry.py new file mode 100644 index 0000000000..df10bd8d54 --- /dev/null +++ b/dimos/robot/unitree_webrtc/type/odometry.py @@ -0,0 +1,74 @@ +from typing import TypedDict, Literal +from datetime import datetime +from dataclasses import dataclass +from dimos.types.vector import Vector +from dimos.types.position import Position +from dimos.robot.unitree_webrtc.type.timeseries import Timestamped, to_human_readable + +raw_odometry_msg_sample = { + "type": "msg", + "topic": "rt/utlidar/robot_pose", + "data": { + "header": {"stamp": {"sec": 1746565669, "nanosec": 448350564}, "frame_id": "odom"}, + "pose": { + "position": {"x": 5.961965, "y": -2.916958, "z": 0.319509}, + "orientation": {"x": 0.002787, "y": -0.000902, "z": -0.970244, "w": -0.242112}, + }, + }, +} + + +class TimeStamp(TypedDict): + sec: int + nanosec: int + + +class Header(TypedDict): + stamp: TimeStamp + frame_id: str + + +class 
@dataclass
class Odometry(Timestamped, Position):
    # NOTE(review): annotated datetime, but from_msg stores the raw ROS
    # stamp dict {'sec', 'nanosec'}; to_human_readable accepts the dict
    # form at repr time — confirm whether conversion should happen here.
    ts: datetime

    @classmethod
    def from_msg(cls, msg: RawOdometryMessage) -> "Odometry":
        """Build an Odometry sample from a raw robot-pose message.

        NOTE(review): `rot` is filled with the x/y/z components of the pose
        *quaternion* (w is dropped) — these are not Euler angles; confirm
        downstream consumers expect raw quaternion components.
        """
        pose = msg["data"]["pose"]
        orientation = pose["orientation"]
        position = pose["position"]
        pos = Vector(position.get("x"), position.get("y"), position.get("z"))
        rot = Vector(orientation.get("x"), orientation.get("y"), orientation.get("z"))
        return cls(pos=pos, rot=rot, ts=msg["data"]["header"]["stamp"])

    def __repr__(self) -> str:
        return f"Odom ts({to_human_readable(self.ts)}) pos({self.pos}), rot({self.rot})"
@pytest.mark.benchmark
def test_benchmark_icp():
    # Benchmark frame-to-frame ICP registration over mock lidar frames.
    frames = Mock("dynamic_house").iterate()

    prev_frame = None

    def icptest():
        nonlocal prev_frame
        start = time.time()

        current_frame = frames.__next__()
        # NOTE(review): prev_frame is set once (to the second frame) and
        # never advanced afterwards, so every subsequent iteration registers
        # against the same fixed reference frame rather than the previous
        # one — confirm whether `prev_frame = current_frame` should run at
        # the end of each call.
        if not prev_frame:
            prev_frame = frames.__next__()
        end = time.time()

        current_frame.icptransform(prev_frame)
        # for subtracting the time of the function exec
        return (end - start) * -1

    ms = benchmark(100, icptest)
    assert ms < 20, "ICP took too long"

    print(f"ICP takes {ms:.2f} ms")
frameb.pointcloud = frameb.pointcloud.voxel_down_sample(voxel_size=0.1) + + # framea.color(0) + # frameb.color(1) + + # Normally this is a mutating operation (for efficiency) + # but here we need an original frame A for the visualizer + # framea_icp = framea.copy().icptransform(frameb) + pcd = framea.copy().pointcloud + newpcd, _, _ = pcd.voxel_down_sample_and_trace( + voxel_size=0.25, min_bound=pcd.get_min_bound(), max_bound=pcd.get_max_bound(), approximate_class=False + ) + + multivis( + show3d(framea, title="frame a"), + show3d(newpcd, title="frame a downsample"), + ) diff --git a/dimos/robot/unitree_webrtc/type/test_map.py b/dimos/robot/unitree_webrtc/type/test_map.py new file mode 100644 index 0000000000..8533371a45 --- /dev/null +++ b/dimos/robot/unitree_webrtc/type/test_map.py @@ -0,0 +1,39 @@ +import pytest +from dimos.robot.unitree_webrtc.testing.mock import Mock +from dimos.robot.unitree_webrtc.testing.helpers import show3d_stream, show3d +from dimos.robot.unitree_webrtc.utils.reactive import backpressure +from dimos.robot.unitree_webrtc.type.map import splice_sphere, Map +from dimos.robot.unitree_webrtc.lidar import lidar + + +@pytest.mark.vis +def test_costmap_vis(): + map = Map() + for frame in Mock("office").iterate(): + print(frame) + map.add_frame(frame) + costmap = map.costmap + print(costmap) + show3d(costmap.smudge().pointcloud, title="Costmap").run() + + +@pytest.mark.vis +def test_reconstruction_with_realtime_vis(): + show3d_stream(Map().consume(Mock("office").stream(rate_hz=60.0)), clearframe=True).run() + + +@pytest.mark.vis +def test_splice_vis(): + mock = Mock("test") + target = mock.load("a") + insert = mock.load("b") + show3d(splice_sphere(target.pointcloud, insert.pointcloud, shrink=0.7)).run() + + +@pytest.mark.vis +def test_robot_vis(): + show3d_stream( + Map().consume(backpressure(lidar())), + clearframe=True, + title="gloal dynamic map test", + ) diff --git a/dimos/robot/unitree_webrtc/type/test_odometry.py 
b/dimos/robot/unitree_webrtc/type/test_odometry.py new file mode 100644 index 0000000000..3061eeb92e --- /dev/null +++ b/dimos/robot/unitree_webrtc/type/test_odometry.py @@ -0,0 +1,8 @@ +from dimos.robot.unitree_webrtc.type.odometry import Odometry +from dimos.robot.unitree_webrtc.testing.multimock import Multimock + + +def test_odometry_time(): + (timestamp, odom_raw) = Multimock("athens_odom").load_one(33) + print("RAW MSG", odom_raw) + print(Odometry.from_msg(odom_raw)) diff --git a/dimos/robot/unitree_webrtc/type/test_timeseries.py b/dimos/robot/unitree_webrtc/type/test_timeseries.py new file mode 100644 index 0000000000..00f29c3202 --- /dev/null +++ b/dimos/robot/unitree_webrtc/type/test_timeseries.py @@ -0,0 +1,31 @@ +from datetime import timedelta, datetime +from typing import TypeVar +from dimos.robot.unitree_webrtc.type.timeseries import TEvent, TList + + +fixed_date = datetime(2025, 5, 13, 15, 2, 5).astimezone() +start_event = TEvent(fixed_date, 1) +end_event = TEvent(fixed_date + timedelta(seconds=10), 9) + +sample_list = TList([start_event, TEvent(fixed_date + timedelta(seconds=2), 5), end_event]) + + +def test_repr(): + assert ( + str(sample_list) + == "Timeseries(date=2025-05-13, start=15:02:05, end=15:02:15, duration=0:00:10, events=3, freq=0.30Hz)" + ) + + +def test_equals(): + assert start_event == TEvent(start_event.ts, 1) + assert start_event != TEvent(start_event.ts, 2) + assert start_event != TEvent(start_event.ts + timedelta(seconds=1), 1) + + +def test_range(): + assert sample_list.time_range() == (start_event.ts, end_event.ts) + + +def test_duration(): + assert sample_list.duration() == timedelta(seconds=10) diff --git a/dimos/robot/unitree_webrtc/type/timeseries.py b/dimos/robot/unitree_webrtc/type/timeseries.py new file mode 100644 index 0000000000..84d2910622 --- /dev/null +++ b/dimos/robot/unitree_webrtc/type/timeseries.py @@ -0,0 +1,130 @@ +from __future__ import annotations +from datetime import datetime, timedelta, timezone +from 
PAYLOAD = TypeVar("PAYLOAD")


class RosStamp(TypedDict):
    """ROS-style timestamp message: seconds plus nanoseconds since the epoch."""

    sec: int
    nanosec: int


# Anything this module can interpret as a point in time.
EpochLike = Union[int, float, datetime, RosStamp]


def from_ros_stamp(stamp: RosStamp, tz: Optional[timezone] = None) -> datetime:
    """Convert ROS-style timestamp {'sec': int, 'nanosec': int} to datetime.

    Args:
        stamp: Mapping with integer ``sec`` and ``nanosec`` fields.
        tz: Target timezone; ``None`` yields a naive local-time datetime.
    """
    return datetime.fromtimestamp(stamp["sec"] + stamp["nanosec"] / 1e9, tz=tz)


def to_human_readable(ts: EpochLike) -> str:
    """Format any supported timestamp value as ``YYYY-MM-DD HH:MM:SS``."""
    return to_datetime(ts).strftime("%Y-%m-%d %H:%M:%S")


def to_datetime(ts: EpochLike, tz: Optional[timezone] = None) -> datetime:
    """Normalize an EpochLike value to a ``datetime``.

    datetimes pass through unchanged (``tz`` is NOT applied to them);
    ints/floats are treated as POSIX seconds; dicts as ROS stamps.

    Raises:
        TypeError: If ``ts`` is none of the supported types.
    """
    if isinstance(ts, datetime):
        return ts
    if isinstance(ts, (int, float)):
        return datetime.fromtimestamp(ts, tz=tz)
    if isinstance(ts, dict) and "sec" in ts and "nanosec" in ts:
        return datetime.fromtimestamp(ts["sec"] + ts["nanosec"] / 1e9, tz=tz)
    raise TypeError("unsupported timestamp type")


class Timestamped(ABC):
    """Abstract base for an event carrying a timestamp (exposed as ``self.ts``)."""

    def __init__(self, timestamp: EpochLike):
        self.ts = to_datetime(timestamp)


class TEvent(Timestamped, Generic[PAYLOAD]):
    """Concrete timestamped event holding an arbitrary payload in ``data``."""

    def __init__(self, timestamp: EpochLike, data: PAYLOAD):
        super().__init__(timestamp)
        self.data = data

    def __eq__(self, other: object) -> bool:
        """Events are equal when both timestamp and payload match."""
        if not isinstance(other, TEvent):
            return NotImplemented
        return self.ts == other.ts and self.data == other.data

    def __repr__(self) -> str:
        return f"TEvent(ts={self.ts}, data={self.data})"


EVENT = TypeVar("EVENT", bound=Timestamped)  # any object that is a subclass of Timestamped


class Timeseries(ABC, Generic[EVENT]):
    """Abstract class for an iterable of events with timestamps.

    All derived quantities assume the underlying iterable is sorted by
    timestamp, ascending.
    """

    @abstractmethod
    def __iter__(self) -> Iterable[EVENT]: ...

    @property
    def start_time(self) -> datetime:
        """Timestamp of the earliest event, assuming the data is sorted."""
        return next(iter(self)).ts

    @property
    def end_time(self) -> datetime:
        """Timestamp of the latest event, assuming the data is sorted."""
        return next(reversed(list(self))).ts

    @property
    def frequency(self) -> float:
        """Average event rate in Hz (a zero-length duration counts as 1 second)."""
        return len(list(self)) / (self.duration().total_seconds() or 1)

    def time_range(self) -> Tuple[datetime, datetime]:
        """Return (earliest_ts, latest_ts); raises on empty input."""
        return self.start_time, self.end_time

    def duration(self) -> timedelta:
        """Total time spanned by the iterable (last - first)."""
        return self.end_time - self.start_time

    def closest_to(self, timestamp: EpochLike) -> EVENT:
        """Return the event closest in time to ``timestamp``.

        Assumes the series is sorted: the scan stops as soon as the
        distance to the target starts growing again.
        """
        target_ts = to_datetime(timestamp).timestamp()

        closest = None
        min_dist = float("inf")
        for event in self:
            dist = abs(event.ts.timestamp() - target_ts)
            if dist > min_dist:
                break  # sorted input: distances only grow from here on
            min_dist = dist
            closest = event
        return closest

    def __repr__(self) -> str:
        """One-line summary: date, start/end times, duration, count, rate."""
        return f"Timeseries(date={self.start_time.strftime('%Y-%m-%d')}, start={self.start_time.strftime('%H:%M:%S')}, end={self.end_time.strftime('%H:%M:%S')}, duration={self.duration()}, events={len(list(self))}, freq={self.frequency:.2f}Hz)"

    def __str__(self) -> str:
        return self.__repr__()


class TList(list[EVENT], Timeseries[EVENT]):
    """A Timeseries backed by a plain ``list``."""

    def __repr__(self) -> str:
        """Use the Timeseries summary repr instead of the list repr."""
        return Timeseries.__repr__(self)
T = TypeVar("T", bound="Vector")


class Vector:
    """A wrapper around numpy arrays for vector operations with intuitive syntax."""

    def __init__(self, *args: Any) -> None:
        """Initialize a vector from components, an iterable, or an x/y/z object.

        Examples:
            Vector(1, 2)                 # 2D vector
            Vector(1, 2, 3)              # 3D vector
            Vector([1, 2, 3])            # From list
            Vector(np.array([1, 2, 3]))  # From numpy array
        """
        if len(args) == 1 and hasattr(args[0], "__iter__"):
            self._data = np.array(args[0], dtype=float)
        elif len(args) == 1:
            # Single non-iterable argument: assumed to expose .x/.y/.z
            # attributes (e.g. a ROS point/vector message). A bare scalar
            # is NOT supported by this branch.
            self._data = np.array([args[0].x, args[0].y, args[0].z], dtype=float)
        else:
            self._data = np.array(args, dtype=float)

    @property
    def yaw(self) -> float:
        """First component interpreted as a yaw angle.

        NOTE(review): simply aliases ``x``; presumably this class doubles as
        a rotation vector (yaw, pitch, roll) — confirm against callers.
        """
        return self.x

    @property
    def tuple(self) -> Tuple[float, ...]:
        """Tuple representation of the vector."""
        return tuple(self._data)

    @property
    def x(self) -> float:
        """X component of the vector (0.0 for empty vectors)."""
        return self._data[0] if len(self._data) > 0 else 0.0

    @property
    def y(self) -> float:
        """Y component of the vector (0.0 when fewer than 2 components)."""
        return self._data[1] if len(self._data) > 1 else 0.0

    @property
    def z(self) -> float:
        """Z component of the vector (0.0 when fewer than 3 components)."""
        return self._data[2] if len(self._data) > 2 else 0.0

    @property
    def dim(self) -> int:
        """Dimensionality of the vector."""
        return len(self._data)

    @property
    def data(self) -> NDArray[np.float64]:
        """The underlying numpy array (shared, not a copy)."""
        return self._data

    def __len__(self) -> int:
        return len(self._data)

    def __getitem__(self, idx: int) -> float:
        return float(self._data[idx])

    def __iter__(self) -> Iterable[float]:
        return iter(self._data)

    def __repr__(self) -> str:
        components = ",".join(f"{x:.6g}" for x in self._data)
        return f"({components})"

    def __str__(self) -> str:
        if self.dim < 2:
            return self.__repr__()

        def get_arrow() -> str:
            # Unicode arrows indexed counter-clockwise starting at "left";
            # (renamed from `repr`, which shadowed the builtin).
            arrows = ["←", "↖", "↑", "↗", "→", "↘", "↓", "↙"]

            if self.y == 0 and self.x == 0:
                return "·"

            # Angle of (x, y); shift by pi so the result lies in [0, 2*pi),
            # then quantize into 8 sectors of pi/4 each.
            angle = np.arctan2(self.y, self.x)
            dir_index = int(((angle + np.pi) * 4 / np.pi) % 8)
            return arrows[dir_index]

        return f"{get_arrow()} Vector {self.__repr__()}"

    def serialize(self) -> dict:
        """Serialize the vector to a plain dictionary."""
        return {"type": "vector", "c": self._data.tolist()}

    def __eq__(self, other: Any) -> bool:
        """Element-wise equality against another Vector or any array-like."""
        if isinstance(other, Vector):
            return np.array_equal(self._data, other._data)
        return np.array_equal(self._data, np.array(other, dtype=float))

    def __add__(self: T, other: Union["Vector", Iterable[float]]) -> T:
        if isinstance(other, Vector):
            return self.__class__(self._data + other._data)
        return self.__class__(self._data + np.array(other, dtype=float))

    def __sub__(self: T, other: Union["Vector", Iterable[float]]) -> T:
        # BUG FIX: removed a leftover debug print of both operands.
        if isinstance(other, Vector):
            return self.__class__(self._data - other._data)
        return self.__class__(self._data - np.array(other, dtype=float))

    def __mul__(self: T, scalar: float) -> T:
        return self.__class__(self._data * scalar)

    def __rmul__(self: T, scalar: float) -> T:
        return self.__mul__(scalar)

    def __truediv__(self: T, scalar: float) -> T:
        return self.__class__(self._data / scalar)

    def __neg__(self: T) -> T:
        return self.__class__(-self._data)

    def dot(self, other: Union["Vector", Iterable[float]]) -> float:
        """Compute dot product."""
        if isinstance(other, Vector):
            return float(np.dot(self._data, other._data))
        return float(np.dot(self._data, np.array(other, dtype=float)))

    def cross(self: T, other: Union["Vector", Iterable[float]]) -> T:
        """Compute cross product (3D vectors only).

        Raises:
            ValueError: If either operand is not 3-dimensional.
        """
        if self.dim != 3:
            raise ValueError("Cross product is only defined for 3D vectors")

        if isinstance(other, Vector):
            other_data = other._data
        else:
            other_data = np.array(other, dtype=float)

        if len(other_data) != 3:
            raise ValueError("Cross product requires two 3D vectors")

        return self.__class__(np.cross(self._data, other_data))

    def length(self) -> float:
        """Compute the Euclidean length (magnitude) of the vector."""
        return float(np.linalg.norm(self._data))

    def length_squared(self) -> float:
        """Compute the squared length of the vector (faster than length())."""
        return float(np.sum(self._data * self._data))

    def normalize(self: T) -> T:
        """Return a normalized unit vector; a near-zero vector maps to zeros."""
        length = self.length()
        if length < 1e-10:  # Avoid division by near-zero
            return self.__class__(np.zeros_like(self._data))
        return self.__class__(self._data / length)

    def to_2d(self: T) -> T:
        """Convert a vector to a 2D vector by taking only the x and y components."""
        return self.__class__(self._data[:2])

    def distance(self, other: Union["Vector", Iterable[float]]) -> float:
        """Compute Euclidean distance to another vector."""
        if isinstance(other, Vector):
            return float(np.linalg.norm(self._data - other._data))
        return float(np.linalg.norm(self._data - np.array(other, dtype=float)))

    def distance_squared(self, other: Union["Vector", Iterable[float]]) -> float:
        """Compute squared Euclidean distance to another vector (faster than distance())."""
        if isinstance(other, Vector):
            diff = self._data - other._data
        else:
            diff = self._data - np.array(other, dtype=float)
        return float(np.sum(diff * diff))

    def angle(self, other: Union["Vector", Iterable[float]]) -> float:
        """Angle in radians between this vector and another.

        Returns 0.0 when either operand is (near-)zero — BUG FIX: the
        zero-length guard previously only applied to Vector operands, so a
        zero list/array argument produced NaN instead of 0.0.
        """
        if isinstance(other, Vector):
            other_data = other._data
        else:
            other_data = np.array(other, dtype=float)

        if self.length() < 1e-10 or float(np.linalg.norm(other_data)) < 1e-10:
            return 0.0

        cos_angle = np.clip(
            np.dot(self._data, other_data)
            / (np.linalg.norm(self._data) * np.linalg.norm(other_data)),
            -1.0,
            1.0,
        )
        return float(np.arccos(cos_angle))

    def project(self: T, onto: Union["Vector", Iterable[float]]) -> T:
        """Project this vector onto another vector (zero target ⇒ zero vector)."""
        if isinstance(onto, Vector):
            onto_data = onto._data
        else:
            onto_data = np.array(onto, dtype=float)

        onto_length_sq = np.sum(onto_data * onto_data)
        if onto_length_sq < 1e-10:
            return self.__class__(np.zeros_like(self._data))

        scalar_projection = np.dot(self._data, onto_data) / onto_length_sq
        return self.__class__(scalar_projection * onto_data)

    # this is here to test ros_observable_topic
    # doesn't happen irl afaik that we want a vector from ros message
    @classmethod
    def from_msg(cls: type[T], msg: Any) -> T:
        """Build a vector by unpacking an iterable message."""
        return cls(*msg)

    @classmethod
    def zeros(cls: type[T], dim: int) -> T:
        """Create a zero vector of given dimension."""
        return cls(np.zeros(dim))

    @classmethod
    def ones(cls: type[T], dim: int) -> T:
        """Create a vector of ones with given dimension."""
        return cls(np.ones(dim))

    @classmethod
    def unit_x(cls: type[T], dim: int = 3) -> T:
        """Create a unit vector in the x direction (requires dim >= 1)."""
        v = np.zeros(dim)
        v[0] = 1.0
        return cls(v)

    @classmethod
    def unit_y(cls: type[T], dim: int = 3) -> T:
        """Create a unit vector in the y direction (requires dim >= 2)."""
        v = np.zeros(dim)
        v[1] = 1.0
        return cls(v)

    @classmethod
    def unit_z(cls: type[T], dim: int = 3) -> T:
        """Create a unit vector in the z direction (zeros when dim <= 2)."""
        v = np.zeros(dim)
        if dim > 2:
            v[2] = 1.0
        return cls(v)

    def to_list(self) -> List[float]:
        """Convert the vector to a list."""
        return [float(x) for x in self._data]

    def to_tuple(self) -> Tuple[float, ...]:
        """Convert the vector to a tuple."""
        return tuple(self._data)

    def to_numpy(self) -> NDArray[np.float64]:
        """Convert the vector to a numpy array (shared, not a copy)."""
        return self._data
# Protocol approach for static type checking
@runtime_checkable
class VectorLike(Protocol):
    """Structural type for anything indexable, sized and iterable over floats."""

    def __getitem__(self, key: int) -> float: ...
    def __len__(self) -> int: ...
    def __iter__(self) -> Iterable[float]: ...


def to_numpy(value: VectorLike) -> NDArray[np.float64]:
    """Return *value* as a numpy array.

    Vector and ndarray inputs are returned without copying; anything else
    is converted to a fresh float array.
    """
    if isinstance(value, Vector):
        return value.data
    if isinstance(value, np.ndarray):
        return value
    return np.array(value, dtype=float)


def to_vector(value: VectorLike) -> Vector:
    """Return *value* as a Vector (identity for Vector inputs)."""
    return value if isinstance(value, Vector) else Vector(value)


def to_tuple(value: VectorLike) -> Tuple[float, ...]:
    """Return *value* as a tuple of floats."""
    if isinstance(value, Vector):
        source = value.data
    elif isinstance(value, (np.ndarray, tuple)):
        source = value
    else:
        # Fall back to generic indexing for any other sequence type.
        source = [value[i] for i in range(len(value))]
    return tuple(float(c) for c in source)


def to_list(value: VectorLike) -> List[float]:
    """Return *value* as a list of floats."""
    if isinstance(value, Vector):
        items = value.data
    elif isinstance(value, (np.ndarray, list)):
        items = value
    else:
        # Fall back to generic indexing for any other sequence type.
        items = [value[i] for i in range(len(value))]
    return [float(c) for c in items]


def _has_dim(value: VectorLike, dim: int) -> bool:
    """Shared dimensionality check behind is_2d/is_3d."""
    if isinstance(value, np.ndarray):
        # Either the trailing axis or the total element count may match.
        return value.shape[-1] == dim or value.size == dim
    return len(value) == dim


def is_2d(value: VectorLike) -> bool:
    """True when *value* has exactly two components."""
    return _has_dim(value, 2)


def is_3d(value: VectorLike) -> bool:
    """True when *value* has exactly three components."""
    return _has_dim(value, 3)


def x(value: VectorLike) -> float:
    """First (X) component of *value*."""
    if isinstance(value, Vector):
        return value.x
    return float(to_numpy(value)[0])


def y(value: VectorLike) -> float:
    """Second (Y) component of *value*, or 0.0 when absent."""
    if isinstance(value, Vector):
        return value.y
    arr = to_numpy(value)
    return float(arr[1]) if len(arr) > 1 else 0.0


def z(value: VectorLike) -> float:
    """Third (Z) component of *value*, or 0.0 when absent."""
    if isinstance(value, Vector):
        return value.z
    arr = to_numpy(value)
    return float(arr[2]) if len(arr) > 2 else 0.0
print() + + # Test to_tuple + print("to_tuple() conversions:") + print(f"Vector → Tuple: {to_tuple(vector_obj)}") + print(f"NumPy → Tuple: {to_tuple(numpy_arr)}") + print(f"Tuple → Tuple: {to_tuple(tuple_vec)}") + print(f"List → Tuple: {to_tuple(list_vec)}") + print() + + # Test to_list + print("to_list() conversions:") + print(f"Vector → List: {to_list(vector_obj)}") + print(f"NumPy → List: {to_list(numpy_arr)}") + print(f"Tuple → List: {to_list(tuple_vec)}") + print(f"List → List: {to_list(list_vec)}") + print() + + # Test component extraction + print("Component extraction:") + print("x() function:") + print(f"x(Vector): {x(vector_obj)}") + print(f"x(NumPy): {x(numpy_arr)}") + print(f"x(Tuple): {x(tuple_vec)}") + print(f"x(List): {x(list_vec)}") + print() + + print("y() function:") + print(f"y(Vector): {y(vector_obj)}") + print(f"y(NumPy): {y(numpy_arr)}") + print(f"y(Tuple): {y(tuple_vec)}") + print(f"y(List): {y(list_vec)}") + print() + + print("z() function:") + print(f"z(Vector): {z(vector_obj)}") + print(f"z(NumPy): {z(numpy_arr)}") + print(f"z(Tuple): {z(tuple_vec)}") + print(f"z(List): {z(list_vec)}") + print() + + # Test dimension checking + print("Dimension checking:") + vec2d = Vector(1.0, 2.0) + vec3d = Vector(1.0, 2.0, 3.0) + arr2d = np.array([1.0, 2.0]) + arr3d = np.array([1.0, 2.0, 3.0]) + + print(f"is_2d(Vector(1,2)): {is_2d(vec2d)}") + print(f"is_2d(Vector(1,2,3)): {is_2d(vec3d)}") + print(f"is_2d(np.array([1,2])): {is_2d(arr2d)}") + print(f"is_2d(np.array([1,2,3])): {is_2d(arr3d)}") + print(f"is_2d((1,2)): {is_2d((1.0, 2.0))}") + print(f"is_2d((1,2,3)): {is_2d((1.0, 2.0, 3.0))}") + print() + + print(f"is_3d(Vector(1,2)): {is_3d(vec2d)}") + print(f"is_3d(Vector(1,2,3)): {is_3d(vec3d)}") + print(f"is_3d(np.array([1,2])): {is_3d(arr2d)}") + print(f"is_3d(np.array([1,2,3])): {is_3d(arr3d)}") + print(f"is_3d((1,2)): {is_3d((1.0, 2.0))}") + print(f"is_3d((1,2,3)): {is_3d((1.0, 2.0, 3.0))}") + print() + + # Test the Protocol interface + print("Testing 
VectorLike Protocol:") + print(f"isinstance(Vector(1,2), VectorLike): {isinstance(vec2d, VectorLike)}") + print(f"isinstance(np.array([1,2]), VectorLike): {isinstance(arr2d, VectorLike)}") + print( + f"isinstance((1,2), VectorLike): {isinstance((1.0, 2.0), VectorLike)}" + ) + print( + f"isinstance([1,2], VectorLike): {isinstance([1.0, 2.0], VectorLike)}" + ) + print() + + # Test mixed operations using different vector types + # These functions aren't defined in vectortypes, but demonstrate the concept + def distance(a: VectorLike, b: VectorLike) -> float: + a_np = to_numpy(a) + b_np = to_numpy(b) + diff = a_np - b_np + return float(np.sqrt(np.sum(diff * diff))) + + def midpoint(a: VectorLike, b: VectorLike) -> NDArray[np.float64]: + a_np = to_numpy(a) + b_np = to_numpy(b) + return (a_np + b_np) / 2 + + print("Mixed operations between different vector types:") + print( + f"distance(Vector(1,2,3), [4,5,6]): {distance(vec3d, [4.0, 5.0, 6.0])}" + ) + print( + f"distance(np.array([1,2,3]), (4,5,6)): {distance(arr3d, (4.0, 5.0, 6.0))}" + ) + print(f"midpoint(Vector(1,2,3), np.array([4,5,6])): {midpoint(vec3d, numpy_arr)}") diff --git a/dimos/robot/unitree_webrtc/unitree_go2.py b/dimos/robot/unitree_webrtc/unitree_go2.py new file mode 100644 index 0000000000..f6e4c1b47c --- /dev/null +++ b/dimos/robot/unitree_webrtc/unitree_go2.py @@ -0,0 +1,104 @@ +from dataclasses import dataclass +from dimos.types.path import Path +from dimos.types.vector import Vector +from typing import Union, Optional +from dimos.robot.unitree_webrtc.type.map import Map +from dimos.robot.unitree_webrtc.connection import WebRTCRobot +from dimos.robot.global_planner.planner import AstarPlanner +from dimos.utils.reactive import backpressure +from dimos.utils.reactive import getter_streaming +from dimos.robot.unitree.unitree_skills import MyUnitreeSkills +from dimos.skills.skills import AbstractRobotSkill, AbstractSkill, SkillLibrary +import os +from go2_webrtc_driver.constants import VUI_COLOR +from 
dimos.robot.local_planner import VFHPurePursuitPlanner, navigate_path_local + + +class Color(VUI_COLOR): ... + + +class UnitreeGo2(WebRTCRobot): + def __init__( + self, + ip: str, + mode: str = "ai", + skills: Optional[Union[MyUnitreeSkills, AbstractSkill]] = None, + skill_library: SkillLibrary = None, + output_dir: str = os.path.join(os.getcwd(), "assets", "output"), + ): + super().__init__(ip=ip, mode=mode) + + self.odom = getter_streaming(self.odom_stream()) + self.map = Map() + self.map_stream = self.map.consume(self.lidar_stream()) + + self.global_planner = AstarPlanner( + set_local_nav=lambda path, stop_event=None, goal_theta=None: navigate_path_local( + self, path, timeout=120.0, goal_theta=goal_theta, stop_event=stop_event + ), + get_costmap=lambda: self.map.costmap, + get_robot_pos=lambda: self.odom().pos, + ) + + # # Initialize skills + # if skills is None: + # skills = MyUnitreeSkills(robot=self) + + # self.skill_library = skills if skills else SkillLibrary() + + # if self.skill_library is not None: + # for skill in self.skill_library: + # if isinstance(skill, AbstractRobotSkill): + # self.skill_library.create_instance(skill.__name__, robot=self) + # if isinstance(self.skill_library, MyUnitreeSkills): + # self.skill_library._robot = self + # self.skill_library.init() + # self.skill_library.initialize_skills() + + # # Camera stuff + # self.camera_intrinsics = [819.553492, 820.646595, 625.284099, 336.808987] + # self.camera_pitch = np.deg2rad(0) # negative for downward pitch + # self.camera_height = 0.44 # meters + + # os.makedirs(self.output_dir, exist_ok=True) + + # # Initialize visual servoing if enabled + # if self.get_video_stream() is not None: + # self.person_tracker = PersonTrackingStream( + # camera_intrinsics=self.camera_intrinsics, + # camera_pitch=self.camera_pitch, + # camera_height=self.camera_height, + # ) + # self.object_tracker = ObjectTrackingStream( + # camera_intrinsics=self.camera_intrinsics, + # camera_pitch=self.camera_pitch, + # 
@dataclass
class Position:
    """A pose made of a translation vector and a rotation vector."""

    # Quoted annotations so the class definition does not need Vector evaluated.
    pos: "Vector"
    rot: "Vector"

    def __repr__(self) -> str:
        return f"pos({self.pos}), rot({self.rot})"

    def __str__(self) -> str:
        return self.__repr__()
logging.Logger: """Set up a logger with color output. Args: @@ -41,26 +39,26 @@ def setup_logger( """ if level is None: # Get level from environment variable or default to INFO - level_name = os.getenv('DIMOS_LOG_LEVEL', 'INFO') + level_name = os.getenv("DIMOS_LOG_LEVEL", "INFO") level = getattr(logging, level_name) - + if log_format is None: log_format = "%(log_color)s%(asctime)s - %(name)s - %(levelname)s - %(message)s" try: # Get or create logger logger = logging.getLogger(name) - + # Remove any existing handlers to avoid duplicates if logger.hasHandlers(): logger.handlers.clear() - + # Set logger level first logger.setLevel(level) - + # Ensure we're not blocked by parent loggers logger.propagate = False - + # Create and configure handler handler = colorlog.StreamHandler() handler.setLevel(level) # Explicitly set handler level diff --git a/dimos/utils/reactive.py b/dimos/utils/reactive.py index 3a06311f6d..0a609dd23e 100644 --- a/dimos/utils/reactive.py +++ b/dimos/utils/reactive.py @@ -1,15 +1,17 @@ import threading -from typing import Optional, TypeVar, Generic +from typing import Optional, TypeVar, Generic, Any, Callable import reactivex as rx from reactivex import operators as ops from reactivex.scheduler import ThreadPoolScheduler +from reactivex.disposable import Disposable from reactivex.observable import Observable from rxpy_backpressure import BackPressure from dimos.utils.threadpool import get_scheduler -T = TypeVar('T') +T = TypeVar("T") + # Observable ─► ReplaySubject─► observe_on(pool) ─► backpressure.latest ─► sub1 (fast) # ├──► observe_on(pool) ─► backpressure.latest ─► sub2 (slow) @@ -64,22 +66,19 @@ def dispose(self) -> None: if self._connection: self._connection.dispose() -def getter_ondemand( - observable: Observable[T], - timeout: Optional[float] = 30.0 -) -> T: + +def getter_ondemand(observable: Observable[T], timeout: Optional[float] = 30.0) -> T: def getter(): try: # Wait for first value with optional timeout - value = observable.pipe( - 
T = TypeVar("T")
CB = Callable[[T], Any]


def callback_to_observable(
    start: Callable[[CB[T]], Any],
    stop: Callable[[CB[T]], Any],
) -> Observable[T]:
    """Wrap a start/stop callback API as an Observable.

    On subscribe, ``start`` is called with a listener that forwards every
    value to the observer; disposing the subscription calls ``stop`` with
    the very same listener so the producer can unregister it.
    """

    def _subscribe(observer, _scheduler=None):
        # Bind once so start() and stop() see the identical callable.
        forward: CB[T] = observer.on_next
        start(forward)
        return Disposable(lambda: stop(forward))

    return rx.create(_subscribe)
Any -from dimos.utils.reactive import backpressure, getter_streaming, getter_ondemand from reactivex.disposable import Disposable +from dimos.utils.reactive import backpressure, getter_streaming, getter_ondemand, callback_to_observable def measure_time(func: Callable[[], Any], iterations: int = 1) -> float: @@ -15,18 +15,23 @@ def measure_time(func: Callable[[], Any], iterations: int = 1) -> float: total_time = end_time - start_time return result, total_time + def assert_time(func: Callable[[], Any], assertion: Callable[[int], bool], assert_fail_msg=None) -> None: - [result, total_time ] = measure_time(func) + [result, total_time] = measure_time(func) assert assertion(total_time), assert_fail_msg + f", took {round(total_time, 2)}s" return result + def min_time(func: Callable[[], Any], min_t: int, assert_fail_msg="Function returned too fast"): return assert_time(func, (lambda t: t > min_t), assert_fail_msg + f", min: {min_t} seconds") + def max_time(func: Callable[[], Any], max_t: int, assert_fail_msg="Function took too long"): return assert_time(func, (lambda t: t < max_t), assert_fail_msg + f", max: {max_t} seconds") -T = TypeVar('T') + +T = TypeVar("T") + def dispose_spy(source: rx.Observable[T]) -> rx.Observable[T]: state = {"active": 0} @@ -34,9 +39,11 @@ def dispose_spy(source: rx.Observable[T]) -> rx.Observable[T]: def factory(observer, scheduler=None): state["active"] += 1 upstream = source.subscribe(observer, scheduler=scheduler) + def _dispose(): upstream.dispose() state["active"] -= 1 + return Disposable(_dispose) proxy = rx.create(factory) @@ -45,16 +52,11 @@ def _dispose(): return proxy - - def test_backpressure_handling(): received_fast = [] received_slow = [] # Create an observable that emits numpy arrays instead of integers - source = dispose_spy(rx.interval(0.1).pipe( - ops.map(lambda i: np.array([i, i+1, i+2])), - ops.take(50) - )) + source = dispose_spy(rx.interval(0.1).pipe(ops.map(lambda i: np.array([i, i + 1, i + 2])), ops.take(50))) # Wrap 
with backpressure handling safe_source = backpressure(source) @@ -64,9 +66,9 @@ def test_backpressure_handling(): # Slow sub (shouldn't block above) subscription2 = safe_source.subscribe(lambda x: (time.sleep(0.25), received_slow.append(x))) - + time.sleep(2.5) - + subscription1.dispose() assert not source.is_disposed(), "Observable should not be disposed yet" subscription2.dispose() @@ -76,30 +78,27 @@ def test_backpressure_handling(): # Check results print("Fast observer received:", len(received_fast), [arr[0] for arr in received_fast]) print("Slow observer received:", len(received_slow), [arr[0] for arr in received_slow]) - + # Fast observer should get all or nearly all items assert len(received_fast) > 15, f"Expected fast observer to receive most items, got {len(received_fast)}" - + # Slow observer should get fewer items due to backpressure handling assert len(received_slow) < len(received_fast), "Slow observer should receive fewer items than fast observer" # Specifically, processing at 0.25s means ~4 items per second, so expect 8-10 items assert 7 <= len(received_slow) <= 11, f"Expected 7-11 items, got {len(received_slow)}" - + # The slow observer should skip items (not process them in sequence) # We test this by checking that the difference between consecutive arrays is sometimes > 1 has_skips = False for i in range(1, len(received_slow)): - if received_slow[i][0] - received_slow[i-1][0] > 1: + if received_slow[i][0] - received_slow[i - 1][0] > 1: has_skips = True break assert has_skips, "Slow observer should skip items due to backpressure" def test_getter_streaming_blocking(): - source = dispose_spy(rx.interval(0.2).pipe( - ops.map(lambda i: np.array([i, i+1, i+2])), - ops.take(50) - )) + source = dispose_spy(rx.interval(0.2).pipe(ops.map(lambda i: np.array([i, i + 1, i + 2])), ops.take(50))) assert source.is_disposed() getter = min_time(lambda: getter_streaming(source), 0.2, "Latest getter needs to block until first msg is ready") @@ -113,6 +112,7 @@ def 
test_getter_streaming_blocking(): getter.dispose() assert source.is_disposed(), "Observable should be disposed" + def test_getter_streaming_blocking_timeout(): source = dispose_spy(rx.interval(0.2).pipe(ops.take(50))) with pytest.raises(Exception): @@ -120,10 +120,13 @@ def test_getter_streaming_blocking_timeout(): getter.dispose() assert source.is_disposed() + def test_getter_streaming_nonblocking(): source = dispose_spy(rx.interval(0.2).pipe(ops.take(50))) - getter = max_time(lambda: getter_streaming(source, nonblocking=True), 0.1, "nonblocking getter init shouldn't block") + getter = max_time( + lambda: getter_streaming(source, nonblocking=True), 0.1, "nonblocking getter init shouldn't block" + ) min_time(getter, 0.2, "Expected for first value call to block if cache is empty") assert getter() == 0 @@ -136,10 +139,10 @@ def test_getter_streaming_nonblocking(): time.sleep(0.5) assert getter() >= 4, f"Expected value >= 4, got {getter()}" - getter.dispose() assert source.is_disposed(), "Observable should be disposed" + def test_getter_streaming_nonblocking_timeout(): source = dispose_spy(rx.interval(0.2).pipe(ops.take(50))) getter = getter_streaming(source, timeout=0.1, nonblocking=True) @@ -148,6 +151,7 @@ def test_getter_streaming_nonblocking_timeout(): assert not source.is_disposed(), "is not disposed, this is a job of the caller" + def test_getter_ondemand(): source = dispose_spy(rx.interval(0.1).pipe(ops.take(50))) getter = getter_ondemand(source) @@ -157,9 +161,50 @@ def test_getter_ondemand(): assert getter() == 0, f"Expected to get the first value of 0, got {getter()}" assert source.is_disposed(), "Observable should be disposed" + def test_getter_ondemand_timeout(): source = dispose_spy(rx.interval(0.2).pipe(ops.take(50))) getter = getter_ondemand(source, timeout=0.1) with pytest.raises(Exception): getter() assert source.is_disposed(), "Observable should be disposed" + + +def test_callback_to_observable(): + # Test converting a callback-based API to an 
Observable + received = [] + callback = None + + # Mock start function that captures the callback + def start_fn(cb): + nonlocal callback + callback = cb + return "start_result" + + # Mock stop function + stop_called = False + + def stop_fn(cb): + nonlocal stop_called + stop_called = True + + # Create observable from callback + observable = callback_to_observable(start_fn, stop_fn) + + # Subscribe to the observable + subscription = observable.subscribe(lambda x: received.append(x)) + + # Verify start was called and we have access to the callback + assert callback is not None + + # Simulate callback being triggered with different messages + callback("message1") + callback(42) + callback({"key": "value"}) + + # Check that all messages were received + assert received == ["message1", 42, {"key": "value"}] + + # Dispose subscription and check that stop was called + subscription.dispose() + assert stop_called, "Stop function should be called on dispose" diff --git a/dimos/web/websocket_vis/clientside/decoder.ts b/dimos/web/websocket_vis/clientside/decoder.ts index 6eb61d6087..ff4439d799 100644 --- a/dimos/web/websocket_vis/clientside/decoder.ts +++ b/dimos/web/websocket_vis/clientside/decoder.ts @@ -1,19 +1,19 @@ -import { Costmap, EncodedSomething, Grid, Path, Vector } from "./types.ts" +import { Costmap, EncodedSomething, Grid, Path, Vector } from "./types.ts"; export function decode(data: EncodedSomething) { - console.log("decoding", data) + // console.log("decoding", data) if (data.type == "costmap") { - return Costmap.decode(data) + return Costmap.decode(data); } if (data.type == "vector") { - return Vector.decode(data) + return Vector.decode(data); } if (data.type == "grid") { - return Grid.decode(data) + return Grid.decode(data); } if (data.type == "path") { - return Path.decode(data) + return Path.decode(data); } - return "UNKNOWN" + return "UNKNOWN"; } diff --git a/dimos/web/websocket_vis/clientside/init.ts b/dimos/web/websocket_vis/clientside/init.ts index 
baeb3ae71b..4367d89819 100644 --- a/dimos/web/websocket_vis/clientside/init.ts +++ b/dimos/web/websocket_vis/clientside/init.ts @@ -1,7 +1,7 @@ -import { io } from "npm:socket.io-client" -import { decode } from "./decoder.ts" -import { Drawable, EncodedSomething } from "./types.ts" -import { Visualizer as ReactVisualizer } from "./vis2.tsx" +import { io } from "npm:socket.io-client"; +import { decode } from "./decoder.ts"; +import { Drawable, EncodedSomething } from "./types.ts"; +import { Visualizer as ReactVisualizer } from "./vis2.tsx"; // Store server state locally let serverState = { @@ -9,25 +9,25 @@ let serverState = { connected_clients: 0, data: {}, draw: {}, -} +}; -let reactVisualizer: ReactVisualizer | null = null +let reactVisualizer: ReactVisualizer | null = null; -const socket = io() +const socket = io(); socket.on("connect", () => { - console.log("Connected to server") - serverState.status = "connected" -}) + console.log("Connected to server"); + serverState.status = "connected"; +}); socket.on("disconnect", () => { - console.log("Disconnected from server") - serverState.status = "disconnected" -}) + console.log("Disconnected from server"); + serverState.status = "disconnected"; +}); socket.on("message", (data) => { - console.log("Received message:", data) -}) + //console.log("Received message:", data) +}); // Deep merge function for client-side state updates function deepMerge(source: any, destination: any): any { @@ -42,83 +42,83 @@ function deepMerge(source: any, destination: any): any { !Array.isArray(source[key]) && !Array.isArray(destination[key]) ) { - deepMerge(source[key], destination[key]) + deepMerge(source[key], destination[key]); } else { // Otherwise, just copy the value - destination[key] = source[key] + destination[key] = source[key]; } } - return destination + return destination; } -type DrawConfig = { [key: string]: any } +type DrawConfig = { [key: string]: any }; -type EncodedDrawable = EncodedSomething +type EncodedDrawable = 
EncodedSomething; type EncodedDrawables = { - [key: string]: EncodedDrawable -} + [key: string]: EncodedDrawable; +}; type Drawables = { - [key: string]: Drawable -} + [key: string]: Drawable; +}; function decodeDrawables(encoded: EncodedDrawables): Drawables { - const drawables: Drawables = {} + const drawables: Drawables = {}; for (const [key, value] of Object.entries(encoded)) { // @ts-ignore - drawables[key] = decode(value) + drawables[key] = decode(value); } - return drawables + return drawables; } function state_update(state: { [key: string]: any }) { - console.log("Received state update:", state) + //console.log("Received state update:", state) // Use deep merge to update nested properties if (state.draw) { - state.draw = decodeDrawables(state.draw) + state.draw = decodeDrawables(state.draw); } - console.log("Decoded state update:", state) + // console.log("Decoded state update:", state); // Create a fresh copy of the server state to trigger rerenders properly - serverState = { ...deepMerge(state, { ...serverState }) } + serverState = { ...deepMerge(state, { ...serverState }) }; - updateUI() + updateUI(); } -socket.on("state_update", state_update) -socket.on("full_state", state_update) +socket.on("state_update", state_update); +socket.on("full_state", state_update); // Function to send data to server function emitMessage(data: any) { - socket.emit("message", data) + socket.emit("message", data); } // Function to update UI based on state function updateUI() { - console.log("Current state:", serverState) + // console.log("Current state:", serverState); // Update both visualizers if they exist and there's data to display if (serverState.draw && Object.keys(serverState.draw).length > 0) { if (reactVisualizer) { - reactVisualizer.visualizeState(serverState.draw) + reactVisualizer.visualizeState(serverState.draw); } } } // Initialize the application function initializeApp() { - console.log("DOM loaded, initializing UI") - reactVisualizer = new 
ReactVisualizer("#vis") + console.log("DOM loaded, initializing UI"); + reactVisualizer = new ReactVisualizer("#vis"); // Set up click handler to convert clicks to world coordinates and send to server reactVisualizer.onWorldClick((worldX, worldY) => { - emitMessage({ type: "click", position: [worldX, worldY] }) - }) + emitMessage({ type: "click", position: [worldX, worldY] }); + }); - updateUI() + updateUI(); } -console.log("Socket.IO client initialized") +console.log("Socket.IO client initialized"); // Call initialization once when the DOM is loaded -document.addEventListener("DOMContentLoaded", initializeApp) +document.addEventListener("DOMContentLoaded", initializeApp); diff --git a/dimos/web/websocket_vis/clientside/vis2.tsx b/dimos/web/websocket_vis/clientside/vis2.tsx index 82a31ceaa4..e0052a210d 100644 --- a/dimos/web/websocket_vis/clientside/vis2.tsx +++ b/dimos/web/websocket_vis/clientside/vis2.tsx @@ -1,7 +1,7 @@ -import * as d3 from "npm:d3" -import * as React from "npm:react" -import * as ReactDOMClient from "npm:react-dom/client" -import { Costmap, Drawable, Path, Vector } from "./types.ts" +import * as d3 from "npm:d3"; +import * as React from "npm:react"; +import * as ReactDOMClient from "npm:react-dom/client"; +import { Costmap, Drawable, Path, Vector } from "./types.ts"; // ─────────────────────────────────────────────────────────────────────────────── // React component @@ -9,70 +9,70 @@ import { Costmap, Drawable, Path, Vector } from "./types.ts" const VisualizerComponent: React.FC<{ state: Record }> = ({ state, }) => { - const svgRef = React.useRef(null) + const svgRef = React.useRef(null); const [dimensions, setDimensions] = React.useState({ width: 800, height: 600, - }) - const { width, height } = dimensions + }); + const { width, height } = dimensions; // Update dimensions when container size changes React.useEffect(() => { - if (!svgRef.current) return + if (!svgRef.current) return; const updateDimensions = () => { - const rect = 
svgRef.current?.parentElement?.getBoundingClientRect() + const rect = svgRef.current?.parentElement?.getBoundingClientRect(); if (rect) { - setDimensions({ width: rect.width, height: rect.height }) + setDimensions({ width: rect.width, height: rect.height }); } - } + }; // Initial update - updateDimensions() + updateDimensions(); // Create resize observer - const observer = new ResizeObserver(updateDimensions) - observer.observe(svgRef.current.parentElement as Element) + const observer = new ResizeObserver(updateDimensions); + observer.observe(svgRef.current.parentElement as Element); - return () => observer.disconnect() - }, []) + return () => observer.disconnect(); + }, []); /** Build a world→pixel transformer from the *first* cost‑map we see. */ const { worldToPx, pxToWorld } = React.useMemo(() => { const ref = Object.values(state).find( (d): d is Costmap => d instanceof Costmap, - ) - if (!ref) return { worldToPx: undefined, pxToWorld: undefined } + ); + if (!ref) return { worldToPx: undefined, pxToWorld: undefined }; const { grid: { shape }, origin, resolution, - } = ref - const [rows, cols] = shape + } = ref; + const [rows, cols] = shape; // Same sizing/centering logic used in visualiseCostmap - const cell = Math.min(width / cols, height / rows) - const gridW = cols * cell - const gridH = rows * cell - const offsetX = (width - gridW) / 2 - const offsetY = (height - gridH) / 2 + const cell = Math.min(width / cols, height / rows); + const gridW = cols * cell; + const gridH = rows * cell; + const offsetX = (width - gridW) / 2; + const offsetY = (height - gridH) / 2; const xScale = d3 .scaleLinear() .domain([origin.coords[0], origin.coords[0] + cols * resolution]) - .range([offsetX, offsetX + gridW]) + .range([offsetX, offsetX + gridW]); const yScale = d3 .scaleLinear() .domain([origin.coords[1], origin.coords[1] + rows * resolution]) - .range([offsetY + gridH, offsetY]) // invert y (world ↑ => svg ↑) + .range([offsetY + gridH, offsetY]); // invert y (world ↑ => 
svg ↑) // World coordinates to pixel coordinates const worldToPxFn = ( x: number, y: number, - ): [number, number] => [xScale(x), yScale(y)] + ): [number, number] => [xScale(x), yScale(y)]; // Pixel coordinates to world coordinates (inverse transform) const pxToWorldFn = ( @@ -81,76 +81,42 @@ const VisualizerComponent: React.FC<{ state: Record }> = ({ ): [number, number] => [ xScale.invert(x), yScale.invert(y), - ] + ]; return { worldToPx: worldToPxFn, pxToWorld: pxToWorldFn, - } - }, [state]) - - // ── main draw effect ──────────────────────────────────────────────────────── - const handleClick = React.useCallback((event: MouseEvent) => { - if (!svgRef.current || !pxToWorld) return - - // Get SVG element position and dimensions - const svgRect = svgRef.current.getBoundingClientRect() - - // Calculate click position relative to SVG viewport - const viewportX = event.clientX - svgRect.left - const viewportY = event.clientY - svgRect.top - - // Convert to SVG coordinate space (accounting for viewBox) - const svgPoint = new DOMPoint(viewportX, viewportY) - const transformedPoint = svgPoint.matrixTransform( - svgRef.current.getScreenCTM()?.inverse(), - ) - - // Convert to world coordinates - const [worldX, worldY] = pxToWorld( - transformedPoint.x, - transformedPoint.y, - ) + }; + }, [state]); - console.log( - "Click at world coordinates:", - worldX.toFixed(2), - worldY.toFixed(2), - ) - }, [pxToWorld]) + // Removed component-level click handler as we're using the global one in Visualizer class React.useEffect(() => { - if (!svgRef.current) return - const svg = d3.select(svgRef.current) - svg.selectAll("*").remove() + if (!svgRef.current) return; + const svg = d3.select(svgRef.current); + svg.selectAll("*").remove(); // 1. maps (bottom layer) Object.values(state).forEach((d) => { - if (d instanceof Costmap) visualiseCostmap(svg, d, width, height) - }) + if (d instanceof Costmap) visualiseCostmap(svg, d, width, height); + }); // 2. 
paths (middle layer) Object.entries(state).forEach(([key, d]) => { if (d instanceof Path) { - visualisePath(svg, d, key, worldToPx, width, height) + visualisePath(svg, d, key, worldToPx, width, height); } - }) + }); // 3. vectors (top layer) Object.entries(state).forEach(([key, d]) => { if (d instanceof Vector) { - visualiseVector(svg, d, key, worldToPx, width, height) + visualiseVector(svg, d, key, worldToPx, width, height); } - }) - - // Add click handler - const svgElement = svgRef.current - svgElement.addEventListener("click", handleClick) - - return () => { - svgElement.removeEventListener("click", handleClick) - } - }, [state, worldToPx, handleClick]) + }); + + // Removed click handler as we're using the global one in Visualizer class + }, [state, worldToPx]); return (
}> = ({ backgroundColor: "black", borderRadius: "8px", boxShadow: "0 4px 12px rgba(0, 0, 0, 0.15)", + pointerEvents: "none", // Make SVG transparent to pointer events }} />
- ) -} + ); +}; // ─────────────────────────────────────────────────────────────────────────────── // Helper: costmap @@ -182,53 +149,53 @@ function visualiseCostmap( width: number, height: number, ): void { - const { grid, origin, resolution } = costmap - const [rows, cols] = grid.shape + const { grid, origin, resolution } = costmap; + const [rows, cols] = grid.shape; - const cell = Math.min(width / cols, height / rows) - const gridW = cols * cell - const gridH = rows * cell + const cell = Math.min(width / cols, height / rows); + const gridW = cols * cell; + const gridH = rows * cell; const group = svg .append("g") .attr( "transform", `translate(${(width - gridW) / 2}, ${(height - gridH) / 2})`, - ) + ); // Custom color interpolation function that maps 0 to white and other values to Inferno scale const customColorScale = (t: number) => { // If value is 0 (or very close to it), return dark bg color // bluest #2d2136 - if (t == 0) return "white" - if (t < 0) return "#2d2136" - if (t > 0.95) return "#000000" - - const color = d3.interpolateTurbo((t * 2) - 1) - const hsl = d3.hsl(color) - hsl.s *= 0.75 - return hsl.toString() - } + if (t == 0) return "white"; + if (t < 0) return "#2d2136"; + if (t > 0.95) return "#000000"; + + const color = d3.interpolateTurbo((t * 2) - 1); + const hsl = d3.hsl(color); + hsl.s *= 0.75; + return hsl.toString(); + }; const colour = d3.scaleSequential(customColorScale).domain([ -1, 100, - ]) + ]); const fo = group.append("foreignObject").attr("width", gridW).attr( "height", gridH, - ) + ); - const canvas = document.createElement("canvas") - canvas.width = cols - canvas.height = rows + const canvas = document.createElement("canvas"); + canvas.width = cols; + canvas.height = rows; Object.assign(canvas.style, { width: "100%", height: "100%", objectFit: "contain", backgroundColor: "black", - }) + }); fo.append("xhtml:div") .style("width", "100%") @@ -237,35 +204,35 @@ function visualiseCostmap( .style("alignItems", "center") 
.style("justifyContent", "center") .node() - ?.appendChild(canvas) + ?.appendChild(canvas); - const ctx = canvas.getContext("2d") + const ctx = canvas.getContext("2d"); if (ctx) { - const img = ctx.createImageData(cols, rows) - const data = grid.data // row‑major, (0,0) = world south‑west + const img = ctx.createImageData(cols, rows); + const data = grid.data; // row‑major, (0,0) = world south‑west // Flip vertically so world north appears at top of SVG for (let i = 0; i < data.length; i++) { - const row = Math.floor(i / cols) - const col = i % cols + const row = Math.floor(i / cols); + const col = i % cols; // Flip Y coordinate (invert row) to put origin at bottom-left - const invertedRow = rows - 1 - row - const srcIdx = invertedRow * cols + col - - const value = data[i] // Get value from original index - const c = d3.color(colour(value)) - if (!c) continue - const o = srcIdx * 4 // Write to flipped position - img.data[o] = c.r ?? 0 - img.data[o + 1] = c.g ?? 0 - img.data[o + 2] = c.b ?? 0 - img.data[o + 3] = 255 + const invertedRow = rows - 1 - row; + const srcIdx = invertedRow * cols + col; + + const value = data[i]; // Get value from original index + const c = d3.color(colour(value)); + if (!c) continue; + const o = srcIdx * 4; // Write to flipped position + img.data[o] = c.r ?? 0; + img.data[o + 1] = c.g ?? 0; + img.data[o + 2] = c.b ?? 
0; + img.data[o + 3] = 255; } - ctx.putImageData(img, 0, 0) + ctx.putImageData(img, 0, 0); } - addCoordinateSystem(group, gridW, gridH, origin, resolution) + addCoordinateSystem(group, gridW, gridH, origin, resolution); } // ─────────────────────────────────────────────────────────────────────────────── @@ -278,25 +245,25 @@ function addCoordinateSystem( origin: Vector, resolution: number, ): void { - const minX = origin.coords[0] - const minY = origin.coords[1] + const minX = origin.coords[0]; + const minY = origin.coords[1]; - const maxX = minX + (width * resolution) - const maxY = minY + (height * resolution) - console.log(group, width, origin, maxX) + const maxX = minX + (width * resolution); + const maxY = minY + (height * resolution); + //console.log(group, width, origin, maxX); const xScale = d3.scaleLinear().domain([ minX, maxX, - ]).range([0, width]) + ]).range([0, width]); const yScale = d3.scaleLinear().domain([ minY, maxY, - ]).range([height, 0]) + ]).range([height, 0]); - const gridSize = 1.0 - const gridColour = "#000" - const gridGroup = group.append("g").attr("class", "grid") + const gridSize = 1 / resolution; + const gridColour = "#000"; + const gridGroup = group.append("g").attr("class", "grid"); for ( const x of d3.range( @@ -313,7 +280,7 @@ function addCoordinateSystem( .attr("y2", height) .attr("stroke", gridColour) .attr("stroke-width", 0.5) - .attr("opacity", 0.25) + .attr("opacity", 0.25); } for ( @@ -331,7 +298,7 @@ function addCoordinateSystem( .attr("y2", yScale(y)) .attr("stroke", gridColour) .attr("stroke-width", 0.5) - .attr("opacity", 0.25) + .attr("opacity", 0.25); } const stylise = ( @@ -339,22 +306,22 @@ function addCoordinateSystem( ) => { sel.selectAll("line,path") .attr("stroke", "#ffffff") - .attr("stroke-width", 1) + .attr("stroke-width", 1); sel.selectAll("text") - .attr("fill", "#ffffff") // Change the color here - } + .attr("fill", "#ffffff"); // Change the color here + }; group .append("g") .attr("transform", `translate(0, 
${height})`) .call(d3.axisBottom(xScale).ticks(7)) - .call(stylise) - group.append("g").call(d3.axisLeft(yScale).ticks(7)).call(stylise) + .call(stylise); + group.append("g").call(d3.axisLeft(yScale).ticks(7)).call(stylise); if (minX <= 0 && 0 <= maxX && minY <= 0 && 0 <= maxY) { const originPoint = group.append("g") .attr("class", "origin-marker") - .attr("transform", `translate(${xScale(0)}, ${yScale(0)})`) + .attr("transform", `translate(${xScale(0)}, ${yScale(0)})`); // Add outer ring originPoint.append("circle") @@ -362,7 +329,7 @@ function addCoordinateSystem( .attr("fill", "none") .attr("stroke", "#00e676") .attr("stroke-width", 1) - .attr("opacity", 0.5) + .attr("opacity", 0.5); // Add center point originPoint.append("circle") @@ -370,7 +337,7 @@ function addCoordinateSystem( .attr("fill", "#00e676") .attr("opacity", 0.9) .append("title") - .text("World Origin (0,0)") + .text("World Origin (0,0)"); } } @@ -385,17 +352,17 @@ function visualisePath( width: number, height: number, ): void { - if (path.coords.length < 2) return + if (path.coords.length < 2) return; const points = path.coords.map(([x, y]) => { - return wp ? wp(x, y) : [width / 2 + x, height / 2 - y] - }) + return wp ? 
wp(x, y) : [width / 2 + x, height / 2 - y]; + }); // Create a path line - const line = d3.line() + const line = d3.line(); // Create a gradient for the path - const pathId = `path-gradient-${label.replace(/\s+/g, "-")}` + const pathId = `path-gradient-${label.replace(/\s+/g, "-")}`; svg.append("defs") .append("linearGradient") @@ -414,7 +381,7 @@ function visualisePath( ]) .enter().append("stop") .attr("offset", (d) => d.offset) - .attr("stop-color", (d) => d.color) + .attr("stop-color", (d) => d.color); // Create the path with gradient and animation svg.append("path") @@ -425,7 +392,7 @@ function visualisePath( .attr("stroke-linecap", "round") .attr("filter", "url(#glow)") .attr("opacity", 0.9) - .attr("d", line) + .attr("d", line); } // ─────────────────────────────────────────────────────────────────────────────── @@ -441,12 +408,12 @@ function visualiseVector( ): void { const [cx, cy] = wp ? wp(vector.coords[0], vector.coords[1]) - : [width / 2 + vector.coords[0], height / 2 - vector.coords[1]] + : [width / 2 + vector.coords[0], height / 2 - vector.coords[1]]; // Create a vector marker group const vectorGroup = svg.append("g") .attr("class", "vector-marker") - .attr("transform", `translate(${cx}, ${cy})`) + .attr("transform", `translate(${cx}, ${cy})`); // Add a glowing outer ring vectorGroup.append("circle") @@ -455,21 +422,21 @@ function visualiseVector( // .attr("stroke", "#4fc3f7") .attr("stroke", "red") .attr("stroke-width", "1") - .attr("opacity", 0.9) + .attr("opacity", 0.9); // Add inner dot vectorGroup.append("circle") .attr("r", ".4em") // .attr("fill", "#4fc3f7") - .attr("fill", "red") + .attr("fill", "red"); // Add text with background const text = `${label} (${vector.coords[0].toFixed(2)}, ${ vector.coords[1].toFixed(2) - })` + })`; // Create a group for the text and background - const textGroup = svg.append("g") + const textGroup = svg.append("g"); // Add text element const textElement = textGroup @@ -478,10 +445,10 @@ function visualiseVector( 
.attr("y", cy + 25) .attr("font-size", "1em") .attr("fill", "white") - .text(text) + .text(text); // Add background rect - const bbox = textElement.node()?.getBBox() + const bbox = textElement.node()?.getBBox(); if (bbox) { textGroup .insert("rect", "text") @@ -491,7 +458,7 @@ function visualiseVector( .attr("height", bbox.height + 2) .attr("fill", "black") .attr("stroke", "black") - .attr("opacity", 0.75) + .attr("opacity", 0.75); } } @@ -499,131 +466,160 @@ function visualiseVector( // Wrapper class // ─────────────────────────────────────────────────────────────────────────────── export class Visualizer { - private container: HTMLElement | null - private state: Record = {} - private resizeObserver: ResizeObserver | null = null - private root: ReactDOMClient.Root - private onClickCallback: ((worldX: number, worldY: number) => void) | null = - null + private container: HTMLElement | null; + private state: Record = {}; + private resizeObserver: ResizeObserver | null = null; + private root: ReactDOMClient.Root; + private onClickCallback: ((worldX: number, worldY: number) => void) | null = null; + private lastClickTime: number = 0; + private clickThrottleMs: number = 150; // Minimum ms between processed clicks constructor(selector: string) { - this.container = document.querySelector(selector) - if (!this.container) throw new Error(`Container not found: ${selector}`) - this.root = ReactDOMClient.createRoot(this.container) + this.container = document.querySelector(selector); + if (!this.container) { + throw new Error(`Container not found: ${selector}`); + } + this.root = ReactDOMClient.createRoot(this.container); // First paint - this.render() + this.render(); // Keep canvas responsive if (window.ResizeObserver) { - this.resizeObserver = new ResizeObserver(() => this.render()) - this.resizeObserver.observe(this.container) + this.resizeObserver = new ResizeObserver(() => this.render()); + this.resizeObserver.observe(this.container); } - // Set up global click handler to 
capture world coordinates - document.addEventListener("click", this.handleGlobalClick.bind(this)) + // Bind the click handler once to preserve reference for cleanup + this.handleGlobalClick = this.handleGlobalClick.bind(this); + + // Set up click handler directly on the container with capture phase + // This ensures we get the event before any SVG elements + if (this.container) { + this.container.addEventListener("click", this.handleGlobalClick, true); + } } /** Register a callback for when user clicks on the visualization */ public onWorldClick( callback: (worldX: number, worldY: number) => void, ): void { - this.onClickCallback = callback + this.onClickCallback = callback; } /** Handle global click events, filtering for clicks within our SVG */ private handleGlobalClick(event: MouseEvent): void { - if (!this.onClickCallback || !this.container) return - - // Check if click was inside our container - const containerRect = this.container.getBoundingClientRect() - if ( - event.clientX < containerRect.left || - event.clientX > containerRect.right || - event.clientY < containerRect.top || - event.clientY > containerRect.bottom - ) { - return // Click was outside our container + if (!this.onClickCallback || !this.container) return; + + // Stop propagation to prevent other handlers from interfering + event.stopPropagation(); + + // Throttle clicks to prevent issues with high refresh rates + const now = Date.now(); + if (now - this.lastClickTime < this.clickThrottleMs) { + console.log("Click throttled"); + return; } - + this.lastClickTime = now; + + // We don't need to check if click was inside container since we're attaching + // the event listener directly to the container + + console.log("Processing click at", event.clientX, event.clientY); + // Find our SVG element - const svgElement = this.container.querySelector("svg") - if (!svgElement) return + const svgElement = this.container.querySelector("svg"); + if (!svgElement) return; // Calculate click position relative 
to SVG viewport - const svgRect = svgElement.getBoundingClientRect() - const viewportX = event.clientX - svgRect.left - const viewportY = event.clientY - svgRect.top + const svgRect = svgElement.getBoundingClientRect(); + const viewportX = event.clientX - svgRect.left; + const viewportY = event.clientY - svgRect.top; // Convert to SVG coordinate space (accounting for viewBox) - const svgPoint = new DOMPoint(viewportX, viewportY) + const svgPoint = new DOMPoint(viewportX, viewportY); const transformedPoint = svgPoint.matrixTransform( svgElement.getScreenCTM()?.inverse() || new DOMMatrix(), - ) + ); // Find a costmap to use for coordinate conversion const costmap = Object.values(this.state).find( (d): d is Costmap => d instanceof Costmap, - ) + ); - if (!costmap) return + if (!costmap) return; const { grid: { shape }, origin, resolution, - } = costmap - const [rows, cols] = shape + } = costmap; + const [rows, cols] = shape; // Use the current SVG dimensions instead of hardcoded values - const width = svgRect.width - const height = svgRect.height + const width = svgRect.width; + const height = svgRect.height; // Calculate scales (same logic as in the component) - const cell = Math.min(width / cols, height / rows) - const gridW = cols * cell - const gridH = rows * cell - const offsetX = (width - gridW) / 2 - const offsetY = (height - gridH) / 2 + const cell = Math.min(width / cols, height / rows); + const gridW = cols * cell; + const gridH = rows * cell; + const offsetX = (width - gridW) / 2; + const offsetY = (height - gridH) / 2; const xScale = d3 .scaleLinear() .domain([offsetX, offsetX + gridW]) - .range([origin.coords[0], origin.coords[0] + cols * resolution]) + .range([origin.coords[0], origin.coords[0] + cols * resolution]); const yScale = d3 .scaleLinear() .domain([offsetY + gridH, offsetY]) - .range([origin.coords[1], origin.coords[1] + rows * resolution]) + .range([origin.coords[1], origin.coords[1] + rows * resolution]); // Convert to world coordinates - 
const worldX = xScale(transformedPoint.x) - const worldY = yScale(transformedPoint.y) + const worldX = xScale(transformedPoint.x); + const worldY = yScale(transformedPoint.y); + console.log("Calling callback with world coords:", worldX.toFixed(2), worldY.toFixed(2)); + // Call the callback with the world coordinates - this.onClickCallback(worldX, worldY) + this.onClickCallback(worldX, worldY); } /** Push a new application‑state snapshot to the visualiser */ public visualizeState(state: Record): void { - this.state = { ...state } - this.render() + // Store reference to current state before updating + const prevState = this.state; + this.state = { ...state }; + + // Don't re-render if we're currently processing a click + const timeSinceLastClick = Date.now() - this.lastClickTime; + if (timeSinceLastClick < this.clickThrottleMs) { + console.log("Skipping render during click processing"); + return; + } + + this.render(); } /** React‑render the component tree */ private render(): void { - this.root.render() + this.root.render(); } /** Tear down listeners and free resources */ public cleanup(): void { if (this.resizeObserver && this.container) { - this.resizeObserver.unobserve(this.container) - this.resizeObserver.disconnect() + this.resizeObserver.unobserve(this.container); + this.resizeObserver.disconnect(); + } + + if (this.container) { + this.container.removeEventListener("click", this.handleGlobalClick, true); } - document.removeEventListener("click", this.handleGlobalClick.bind(this)) } } // Convenience factory ---------------------------------------------------------- export function createReactVis(selector: string): Visualizer { - return new Visualizer(selector) + return new Visualizer(selector); } diff --git a/dimos/web/websocket_vis/server.py b/dimos/web/websocket_vis/server.py index 9155319ea1..ea9fbb00d5 100644 --- a/dimos/web/websocket_vis/server.py +++ b/dimos/web/websocket_vis/server.py @@ -148,8 +148,6 @@ def process_drawable(self, drawable: Drawable): 
def connect(self, obs: Observable[Tuple[str, Drawable]], window_name: str = "main"): """Connect to an Observable stream and update state on new data""" - # Subscribe to the stream - print("Subing to", obs) def new_update(data): [name, drawable] = data diff --git a/dimos/web/websocket_vis/static/js/clientside.js b/dimos/web/websocket_vis/static/js/clientside.js index fdfc2f57da..6aaaa7089c 100644 --- a/dimos/web/websocket_vis/static/js/clientside.js +++ b/dimos/web/websocket_vis/static/js/clientside.js @@ -16240,7 +16240,6 @@ var Grid = class _Grid { // clientside/decoder.ts function decode3(data) { - console.log("decoding", data); if (data.type == "costmap") { return Costmap.decode(data); } @@ -19636,25 +19635,6 @@ var VisualizerComponent = ({ pxToWorld: pxToWorldFn }; }, [state]); - const handleClick = React.useCallback((event) => { - if (!svgRef.current || !pxToWorld) return; - const svgRect = svgRef.current.getBoundingClientRect(); - const viewportX = event.clientX - svgRect.left; - const viewportY = event.clientY - svgRect.top; - const svgPoint = new DOMPoint(viewportX, viewportY); - const transformedPoint = svgPoint.matrixTransform( - svgRef.current.getScreenCTM()?.inverse() - ); - const [worldX, worldY] = pxToWorld( - transformedPoint.x, - transformedPoint.y - ); - console.log( - "Click at world coordinates:", - worldX.toFixed(2), - worldY.toFixed(2) - ); - }, [pxToWorld]); React.useEffect(() => { if (!svgRef.current) return; const svg = select_default2(svgRef.current); @@ -19672,12 +19652,7 @@ var VisualizerComponent = ({ visualiseVector(svg, d, key, worldToPx, width, height); } }); - const svgElement = svgRef.current; - svgElement.addEventListener("click", handleClick); - return () => { - svgElement.removeEventListener("click", handleClick); - }; - }, [state, worldToPx, handleClick]); + }, [state, worldToPx]); return /* @__PURE__ */ React.createElement( "div", { @@ -19695,7 +19670,9 @@ var VisualizerComponent = ({ style: { backgroundColor: "black", 
borderRadius: "8px", - boxShadow: "0 4px 12px rgba(0, 0, 0, 0.15)" + boxShadow: "0 4px 12px rgba(0, 0, 0, 0.15)", + pointerEvents: "none" + // Make SVG transparent to pointer events } } ) @@ -19765,7 +19742,6 @@ function addCoordinateSystem(group, width, height, origin, resolution) { const minY = origin.coords[1]; const maxX = minX + width * resolution; const maxY = minY + height * resolution; - console.log(group, width, origin, maxX); const xScale = linear2().domain([ minX, maxX @@ -19774,7 +19750,7 @@ function addCoordinateSystem(group, width, height, origin, resolution) { minY, maxY ]).range([height, 0]); - const gridSize = 1; + const gridSize = 1 / resolution; const gridColour = "#000"; const gridGroup = group.append("g").attr("class", "grid"); for (const x2 of range( @@ -19832,21 +19808,29 @@ function visualiseVector(svg, vector, label, wp, width, height) { } } var Visualizer = class { + // Minimum ms between processed clicks constructor(selector) { __publicField(this, "container"); __publicField(this, "state", {}); __publicField(this, "resizeObserver", null); __publicField(this, "root"); __publicField(this, "onClickCallback", null); + __publicField(this, "lastClickTime", 0); + __publicField(this, "clickThrottleMs", 150); this.container = document.querySelector(selector); - if (!this.container) throw new Error(`Container not found: ${selector}`); + if (!this.container) { + throw new Error(`Container not found: ${selector}`); + } this.root = ReactDOMClient.createRoot(this.container); this.render(); if (window.ResizeObserver) { this.resizeObserver = new ResizeObserver(() => this.render()); this.resizeObserver.observe(this.container); } - document.addEventListener("click", this.handleGlobalClick.bind(this)); + this.handleGlobalClick = this.handleGlobalClick.bind(this); + if (this.container) { + this.container.addEventListener("click", this.handleGlobalClick, true); + } } /** Register a callback for when user clicks on the visualization */ onWorldClick(callback) { 
@@ -19855,10 +19839,14 @@ var Visualizer = class { /** Handle global click events, filtering for clicks within our SVG */ handleGlobalClick(event) { if (!this.onClickCallback || !this.container) return; - const containerRect = this.container.getBoundingClientRect(); - if (event.clientX < containerRect.left || event.clientX > containerRect.right || event.clientY < containerRect.top || event.clientY > containerRect.bottom) { + event.stopPropagation(); + const now2 = Date.now(); + if (now2 - this.lastClickTime < this.clickThrottleMs) { + console.log("Click throttled"); return; } + this.lastClickTime = now2; + console.log("Processing click at", event.clientX, event.clientY); const svgElement = this.container.querySelector("svg"); if (!svgElement) return; const svgRect = svgElement.getBoundingClientRect(); @@ -19889,11 +19877,18 @@ var Visualizer = class { const yScale = linear2().domain([offsetY + gridH, offsetY]).range([origin.coords[1], origin.coords[1] + rows * resolution]); const worldX = xScale(transformedPoint.x); const worldY = yScale(transformedPoint.y); + console.log("Calling callback with world coords:", worldX.toFixed(2), worldY.toFixed(2)); this.onClickCallback(worldX, worldY); } /** Push a new application‑state snapshot to the visualiser */ visualizeState(state) { + const prevState = this.state; this.state = { ...state }; + const timeSinceLastClick = Date.now() - this.lastClickTime; + if (timeSinceLastClick < this.clickThrottleMs) { + console.log("Skipping render during click processing"); + return; + } this.render(); } /** React‑render the component tree */ @@ -19906,7 +19901,9 @@ var Visualizer = class { this.resizeObserver.unobserve(this.container); this.resizeObserver.disconnect(); } - document.removeEventListener("click", this.handleGlobalClick.bind(this)); + if (this.container) { + this.container.removeEventListener("click", this.handleGlobalClick, true); + } } }; @@ -19928,7 +19925,6 @@ socket.on("disconnect", () => { serverState.status = 
"disconnected"; }); socket.on("message", (data) => { - console.log("Received message:", data); }); function deepMerge(source, destination) { for (const key in source) { @@ -19948,11 +19944,9 @@ function decodeDrawables(encoded) { return drawables; } function state_update(state) { - console.log("Received state update:", state); if (state.draw) { state.draw = decodeDrawables(state.draw); } - console.log("Decoded state update:", state); serverState = { ...deepMerge(state, { ...serverState }) }; updateUI(); } @@ -19962,7 +19956,6 @@ function emitMessage(data) { socket.emit("message", data); } function updateUI() { - console.log("Current state:", serverState); if (serverState.draw && Object.keys(serverState.draw).length > 0) { if (reactVisualizer) { reactVisualizer.visualizeState(serverState.draw); diff --git a/pyproject.toml b/pyproject.toml index e2f39fd9dd..d9154ec8aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,9 +7,15 @@ version = "0.0.2" description = "Powering agentive generalist robotics" [tool.ruff] -# Allow lines to be as long as 120. 
line-length = 120 [tool.pytest.ini_options] testpaths = ["dimos/robot"] norecursedirs = ["dimos/robot/unitree/external"] +markers = [ + "vis: marks tests that run visuals and require a visual check by dev", + "benchmark: benchmark, executes something multiple times, calculates avg, prints to console", + "exclude: arbitrary exclusion from CI and default test exec", + "tool: dev tooling" +] +addopts = "-m 'not vis and not benchmark and not exclude and not tool'" diff --git a/requirements.txt b/requirements.txt index 00a25d4e16..cce69542ac 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,6 +22,8 @@ reactivex git+https://github.com/dimensionalOS/rxpy-backpressure.git pytest-asyncio==0.26.0 asyncio==3.4.3 +-e git+https://github.com/legion1581/go2_webrtc_connect.git@fe64abb5987594e8c048427a98445799f6f6a9cc#egg=go2_webrtc_connect +-e git+https://github.com/legion1581/aioice.git@ff5755a1e37127411b5fc797c105804db8437445#egg=aioice # Web Extensions fastapi>=0.115.6 @@ -86,4 +88,4 @@ scikit-learn lvis nltk git+https://github.com/openai/CLIP.git -git+https://github.com/facebookresearch/detectron2.git@v0.6 \ No newline at end of file +git+https://github.com/facebookresearch/detectron2.git@v0.6 diff --git a/tests/run.py b/tests/run.py index 56ec5a7f90..d62c3a1103 100644 --- a/tests/run.py +++ b/tests/run.py @@ -18,7 +18,7 @@ import time from dotenv import load_dotenv from dimos.agents.claude_agent import ClaudeAgent -from dimos.robot.unitree.unitree_go2 import UnitreeGo2 +from dimos.robot.unitree_webrtc.unitree_go2 import UnitreeGo2 from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl from dimos.robot.unitree.unitree_skills import MyUnitreeSkills from dimos.web.robot_web_interface import RobotWebInterface @@ -53,11 +53,11 @@ def parse_arguments(): # Initialize robot with spatial memory parameters robot = UnitreeGo2(ip=os.getenv('ROBOT_IP'), - ros_control=UnitreeROSControl(), skills=MyUnitreeSkills(), mock_connection=False, 
spatial_memory_dir=args.spatial_memory_dir, # Will use default if None - new_memory=args.new_memory) # Create a new memory if specified + new_memory=args.new_memory, # Create a new memory if specified + mode = "ai") # Create a subject for agent responses agent_response_subject = rx.subject.Subject() @@ -133,8 +133,8 @@ def combine_with_locations(object_detections): # stt_node = stt() # Read system query from prompt.txt file -with open(os.path.join(os.path.dirname(os.path.dirname(__file__)), 'assets', 'agent', 'prompt.txt'), 'r') as f: - system_query = f.read() +# with open(os.path.join(os.path.dirname(os.path.dirname(__file__)), 'prompt.txt'), 'r') as f: +# system_query = f.read() # Create a ClaudeAgent instance agent = ClaudeAgent( @@ -143,7 +143,7 @@ def combine_with_locations(object_detections): input_query_stream=web_interface.query_stream, input_data_stream=enhanced_data_stream, # Add the enhanced data stream skills=robot.get_skills(), - system_query=system_query, + system_query="What do you see", model_name="claude-3-7-sonnet-latest", thinking_budget_tokens=0 ) @@ -157,15 +157,15 @@ def combine_with_locations(object_detections): robot_skills.add(NavigateWithText) robot_skills.add(FollowHuman) robot_skills.add(GetPose) -# robot_skills.add(Speak) -robot_skills.add(NavigateToGoal) +robot_skills.add(Speak) +# robot_skills.add(NavigateToGoal) robot_skills.create_instance("ObserveStream", robot=robot, agent=agent) robot_skills.create_instance("KillSkill", robot=robot, skill_library=robot_skills) robot_skills.create_instance("NavigateWithText", robot=robot) robot_skills.create_instance("FollowHuman", robot=robot) robot_skills.create_instance("GetPose", robot=robot) -robot_skills.create_instance("NavigateToGoal", robot=robot) -# robot_skills.create_instance("Speak", tts_node=tts_node) +# robot_skills.create_instance("NavigateToGoal", robot=robot) +robot_skills.create_instance("Speak", tts_node=tts_node) # Subscribe to agent responses and send them to the subject 
agent.get_response_observable().subscribe( diff --git a/tests/run_webrtc.py b/tests/run_webrtc.py new file mode 100644 index 0000000000..ff96bed14e --- /dev/null +++ b/tests/run_webrtc.py @@ -0,0 +1,75 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import cv2 +import os +import asyncio +from dotenv import load_dotenv +from dimos.robot.unitree_webrtc.unitree_go2 import UnitreeGo2, Color +from dimos.robot.unitree_webrtc.testing.helpers import show3d_stream +from dimos.web.websocket_vis.server import WebsocketVis +from dimos.types.vector import Vector +import logging +import open3d as o3d +import reactivex.operators as ops +import numpy as np +import time +import threading + +# logging.basicConfig(level=logging.DEBUG) + +load_dotenv() +robot = UnitreeGo2(ip=os.getenv("ROBOT_IP"), mode="ai") + +websocket_vis = WebsocketVis() +websocket_vis.start() +websocket_vis.connect(robot.global_planner.vis_stream()) + + +def msg_handler(msgtype, data): + if msgtype == "click": + try: + robot.global_planner.set_goal(Vector(data["position"])) + except Exception as e: + print(f"Error setting goal: {e}") + return + + +def threaded_msg_handler(msgtype, data): + thread = threading.Thread(target=msg_handler, args=(msgtype, data)) + thread.daemon = True + thread.start() + + +websocket_vis.msg_handler = threaded_msg_handler + +print("standing up") +robot.standup() +print("robot is up") + + +def newmap(msg): + return ["costmap", 
robot.map.costmap.smudge()] + + +websocket_vis.connect(robot.map_stream.pipe(ops.map(newmap))) +websocket_vis.connect(robot.odom_stream().pipe(ops.map(lambda pos: ["robot_pos", pos.pos.to_2d()]))) + +try: + while True: + # robot.move_vel(Vector(0.1, 0.1, 0.1)) + time.sleep(0.01) + +except KeyboardInterrupt: + print("Stopping robot") + robot.liedown() diff --git a/tests/test_robot.py b/tests/test_robot.py index 3e4a968799..b452100c85 100644 --- a/tests/test_robot.py +++ b/tests/test_robot.py @@ -4,7 +4,7 @@ from dimos.robot.unitree.unitree_go2 import UnitreeGo2 from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl from dimos.robot.unitree.unitree_skills import MyUnitreeSkills -from dimos.robot.local_planner.local_planner import navigate_to_goal_local +from dimos.robot.local_planner import navigate_to_goal_local from dimos.web.robot_web_interface import RobotWebInterface from reactivex import operators as RxOps import tests.test_header